새소식

PS/백준

[백준/9735] 삼차 방정식 풀기 (Python)

  • -

Problem : https://www.acmicpc.net/problem/9735

 

9735번: 삼차 방정식 풀기

첫째 줄에 테스트 케이스의 개수 N (0 < N < 100)이 주어진다. 다음 N개 줄에는 삼차 방정식의 계수 A, B, C, D가 한 줄에 하나씩 주어진다.

www.acmicpc.net

 

Difficulty : Platinum 3

 

Status : Solved

 

Time : 00:42:31

 


 

문제 설명

 

더보기

삼차 방정식 Ax3 + Bx2 + Cx + D = 0 의 모든 실수 해를 찾는 프로그램을 작성하시오.

입력으로 주어지는 방정식은 정수 해를 적어도 한 개 갖는다.

A, B, C, D는 -2,000,000보다 크거나 같고, 2,000,000보다 작거나 같은 정수이고, A는 0이 아니다. 모든 해는 -1,000,000보다 크거나 같고, 1,000,000보다 작거나 같다. 주어지는 방정식의 해의 차이는 10-4보다 크다.

 

입력 및 출력

 

더보기

입력

 

첫째 줄에 테스트 케이스의 개수 N (0 < N < 100)이 주어진다. 다음 N개 줄에는 삼차 방정식의 계수 A, B, C, D가 한 줄에 하나씩 주어진다.

 

출력

 

입력으로 주어진 방정식마다 모든 실수 해를 오름차순으로 출력한다. 해의 절대/상대 오차는 10-4까지 허용한다. 중근이 존재하는 경우에는 한 번만 출력한다.

 

입력 예시

 

2
2 -7 7 -2
2 0 0 0

 

출력 예시

 

0.5000 1.0000 2.0000
0.0000

 

 

 


 

풀이

 

문제 조건상 삼차방정식의 해를 직접 조사해서 적용하라는 것은 아닐 것이고... 다음 힌트를 얻을 수 있다.

 

  • 삼차방정식은 정수 해를 적어도 한 개 갖는다.

 

즉 우리는 다음 과정으로 문제를 풀이해 볼 수 있다.

  • 삼차방정식의 정수해 찾기
  • 조립제법으로 삼차방정식을 (정수해 일차방정식) x (이차방정식) 꼴로 나타내기
  • 이차방정식의 해 찾기

그런데 잠깐. 모든 해는 -10^6 에서 10^6 사이의 값을 가지므로, 브루트포스로 풀려면 골치가 아파진다. 여기서 우리는 다음 사실을 주목할 필요가 있다.

 

정수해 x는 D의 약수이다! 따라서 첫 번째 수식인 삼차방정식 정수해 찾기는 O(sqrt(D))로 시간복잡도가 획기적으로 줄어든다. 이제 조립제법을 보자. 정수해 x가 alpha라고 가정하도록 하자.

즉 다음과 같은 결과로 정리할 수 있다.

 

엣지 케이스도 고려하자. 만약 D == 0이라면, 정수해는 0이고 차수를 낮춰주기만 하면 되겠다. 이제 이차방정식을 풀이하러 가면 된다!

 

 

풀이 코드

import sys
input = sys.stdin.readline
MAX = 10**6
eps = 10**(-4)

def solve_cubic_eq(a, b, c, d) :
  idx = 1
  _range = []
  for i in range(1, int(abs(d)**0.5)+1) :
    if d % i == 0 :
      _range.append(i)
      _range.append(-i)
      if i != abs(d) // i :
        _range.append(abs(d) // i)
        _range.append(-abs(d) // i)
  
  for i in _range :
    if a*i**3 + b*i**2 + c*i + d == 0 :
      return i

def solve_quard_eq(a, b, c) :
  if b ** 2 - 4*a*c < 0 :
    return []
  elif b ** 2 - 4*a*c == 0 :
    return [-b / (2*a)]
  else :
    return [(-b + (b ** 2 - 4*a*c) ** 0.5) / (2*a), (-b - (b ** 2 - 4*a*c) ** 0.5) / (2*a) ]

def solve() :
  A, B, C, D = map(int, input().split())
  if D != 0 :
    a = solve_cubic_eq(A, B, C, D)
    b, c = B + a*A, -D / a
  else :
    a = 0
    b, c = B, C
  tmp = solve_quard_eq(A, b, c)
  ans = [a]
  for t in tmp :
    flg = True
    for _ans in ans :
      if abs(t - _ans) <= eps :
        flg = False
    if flg :
      ans.append(t)
  for _ans in sorted(ans) :
    print("{:.04f}".format(_ans), end = ' ')
  print()

for _ in range(int(input())) :
  solve()

풀이 완료!

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.