도미노는 재밌다. 도미노 블록을 일렬로 길게 늘어세운 뒤 블록 하나를 넘어뜨리면 그 블록이 넘어지며 다음 블록을 넘어뜨리는 일이 반복되어 일렬로 늘어선 블록들을 연쇄적으로 모두 쓰러뜨릴 수 있다. 그러나, 가끔씩 도미노가 다른 블록을 넘어뜨리지 못하게 배치되어 있다면, 우리는 다음 블록을 수동으로 넘어뜨려야 한다.
이제 각 도미노 블록의 배치가 주어졌을 때, 모든 블록을 넘어뜨리기 위해 손으로 넘어뜨려야 하는 블록 개수의 최솟값을 구하자.
각 테스트 케이스의 첫 번째 줄에는 두 정수 N, M이 주어지며 두 수는 100,000을 넘지 않는다. N은 도미노의 개수를, M은 관계의 개수를 나타낸다. 도미노 블록의 번호는 1과 N 사이의 정수다. 다음 M개의 줄에는 각각 두 정수 x, y가 주어지는데, 이는 x번 블록이 넘어지면 y번 블록도 넘어짐을 뜻한다.
출력
각 테스트 케이스마다 한 줄에 정수 하나를 출력한다. 정답은 손으로 넘어뜨려야 하는 최소의 도미노 블록 개수이다.
입력 예시
1
3 2
1 2
2 3
출력 예시
1
풀이
1. 쓰러뜨리는 도미노는 다른 도미노에 영향을 미친다. 즉 우리는 다른 어떤 도미노도 영향을 미치지 못하는 도미노들을 가급적 쓰러뜨리는 것이 최소 도미노를 쓰러뜨릴 수 있는 방법임을 눈치챌 수 있다. 위상 정렬을 수행해서, 최우선 위상인 도미노를 쓰러뜨리는 전략을 생각해보자.
2. 그런데, 이 도미노 구조는 사이클을 형성할 수 있다. 가령 1 -> 2, 2 -> 3, 3 -> 1인 도미노는 셋 다 위상이 1일 것이다(입력이 존재하므로). 이 경우는 어떤 도미노를 쓰러뜨리던 결과는 같다.
잘 생각해보면, 이런 싸이클, 혹은 싸이클들이 복합적으로 이루어진 도미노들은 하나의 강한 연결 요소(SCC)로 이루어짐을 알 수 있다. 즉 같은 강한 연결 요소 내의 도미노들은 어떤 도미노를 먼저 쓰러뜨리더라도 그 연결 요소 내의 모든 도미노가 쓰러짐을 보장한다. 따라서 같은 SCC 내의 모든 도미노를 하나의 도미노로 취급해야 한다.
즉 우리는
SCC끼리 묶어 하나의 도미노로 만든 뒤
이 처리가 끝난 그래프를 위상 정렬을 수행하여
indegree 위상이 0인 경우를 세면
쓰러뜨려야 하는 최소 도미노 블록 개수를 구할 수 있다.
풀이 코드
from collections import defaultdict
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**7)
MAX = float('inf')
def solve() :
global idx, scc_idx
N, M = map(int, input().split())
scc_visited = [MAX]*(N+1)
scc_node = [-1]*(N+1)
stk = list()
scc_idx = 0
idx = 0
result = 0
edge_dict = defaultdict(list)
for _ in range(M) :
a, b = map(int, input().split())
edge_dict[a].append(b)
def scc(node) :
global idx, scc_idx
ret = scc_visited[node] = idx
idx += 1
stk.append(node)
for nxt in edge_dict[node] :
if scc_node[nxt] > -1 :
continue
if scc_visited[nxt] == MAX :
scc(nxt)
scc_visited[node] = min(scc_visited[node], scc_visited[nxt])
if scc_visited[node] == ret :
while stk :
n = stk.pop()
scc_node[n] = scc_idx
if n == node :
break
scc_idx += 1
for i in range(1, N+1) :
if scc_visited[i] == MAX :
scc(i)
order = [0]*scc_idx
for i in range(1, N+1) :
for j in edge_dict[i] :
if scc_node[i] == scc_node[j] :
continue
order[scc_node[j]] += 1
for i in range(scc_idx) :
if not order[i] :
result += 1
print(result)
for _ in range(int(input())) :
solve()