새소식

PS/백준

[백준/13537/13544] 수열과 쿼리 1 / 3 (Python)

  • -

Problem 1 : https://www.acmicpc.net/problem/13537

 

13537번: 수열과 쿼리 1

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. i j k: Ai, Ai+1, ..., Aj로 이루어진 부분 수열 중에서 k보다 큰 원소의 개수를 출력한다.

www.acmicpc.net

 

Problem 2 : https://www.acmicpc.net/problem/13544

 

13544번: 수열과 쿼리 3

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. i j k: Ai, Ai+1, ..., Aj로 이루어진 부분 수열 중에서 k보다 큰 원소의 개수를 출력한다.

www.acmicpc.net

 

Difficulty : Platinum 3

 

Status : Solved(pypy)

 

Time : 00:19:38

 


 

문제 설명 (13573번 기준)

 

더보기

 

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오.

i j k: Ai, Ai+1, ..., Aj로 이루어진 부분 수열 중에서 k보다 큰 원소의 개수를 출력한다.

 

 

입력 및 출력

 

더보기

입력

 

첫째 줄에 수열의 크기 N (1 ≤ N ≤ 100,000)이 주어진다.

둘째 줄에는 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 10^9)

셋째 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.

넷째 줄부터 M개의 줄에는 쿼리 i, j, k가 한 줄에 하나씩 주어진다. (1 ≤ i ≤ j ≤ N, 1 ≤ k ≤ 10^9)

 

출력

 

각각의 쿼리마다 정답을 한 줄에 하나씩 출력한다.

 

입력 예시

 

5
5 1 2 3 4
3
2 4 1
4 4 4
1 5 2

 

출력 예시

 

2
0
3

 

 


 

풀이

 

 

임의의 구간 중 k값보다 큰 값을 빠르게 얻어내려면, 구간에 대한 접근이 매우 빨라야 함을 알 수 있다. 즉 O(logN) 혹은 O((logN)^2)의 시간복잡도가 필요하다. 우리는 이를 수행하는 세그먼트 트리를 쉽게 떠올릴 수 있다.

 

또한, 어떤 구간 중 k값보다 큰 값을 빠르게 구하는 문제를 정렬된 구간에 대해 k값의 upper bound 인덱스를 구하는 문제로 변형할 수 있다. upper bound의 인덱스값부터 끝값까지의 값이 모두 k값보다 클테니까. 즉 세그먼트 트리의 모든 노드가 정렬된 순서를 유지하여 저장한다면, 이 문제를 합리적인 시간 내에 풀 수 있는 셈이다. 모든 노드가 자신의 범위 내의 정렬된 구간을 갖고 있는 세그먼트 트리머지 소트 트리라고 한다. 머지 소트를 수행하며 그 과정을 ㅌ 본 머지 소트 트리의 search 연산은

 

  • 구간 탐색을 재귀적으로 찾기 (O(logN))
  • 만약 현재 탐색 구간이 찾고자 하는 범위에 완전히 포함된다면, 현재 저장된 구간에 대해서 upper bound를 찾기 (O(logN))

 

로 구분된다. 이 때 초기화에 O(NlogN)이 소요되며, 모든 쿼리를 처리하는 시간복잡도는 O(N(logN)^2)이 된다

 

풀이 코드 1(13537)

import sys
input = sys.stdin.readline
N = int(input())
num_list = list(map(int, input().split()))

class MergeSortTree :
  def __init__(self) :
    self.tree = [list() for _ in range(4*N)]
    def _init(start, end, idx) :
      if start == end :
        self.tree[idx] = [num_list[start]]
        return
      mid = (start + end) // 2
      _init(start, mid, idx*2)
      _init(mid+1, end, idx*2+1)

      l, r = 0, 0
      while l < mid-start+1 and r < end - mid :
        if self.tree[idx*2][l] > self.tree[idx*2+1][r] :
          self.tree[idx].append(self.tree[idx*2+1][r])
          r += 1
        else :
          self.tree[idx].append(self.tree[idx*2][l])
          l += 1
      while r < end - mid :
        self.tree[idx].append(self.tree[idx*2+1][r])
        r += 1
      while l < mid-start+1 :
        self.tree[idx].append(self.tree[idx*2][l])
        l += 1
    _init(0, N-1, 1)

  def search(self, left, right, val) :
    left -= 1
    right -= 1

    def upper_bound(start, end, idx) :
      while start < end :
        mid = (start + end) // 2
        if self.tree[idx][mid] <= val :
          start = mid + 1
        else :
          end = mid
      return len(self.tree[idx]) - end
    
    def _search(start, end, idx) :
      if right < start or left > end :
        return 0
      if left <= start <= end <= right :
        return upper_bound(0, end-start+1, idx)
      mid = (start + end) // 2
      lval = _search(start, mid, idx*2)
      rval = _search(mid+1, end, idx*2+1)
      return lval + rval

    print(_search(0, N-1, 1))

segtree= MergeSortTree()
for _ in range(int(input())) :
  i, j, k = map(int, input().split())
  segtree.search(i, j, k)

풀이 완료!

 


 

수열과 쿼리 3 문제도 기본 풀이법은 위와 동일하다. 단 last answer를 저장 후 이후에 xor해주는 경우만 추가로 고려해주면 된다.

 

풀이 코드 2(13544)

import sys
input = sys.stdin.readline
N = int(input())
num_list = list(map(int, input().split()))

class MergeSortTree :
  def __init__(self) :
    self.tree = [list() for _ in range(4*N)]
    def _init(start, end, idx) :
      if start == end :
        self.tree[idx] = [num_list[start]]
        return
      mid = (start + end) // 2
      _init(start, mid, idx*2)
      _init(mid+1, end, idx*2+1)

      l, r = 0, 0
      while l < mid-start+1 and r < end - mid :
        if self.tree[idx*2][l] > self.tree[idx*2+1][r] :
          self.tree[idx].append(self.tree[idx*2+1][r])
          r += 1
        else :
          self.tree[idx].append(self.tree[idx*2][l])
          l += 1
      while r < end - mid :
        self.tree[idx].append(self.tree[idx*2+1][r])
        r += 1
      while l < mid-start+1 :
        self.tree[idx].append(self.tree[idx*2][l])
        l += 1
    _init(0, N-1, 1)

  def search(self, left, right, val) :
    left -= 1
    right -= 1

    def upper_bound(start, end, idx) :
      while start < end :
        mid = (start + end) // 2
        if self.tree[idx][mid] <= val :
          start = mid + 1
        else :
          end = mid
      return len(self.tree[idx]) - end

    def _search(start, end, idx) :
      if right < start or left > end :
        return 0
      if left <= start <= end <= right :
        return upper_bound(0, end-start+1, idx)
      mid = (start + end) // 2
      lval = _search(start, mid, idx*2)
      rval = _search(mid+1, end, idx*2+1)
      return lval + rval

    return _search(0, N-1, 1)

segtree= MergeSortTree()
last_ans = 0
def cvt(x) :
  return int(x) ^ last_ans
  
for _ in range(int(input())) :
  i, j, k = map(cvt, input().split())
  last_ans = segtree.search(i, j, k)
  print(last_ans)

풀이 완료!

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.