1번부터 N번까지 번호가 매겨져 있는 N개의 나무가 있다. i번 나무는 좌표 X[i]에 심어질 것이다.
동호는 나무를 1번 나무부터 차례대로 좌표 X[i]에 심으려고 한다. 1번 나무를 심는 비용은 없고, 각각의 나무를 심는데 드는 비용은 현재 심어져있는 모든 나무 까지 거리의 합이다. 만약 3번 나무를 심는다면, 1번 나무와의 거리 + 2번 나무와의 거리가 3번 나무를 심는데 드는 비용이다.
첫째 줄에 나무의 개수 N (2 ≤ N ≤ 200,000)이 주어진다. 둘째 줄부터 N개의 줄에 1번 나무의 좌표부터 차례대로 주어진다. 각각의 좌표는 200,000보다 작은 자연수 또는 0이다.
출력
문제의 정답을 1,000,000,007로 나눈 나머지를 출력한다.
입력 예시
5
3
4
5
6
7
출력 예시
180
풀이
n번째 나무의 비용을 어떻게 빠르게 구할 수 있을까? n번째 나무의 좌표를 val이라고 하자. 1 ~ n-1번째 나무들을 기준으로, 오른쪽에 있는 좌표들의 집합을 r이라고 두고, 왼쪽에 있는 좌표들의 집합을 l이라고 두자. 다음과 같은 식으로 나타낼 수 있을 것이다.
즉, 왼쪽에 있는 모든 좌표들과의 거리합과, 오른쪽에 있는 모든 좌표들의 거리합이 된다.
이를 풀어 쓰면 위와 같은 식이 나온다. 즉 우리가 필요한 정보는
val보다 왼쪽에 있는 좌표들의 개수
val보다 왼쪽에 있는 좌표들의 총합
val보다 오른쪽에 있는 좌표들의 개수
val보다 오른쪽에 있는 좌표들의 총합
이 되며, 이를 O(N)미만의 시간복잡도 내에 구할 수 있어야 한다. 또한, 다음의 n+1번째 나무의 비용을 구할 때는 n번째 나무 좌표 val역시 반영이 되어야 하므로, 어떤 자료구조를 사용한다면 이러한 정보 반영 역시 O(N)미만으로 소요되어야 한다. 즉 세그먼트 트리를 생각해 볼 수 있겠다.
세그먼트 트리에는 현재 범위의 좌표들의 개수와 총합, 두 가지 데이터를 저장하도록 하고, 정보를 검색할 때 val을 기준으로 왼쪽 범위와 오른쪽 범위를 검색토록 한다. update시 리프 노드에는 (+1, +val)로 정보를 업데이트하고, 그 상위 노드는 (left[0] + right[0], left[1] + right[1])로 재귀적으로 업데이트하면 된다.
풀이 코드
N = int(input())
MAXVAL = 200000
MOD = 1000000007
num_list = [int(input()) for _ in range(N)]
class SegTree() :
def __init__(self) :
self.tree = [[0]*2 for _ in range(4*MAXVAL)]
def search(self, val) :
l, r = (0, 0), (0, 0)
if val > 0 :
l = self._search(1, 0, MAXVAL, 0, val-1)
if val < MAXVAL :
r = self._search(1, 0, MAXVAL, val+1, MAXVAL)
result = (val*l[0] - l[1]) + (r[1] - val*r[0])
return result % MOD
def _search(self, idx, start, end, l, r) :
if end < l or r < start :
return 0, 0
if l <= start and end <= r :
return self.tree[idx]
mid = (start + end) // 2
left = self._search(2*idx, start, mid, l, r)
right = self._search(2*idx+1, mid+1, end, l, r)
return (left[0] + right[0]) % MOD, (left[1] + right[1]) % MOD
def update(self, val) :
self._update(1, 0, MAXVAL, val)
def _update(self, idx, start, end, val) :
if end < val or val < start :
return
if start == end :
self.tree[idx][0] = (self.tree[idx][0] + 1) % MOD
self.tree[idx][1] = (self.tree[idx][1] + start) % MOD
return
mid = (start + end) // 2
self._update(2*idx, start, mid, val)
self._update(2*idx+1, mid+1, end, val)
self.tree[idx][0] = ( self.tree[2*idx][0] + self.tree[2*idx+1][0] ) % MOD
self.tree[idx][1] = ( self.tree[2*idx][1] + self.tree[2*idx+1][1] ) % MOD
ans = 1
segtree = SegTree()
segtree.update(num_list[0])
for i in range(1, N) :
ans = (ans * segtree.search(num_list[i])) % MOD
segtree.update(num_list[i])
print(ans)