새소식

PS/백준

[백준/7469] K번째 수 (Python)

  • -

Problem : https://www.acmicpc.net/problem/7469

 

7469번: K번째 수

현정이는 자료 구조 프로젝트를 하고 있다. 다른 학생들은 프로젝트 주제로 스택, 큐와 같은 기본 자료 구조를 구현하는 주제를 선택했다. 하지만, 현정이는 새로운 자료 구조를 만들었다. 현정

www.acmicpc.net

 

Difficulty : Platinum 2

 

Status : Solved

 

Time : 00:41:23

 


 

문제 설명

 

더보기

 

현정이는 자료 구조 프로젝트를 하고 있다. 다른 학생들은 프로젝트 주제로 스택, 큐와 같은 기본 자료 구조를 구현하는 주제를 선택했다. 하지만, 현정이는 새로운 자료 구조를 만들었다.

현정이가 만든 자료구조는 배열을 응용하는 것이다. 배열 a[1...n]에는 서로 다른 수가 n개 저장되어 있다. 현정이는 여기에 Q(i,j,k)라는 함수를 구현해 모두를 놀라게 할 것이다.

Q(i,j,k): 배열 a[i...j]를 정렬했을 때, k번째 수를 리턴하는 함수

예를 들어, a = (1,5,2,6,3,7,4)인 경우 Q(2,5,3)의 답을 구하는 과정을 살펴보자. a[2...5]는 (5,2,6,3)이고, 이 배열을 정렬하면 (2,3,5,6)이 된다. 정렬한 배열에서 3번째 수는 5이다. 따라서 Q(2,5,3)의 리턴값은 5이다.

배열 a가 주어지고, Q함수를 호출한 횟수가 주어졌을 때, 각 함수의 리턴값을 출력하는 프로그램을 작성하시오.

 

 

입력 및 출력

 

더보기

입력

 

첫째 줄에 배열의 크기 n과 함수 Q를 호출한 횟수 m이 주어진다. (1 ≤ n ≤ 100,000, 1 ≤ m ≤ 5,000)

둘째 줄에는 배열에 포함된 정수가 순서대로 주어진다. 각 정수는 절댓값이 10^9를 넘지 않는 정수이다.

다음 m개 줄에는 Q(i,j,k)를 호출할 때 사용한 인자 i,j,k가 주어진다. (1 ≤ i ≤ j ≤ n, 1 ≤ k ≤ j-i+1)

 

출력

 

Q함수를 호출할 때마다 그 함수의 리턴값을 한 줄에 하나씩 출력한다. 

 

입력 예시

 

7 3
1 5 2 6 3 7 4
2 5 3
4 4 1
1 7 3

 

출력 예시

 

5
6
3

 

 


 

풀이

 

처음에는 좌표 압축 + 카운팅 세그먼트트리 + mo's + 오프라인 쿼리로 풀어보려 했지만, 여지없이 시간초과를 겪었다. 시간복잡도가 O((N+M)*sqrt(N)*log(N))이 되며, 이 값은 사실상 N >> M이므로 O(Nsqrt(N)log(N))이 된다. N이 10^6이라고 가정했을 때 시간 초과를 일으키기 딱 좋다.

 

즉 발상을 전환할 필요가 있었다.

  • 머지 소트 트리를 초기화하면, 트리의 [start, end] 구간의 정렬된 리스트를 구할 수 있다.
  • 서칭 시, [left, right] 구간 내에 해당하는 모든 트리의 정렬된 리스트를 구할 수 있다.
    • 시간복잡도는 O(logN)이 소요된다.
  • 이를 매개변수 이분 탐색으로 k번째 값을 구할 수 있을 것이다.
    • 정렬된 리스트들에서 매개 변수가 몇번째 원소인지를 이분 탐색(lower bound)으로 각각 구한다면, 그 합은 전체 리스트에서의 위치(lower bound)를 나타낸다.
      • 이를테면 M = 2, List = [[1, 2, 3], [0, 2], [1, 4, 5]] 라고 가정하자.
      • 각 lower bound는 [1, 1, 1] 이고 이 합인 3은 총 List인 [0, 1, 1, 2, 3, 4, 5]에서의 lower bound와 동일하다.
    • 따라서 이 lower bound = k를 만족하는 값을 lower bound로 구하면 된다.
    • 이분 탐색의 시간복잡도는 O((logN)^2)이 된다.
  • 총 시간복잡도는 O(M(logN)^3)이 될 것이다.

 

 

풀이 코드

import bisect
import sys
input = sys.stdin.readline

N, M = map(int, input().split())
MAX = 10**9
nums = list(map(int, input().split()))

tree = [list() for _ in range(4*N)]

def init(start = 0, end = N-1, idx = 1) :
  if start == end :
    tree[idx] = [nums[start]]
    return
  mid = (start + end) // 2
  init(start, mid, idx*2)
  init(mid+1, end, idx*2+1)
  tree[idx] = tree[idx*2] + tree[idx*2+1]
  tree[idx].sort()

def search(left, right, target) :
  idx_list = []
  def _search(start, end, idx) :
    if right < start or left > end :
      return
    if left <= start <= end <= right :
      idx_list.append(idx)
      return
    mid = (start + end) // 2
    _search(start, mid, idx*2)
    _search(mid+1, end, idx*2+1)

  _search(0, N-1, 1)
  start, end = -MAX, MAX
  ans = start
  while start < end :
    mid = (start + end) // 2
    cur = 0
    for idx in idx_list :
      cur += bisect.bisect_left(tree[idx], mid)
    if cur > target :
      end = mid
    else :
      ans = mid
      start = mid + 1
  print(ans)

init()
for _ in range(M) :
  i, j, k = map(int, input().split())
  search(i-1, j-1, k-1)

풀이 완료!

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.