새소식

PS/백준

[백준/1280] 나무 심기 (Python)

  • -

Problem : https://www.acmicpc.net/problem/1280

 

1280번: 나무 심기

첫째 줄에 나무의 개수 N (2 ≤ N ≤ 200,000)이 주어진다. 둘째 줄부터 N개의 줄에 1번 나무의 좌표부터 차례대로 주어진다. 각각의 좌표는 200,000보다 작은 자연수 또는 0이다.

www.acmicpc.net

 

Difficulty : Platinum 4

 

Status : Solved

 

Time : 00:58:31

 


 

문제 설명

 

더보기

 

1번부터 N번까지 번호가 매겨져 있는 N개의 나무가 있다. i번 나무는 좌표 X[i]에 심어질 것이다.

동호는 나무를 1번 나무부터 차례대로 좌표 X[i]에 심으려고 한다. 1번 나무를 심는 비용은 없고, 각각의 나무를 심는데 드는 비용은 현재 심어져있는 모든 나무 까지 거리의 합이다. 만약 3번 나무를 심는다면, 1번 나무와의 거리 + 2번 나무와의 거리가 3번 나무를 심는데 드는 비용이다.

2번 나무부터 N번 나무까지를 심는 비용의 곱을 출력하는 프로그램을 작성하시오.

 

 

입력 및 출력

 

더보기

입력

 

첫째 줄에 나무의 개수 N (2 ≤ N ≤ 200,000)이 주어진다. 둘째 줄부터 N개의 줄에 1번 나무의 좌표부터 차례대로 주어진다. 각각의 좌표는 200,000보다 작은 자연수 또는 0이다.

 

출력

 

문제의 정답을 1,000,000,007로 나눈 나머지를 출력한다.

 

입력 예시

 

5
3
4
5
6
7

 

출력 예시

 

180

 

 


 

풀이

 

n번째 나무의 비용을 어떻게 빠르게 구할 수 있을까? n번째 나무의 좌표를 val이라고 하자. 1 ~ n-1번째 나무들을 기준으로, 오른쪽에 있는 좌표들의 집합을 r이라고 두고, 왼쪽에 있는 좌표들의 집합을 l이라고 두자. 다음과 같은 식으로 나타낼 수 있을 것이다.

즉, 왼쪽에 있는 모든 좌표들과의 거리합과, 오른쪽에 있는 모든 좌표들의 거리합이 된다.

이를 풀어 쓰면 위와 같은 식이 나온다. 즉 우리가 필요한 정보는

 

  • val보다 왼쪽에 있는 좌표들의 개수
  • val보다 왼쪽에 있는 좌표들의 총합
  • val보다 오른쪽에 있는 좌표들의 개수
  • val보다 오른쪽에 있는 좌표들의 총합

 

이 되며, 이를 O(N)미만의 시간복잡도 내에 구할 수 있어야 한다. 또한, 다음의 n+1번째 나무의 비용을 구할 때는 n번째 나무 좌표 val역시 반영이 되어야 하므로, 어떤 자료구조를 사용한다면 이러한 정보 반영 역시 O(N)미만으로 소요되어야 한다. 즉 세그먼트 트리를 생각해 볼 수 있겠다.

 

세그먼트 트리에는 현재 범위의 좌표들의 개수와 총합, 두 가지 데이터를 저장하도록 하고, 정보를 검색할 때 val을 기준으로 왼쪽 범위와 오른쪽 범위를 검색토록 한다. update시 리프 노드에는 (+1, +val)로 정보를 업데이트하고, 그 상위 노드는 (left[0] + right[0], left[1] + right[1])로 재귀적으로 업데이트하면 된다.

 

풀이 코드

N = int(input())
MAXVAL = 200000
MOD = 1000000007
num_list = [int(input()) for _ in range(N)]

class SegTree() :
  def __init__(self) :
    self.tree = [[0]*2 for _ in range(4*MAXVAL)]

  def search(self, val) :
    l, r = (0, 0), (0, 0)
    if val > 0 :
      l = self._search(1, 0, MAXVAL, 0, val-1)
    if val < MAXVAL :
      r = self._search(1, 0, MAXVAL, val+1, MAXVAL)

    result = (val*l[0] - l[1]) + (r[1] - val*r[0])
    return result % MOD

  def _search(self, idx, start, end, l, r) :
    if end < l or r < start :
      return 0, 0
    if l <= start and end <= r :
      return self.tree[idx]

    mid = (start + end) // 2
    left = self._search(2*idx, start, mid, l, r)
    right = self._search(2*idx+1, mid+1, end, l, r)

    return (left[0] + right[0]) % MOD, (left[1] + right[1]) % MOD

  def update(self, val) :
    self._update(1, 0, MAXVAL, val)

  def _update(self, idx, start, end, val) :
    if end < val or val < start :
      return
    if start == end :
      self.tree[idx][0] = (self.tree[idx][0] + 1) % MOD
      self.tree[idx][1] = (self.tree[idx][1] + start) % MOD
      return

    mid = (start + end) // 2
    self._update(2*idx, start, mid, val)
    self._update(2*idx+1, mid+1, end, val)

    self.tree[idx][0] = ( self.tree[2*idx][0] + self.tree[2*idx+1][0] ) % MOD
    self.tree[idx][1] = ( self.tree[2*idx][1] + self.tree[2*idx+1][1] ) % MOD

ans = 1
segtree = SegTree()
segtree.update(num_list[0])
for i in range(1, N) :
  ans = (ans * segtree.search(num_list[i])) % MOD
  segtree.update(num_list[i])

print(ans)

풀이 완료!

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.