둘째 줄에는 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 1,000,000)
셋째 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.
넷째 줄부터 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다. 1번 쿼리의 경우 1 ≤ i ≤ j ≤ N, -1,000,000 ≤ k ≤ 1,000,000 이고, 2번 쿼리의 경우 1 ≤ x ≤ N이다. 2번 쿼리는 하나 이상 주어진다.
출력
2번 쿼리가 주어질 때마다 출력한다.
입력 예시
5
1 2 3 4 5
4
1 3 4 6
2 3
1 1 3 -2
2 3
출력 예시
9
7
풀이
나의 첫 lazy segment tree 문제. 오래 전부터 실패했던 문제인데, 이번 기회에 공부하면서 풀어 보았다.
일반적인 방식으로 쿼리를 진행하게 되면, 최대 시간복잡도가 O(N*M)이 되는 경우가 주어진다. N의 최대 값과 M의 최대값이 모두 100,000이므로 시간 초과가 벌어지게 된다.
노드 하나의 업데이트 및 출력에 O(logN)이 소요되는 일반 세그먼트 트리를 적용시켜 보아도, 1번 쿼리의 경우 최대 N개의 인자를 업데이트해야 할 수도 있으므로 한 번의 업데이트에 O(NlogN)이 소모되며, 이 때 역시 시간복잡도는 O(NlogN*M)이 소요되므로 오히려 성능이 악화된다.
따라서 구간을 업데이트할 때 역시 구간 길이에 거의 시간복잡도 영향을 받지 않는 lazy segment tree를 사용하는 게 좋다. 핵심은 tree와 같은 크기의 lazy 배열을 생성하고, 각 노드에 접근할 때마다 한 층씩 아래로 lazy를 전파(propagation)하는 lazy propagation이 되겠다. 지금 이것과 같은 경우는 2번 쿼리가 단일 결과값을 요구하지만, 구간합을 요구하면 이 코드 역시 추가적인 수정이 필요해진다.
코드 역시 세그먼트 트리를 class로 지정하여 표현해 보았다. 훨씬 깔끔하니 보기가 좋네 ㅎ....
풀이 코드
import sys
input = sys.stdin.readline
class Seg_Tree() :
def __init__(self, length, val_list) :
self.tree = [0]*(length*4)
self.lazy = [0]*(length*4)
self.val_list = val_list
self.length = length
start, end = 1, self.length
self._init(1, start, end)
def _init(self, node, start, end) :
if start == end :
self.tree[node] = self.val_list[start]
return
mid = (start + end) // 2
self._init(node*2, start, mid)
self._init(node*2 + 1, mid+1, end)
def _propagate(self, node, start, end) :
if self.lazy[node] != 0 :
if start != end :
self.lazy[node*2] += self.lazy[node]
self.lazy[node*2 + 1] += self.lazy[node]
else :
self.tree[node] += self.lazy[node]
self.lazy[node] = 0
def update(self, l, r, val) :
start, end = 1, self.length
self._update(1, start, end, l, r, val)
def _update(self, node, start, end, l, r, val) :
self._propagate(node, start, end)
if l > end or r < start :
return
if l <= start and end <= r :
self.lazy[node] += val
self._propagate(node, start, end)
return
if self.lazy[node] != 0 :
self.lazy[node*2] += self.lazy[node]
self.lazy[node*2 + 1] += self.lazy[node]
self.lazy[node] = 0
mid = (start + end) // 2
self._update(node*2, start, mid, l, r, val)
self._update(node*2+1, mid+1, end, l, r, val)
def out(self, target) :
start, end = 1, self.length
return self._out(1, start, end, target)
def _out(self, node, start, end, target) :
self._propagate(node, start, end)
if target == start == end :
return self.tree[node]
if target < start or target > end :
return 0
mid = (start + end) // 2
if target <= mid :
return self._out(node*2, start, mid, target)
else :
return self._out(node*2+1, mid+1, end, target)
def solve() :
N = int(input())
val_list = [0] + list(map(int, input().split()))
M = int(input())
seg_tree = Seg_tree(N, val_list)
for _ in range(M) :
q, *commands = map(int, input().split())
if q == 1 :
i, j, val = commands
seg_tree.update(i, j, val)
else :
target = commands[0]
print(seg_tree.out(target))
solve()