넷째 줄부터 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다. (1 ≤ x ≤ y ≤ N, 1 ≤ v ≤ 109)
출력
4번 쿼리에 대해서 정답을 한 줄에 하나씩 순서대로 출력한다.
입력 예시
4
1 2 3 4
4
4 1 4
1 1 3 10
2 2 4 2
4 1 4
출력 예시
10
69
풀이
핵심은 구간의 덧셈, 곱셈, 초기화 세 가지를 어떻게 빠른 시간 내로 적용시킬 수 있는가로 보였다. lazy segment tree를 사용하되, 이 lazy 값을 어떻게 전파하냐가 중요하다.
가령, [1, 1, 1]로 이루어진 구간을 업데이트한다고 하자. +2, x3 순으로 업데이트한다면 그 결과는 [9, 9, 9]가 될 것이고, x3, +2 순으로 업데이트한다면 그 결과는 [5, 5, 5]가 될 것이다. 즉 구간의 lazy값에는 쿼리의 순서가 반영되어야 한다. lazy를 하나의 큐로 본다면 가장 확실하게 처리할 수 있겠지만, 이는 메모리 초과를 야기한다.
해답은 lazy값을 (mul, sum)으로 구성하자. 원래 값이 x라면, 반영 후의 값은 x*mul + sum*(end - start + 1)이 될 것이다. lazy의 전파 방법은 다음과 같다.
mul -> sum 순으로 전파한다.
mul값은 자식 노드의 lazy의 (mul, sum) 전체에 영향을 미친다.
가령 mul값이 m이라면, 자손 노드의 lazy는 (mul * m, sum * m)이 된다.
lazy를 하나의 수식(ax + b)라고 생각하자. a는 mul, b는 sum, x는 반영되는 값이다.
그렇다면, 위에서 전파되는 새로운 곱셈값 a'을 적용할 때 a' ( ax + b ) = a'ax + a'b가 될 것이다.
sum 값은 자식 노드의 lazy의 sum값에만 영향을 미친다.
sum 값이 n라면, 자손 노드의 lazy는 최종적으로 (mul * m, sum * m + n)이 된다.
위와 같이 수식으로 나타내면, 새로운 덧셈값 b'을 적용할 때 a'ax + a'b + b가 될 것이다.
이렇게 하면 순서를 고려하여 덧셈, 곱셈에 대한 결과가 보존된다.
v로 초기화하는 경우는 곱셈을 활용한다. (mul, sum) = (0, v)라면, 전파 과정에서 기존의 값에 0이 곱해져서 사라지고, v값이 더해져 v로 초기화되는 셈이다.
풀이 코드
import sys
input = sys.stdin.readline
MOD = int(1e9) + 7
N = int(input())
nums = list(map(int, input().split()))
tree = [0]*(4*N+4)
lazy = [[1, 0] for _ in range(4*N+4)]
def init(start = 0, end = N-1, idx = 1) :
if start == end :
tree[idx] = nums[start]
return
mid = (start + end) // 2
init(start, mid, idx*2)
init(mid+1, end, idx*2+1)
tree[idx] = tree[idx*2] + tree[idx*2+1]
def propagate(start, end, idx) :
if start != end :
for i in range(2) :
lazy[idx*2][i] = (lazy[idx*2][i]*lazy[idx][0]) % MOD
lazy[idx*2+1][i] = (lazy[idx*2+1][i]*lazy[idx][0]) % MOD
lazy[idx*2][1] = (lazy[idx*2][1] + lazy[idx][1]) % MOD
lazy[idx*2+1][1] = (lazy[idx*2+1][1] + lazy[idx][1]) % MOD
tree[idx] = (tree[idx] * lazy[idx][0]) % MOD
tree[idx] = (tree[idx] + lazy[idx][1]*(end - start + 1)) % MOD
lazy[idx][0] = 1
lazy[idx][1] = 0
def update(left, right, val, ops, start = 0, end = N-1, idx = 1) :
propagate(start, end, idx)
if right < start or left > end :
return
if left <= start <= end <= right :
if ops == 1 :
lazy[idx][1] = (lazy[idx][1] + val) % MOD
elif ops == 2 :
lazy[idx][0] = (lazy[idx][0] * val) % MOD
lazy[idx][1] = (lazy[idx][1] * val) % MOD
else :
lazy[idx][0] = 0
lazy[idx][1] = val
propagate(start, end, idx)
return
mid = (start + end) // 2
update(left, right, val, ops, start, mid, idx*2)
update(left, right, val, ops, mid+1, end, idx*2+1)
tree[idx] = (tree[idx*2] + tree[idx*2+1]) % MOD
def search(left, right, start = 0, end = N-1, idx = 1) :
propagate(start, end, idx)
if right < start or left > end :
return 0
if left <= start <= end <= right :
return tree[idx] % MOD
mid = (start + end) // 2
ret = 0
ret = (ret + search(left, right, start, mid, idx*2)) % MOD
ret = (ret + search(left, right, mid+1, end, idx*2+1)) % MOD
return ret
init()
M = int(input())
for _ in range(M) :
q, *cmd = map(int, input().split())
if q <= 3 :
x, y, v = cmd
update(x-1, y-1, v, q)
else :
x, y = cmd
print(search(x-1, y-1) % MOD)