아래 references의 자료들을 공부하고 짧게 정리한 내용입니다! 저도 공부하면서 정리한 내용이라 틀린 것이 있다면 언제든 댓글 달아주신다면 감사하겠습니다~ :D
____
1. Reparameterization Trick
모델을 학습시킬 때 주어진 데이터셋이 아닌, parameterized된 모델로부터 data를 sampling해서 학습하는 경우가 있다:

이런 경우 그냥 pθ(x)에서 데이터를 샘플링해서 쓰면 되나?? 라는 생각이 들겠지만 gradient를 구해보면 문제가 생긴 것을 알 수 있다.

위 식과 같이 gradient를 구해보면, 오른쪽 항의 경우 pθ(x)에 대한 expectation으로 정리가 되지만, 왼쪽 항의 경우는 pθ(x)에 대한 expectation으로 정리되지 않는다. gradient를 구해야하는데 gradient에 대한 expectation으로 정리가 된다.
따라서 reparameterization trick은, 위의 식을 parameterize되지 않은 변수에 대한 expectation으로 바꾸는 trick이다. 만약 p(x)가 gaussian이었다면, fθ의 input x는 gaussian에서 샘플링 된 것이므로 다음과 같이 분해할 수 있다:

여기서 μ와 σ는 p(x)의 parameter이므로 gradient를 구해야 할 대상이다. 반면 ε는 parameterize되지 않았으며 stochastic noise로 취급할 수 있다. 결과적으로 pθ(x)에 대한 expectation을 ε의 분포 N(0,I)에 대한 expectation으로 바꾸자는 것이다.
2. Policy Gradient에서 Reparameterization Trick
방법1: Reparameterization Trick
그렇다면 어떻게 ε의 분포 N(0,I)에 대한 expectation으로 바꾸는지에 대해 살펴보자. 아래는 policy gradient objective의 gradient 수식이다:

앞서 살펴본 것과 같이 τ는 policy distribution pθ(τ)에서 샘플링된다. 따라서 gradient를 구하기 어려울 것이므로, R(τ)의 input τ를 다음과 같이 분해한다:

이제부터 μ + σ*ε은 함수 g로, ε을 샘플링하는 간단한 gaussian distribution(ex. normal gaussian)은 q로 표현하도록 하자. 결과적으로 policy gradient의 objective는 다음과 같이 정리된다:

이렇게 결과적으로 미분을 expectation안으로 집어넣을 수 있었다! 따라서 batch의 평균 gradient를 구할 수 있게 되었다.
방법2: Score Function Estimator
그런데 기억상으로 policy gradient에서 reparameterization trick을 안썼던것 같은데? 하고 의아해하는 사람도 있을 것이다. (저입니다;) 사실 policy gradient에서는 score function estimator를 활용해서 식을 변형하는 trick을 주로 사용한다! 식을 정리해보면:

이렇게 분자 분모에 p(τ)를 곱함으로써 p(τ)에 대한 expectation으로 정리하면, 안의 내용들은 깔끔하게도 log p(τ)에 대한 gradient로 정리된다. 이 때 ∇log p(τ)를 score function estimator라고 한다.
3. VAE에서 Reparameterization Trick
그렇다면 VAE에서도 꼭 reparameterization trick을 써야하나? policy gradient같은 trick은 없었을까? 하고 궁금증이 들 것이다. 결론부터 얘기하면 VAE는 두 가지 모델로 구성되어 두 가지 parameter를 갖고 있기 때문에 불가능하다. ㅠㅠ 일단 VAE의 그림을 다시 한번 간단하게 살펴보자:

이와 같이 encoder에서 x를 latent z로 인코딩하고, decoder가 이 z를 input으로 받아 다시 x를 복원하는 방식으로 학습된다! 결과적으로 gradient는 decoder의 parameter θ에 대해 먼저 계산되고, 이후 encoder의 parameter Φ에 대해 계산될 것이다. 그럼 일단 먼저 VAE의 loss부터 다시 한번 살펴보자(유도 과정은 variational inference 포스팅을 참조해주시면 좋을거 같습니다~):

gradient를 구할 때 문제가 생기는 항은 오른쪽의 Eq[log pθ]항이다. 오른쪽 항을 일단 decoder의 parameter θ에 대해 미분해보면:

당연하지만 encoder에 대한 expectation이므로 θ에 대한 미분이 expectation안으로 들어가버린다. 결과적으로 encoder의 parameter값을 반영하지 않은 채 gradient를 구하게 되는데, 이러한 경우 encoder qΦ로부터 data를 샘플해서 구한다고 해보자. 결과적으로 식은

이러한 형태가 될 것이고, 결과적으로 식이 encoder qΦ와 관련이 없어진다. 즉 뭔가 policy gradient에서와 같은 trick을 써보기도 전에 decoder의 gradient가 encoder로 연결이 되지 않게 된다. 따라서 오른쪽 항을 reparameterization trick을 통해 ε에 대한 expectation으로 바꿔 주면, 안에 encoder의 parameter값도 들어갈 것이므로 expectation 내부의 식을 encoder와 decoder parameter 모두에 대해 나타내줄 수 있다.
그렇다면 VAE의 loss 식을 다시 한번 들여다보자. 여기서는 일단 편의를 위해 x의 요소 하나 x0에 대한 loss를 구했다고 생각해보면:

reparameterization trick을 위해 decoder의 input z를 다음과 같이 정의한다:

이 때 z는 앞서 본 경우와 달리 conditional gaussian에서 샘플링되었기 때문에, x에 대한 함수로 정의된다. (conditional gaussian 설명은 이 포스팅 참조) 여기서는 encoder가 q라 헷갈리므로 ε을 샘플링하는 간단한 가우시안을 p(ε)라고 하자. 그러면 이제 ε에 대한 expectation으로 바꿔보면:

이렇게 오른쪽 항을 간단하게 정리할 수 있다! 결과적으로 p(ε)에서 샘플링한 L개의 z들에 대한 expectation을 구하게 된다. 앞서 정의했듯이 z = gΦ(ε, x)와 같이 Φ에 대한 함수로 변경했으므로, 별 변화가 없어보이지만 사실 parameter Φ가 expectation 내부로 들어간 것이다. 헷갈린다면 둘째줄 식의 p(x0|z)에서 z를 gΦ(ε, x)로 바꿔보면 잘 이해될 것이다!
따라서 결과적으로 VAE의 loss는 다음과 같이 정리될 수 있다:

[References]
Reparameterization trick:
https://untitledtblog.tistory.com/181
Reparameterization Trick에 대한 수학적 이해와 기댓값의 미분가능성
1. 기댓값 (Expectation)의 미분가능성 많은 머신러닝 문제에서 우리는 식 (1)과 같이 모델 매개변수 $\theta$에 대해 어떠한 기댓값을 최대화하고자 한다. $$ \theta^* = \underset{\theta}{\text{argmax}} \; E_{p(x)}[f
untitledtblog.tistory.com
VAE:
https://jaejunyoo.blogspot.com/2017/05/auto-encoding-variational-bayes-vae-3.html
초짜 대학원생의 입장에서 이해하는 Auto-Encoding Variational Bayes (VAE) (3)
Machine learning and research topics explained in beginner graduate's terms. 초짜 대학원생의 쉽게 풀어 설명하는 머신러닝
jaejunyoo.blogspot.com
Policy Gradient:
https://stillbreeze.github.io/REINFORCE-vs-Reparameterization-trick/
REINFORCE vs Reparameterization Trick
An introduction and comparison of two popular techniques for estimating gradients in machine learning models
stillbreeze.github.io
'Courses > ML(PRML,cs229)' 카테고리의 다른 글
forward KL, reverse KL, cross-entropy (0) | 2024.09.16 |
---|---|
Exponential Family (0) | 2024.09.15 |
Probability Distributions (0) | 2024.09.15 |