새소식

PS/백준

[백준/11658] 구간 합 구하기 3 (Python)

  • -

Problem : https://www.acmicpc.net/problem/11658

 

11658번: 구간 합 구하기 3

첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는

www.acmicpc.net

 

Difficulty : Platinum 4

 

Status : Solved

 

Time : 00:22:05

 


 

문제 설명

 

더보기

 

N×N개의 수가 N×N 크기의 표에 채워져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 표의 i행 j열은 (i, j)로 나타낸다. (x1, y1)부터 (x2, y2)까지 합이란 x1 ≤ x ≤ x2, y1 ≤ y ≤ y2를 만족하는 모든 (x, y)에 있는 수의 합이다.

예를 들어, N = 4이고, 표가 아래와 같이 채워져 있는 경우를 살펴보자.

 

 

여기서 (2, 2)부터 (3, 4)까지 합을 구하면 3+4+5+4+5+6 = 27이 된다. (2, 3)을 7로 바꾸고 (2, 2)부터 (3, 4)까지 합을 구하면 3+7+5+4+5+6=30 이 된다.

표에 채워져 있는 수와 변경하는 연산과 합을 구하는 연산이 주어졌을 때, 이를 처리하는 프로그램을 작성하시오.

 

 

입력 및 출력

 

더보기

입력

 

첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네 개의 정수 w, x, y, c 또는 다섯 개의 정수 w, x1, y1, x2, y2가 주어진다. w = 0인 경우는 (x, y)를 c (1 ≤ c ≤ 1,000)로 바꾸는 연산이고, w = 1인 경우는 (x1, y1)부터 (x2, y2)의 합을 구해 출력하는 연산이다. (1 ≤ x1 ≤ x2 ≤ N, 1 ≤ y1 ≤ y2 ≤ N) 표에 채워져 있는 수는 1,000보다 작거나 같은 자연수이다.

 

출력

 

w = 1인 입력마다 구한 합을 순서대로 한 줄에 하나씩 출력한다.

 

입력 예시

 

4 5 1 2 3 4 2 3 4 5 3 4 5 6 4 5 6 7 1 2 2 3 4 0 2 3 7 1 2 2 3 4 0 3 4 5 1 3 4 3 4

 

출력 예시

 

27 30 5

 

 


 

풀이

 

앞서 2차원 세그먼트 트리로 크게 죽쑤고 나서 홧김에(?) 도전한 문제.

 

2차원이라고 겁낼 필요 없고, 2차원 구간 세그먼트 트리를 작성하면 된다. 수행 속도 및 메모리를 고려해 펜윅 트리로 작성하였다. 한 행이 세그먼트 트리로 이루어진 세그먼트 트리를 생각하면 구현이 더 쉽다.

 

펜윅 트리는 특성상 [1, N]까지의 구간합이 구해지므로, 실제로 펜윅 트리를 이용할 때 search 함수가 총 4번 호출되어야 한다는 점을 기억하자. update 및 search 모두 O(logN^2)의 시간복잡도가 소요된다.

 

 

풀이 코드

import sys input = sys.stdin.readline MAX = float('inf') N, M = map(int, input().split()) tree = [[0]*(N+1) for _ in range(N+1)] def update2D(x, y, val) : while y <= N : _x = x while _x <= N : tree[y][_x] += val _x += -_x & _x y += -y & y def search2D(x, y) : if not x or not y : return 0 result = 0 while y : _x = x while _x : result += tree[y][_x] _x -= -_x & _x y -= -y & y return result maps = [list(map(int, input().split())) for _ in range(N)] for i in range(N) : for j in range(N) : update2D(j+1, i+1, maps[i][j]) for _ in range(M) : w, *cmd = map(int, input().split()) if w == 0 : y, x, c = cmd update2D(x, y, c - maps[y-1][x-1]) maps[y-1][x-1] = c else : y1, x1, y2, x2 = cmd print(search2D(x2, y2)+search2D(x1-1, y1-1)-search2D(x2, y1-1)-search2D(x1-1, y2))

풀이 완료!

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.