인도의 도시 중 하나인 시루세리에는 모든 도로들이 일방통행으로 되어 있다. 도로들이 만나는 모든 교차로에는 시루세리 은행의 현금입출금기(ATM)가 설치되어 있다. 시루세리에는 유명한 레스토랑 체인인 아웃백 커리 하우스가 있다. 이 레스토랑의 각 체인점들은 교차로에만 위치한다. 물론 각 교차로마다 항상 이 레스토랑 체인점이 있는 것은 아니다. 이 레스토랑은 현금만 사용할 수 있다.
시루세리에 사는 반디치는 오늘 오후에 이 레스토랑에서 가족들과 파티를 열려고 한다. 그런데 갖고 있는 현금이 부족하여 레스토랑으로 가는 동안에 가능한 한 많은 현금을 ATM 기기로부터 인출할 계획을 세웠다. 그는 자신의 집에서 출발하여 차로 이동하면서 통과하는 모든 교차로 ATM 기기에 들어있는 현금 전부를 인출하려고 한다. 차량의 최종 목적지는 아웃백 커리 하우스 체인점 중의 한 곳이고, 이 체인점이 어떤 교차로에 위치하는지는 상관없다.
반디치는 시루세리 은행의 홈페이지 정보를 통해 각 ATM 기기에 현금이 얼마나 들어 있는지를 알고 있다. 이동 시 동일한 도로나 교차로를 여러 번 지날 수 있다. ATM 기기의 현금은 새로 보충되지 않기 때문에 첫 번째 이후 다시 방문하는 교차로의 ATM 기기에는 인출할 현금이 없다.
예를 들어, 아래 그림처럼 도시에 6개의 교차로가 있다고 하자. 교차로는 원으로 표시되어 있고, 화살표는 도로를 나타낸다. 이중 원으로 표시된 교차로에는 레스토랑이 있다. 각 ATM 기기가 갖고 있는 현금의 액수는 교차로 위에 표시된 숫자이다. 이 예에서 현금 인출을 1번 교차로부터 시작한다면, 반디치는 1-2-4-1-2-3-5의 경로를 통해서 총 47의 현금을 인출할 수 있다.
반디치가 출발 장소에서 어떤 레스토랑까지 이동하면서 인출할 수 있는 현금의 최대 액수가 얼마인지를 계산하는 프로그램을 작성하시오.
첫째 줄에 교차로의 수와 도로의 수를 나타내는 2개의 정수 N과 M(N, M ≤ 500,000)이 차례로 주어진다. 교차로는 1부터 N까지 번호로 표시된다. 그 다음 M개의 줄에는 각 줄마다 각 도로의 시작 교차로 번호와 끝 교차로 번호를 나타내는 2개의 정수가 주어진다. 그 다음 N개의 줄에는 1번 교차로부터 차례대로 각 교차로의 ATM 기기가 보유한 현금의 액수를 나타내는 정수가 각 줄에 하나씩 주어진다. 그 다음 줄에는 두 개의 정수 S와 P가 주어진다. 여기서 S는 출발 장소(현금 인출의 시작 장소)인 교차로 번호이고 P는 레스토랑의 개수이다(1 ≤ P ≤ N). 그 다음 줄에는 각 레스토랑이 있는 교차로의 번호를 나열한 P개의 정수가 주어진다.
각 ATM 기기에 들어 있는 현금의 액수는 0 이상 4,000 이하이다. 모든 입력에서 경로의 출발 장소로부터 일방통행 도로를 통해 도달 가능한 레스토랑이 항상 하나 이상 존재한다.
출력
출력은 한 개의 정수이다. 이 정수는 반디치가 출발 장소에서 어떤 레스토랑까지 이동하면서 인출할 수 있는 현금의 최대 액수이다.
어떤 교차로에서 다른 교차로를 통해 다시 교차로로 들어올 수 있다면, 이 교차로로 돌아올 수 있는 가능한 한 모든 교차로를 돌며 ATM을 출금하는 것이 합리적이다.
즉 같은 강한 연결 요소(SCC) 내 지점을 모두 방문하는 것이 좋다.
또한, 강한 연결 요소 내부에 레스토랑이 존재한다면, 그 요소 내의 모든 교차로는 마지막으로 방문할 수 있는 셈이다.
즉 한 SCC 요소를 하나의 교차로로 취급할 수 있겠다.
이 때 출금할 수 있는 모든 현금액은 요소 내의 현금액의 합과 같다.
탈출 여부(레스토랑 존재 여부)는 요소 내의 탈출 여부 합집합(OR)과 같다.
풀이는
SCC 요소를 찾아내어
그 요소들을 하나의 노드로 변환하고
시작 노드에서 그래프 탐색을 수행하여(이 때 위상 정렬되었음을 이용해도 좋다)
탈출이 가능한 노드의 최대 현금값을 출력하는
식으로 이루어져야 한다.
어설프게 시간에 쫓기며 풀다가 코드가 너무 더러워졌다..!
풀이 코드
from collections import defaultdict, deque
import sys
sys.setrecursionlimit(600000)
input = sys.stdin.readline
N, M = map(int, input().split())
raw_edge = defaultdict(list)
for _ in range(M) :
a, b = map(int, input().split())
raw_edge[a].append(b)
raw_cash = [0] + [int(input()) for _ in range(N)]
raw_visited = [-1]*(N+1)
scc_finished = [-1]*(N+1)
scc_idx = 0
idx = 0
stk = []
def scc(node) :
global idx, scc_idx
raw_visited[node] = ret = idx
idx += 1
stk.append(node)
for nxt in raw_edge[node] :
if scc_finished[nxt] > -1 :
continue
if raw_visited[nxt] == -1 :
scc(nxt)
raw_visited[node] = min(raw_visited[node], raw_visited[nxt])
if ret == raw_visited[node] :
while stk :
n = stk.pop()
scc_finished[n] = scc_idx
if n == node :
break
scc_idx += 1
for i in range(1, N+1) :
if scc_finished[i] == -1 :
scc(i)
S, P = map(int, input().split())
cash = [0]*scc_idx
edge_list = [set() for _ in range(scc_idx)]
max_cash = [0]*scc_idx
finish = [False]*scc_idx
for r in list(map(int, input().split())) :
finish[scc_finished[r]] = True
result = 0
for i in range(1, N+1) :
_i = scc_finished[i]
cash[_i] += raw_cash[i]
for j in raw_edge[i] :
if _i != scc_finished[j] :
edge_list[_i].add(scc_finished[j])
if i == S :
s = _i
def bfs(node) :
max_cash[node] += cash[node]
q = deque([node])
while q :
n = q.popleft()
for nxt in edge_list[n] :
if max_cash[nxt] < max_cash[n] + cash[nxt] :
max_cash[nxt] = max_cash[n] + cash[nxt]
q.append(nxt)
bfs(s)
for i in range(scc_idx) :
if finish[i] :
result = max(result, max_cash[i])
print(result)