edges의 각 행은 [v1, v2] 2개의 정수로 이루어져 있으며, 이는 v1번 정점과 v2번 정점 사이에 간선이 있음을 의미합니다.
v1, v2는 각각 1 이상 n 이하입니다.
v1, v2는 다른 수입니다.
입력으로 주어지는 그래프는 항상 트리입니다.
입출력 예
n
edges
result
4
[[1,2],[2,3],[3,4]]
2
5
[[1,5],[2,5],[3,5],[4,5]]
2
풀이
이 문제를 선형 시간 안에 풀기 위해서는 '트리의 지름'의 개념을 알아야 한다.
트리의 지름은 '트리에서 임의의 두 노드를 골랐을 때 가장 큰 거리'로 정의된다. 트리의 지름을 구하는 방법은 매우 간단하다.
임의의 노드 x를 루트로 잡고 가장 거리가 긴 노드 y를 구한다.
그리고 그 노드 y를 루트로 잡고 가장 거리가 긴 노드 z를 구한다.
y-z간의 거리가 트리의 지름이 된다.
위의 두 과정 모두 DFS 혹은 BFS로 O(N)의 선형 시간 안에 구할 수 있으므로, 총 시간복잡도 역시 O(N)이 된다. 우리는 트리 트리오의 중간값의 최댓값을 알고 싶으며, 이 중 다음과 같은 경우가 나올 수 있다.
트리의 지름이 한 개인 경우 : 트리의 두 노드 y-z가 트리의 지름일 경우, 그 경로상의 맨 첫번째 노드를 고르면 당연히 세 거리는 (지름, 지름-1, 1)이 된다. 따라서 최대 중간값은 (지름-1)이 된다.
트리의 지름이 두 개 이상인 경우 : 트리의 지름이 둘 이상이므로 (지름, 지름, x)꼴의 조합이 가능해진다. 따라서 최대 중간값은 (지름)이다. 이 때 두 개 이상의 지름은 하나의 노드를 공유하게 된다... 위 알고리즘의 증명과 연관이 있으니 한 번 증명해 보야야할 것 같다.
또한 트리의 지름이 두 개 이상인 경우 두 지름은 하나의 노드를 공유한다는 조건이 있으므로, 지름의 개수를 올바르게 파악하려면 노드 z를 기준으로도 트리의 지름 개수를 세어줄 필요가 있다. 따라서 최대 3번의 탐색이 필요하다. 다음 예시를 보자.
1을 루트로 잡았을 때 DFS 순서를 나타낸 트리이다. 1에서 탐색을 진행하면 최대 거리 노드로 4를 반환하게 되고, 4에서 다시 탐색을 진행하면 최대 거리 노드로 7을 반환한다. 그러나 이 때 최대 거리 중복 여부가 제대로 체크되지 않는다. (트리의 지름은 4-7, 5-7이나, 4에서부터 DFS를 수행할 때는 4-7만 체크되기 때문이다) 따라서 7에서 탐색을 다시 수행하여 5-7 역시 체크할 수 있어야 한다.
dfs : 경로탐색 함수. 선형 시간 내에 출발점에서 가장 먼 거리를 갖는 노드, 그 거리, 그리고 최대거리의 중복 여부를 반환한다.(max_check가 0이면 중복된다)
solution : 메인함수. 두 번의 dfs를 통해 먼저 트리의 지름을 찾는다. 만약 이 과정에서 트리의 지름이 두 개 이상이 발견된다면 바로 지름값을 반환한다. 만약 그렇지 않으면 다시 한 번 dfs를 통해 다른 노드에서의 트리 지름 개수를 판별하고, 그 결과에 따라 (지름) 혹은 (지름-1)을 반환한다.
풀이 코드
from collections import defaultdict
def dfs(start, n, edge_dict) :
q = [(start, 0)]
visited = [False]*(n+1)
visited[start] = True
max_node, max_dist, max_check = start, 0, 1
while q :
node, dist = q.pop()
if dist > max_dist :
max_node, max_dist, max_check = node, dist, 1
elif dist == max_dist :
max_check = 0
for next_node in edge_dict[node] :
if not visited[next_node] :
visited[next_node] = True
q.append((next_node, dist+1))
return max_node, max_dist, max_check
def solution(n, edges):
edge_dict = defaultdict(list)
for a, b in edges :
edge_dict[a].append(b)
edge_dict[b].append(a)
start_node, _, _ = dfs(1, n, edge_dict)
end_node, diameter, check = dfs(start_node, n, edge_dict)
if not check :
return diameter
_, diameter, check = dfs(end_node, n, edge_dict)
return diameter-check