첫 줄에는 논의 수 N(1 ≤ N ≤ 300)이 주어진다. 다음 N개의 줄에는 i번째 논에 우물을 팔 때 드는 비용 Wi(1 ≤ Wi ≤ 100,000)가 순서대로 들어온다. 다음 N개의 줄에 대해서는 각 줄에 N개의 수가 들어오는데 이는 i번째 논과 j번째 논을 연결하는데 드는 비용 Pi,j(1 ≤ Pi,j ≤ 100,000, Pi,j = Pj,i, Pi,i = 0)를 의미한다.
출력
첫 줄에 모든 논에 물을 대는데 필요한 최소비용을 출력한다.
입력 예시
4
5
4
4
3
0 2 2 2
2 0 3 3
2 3 0 4
2 3 4 0
출력 예시
9
풀이
전반적인 접근은 최소 스패닝 트리(MST)를 구성하는 것이다.
1. 우선 내가 생각한 풀이법은 크루스칼 알고리즘을 적용하되, Union의 조건을 수정하는 것이었다. 모든 독립된 트리는 기본적으로 최소 하나의 수원을 가지고 있어야 한다. 따라서 양 노드를 Union하기 위해서는 (양 노드의 수원 코스트의 합) 보다 (양 노드의 수원 코스트의 최소값 + 두 노드를 연결하는 비용)이 작을 때만 이루어져야 한다. 그리고 마지막에는 부모 노드를 확인하며 트리의 개수와 수원 코스트를 가져와야 한다.
2. 그리고 다른 분들의 풀이를 보면서 배우게 된 사실. 수원 역시 하나의 노드로 간주하면 더 쉽게 풀린다.
앞의 예제를 예시로 들면...
4
5
4
4
3
0 2 2 2
2 0 3 3
2 3 0 4
2 3 4 0
이 입력은 노드 개수가 5인 경우로 간주할 수 있으며, 여기서 인접행렬은
0 5 4 4 3
5 0 2 2 2
4 2 0 3 3
4 2 3 0 4
3 2 3 4 0
로 변한다. 재밌는 발상의 전환이다. 이 방식을 사용하면 수윈 코스트를 저장할 필요가 없이 MST만 사용하면 되므로 코드가 훨신 간결해진다.
풀이 코드 1
import sys
input = sys.stdin.readline
MAX = float('inf')
N = int(input())
well_cost = [int(input()) for _ in range(N)]
road_list = list()
for i in range(N) :
road_cost = list(map(int, input().split()))
for j in range(i+1, N) :
road_list.append((road_cost[j], i, j))
road_list.sort()
parent = [(i, well_cost[i]) for i in range(N)]
def find(a) :
if a == parent[a][0] :
return parent[a]
parent[a] = find(parent[a][0])
return parent[a]
def union(a, b) :
pa, pa_cost = find(a)
pb, pb_cost = find(b)
if pa_cost < pb_cost :
parent[pb] = (pa, pa_cost)
else :
parent[pa] = (pb, pb_cost)
answer, cnt = 0, N-1
for cost, a, b in road_list :
if not cnt :
break
pa, pa_cost = find(a)
pb, pb_cost = find(b)
if pa != pb and max(pa_cost, pb_cost) >= cost :
cnt -= 1
union(a, b)
answer += cost
tree_set = set()
for i in range(N) :
pi, pcost = find(i)
if pi not in tree_set :
tree_set.add(pi)
answer += pcost
print(answer)
풀이 코드 2
import sys
input = sys.stdin.readline
MAX = float('inf')
N = int(input())
road_list = list()
for i in range(N) :
well_cost = int(input())
road_list.append((well_cost, i+1, 0))
for i in range(N) :
road_cost = list(map(int, input().split()))
for j in range(i+1, N) :
road_list.append((road_cost[j], i+1, j+1))
road_list.sort()
parent = list(range(N+1))
def find(a) :
if a == parent[a] :
return parent[a]
parent[a] = find(parent[a])
return parent[a]
def union(pa, pb) :
if pa < pb :
parent[pb] = pa
else :
parent[pa] = pb
answer, cnt = 0, N
for cost, a, b in road_list :
if not cnt :
break
pa = find(a)
pb = find(b)
if pa != pb :
cnt -= 1
union(pa, pb)
answer += cost
print(answer)