파이썬은 PS에서는 정말x100 좌절감을 느끼게 만든다. 고 생각한다. 이제 좀 시간제한이 빡빡한 문제는 C++이나 자바로 풀어봐야하나... 싶다.
다음은 시도한 풀이법들이다.
좌표 압축 + Mo's : 총 좌표의 범위가 매우 크므로 좌표 압축을 시행하고, 이를 sqrt decomposition + Mo's로 풀어보려는 시도. 시간복잡도는 O((Q+N)sqrt(N))이다. 시간 초과.
머지소트트리 : 이 경우부터는 다른 언어에서는 아마 풀이되었을 것이다. 우선 다음에 관해 생각해보자. 임의의 수 n[i]가 다음에 등장하는 index값을 nxt[i]라고 두자. [l, r] 구간에 대해서, nxt[i]값이 r+1 이상인 값들의 수가 그 구간의 서로 다른 숫자의 개수가 된다. 만약 중복되는 수 n[i] == n[j] 가 있다고 가정하면(l <= i < j <= r), nxt[i] = j가 될 것이므로 앞서 등장한 숫자가 무시된다. 따라서 [l, r] 구간의 마지막에 등장하는 수만이 r+1 이상의 값을 가지게 된다.
머지소트트리는 여기서부터 시작한다. [l, r]구간의 모든 nxt값이 이미 정렬된 상태라면, 이 값중 r+1 이상의 값을 O(logN) 시간 내에 찾을 수 있다. 초기 머지 소트의 시간복잡도가 O(NlogN), 구간 탐색이 O(logN), 이분 탐색이 O(logN)이므로 총 시간복잡도는 O(NlogN + QlogN^2)이다. 이 경우 역시 시간 초과.
여기서 15일 정도 지체하다가, 오늘에서야 발상의 전환으로 문제에 다시 도전해 볼 수 있었다. nxt[i] >= r+1인 인덱스의 집합을 가정하자. 이 인덱스의 집합 중 [l, r] 구간에 속하는 인덱스의 수 역시 서로 다른 숫자의 개수를 의미한다.
즉 오프라인 쿼리를 시행하되, r을 기준으로 내림차순 정렬하자.
nxt[i] >= r + 1인 모든 인덱스를 세그먼트 트리에 삽입한다. 이 과정은 결국 모든 N개의 수에 대해 시행되므로 시간복잡도는 O(NlogN)이다.
이제 이 인덱스가 저장된 구간합 세그먼트 트리에서 [l, r] 값을 구한다. 이 경우의 시간복잡도는 O(logN)이고, 총 쿼리 수는 Q이므로 O(QlogN)이 된다.
따라서 이 경우의 시간복잡도는 O((N+Q)logN)이 된다.
또한, 단순 구간합 세그먼트 트리는 바텀업 방식으로도 구현이 가능하다는 것을 그제서야 배웠다. 이 경우는 함수의 재귀적 호출이 매우 적어지므로 메모리상 이득과 시간적 이득을 동시에 챙길 수 있다. 즉 오프라인 쿼리 + 세그먼트 트리(바텀업) 형식으로 풀이하였을 때, 드디어 ac를 받을 수 있었다...
풀이 코드
from collections import defaultdict
import math
import sys
input = sys.stdin.readline
N = int(input())
sz = 1 << math.ceil(math.log2(N+1))
tree = [0]*(2*sz)
def update(target, val) :
target += sz
tree[target] += val
while target > 1 :
target >>= 1
tree[target] = tree[target<<1] + tree[target<<1|1]
def search(l, r) :
ret = 0
l += sz
r += sz
while l <= r :
if l & 1 :
ret += tree[l]
l += 1
if not r & 1 :
ret += tree[r]
r -= 1
l >>= 1
r >>= 1
return ret
nums = list(map(int, input().split()))
nxt_dict = dict()
nxt_list = defaultdict(list)
for i in range(N-1, -1, -1) :
nxt = nxt_dict[nums[i]] if nums[i] in nxt_dict else N
nxt_list[nxt].append(i)
nxt_dict[nums[i]] = i
Q = int(input())
ans = [0]*Q
queries = [[i] + list(map(int, input().split())) for i in range(Q)]
queries.sort(key = lambda x : -x[2])
i = N+1
for idx, l, r in queries :
l -= 1
r -= 1
for j in range(r+1, i) :
for k in nxt_list[j] :
update(k, 1)
i = r+1
ans[idx] = search(l, r)
print(*ans, sep = '\n')