새소식

PS/CodeUp

[CodeUp/5301] Softmax (Python)

  • -

Problem : softmax (codeup.kr)

Status : Solved

Time : ??????

 


 

문제 설명

 

더보기

찬형이는 오늘 학교에서 인공지능수학 수업을 들었다. 다중 클래스 분류 문제를 풀 때 인공지능이 계산한 값을 확률로 바꿔주는 Softmax를 배웠다.  

 

$$ Softmax(x) = \frac{e^{a_k}}{\sum_{i = 1}^{n}e^{a_i}} $$


찬형이는 오늘 배운 Softmax 함수를 직접 구현해보기로 했다. 하지만 찬형이는 CodeUp 기초 100제를 풀고 있는 코딩 초보라 Softmax를 구현하기에는 무리가 있어 보인다. 

어려움을 겪고 있는 찬형이를 위해 Softmax 함수를 구현해주자!

(사용 불가 : numpy, exec, math)

 

입력 및 출력

 

더보기

입력

첫째 줄에 데이터의 개수 n(1≤n≤5,000)이 입력된다.

둘째 줄에 n개의 실수 데이터가 공백으로 구분되어 입력된다.

(−20,000≤ai≤20,000, a1,a2,...,an)

ai는 최대 소수점 넷째 자리까지 입력된다.

 

출력

n개 데이터의 Softmax의 값을 출력한다. 출력은 소수점 열다섯째 자리까지 출력한다.  

 

입력 예시

4

1.1 1.2 1.3 1.4

 

출력 예시

1: 0.213838220365984

2: 0.236327782321537

3: 0.261182592155075

4: 0.288651405157402

 

 


 

풀이

 

지이이이인-짜 오래걸렸다. Pytorch를 비롯해 인공지능 프레임워크를 사용해보기도 했고, 꽤 잘 다룬다고 생각했었는데 일일히 구현하는 건 쉽지 않은 모양이다. Pytorch source code까지 뜯어보고 나서야 어느 정도 감을 잡았던 것 같다.

 

각설하고 문제 풀이에 들어가자면, 이 문제는 언뜻 보기엔 쉬운 구현 문제지만 여러가지 함정이 있다.

 

우선, math module을 사용할 수 없으므로 자연 상수 e를 다른 방식으로 구현해야 한다. 또한, Softmax는 지수 승으로 이루어져있는데(위 수식 참고바람) 그 특성상 입력값에 따라 기하급수적으로 전체 exponential값이 증가한다. 따라서 Softmax의 overflow 문제를 해결하도록 수식을 변경해야 한다. 마지막으로, 이러한 Softmax 연산은 소수의 사칙연산 문제로 볼 수 있다. 문제는 컴퓨터는 소수 연산에 대해 정확한 값을 제공하지 않는다! (부동 소수점을 이용하기 때문). 따라서, 이러한 소수 연산에 따른 오차 역시 최대한 보정해야 한다.

 

즉 주요 고려 사항은

  1. 자연 상수 e 구현
  2. Softmax의 overflow 해결
  3. 부동소수점 연산 오차 해결

이 되겠다.

 

그러나 우리의 파이썬은 1, 3번을 동시에 해결 가능한 모듈을 하나 준비했는데, 이게 decimal module이 되시겠다. decimal module은 이러한 부동소수점 연산 오차를 지원하기 위한 기능들을 제공한다.

 

decimal — 십진 고정 소수점 및 부동 소수점 산술 — Python 3.11.0 문서

 

decimal — 십진 고정 소수점 및 부동 소수점 산술 — Python 3.11.0 문서

The decimal module provides support for fast correctly rounded decimal floating point arithmetic. It offers several advantages over the float datatype: 모듈 설계의 중심 개념은 세 가지입니다: 십진수, 산술을 위한 컨텍스트, 신호(

docs.python.org

 

또한 한 가지 더. Decimal class는 내부 함수로 exp()를 제공하는데, 이 exp는 원래 Decimal 값의 자연상수 지수승, 즉 $ e^x $를 지원한다! 따라서 이러한 Decimal class를 이용하면 쉽게 문제를 해결할 수 있다(다른 언어에서도 비슷한 역할을 하는 라이브러리가 존재하리라 본다)

(이러한 우회 방법을 이용하지 않고 $ e^x $를 구현할 수도 있다. 테일러 급수를 이용하면 되니까.)

 

Exponential function - Wikipedia

 

Exponential function - Wikipedia

From Wikipedia, the free encyclopedia Jump to navigation Jump to search Mathematical function, denoted exp(x) or e^x This article is about the function f(x) = ex and its generalizations. For functions of the form f(x) = xr, see Power function. For the biva

en.wikipedia.org

 

그럼 남은건 2번째인데... 정답은 Pytorch source code에서 발췌했음을 우선 밝혀 둔다. 인공지능 프레임워크에서 주로 activation function으로 사용하는 기능이기도 하고, 당연히 overflow 문제를 위한 방지책을 만들어두었으리라 쉽게 예상 가능하다. 우선, 다시 Softmax function으로 돌아가보자.

$$ Softmax(x) = \frac{e^{a_k}}{\sum_{i = 1}^{n}e^{a_i}} $$

그런데, Softmax의 분자 및 분모는 e^x꼴로 이루어졌음을 알 수 있다! 따라서 분자와 분모에 상수항인 e^C를 곱해본다면...

$$ Softmax(x) = \frac{e^{a_k}\cdot e^{c}}{\left ( \sum_{i = 1}^{n}e^{a_i}\right )\cdot e^{c}} = \frac{e^{a_k + c}}{\sum_{i = 1}^{n}e^{a_i + c}} $$
 

이런 결과를 가져온다. 즉 모든 입력값에 대해 Scaling이 가능하다는 의미이다. 여기서 C = - max(a) 로 둔다면, 즉 모든 숫자 집합에 대해 그 집합의 최댓값을 뺀 값을 넣어도 원래 Softmax 수식과 동일한 결과를 가져온다.

$$ Softmax(x) = \frac{e^{a_k - max(x)}}{\sum_{i = 1}^{n}e^{a_i - max(x)}} $$

따라서 집합 x의 모든 값은 0 이하가 되고, e^x의 값은 0과 1사이의 값으로 scaling된다.

풀이 코드

from decimal import *

getcontext().prec = 100
min_val = Decimal('1e-15')
def exp_val(num):
  return num.exp()

def make_decimal_result(a, b):
  result = a / b
  return result.quantize(min_val, rounding=ROUND_DOWN)

def softmax():
  global max_val
  N = int(input())
  num_lst = list(map(Decimal, input().split()))
  max_val = max(num_lst)
  for i in range(N) :
    num_lst[i] -= max_val

  exp_lst = list(map(exp_val, num_lst))
  exp_sum = Decimal(0.)
  for i in exp_lst :
    exp_sum += i
  
  for i in range(N) :
    print('{:d}: {:.15f}'.format(i+1, make_decimal_result(exp_lst[i], exp_sum)))
    
softmax()

풀이 완료!

 

총 정답자 수가 2명인데, 한 분은 출제자이니 사실상 내가 1등인가...? 시간은 오래 걸렸어도 기분은 꽤 좋다!

'PS > CodeUp' 카테고리의 다른 글

[CodeUp/1754] 큰 수 비교 (Python)  (0) 2022.11.29
[CodeUp/2753] 수열의 n번째 항 (Python)  (0) 2022.11.29
[CodeUp/4787] 택배 (Python)  (2) 2022.11.28
[CodeUp/4786] 올림픽 (Python)  (1) 2022.11.28
[CodeUp/4425] 잠수함 식별(Python)  (0) 2022.11.28
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.