N×N개의 수가 N×N 크기의 표에 채워져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 표의 i행 j열은 (i, j)로 나타낸다. (x1, y1)부터 (x2, y2)까지 합이란 x1 ≤ x ≤ x2, y1 ≤ y ≤ y2를 만족하는 모든 (x, y)에 있는 수의 합이다.
예를 들어, N = 4이고, 표가 아래와 같이 채워져 있는 경우를 살펴보자.
여기서 (2, 2)부터 (3, 4)까지 합을 구하면 3+4+5+4+5+6 = 27이 된다. (2, 3)을 7로 바꾸고 (2, 2)부터 (3, 4)까지 합을 구하면 3+7+5+4+5+6=30 이 된다.
표에 채워져 있는 수와 변경하는 연산과 합을 구하는 연산이 주어졌을 때, 이를 처리하는 프로그램을 작성하시오.
첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네 개의 정수 w, x, y, c 또는 다섯 개의 정수 w, x1, y1, x2, y2가 주어진다. w = 0인 경우는 (x, y)를 c (1 ≤ c ≤ 1,000)로 바꾸는 연산이고, w = 1인 경우는 (x1, y1)부터 (x2, y2)의 합을 구해 출력하는 연산이다. (1 ≤ x1 ≤ x2 ≤ N, 1 ≤ y1 ≤ y2 ≤ N) 표에 채워져 있는 수는 1,000보다 작거나 같은 자연수이다.
2차원이라고 겁낼 필요 없고, 2차원 구간 세그먼트 트리를 작성하면 된다. 수행 속도 및 메모리를 고려해 펜윅 트리로 작성하였다. 한 행이 세그먼트 트리로 이루어진 세그먼트 트리를 생각하면 구현이 더 쉽다.
펜윅 트리는 특성상 [1, N]까지의 구간합이 구해지므로, 실제로 펜윅 트리를 이용할 때 search 함수가 총 4번 호출되어야 한다는 점을 기억하자. update 및 search 모두 O(logN^2)의 시간복잡도가 소요된다.
풀이 코드
import sys
input = sys.stdin.readline
MAX = float('inf')
N, M = map(int, input().split())
tree = [[0]*(N+1) for _ in range(N+1)]
def update2D(x, y, val) :
while y <= N :
_x = x
while _x <= N :
tree[y][_x] += val
_x += -_x & _x
y += -y & y
def search2D(x, y) :
if not x or not y :
return 0
result = 0
while y :
_x = x
while _x :
result += tree[y][_x]
_x -= -_x & _x
y -= -y & y
return result
maps = [list(map(int, input().split())) for _ in range(N)]
for i in range(N) :
for j in range(N) :
update2D(j+1, i+1, maps[i][j])
for _ in range(M) :
w, *cmd = map(int, input().split())
if w == 0 :
y, x, c = cmd
update2D(x, y, c - maps[y-1][x-1])
maps[y-1][x-1] = c
else :
y1, x1, y2, x2 = cmd
print(search2D(x2, y2)+search2D(x1-1, y1-1)-search2D(x2, y1-1)-search2D(x1-1, y2))