Courses/ML(PRML,cs229)

forward KL, reverse KL, cross-entropy

모끼주인 2024. 9. 16. 19:49

이번 포스팅에서는 loss function으로 많이 사용되는 KL divergence 및 cross-entropy에 대해 정리해 보려고 합니다! 

 

1. forward KL

KL divergence는 확률 분포 간 거리를 재는 척도로 data distribution p(x)와 model distribution q(x) 간의 거리를 잴 수 있기 때문에 loss function으로 많이 사용한다. KL divergence는 확률 분포를 대입하는 순서에 따라 값이 달라지는데, 먼저 forward KL 수식은 다음과 같다:

 

 

위 수식에서는 알아보기 쉽도록 data distribution p(x)=p^(x)로, model distribution q(x)=pθ(x)로 표기했다. data distribution은 실제 세계의 distribution에 대한 추정치이기 때문에 p^을, model distribution은 parameter로 추정한 distribution이기 때문에 pθ로 표기한다.

 

위와 같이 정리한 식을 살펴보면 두 가지 체크할 점이 있다:

  1. data distribution p^(x)에 대한 expectation으로 정리된다는 점
  2. 왼쪽 항이 data distribution p^(x)에 대한 entropy로 정리된다는 점

 

1. data distribution p^(x)에 대한 expectation으로 정리된다는 점

 

출처: https://bekaykang.github.io/posts/KL_divergence/

 

위 그림에서 파란색 distribution은 실제 세계의 distribution인 p^(x)를, 주황색 distribution은 model distribution인 pθ(x)를 의미한다. 왜 pθ(x)가 저런 모양을 가졌는지 생각해보자. 만약 p^(x) 중 확률값이 더 높은 오른쪽 mode에만 집중해서 학습한다면, 왼쪽 mode에서 -∑p^(x) logp(x)는 p(x)=0이므로 무한대에 가까워진다. 따라서 전체적인 data 구간을 cover하는 것이 forward KL을 더 효과적으로 minimize할 수 있다. 따라서 forward KL은 mode-covering이다.

 

이러한 mode-covering behavior의 장점은 data distribution의 전반적인 mode를 covering하기 때문에 diversity를 가질 수 있다는 것이다. 예시로 language model을 생각해보면, forward KL로 학습한 모델(ex. DPO)은 reverse KL(ex. PPO)보다 더 다양한 답변을 생성할 수 있을 것이다.

 

어쨌든 결과적으로 p^(x)에 대한 expectation을 수행하므로, pθ(x)는 data의 전체 구간을 cover하며 p^(x)와 최대한 유사해지도록 학습되는 것을 볼 수 있다. 보통 p^(x)에 대한 expectation은 불가능하다. 그러나 우리는 forward KL을 사용하는 경우를 볼 수 있는데, 바로 classification과 같이 label distribution p^(y)를 추정하는 경우이다. 이 경우 random variable y가 가질 수 있는 경우의 수가 정해져있으므로 (ex. language model의 경우 token 개수) expectation을 추정 가능하다. 

 

2. 왼쪽 항이 data distribution p^(x)에 대한 entropy로 정리된다는 점

 

결과적으로 data distribution p^(x)는 추정해야할 대상이며 값이 바뀌지 않는다. 따라서 p^(x)는 model parameter θ와 무관하기 때문에 p^(x)에 대한 entropy term은 최종적인 loss 식에서 무시할 수 있다. 따라서 결과적으로 forward KL은 다음으로 소개할 cross entropy와 같은 것을 optimize하게 된다.

 

2. cross-entropy

 

 

위의 식은 우리가 흔히 알고 있는 cross entropy 수식이다. 결과적으로 forward KL에서 entropy term을 제거한 형태와 똑같은 것을 볼 수 있다.

 

일반적으로 classification model에서 사용되는 경우를 생각해보자. 이 경우 p^(x)는 data에서의 label distribution인 p^(y)가, pθ는 model로 추정한 distribution인 pθ(y|x)가 된다. p^(y)는 보통 각각의 data point가 어떤 label을 가졌는지 나타낸 one-hot vector를 사용한다. 이와 같이 label을 추정하는 경우 conditional distribution으로 식을 대체하면 된다.

 

3. reverse KL

 

다음으로 reverse KL수식을 살펴보자. 이번에도 두 가지 측면에서 살펴보자.

 

1. model distribution pθ(x)에 대한 expectation으로 정리되는 점

 

reverse KL은 위의 수식에서 보듯이 model distribution에 대한 expectation으로 정리된다. 따라서 PPO와 같이 학습 진행 중에 모델로부터 data를 sampling하는 알고리즘의 경우 함께 사용하기 적절하다.

 

reverse KL에서는 model distribution pθ가 어떻게 optimize되는지 살펴보자. model distribution이 cover하는 범위 내에서 KL을 minimize하면 되므로, forward KL에서와는 반대로 작동한다. 즉 어떤 mode에서 model로부터 추정한 확률값 pθ(x)가 항상 0이 되면 되기 때문에 무한대 값을 가지지 않는다. 따라서 아래 그림과 같이 두 가지의 mode가 있으면, 그 중 확률값이 더 높은 mode쪽으로 치우치게 된다. 따라서 reverse KL은 mode-seeking이다.

출처: https://bekaykang.github.io/posts/KL_divergence/

 

2. 왼쪽 항이 pθ(x)에 대한 entropy term이라는 점

 

이번에는 forward KL과 달리 model distribution pθ에 대한 entropy이기 때문에 마음대로 제거할 수 없다. 대신 entropy term이 추가됨으로써 장점은 mode-seeking behavior를 하는 reverse KL의 단점을 보완해줄 수 있다는 점이다. 만약 language model을 생각해본다면, 하나의 mode에만 수렴하게 되면 유사한 질문들에 대해 무조건 하나의 답변만 생성되도록 학습될 수 있다. 그러나 entropy term이 추가된다면 어느 정도 diversity를 유지하면서 mode-seeking behavior를 하도록 학습될 수 있다.