[4-2] Model Inference: Variational Inference (+connection to VAE, RL)
아래 references의 강의 및 자료들을 공부하고 짧게 정리한 내용입니다! 저도 공부하면서 정리한 내용이라 틀린 것이 있다면 언제든 댓글 달아주신다면 감사하겠습니다~ :D
____
1. Why Variational Inference?
앞서 Sampling Methods 포스팅에서 언급했듯이, 대부분의 inference problem은 intractable하다. Inference problem이란, 구하고자 하는 distribution p(x)가 주어졌을 때 어떤 statistics를 계산하는 것이다. 예시로는 다음과 같은 경우가 있다:
대부분의 경우 첫번째와 같이 p(x)만을 남기기 위해 다른 random variable에 대해 marginalize하거나, 두번째와 같이 posterior를 구하기 위해 bayes rule로 conditional distribution을 구하려고 할 것이다. 결과적으로 두 경우 모두 z에 대한 적분이 포함되고, p(x|z)가 아주 간단한 경우가 아니라면 구하기 어렵다 (EM algorithm 포스팅 에서는 posterior q(z|x)를 직접 계산한 후(E step) maximization(M step)했는데, 포스팅 맨 마지막에 써있듯이 전통적인 방식에서 p(z)가 1-d categorical vector일때 처럼 아주 간단한 경우만 연산했기 때문에 가능했다).
따라서 model을 inference하기 위해서, exact inference 대신 approximate inference하는 방법이 필요했다. 그 중 하나가 Sampling Methods 포스팅에서 소개한 MCMC Sampling이고, 오늘 소개할 것은 variational inference이다.
variational inference의 장점.
- Sampling methods의 경우, global optimum으로 수렴하는 것을 보장할 수 있다 (Sampling Methods 포스팅 Section 2.2 Why this work 참조). 하지만 얼마나 오래 걸릴지 알기 어렵다.
- 빠르게 good solution을 찾기 위해서는 Q(x'|x)를 잘 proposal하는 technique이 필요하다.
- variational inference는 global optimum을 보장하진 않지만 model scaling이 쉽고, stochastic gradient나 parallel training같은 training technique들을 적용하기 쉽다.
variational inference의 단점.
그럼에도 가끔씩 MCMC sampling을 쓰는 경우들을 종종 볼 수 있는 이유는, variational inference를 사용하는 경우 inference하고자 하는 statistics의 closed-form이 연산 가능해야 하기 때문이다. 뭔 소린지 살펴보면:
variational inference에서는 위와 같이 closed-form을 활용하여, p(x)대신 p(x)와 가까운 q(x)를 학습할 것이다. 하지만 이러한 연산이 불가능하다면 아래와 같이 sampling을 활용할 수 있다. (MCMC sampling 활용 예시 논문: HMC)
2. Variational Inference
variational inference란, inference 문제를 optimization 문제로 변형해서 푸는 방법이다. 즉 어떤 true distribution p(x)가 주어졌을 때 p(x)를 직접 inference하지 말고, tractable distribution들의 집합 Q에서 p(x)와 가장 가까운 q를 대신 찾은 후 q(x)를 inference하자는 것이다. 물론 이전에는 집합 Q에서 best function q를 찾는 방식을 사용했지만, 요즘 방식은 q 또한 parameterize해서 표현한다. q를 parameterize할거면 직접 p를 parameterize하면 안되나? 라는 생각이 들겠지만 q를 활용함으로써 간접적으로 연산 가능한 loss function을 유도할 수 있다.
다음으로는 p(x)와 가장 가까운 q(x)는 어떻게 찾지? 라고 생각할텐데, KL divergence를 이용한다! 먼저 p(x) 식을 다시 한번 정의해보자 (p(x) 수식 유도 과정은 Sampling Methods 포스팅 Section1 참조):
그리고 이 p(x)와 q(x)간의 reverse KL을 optimize한다. reverse KL을 사용하는 이유는, forward KL은 p(x)에 대한 expectation으로 정리되는 반면 reverse KL은 q(x)에 대한 expectation으로 나타낼 수 있어 학습하는 모델로부터 데이터를 샘플링할 수 있기 때문이다. forward/reverse KL에 대한 자세한 설명은 이 포스팅을 참조하면 좋을 것 같다.
그런데 reverse KL인 KL(q(x) || p(x))는 사실상 optimize가 불가능한데, 그 이유는 앞서 정의한 p(x)때문이다. KL(q || p)식에도 어쨌든 p가 들어가긴 하는데, 이 때 위의 p(x)식을 그대로 넣어주면 분모의 normalizing constant Z term이 살아남게 된다. 따라서 다음과 같이 식을 고쳐줄 수 있다:
즉 p(x)식의 분모를 p~(x)라고 두고, p~(x)와 q(x)간의 reverse KL을 정리한 것이다. 굳이 이렇게 정리한 이유는, 아래와 같이 정리할 수 있기 때문이다:
앞서 구한 KL 식을 log Z(θ)에 대해 정리해보면, 위와 같은 식을 얻을 수 있다. 이 때 KL(q||p)는 항상 >=0이므로, 위 식의 lower bound는 -KL(q||p~)로 정의할 수 있다. 따라서 우리는 이 -KL(q||p~)만 결과적으로 maximize하면 된다. -KL(q||p~)은 KL(q||p)와 normalizing constant의 차이만 있으므로 -KL(q||p~)를 maximize하면 KL(q||p)도 0에 가까워진다.
결과적으로 -KL(q||p~)식을 정리해보면:
가 되고, 이 식을 optimize하면 된다. 이 식을 variational lower bound라고 부른다. 정리하면, q(x)에 대한 expectation으로 표현되어 연산이 가능해지고, p(x)에서 Z term을 제거한 p~(x)만 inference하면 되기 때문에 연산이 가능해진다. 옛날 방식의 경우 p(x)만 θ로 parameterize하고 q(x)는 후보 집합 Q 중에서 적당한 것을 찾았으나, 아래에 소개할 VAE에서는 q(x) 또한 Φ로 parameterize한다.
즉 MCMC sampling에서는 p(x) 대신 q(x'|x)를 사용하여 많은 데이터를 샘플링하고 이를 통해 값을 inference했다면 (ex. restricted boltzmann machine에서 샘플링한 데이터를 바탕으로 p(x')-p(x)를 inference, Z 제거: PGM 포스팅 2편 참조), 여기서는 p(x)를 inference하는 대신 비슷한 q(x)를 optimize한 뒤 inference한다. 두 가지를 합친 방식도 있는데, variational approximation q(x)를 sampling을 위한 proposal q(x'|x)로 사용하는 방법을 variational MCMC라고 한다고 한다.
3. VAE
이제 variational inference를 공부할 때 항상 보게 되는 VAE에 대해 살펴보자. VAE는 어떤 latent vector z에서 이미지 x를 생성하는 모델이다. 따라서 다음과 같이 p(x|z)를 모델링한다:
위 식은 p(x|z)를 그냥 z에 따라 변하는 conditional gaussian으로 둔 것이다 (conditional gaussian 설명 참조). 그런데 우리는 p(x|z)뿐만 아니라, 어떤 z를 넣어야 어떤 이미지를 생성할 수 있는지 알기 위해서는 posterior p(z|x)도 inference할 수 있어야 한다:
그리고 여기서는 prior p(z)를 normal gaussian으로 두더라도, posterior p(z|x)는 분모 때문에 결과적으로 intractable하다. 따라서 variational inference를 적용해서 p(z|x)대신 q(z)를 optimize 해보기로 하자. 왜 q(z)를 optimize하는지 내용은 뒤에 나온다. 어쨌든 q와 p의 reverse KL을 구해보면:
다음과 같이 정리할 수 있다. 사실 아까는 KL(q||p~)를 구했으면서 왜 여기서는 KL(q||p)를 구하지? 라고 생각할 것이다. 결과적으로 보면 맨 오른쪽의 주황색 term은, p(z|x)를 bayes rule로 p(x|z)p(z) / p(x)로 바꿀 때 분모에서부터 온 것을 알 수 있다. 즉 여기서는 p(x|z)p(z) = p~(x)가 되고, log p(x)가 log Z(θ)가 되는 것이다. 앞에서 정의했던 식을 다시 살펴보면 위 식이랑 똑같은 것을 알 수 있다:
결과적으로 파란색 term은 앞서 구했던 variational lower bound KL(q||p~)와 동일하다. 주황색 term log p(x)는 normalizing constant Z이므로 무시할 수 있다. 따라서 이 파란색 term만 optimize하면 된다!
그런데 문제점은, 우리가 아까 q(z|x)대신 q(z)를 사용했다는 것이다. 위 식을 다시 보면 결과적으로 q는 p(z)와의 KL을 구하는 term에 남게 되기 때문에 q(z)를 사용하는 것이 맞다. 하지만 우리는 p(z|x)에 가까운 q를 구하고 싶은 것이므로, x가 변함에 따라 q(z)도 각각 다른 것을 사용해줘야 하고 따라서 매우 불편하다. 따라서 결국 q(z) 대신 input x에 따라 z를 generalize할 수 있는 q(z|x)를 대신 사용하게 되는데, 이를 amortized inference라고 한다. 따라서 amortized inference를 적용해서 최종 lower bound loss 식을 다시 적어보면:
여기서 왼쪽 KL term은 그대로 gradient ascent가 가능하나, 오른쪽의 q에 대한 expectation term은 reparameterization trick이 필요하다. reparametrization trick은 다음 포스팅에서 정리할 예정이다.
4. RL
RL에서 optimal policy를 찾는 내용 또한 probabilistic inference로 해석이 가능하다. 일례로 DPO논문의 Appendix A.1에서는, PPO objective로부터 optimal policy를 다음과 같이 도출한다:
맨 위는 PPO objective이고, 결과적으로 파란색 부분을 optimal policy π*로 둔다. optimal policy π*로 바꾸고 정리해보면 아래와 같은 식이 나오는데:
정말 신기하게도 앞서 우리가 계속 봤던 아래 식과 똑같다!
즉 위에서 주황색 term은 KL(q||p~) = KL(q||p) - log Z(θ)로 표현된 것이다. 결국 PPO objective도 variational lower bound와 같은 것을 optimize하게 된다.
RL과 probabilistic inference 간의 관계는 다음 포스팅에 정리할 예정이다!
[References]
https://ermongroup.github.io/cs228-notes/
Contents
Contents These notes form a concise introductory course on probabilistic graphical modelsProbabilistic graphical models are a subfield of machine learning that studies how to describe and reason about the world in terms of probabilities.. They are based on
ermongroup.github.io
https://www.cs.cmu.edu/~epxing/Class/10708-20/lectures.html
10-708 – Lectures (tentative)
10-708 – Lectures (tentative) 2020 Spring Lecture Date Topic Slides Videos Further Reading Note Scribe Design of GMs 01 Jan 13 Introduction to GM: (Eric) - Association between random variables - Marginal/partial correlation - Conditional independence pdf
www.cs.cmu.edu