새소식

PS/백준

[백준/13246] 행렬 제곱의 합 (Python)

  • -

Problem : https://www.acmicpc.net/problem/13246

 

13246번: 행렬 제곱의 합

첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000) 둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.

www.acmicpc.net

 

Difficulty : Gold 1

 

Status : Solved

 

Time : 00:16:07

 


 

문제 설명

 

더보기

 

크기가 N*N인 행렬 A가 주어진다. 이때, A의 1제곱부터 A의 B제곱까지 더한 행렬을 구하는 프로그램을 작성하시오. 즉, S = A^1 + A^2 + ... + A^B를 구해야 한다.

수가 매우 커질 수 있으니, S의 각 원소를 1,000으로 나눈 나머지를 출력한다.

 

 

입력 및 출력

 

더보기

입력

 

첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000)

둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.

 

출력

 

첫째 줄부터 N개의 줄에 걸쳐 행렬 S를 출력한다.

 

입력 예시

 

2 5
1 2
3 4

 

출력 예시

 

313 914
871 184

 

 


 

풀이

 

우선 B의 크기가 매우 클 수 있으므로, 분할 정복으로 거듭제곱꼴로 나타내야 할 것 같다. 문제는 그 다음. A^n꼴은 log시간 내에 구할 수 있겠지만, 이 A^n의 합을 어떻게 구할지이다.

 

다시 돌이켜 생각해 보자. 본 함수를 다음과 같이 가정하자.

 

이는 다음 수식으로 바꿔쓸 수 있다.

 

만약 B가 홀수라면 (B%2 == 1 ) 마지막 항이 남는다.

여기서 sigma 항을 보면 f(A, B)와 매우 유사함을 알 수 있다. 즉 이는 재귀적으로 f(A, floor(B/2))로 구할 수 있다. 수식을 정리하면..

 

로, summation 함수 역시 재귀적으로 구할 수 있다.

 

풀이 코드

MOD = 1000
N, B = map(int, input().split())
mat = []
for _ in range(N) :
  _mat = list(map(lambda x : int(x) % MOD, input().split()))
  mat.append(_mat)

def matsum(A, B = None) :
  if B is None :
    B = [[0 if i != j else 1 for j in range(N)] for i in range(N)]
  return [[(A[i][j] + B[i][j]) % MOD for j in range(N)] for i in range(N)]

def matmul(A, B) :
  tmp = [[0]*N for _ in range(N)]
  for i in range(N) :
    for j in range(N) :
      for k in range(N) :
        tmp[i][j] = (tmp[i][j] + A[i][k] * B[k][j]) % MOD
  return tmp

def matpow(A, n) :
  if n == 1 :
    return A
  p = matpow(A, n // 2)
  p = matmul(p, p)
  if n % 2 :
    p = matmul(p, A)
  return p

def solve(m, b) :
  if b == 1 :
    return m
  lmat = matsum(matpow(m, b // 2))
  rmat = solve(m, b // 2)
  res = matmul(lmat, rmat)
  if b % 2 :
    res = matsum(res, matpow(m, b))
  return res

ans = solve(mat, B)
for _ans in ans :
  print(*_ans)

풀이 완료!

 

p.s. 위 수식은 A^B를 각 단계에서 추가로 계산해야하기에 비효율적으로 보인다. 이는 초항을 기준으로 둘을 묶었기 때문이다. 마지막 항을 기준으로 초항을 묶는다면 B % 2 == 0 일때 다음과 같이 변할 것이다.

 

 

수식적으로는 더 깔끔하게 정리할 수 있겠지만, matrix의 sum은 O(N^2), multiplication은 O(N^3)시간복잡도가 소요된다. 즉 summation으로 가급적 처리하는 것이 훨씬 효율적이라고 생각하여 위와 같이 구현하였다. 즉 위 코드의 solve함수를 다음과 같이 바꾸어주자.

 

def solve(m, b) :
  if b == 1 :
    return m
#  lmat = matsum(matpow(m, b // 2))
  if b % 2 :
  	lmat = matsum(m, matpow(m, b // 2 + 1))
  else :
  	lmat = matsum(matpow(m, b // 2))
  rmat = solve(m, b // 2)
  res = matmul(lmat, rmat)
#  if b % 2 :
#    res = matsum(res, matpow(m, b))
  if b % 2 :
    res = matsum(m, matmul(m, res))
  return res

아래쪽이 오리지널 코드, 위쪽이 개선된 코드의 결과이다.

 

실제 테스트에서도 좀 더 빠르게 계산해내는 것을 확인할 수 있다.

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.