지이이이인-짜 오래걸렸다. Pytorch를 비롯해 인공지능 프레임워크를 사용해보기도 했고, 꽤 잘 다룬다고 생각했었는데 일일히 구현하는 건 쉽지 않은 모양이다. Pytorch source code까지 뜯어보고 나서야 어느 정도 감을 잡았던 것 같다.
각설하고 문제 풀이에 들어가자면, 이 문제는 언뜻 보기엔 쉬운 구현 문제지만 여러가지 함정이 있다.
우선, math module을 사용할 수 없으므로 자연 상수 e를 다른 방식으로 구현해야 한다. 또한, Softmax는 지수 승으로 이루어져있는데(위 수식 참고바람) 그 특성상 입력값에 따라 기하급수적으로 전체 exponential값이 증가한다. 따라서 Softmax의 overflow 문제를 해결하도록 수식을 변경해야 한다. 마지막으로, 이러한 Softmax 연산은 소수의 사칙연산 문제로 볼 수 있다. 문제는 컴퓨터는 소수 연산에 대해 정확한 값을 제공하지 않는다! (부동 소수점을 이용하기 때문). 따라서, 이러한 소수 연산에 따른 오차 역시 최대한 보정해야 한다.
즉 주요 고려 사항은
자연 상수 e 구현
Softmax의 overflow 해결
부동소수점 연산 오차 해결
이 되겠다.
그러나 우리의 파이썬은 1, 3번을 동시에 해결 가능한 모듈을 하나 준비했는데, 이게 decimal module이 되시겠다. decimal module은 이러한 부동소수점 연산 오차를 지원하기 위한 기능들을 제공한다.
또한 한 가지 더. Decimal class는 내부 함수로 exp()를 제공하는데, 이 exp는 원래 Decimal 값의 자연상수 지수승, 즉 $ e^x $를 지원한다! 따라서 이러한 Decimal class를 이용하면 쉽게 문제를 해결할 수 있다(다른 언어에서도 비슷한 역할을 하는 라이브러리가 존재하리라 본다)
(이러한 우회 방법을 이용하지 않고 $ e^x $를 구현할 수도 있다. 테일러 급수를 이용하면 되니까.)
그럼 남은건 2번째인데... 정답은 Pytorch source code에서 발췌했음을 우선 밝혀 둔다. 인공지능 프레임워크에서 주로 activation function으로 사용하는 기능이기도 하고, 당연히 overflow 문제를 위한 방지책을 만들어두었으리라 쉽게 예상 가능하다. 우선, 다시 Softmax function으로 돌아가보자.
따라서 집합 x의 모든 값은 0 이하가 되고, e^x의 값은 0과 1사이의 값으로 scaling된다.
풀이 코드
from decimal import *
getcontext().prec = 100
min_val = Decimal('1e-15')
def exp_val(num):
return num.exp()
def make_decimal_result(a, b):
result = a / b
return result.quantize(min_val, rounding=ROUND_DOWN)
def softmax():
global max_val
N = int(input())
num_lst = list(map(Decimal, input().split()))
max_val = max(num_lst)
for i in range(N) :
num_lst[i] -= max_val
exp_lst = list(map(exp_val, num_lst))
exp_sum = Decimal(0.)
for i in exp_lst :
exp_sum += i
for i in range(N) :
print('{:d}: {:.15f}'.format(i+1, make_decimal_result(exp_lst[i], exp_sum)))
softmax()
총 정답자 수가 2명인데, 한 분은 출제자이니 사실상 내가 1등인가...? 시간은 오래 걸렸어도 기분은 꽤 좋다!