backward: 미분값 계산

딥러닝으로 모델을 학습시키기 위해 미분 값을 구하는 과정이 필요하다. 만약 왜 미분이 필요한지 모른다면 '경사하강법과 학습률'을 참고하면 된다. 해당 내용을 몰라도 이번 글을 이해하는 데는 문제가 없다. 

  • 문제점
  • Chain-rule
  • 연산자와 미분 결과
  • 계산 과정 시각화
  • 역방향으로 계산

문제점

일반적으로 미분값을 구할 때, 도함수를 구한 후 값을 대입해 계산한다.

f(x)=ax3+bx2+cddxf(x)=3ax2+2bx

하지만 문제는 모델의 연산 과정이 너무 복잡하다. 

etc-image-0
CNN 구조

f(x)=Linear(Droupout(...(maxpool(relu(conv(...))))))
위 예시는 아주 기본적인 CNN 모델의 구조이다. 그리고 가장 많이 사용되는 손실 함수인 Cross Entropy with Softmax는 아래와 같이 정의된다. 
H(x,y)=1NcNlog(exp(f(xc))iNexp(f(xi)))yc
위 모델을 f(x)라고 할 때, H(x, y)에 f(x)를 대입하고 도함수를 구한다고 생각하면 막막하다. 도함수를 구하는 것이 매우 복잡하고 비효율적이다. 이러한 문제를 해결할 수 있는 아이디어로 Chain-Rule이 있다.


Chain-Rule

미분 값을 구하는 과정을 이해하기 위해서는 연쇄 법칙이라고 부르는 Chain-Rule에 대한 배경 지식이 필요하다.
dLda=dLdcdcdbdbda
결론부터 이야기하면 연쇄 법칙은 분수를 약분하듯이 분자, 분모가 연쇄적으로 계산된다는 법칙이다. 여기까지만 알아도 미분값을 구하는 데는 문제가 없다. 그래도 확실히 하기 위해 조금 더 자세히 이야기해 보겠다. 합성 함수를 미분하기 위해 연쇄 법칙을 적용해 보자. 
 
y=f(g(x))=fg(x)

g(x)=u

f(u)=y
미분 가능한 함수 f와 g에 대해 위와 같은 관계가 성립한다고 하자. 
dydu=limu0yu
미분의 정의를 활용해 작성한 식이다. 여기서 u와 x의 관계를 살펴보자.
u=g(x+x)g(x)

x0  then  u0
u의 변화량은 g(x)의 변화량을 뜻한다. x의 변화량이 0에 가까워지면 u의 변화량도 0에 가까워진다. 따라서 dy/du를 다시 정의하자. 
dydu=limu0yu=limx0yu
dudx=limx0ux
이제 구해둔 단서를 연결해 보자. 
dydududx=limx0yulimx0ux=limx0(yuux)=limx0yx=dydx
극한 값의 특징을 이용해 연쇄 법칙을 확인할 수 있다. 그리고 합성 함수의 미분 값을 여러 미분 값으로 쪼갤 수 있다는 것도 알 수 있다.
 
그렇다면 앞에서 봤던 모델의 미분도 연쇄 법칙을 이용해 쪼개볼 수 있다. 
y=maxpool(relu(conv(x)))=maxpoolreluconv(x)
위와 같은 합성 함수의 연산을 아래와 같이 표현할 수 있다. 
conv(x)=z0relu(z0)=z1maxpool(z1)=y
dydx=dydz1dz1dz0dz0dx
그리고 각각의 함수는 덧셈, 곱셈, 제곱, log, sin, cos 등등 작은 단위의 연산으로 쪼갤 수 있다. 따라서 우리는 작은 단위의 연산에 대해 미분 값을 정의하면, 커다란 함수의 미분값도 계산할 수 있게 된다. 


연산자와 미분 결과

다양한 연산이 있지만 이번 글에서는 가장 기본적인 덧셈(+), 뺄셈(-), 곱셉(*), 제곱(^2)의 미분 값만 살펴보겠다.

덧셈, 뺄셈
dda(a+b)=1,dda(ab)=1
덧셈과 뺄셈은 어떤 값이 더해지던 항상 미분 값은 1이다. 
 
곱셈
dda(ab)=b
a에 대한 미분값으로 b가 나왔다. 즉, 곱해진 값을 미분 값으로 갖는다는 것을 알 수 있다. 
 
제곱
ddaa2=2a
제곱은 알다시피 2를 곱한 값을 미분 값으로 갖는다. 


순방향 계산

계산되는 과정을 확인하기 위해 간단한 식을 하나 만들어보자. 
L=(wx+by)2 
wx+b라는 일차 함수와 y의 차이를 제곱한 값을 L이라고 두었다. 이제 위 식의 계산 과정을 단계별로 생각해 보자. 먼저 w와 x를 곱하고, b를 더한 후, y를 뺀다. 마지막으로 제곱을 한다. {x=4, w=1, b=-3, y=3}일 때, 과정을 그림으로 표현하면 아래와 같다. 

etc-image-1

왼쪽부터 순서대로 값과 연산자를 거쳐 최종적으로 L이 계산된다. 그리고 계산된 각각의 결괏값을 저장해뒀다. 


역방향으로 계산

etc-image-2

dLdw=dLdz2dz2dz1dz1dz0dz0dw

앞에서 봤던 연쇄 법칙을 이용해 dL/dw를 정의했다.
그럼 이제 하나씩 찾아가 보자. 
dLdz2=2z2
L은 제곱(^2)으로 계산되었다. 위에서 제곱은 2를 곱한 값을 미분 값으로 갖는 것을 확인했었다. 
dz2dz1=1
z2는 뺄셈(-)으로 계산되었다. 뺄셈은 항상 1을 미분 값으로 갖는다. 
dz1dz0=1
z1은 덧셈(+)으로 계산되었고 덧셈도 항상 미분 값으로 1을 갖는다. 
dz0dw=x
z0는 곱셈(*)으로 계산되었다. 따라서 곱해진 값인 x를 미분값으로 갖는다. (z0 = wx)
dLdw=dLdz2dz2dz1dz1dz0dz0dw=2z211x=2(2)4=16
결과적으로 w에 대한 L의 미분 값은 -16으로 나온다. 순방향으로 계산하는 과정에서 z2의 값을 저장해 두었기 때문에 별다른 연산 없이 바로 결과를 구할 수 있었다. 따라서 순방향으로 한 번, 역방향으로 한 번 계산하고 나면 미분값을 구할 수 있다. 
직접 손으로 도함수를 계산해 미분 값을 구해도 동일한 결과가 나오는 것을 확인할 수 있다.