Problem : https://www.acmicpc.net/problem/31034
31034번: 초전도체 부수기
당신은 상온 상압 초전도체를 개발하고 세상을 뒤바꿀 논문을 작성했다. 당신은 $N\mathrm{g}$의 초전도체 덩어리를 가지고 있는데, 논문 검증을 위해 $K$개의 연구소에서 초전도체 샘플을 요청했다
www.acmicpc.net
Difficulty : Platinum 3
Status : Solved
Time : 00:18:10
더보기
당신은 상온 상압 초전도체를 개발하고 세상을 뒤바꿀 논문을 작성했다. 당신은 Ng의 초전도체 덩어리를 가지고 있는데, 논문 검증을 위해 K개의 연구소에서 초전도체 샘플을 요청했다! 각 연구소에는 1g 이상의 초전도체 샘플을 보내주면 된다. 다행히도, 당신은 초전도체를 정밀하게 부수는 기술을 가지고 있다. 2 이상의 정수 a에 대하여, aㅎ의 초전도체를 다음과 같은 방법으로 절단할 수 있다
1이상 a 미만의 정수 b를 선택한다.
bg 짜리 초전도체와 (a-b)g짜리 초전도체 두 개로 쪼갠다.
이때, a원의 비용이 든다.
연구 비용 절감을 위해 초전도체를 K개의 조각으로 자르기 위한 최소 비용은 얼마일지 구해야 한다. 단, 제한 조건하에서 위의 방법을 통해 초전도체를 K개의 조각으로 쪼갤 수 있음을 증명할 수 있다.
더보기
입력
첫 번째 줄에 테스트 케이스의 개수 T가 주어진다. T개의 줄에 이어, 각 테스트 케이스마다 한 줄에 초전도체 덩어리의 무게 N과 초전도체 샘플을 요청한 연구소의 수 K가 공백을 사이에 두고 주어진다.
출력
각 테스트 케이스마다 초전도체를 K개의 조각으로 자르기 위한 최소 비용을 출력한다.
입력 예시
3
2 2
5 3
10000 1000
출력 예시
2
7
19965
덧셈 으로 관점을 바꾸어 보자. 초기에는 K개의 조각이 존재하고, K개의 조각을 하나로 합칠 것이다. 이 때 임의의 i와 jg의 조각을 합치면 그 비용은 (i+j)g이 된다. 우리는 다음과 같은 사실을 알 수 있다.
조각을 합칠 때, 덧셈은 K-1번 발생한다.
각 덧셈에서 가장 낮은 비용인 두 조각을 합치는 것 이 최소 비용이다.
초기 비용이 제일 낮으려면, 가능한 한 최소 단위(1g)인 조각의 개수가 많아야 한다 .
즉 다음 사실들을 통해 그리디 하게 접근해보자.
우리는 초기에 K-1개의 1g의 조각과, 1개의 N-K-1g의 조각을 가지고 있다. 우리는 재귀적으로 다음과 같은 방식으로 빠르게 총 비용을
가장 적은 조각(i) 의 개수(j)가 2개 이상일 경우 : 이 조각들만을 짝지어 새로운 조각을 만들자. 새로운 조각은 무게가 i*2이며, 개수는 j // 2가 된다.
가장 적은 조각(i)의 개수가 1개일 경우 : 두 번째로 적은 조각 하나(k)를 사용하여 둘을 합치자. 새 조각은 무게가 i + k이며, 개수는 1개이다.
이 과정을 1개의 조각이 남을 때까지 반복하면 된다. 트리 기반의 SortedList 등을 사용하면 좀 더 구현이 쉽겠지만, 아쉽게도 파이썬 내장 모듈은 이를 지원하지 않아 힙과 카운팅 딕셔너리를 이용해 유사하게 구현하였다.
풀이 코드
from collections import defaultdict
from heapq import *
import sys
input = sys.stdin.readline
def solve():
N, K = map(int, input().split())
count = defaultdict(int)
count[1] += K - 1
count[N - K + 1] += 1
q = list(count.keys())
heapify(q)
ans = 0
while sum(count.values()) > 1:
if count[q[0]] > 1:
key, val = q[0] * 2, count[q[0]] // 2
if key not in count:
heappush(q, key)
count[key] += val
count[q[0]] %= 2
if count[q[0]] == 0:
heappop(q)
ans += key * val
else:
key = heappop(q)
sec_key = q[0]
count[key] -= 1
count[sec_key] -= 1
if key + sec_key not in count:
heappush(q, key + sec_key)
count[key + sec_key] += 1
ans += key + sec_key
if count[sec_key] == 0:
heappop(q)
print(ans)
for _ in range(int(input())) :
solve()
풀이 완료!