[3-2] Model Learning: EM algorithm
아래 references의 강의 및 자료들을 공부하고 짧게 정리한 내용입니다! 저도 공부하면서 정리한 내용이라 틀린 것이 있다면 언제든 댓글 달아주신다면 감사하겠습니다~ :D
____
지난번 포스팅 PGM learning 3-1편 에서는 fully-observed graphical model일 때의 대표적인 learning 방식인 maximum likelihood estimation에 대해서 살펴보았다. 이번 포스팅에서는 paritally-observed GM(VAE와 같이 hidden variable을 사용하는 경우)에서의 learning 방식인 EM algorithm에 대해 살펴보려고 한다!
1. Parameter learning for partially observed GM
1.1 Why latent variable models?
news article을 생성하는 language model을 생각해보자. 물론 최근의 large language model은 매우 큰 모델 사이즈 및 데이터를 기반으로 p(x) 자체를 모델링하지만, 일반적으로 p(x) 자체를 모델링하는 것보다 각 article의 topic t에 기반해서 문장을 생성하는 모델 p(x|t)를 구하는 것이 더 정확할 것이다. 즉 bayes rule을 사용해서 p(x) 대신 p(x|t)p(t)를 모델링하는 것이다. 이 때 random variable t는 supervised learning에서처럼 주어질 수도 있지만, 관측이 불가능한 latent variable이라고 가정해보자. 이러한 경우 모델은 PGM learning 3-1편에서 소개한 fully-observed가 아닌 partially-observed GM이다.
Latent variable models.
latent variable model은 두 random variable x, z에 대한 probability distribution으로 아래와 같이 정의할 수 있다:
여기서 x는 관측 가능한 variable, z는 hidden variable을 의미한다. Directed GM인 경우와 undirected GM인 경우 모두 위와 같은 식으로 모델링되는데, z를 기반으로 x를 생성(z->x)하는 generative model(directed GM)과 z를 label로써 이용하는 discriminative model(undirected GM) 모두 어차피 x와 z간의 joint probability를 구해야 하기 때문이다.
또한 z는 상황에 따라 discrete 또는 continuous variable 모두가 될 수 있다. 만약 label에 따라 데이터를 clustering하는 discriminative model을 설계한다면, z는 label을 의미하므로 discrete variable로 정의하는 것이 알맞을 것이다. 반면 dimensionality reduction을 통해 latent vector z에 정보를 압축하고 이로부터 이미지 x를 생성하고자 한다면 z는 continous variable로 정의하는 것이 좋을 것이다. 따라서 핵심 아이디어는 같으며 상황에 따라 generative / discriminative model 모두에 적용될 수 있다.
Latent variable model 장점.
- 맨 위에서 topic에 따라 news article을 생성하는 모델을 생각해보자. 우리는 p(x) 대신 p(x|t)p(t)를 모델링하게 되는데, 이 때 p(t)에 사전에 알고 있던 topic들의 분포라던가, categorical distribution의 conjugate prior 등을 대입할 수 있다. 따라서 model을 정의할 때 prior knowledge를 활용할 수 있다는 장점이 있다.
- domain에 따라 어떤 input variable은 관찰이 불가능할 수 있는데 이러한 경우 활용 가능하다.
- 아래 예시에 소개할 gaussian mixture와 같이, 더 좋은 expressive power를 가진 모델을 설계할 수 있다.
[Example of LVM] Gaussian mixture models
gaussain mixture model은 위 그림과 같이 말 그대로 여러 개의 gaussian model을 합친 것이다. 당연히 single component gaussain보다 더 좋은 expressive power를 갖게 되고 실제 세계를 모델링하는 데 더 가까워질 수 있다.
gaussain mixture model은 unsupervised clustering task에 사용될 수 있다. 만약 그림 (c)와 같은 데이터가 주어지고, 그림 (a)와 같이 K개의 gaussain component로 구분하는 경우를 생각해보자. 즉 어떤 data point (x_i, z_i)가 주어졌을 때,
다음과 같이 x는 d차원의 실수 vector, z_i는 discrete label로 정의할 수 있다. 즉 z는 multinomial distribution을, p(x|z)는 multi-variate gaussain distribution을 따른다고 가정하면 다음과 같이 정의할 수 있다:
이 때 n은 data point의 index를 의미하며, z_n^{k}는 k번째 class에 해당할 때는 1, 나머지 경우에는 0이 된다. probability distribution에 관한 내용은 이 포스팅에 업로드할 예정이다.
즉 이를 기반으로 p(x) 수식을 나타내기 위해서는 p(x,z) = p(x|z)p(z)를 z에 대해 marginalize하면 된다. 식으로 정리해보면:
첫 번째 줄은 p(x,z) = p(x|z)p(z)이고, 두 번째 줄은 각 p(z)와 p(x|z)에 앞서 정의한 multinomial, gaussian distribution을 대입한 것이다. 갑자기 ∑_{z_n}이 등장하는데, 결과적으로 z_n^{k}는 k번째 class에 해당할 때만 1이 되므로 앞줄의 ∑_k와 같은 의미가 된다. 맨 아래에서는 이를 다시 k에 대한 sum으로 다시 깔끔하게 정리하게 된다.
어쨌든 결과만 보면 K개 gaussian의 선형결합으로 이루어진다. 이 때 π_k는 mixture proportion, N(x;μ,∑)는 mixture component라고 부른다.
굳이 여기서 gaussian mixture의 p(x)를 정의해본 이유는, 결과적으로 p(x)를 inference하기 위해서는 hidden variable인 z에 대한 distribution p(z)가 필요하기 때문이다. 위 수식의 맨 마지막 줄에도 p(z)의 parameter인 π_k가 등장한 것을 볼 수 있다. 따라서 이 경우 gaussian mixture는 partially observed GM이다.
1.2 MLE로 optimize했을 때 문제점
만약 partially observed GM을 그냥 maximum likelihood estimation으로 학습한다고 생각해보자. 우리의 목표는 marginal distribution인 p(x)를 learning하는 것이다.
위 수식에서 L_full의 경우 x와 z가 모두 observable할 때의 log likelihood를, L_partial의 경우 z가 hidden variable일 때 log likelihood를 나타낸 것이다. L_full의 경우 marginalization이 필요 없지만 L_partial의 경우 z에 대해 marginalize한 것을 볼 수 있다. 이렇게 marginalize함으로써 생기는 문제점은 최종 loss 식이 sum of log로 분해되지 않는다는 것인데, 따라서 결과적으로 loss function의 convexity를 보장할 수 없다는 문제가 생긴다.
fully-observe GM에서 loss의 경우, 앞서 PGM learning 3-1편에서 소개했듯이 p(x,z)를 exponential family를 이용하여 정의하면 convexity를 보장할 수 있었다. 하지만 partially-observe GM의 경우, p(x|z)와 p(z)가 모두 exponential family더라도 서로 간의 곱이므로 convexity가 보장되지 않는다. 위에서 언급했듯이 sum of log로 분해하여 convex combination으로 만들고 싶지만 식이 변형되지 않는다.
따라서 이러한 non-convex function을 optimize하기 위해서는 approximate learning이 필요하다고 한다!
1.3 Lower Bound를 이용한 complete log likelihood 정의
approximate learning을 위해 앞서 소개한 loss function을 변형하는 trick을 사용한다. 간단하게 trick을 요약하면 z에 대한 marginalization을 z|x에 대한 expectation으로 바꾸고, 대신 이로 인해 생기는 entropy term을 제거하면 lower bound를 얻을 수 있다. 먼저 expectation 형태로 바꾸는 과정을 살펴보자:
위 식에서 보듯이, 첫 줄에서 둘째줄로 넘어갈 때 갑자기 분모와 분자에 q(z|x)를 곱해주는 것을 볼 수 있다. 이는 ∑q(z|x)부분을 만들기 위한 것인데, 식을 q(z|x)에 대한 expectation으로 바꿀 수 있기 때문이다. 원래 바꾸기 전 식에서도 p(x) = ∑ p(z)p(x|z)니까 z를 sampling해서 p(z)에 대한 expectation으로 만들면 되는데, 왜 굳이 q(z|x)를 이용하는지 궁금할 것이다. 그 이유는 q(z|x)를 learning하지 않고 learning하는 distribution p(x, z)를 통해 얻을 것이기 때문이다. (bayes rule을 사용하면 joint distribution을 통해 conditional distribution을 유추할 수 있다.)
맨 마지막 줄의 부등호는 Jensen's inequality를 이용하여 log를 sum 안으로 넣어준 것인데, 이를 통해 sum of log로 분해하여 convex combination으로 만들 수 있다(q와 p를 exponential family 함수로 가정하는 경우). Jensen's inequality는 convex optimization을 검색하다 보면 나오는데 더 알고 싶은 경우 관련해서 찾아보면 좋을 것 같다.
결과적으로 Jensen's inequality를 통해 convex combination으로 만들었으므로 loss 식을 정리해보면:
위에서 보듯이 결과적으로 lower bound는 log p(x,z)에 대한 expectation과 q(z|x)에 대한 entropy term으로 분해된다. 좋은 점은 오른쪽 entropy term이 q에만 관한 식이므로 model parameter θ에 dependent하지 않고, 결과적으로 loss 식에서 제거할 수 있다는 것이다. 여기서 주의할 점은 q(z|x)를 이미 아는 distribution으로 가정하기 때문에 model parameter θ와 independent하다는 점이다. 앞서 언급했듯이 우리는 q(z|x)를 learning하지 않고 learning하는 distribution p(x, z)를 통해 얻을 것이다. 예를 들어 q를 multinomial distribution으로 둔다면 multinomial distribution의 canonical parameter ϕ를 p(x, z)를 통해 계산할 것이다. 이에 대한 예시는 아래 Section 1.4에 소개되어 있다.
어쨌든 결과적으로 우리는 Lc만 maximize하면 되는데, 살펴보면 새로운 loss Lc는 fully observable일 때와 똑같이 p(x,z)를 optimize한다. 즉 marginalization이 없는 complete log likelihood 형태이다. 따라서 똑같이 maximum likelihood estimation을 통해 모델을 learning할 수 있다.
1.4 EM algorithm
즉 위에서 최종적으로 우리가 학습해야 하는 loss term은 다음과 같았다:
앞서 언급했듯이 fully observable일 때와 똑같이 p(x,z)에 대해 optimize함으로써 z를 관측 가능한 variable처럼 취급하게 되는데, 사실상 z는 관측 가능하지 않다. 따라서 어떤 임의의 distribution q에서 z를 샘플링한 뒤 그 z가 참값인 것처럼 사용하고, 다시 q를 update하는 iterative한 방식을 사용한다.
정리해보면
- model parameter를 θ_0로 initialize 한다.
- repeat
- E step: 주어진 data x에 대해, posterior distribution q(z|x)를 계산한다.
- M step: model parameter θ를 다음과 같은 식으로 update한다.
아래 그림은 EM algorithm이 어떻게 최적점에 수렴하는지 그림으로 나타낸 것이다. x축의 Q(x)는 q(z|x)를, y축의 θ는 model parameter를 의미한다. F(Q,θ)는 우리가 구하려는 lower bound의 값 분포라고 보면 된다.
그림에서 보듯이 먼저 x축(q(z|x))방향으로 이동하고, 이후 y축(θ)방향으로 이동하면서 F(Q,θ)의 최대점으로 수렴하는 것을 볼 수 있다. 따라서 EM algorithm은 F에 대한 coordinate-ascent이다.
[Example] Learning gaussian mixture with EM algorithm
앞서 latent variable model 중 하나로 소개한 Gaussian mixture model에서 learning 예시를 살펴보자.
(1) E step
앞서 언급했듯이 우리는 E step에서 posterior q(z|x)를 추정할 것이다.
갑자기 q(z|x)대신 p(z|x)가 등장하는데, 둘이 똑같은 것으로 보면 된다. p(z|x)를 추정하는 distribution의 의미로 q(z|x)를 사용했지만 둘이 똑같은 것이다.
어쨌든 첫 번째 줄에서는 bayes rule을 이용했고, 두 번째 줄은 앞서 Section1.1의 [Example]에서 정의한 대로 p(x|z)에는 multivariate gaussian의 pdf를, p(z)에는 multinomial distribution의 pdf를 대입한 것이다. 앞서 Section1.1의 [Example]에서는 p(x) 또한 정의했었는데, 이를 구성하는 parameter에는 π,μ,∑가 있었다. 이와 같이 learning할 distribution p(x)의 parameter를 이용해서 posterior p(z|x)를 추정할 수 있다.
사실상 E step은 결과적으로 모델을 inference하는 것과 같다. p(z|x)는 multinomial distribution으로 가정했으므로 결과적으로 k-dimension의 vector가 될 것인데, 우리는 inference를 통해 K개의 class에서 "soft" probability를 할당해 주었다. 즉 참이 아닌 것을 참인 것처럼 가정해서 z의 soft instantiation을 했다고 볼 수 있다.
(2) M step
M step에서는 z를 observable하다고 가정하고, fully-observe GM에서와 같이 p(x,z)를 optimize할 것이다:
p(x,z) 식을 풀어보면 위와 같은 형태가 된다. 첫 번째 줄은 bayes rule이고, 두 번째 줄 이후부터는 Section1.1의 [Example]에서 정의한 대로 p(z)에는 multinomial distribution의 pdf를, p(x|z)에는 multivariate gaussian의 pdf를 대입한 것이다. 이 loss function을 이제 각각의 parameter π,μ,∑에 대해 미분하고, log likelihood가 최대화 되도록 parameter를 update하면 된다. 결과적으로 미분값을 구해보면, μ_k는 class k에서의 mean of data, ∑_k는 class k에서의 variance of data 식이 나온다. 또 π_k의 경우 p(z|x)에 대한 empirical mean 식이 나온다. 신기하게도 gaussian, bernoulli distribution의 MLE 해에서와 같은 결과가 나온다.
어쨌든 이렇게 posterior 추정 -> parameter update를 반복하며 모델을 학습할 수 있다!
Convergence of EM algorithm.
EM algorithm이 수렴할지에 대해 궁금할텐데, 사실 궁극적으로 우리가 최적화하고자 하는 처음 loss term은 non-convex이기 때문에, global optimum으로의 수렴을 보장하지는 않는다고 한다. 실제로 local optimum에 수렴하는 경우가 많고 initialization의 영향을 크게 받는다고 한다.
하지만 최근 딥러닝의 발전에 따라 model의 expression 능력이 크게 증가하고, non-convex에서도 잘 작동할 수 있는 다양한 optimizer (Adam, ...)들이 소개되면서 우리는 local optimum인지 아닌지는 모르겠지만 거의 global optimum처럼 행동하는 모델들을 그냥 사용하고 있다. 그렇다면 EM algorithm이 왜 수렴하는지 expressive model에서 많이 쓰이는 variational inference와 비교를 통해 유추해보기로 하자.
1.5 EM algorithm과 Variational Inference의 관계
사실 EM algorithm을 살펴보며 variational inference와 매우 유사하다는 생각이 많이 들었을 것이다. variational inference에 관한 내용은 사실 다음 포스팅에 정리할 계획이다. 일단 VAE에 관한 내용은 이 포스팅을 참조하면 좋을 것 같다.
어쨌든 variational inference에서 p(x)를 모델링했던 내용을 다시 살펴보면:
KL term을 제외하고 오른쪽 식은 앞서 우리가 구했던 lower bound와 똑같다!
즉 이 식을 EM 관점에서 해석해보면
- E step에서 posterior p(z|x)를 추정하고 q(z) 자리에 대입하는 부분
-> KL term을 0으로 만드므로 ELBO를 먼저 tight하게 만드는 과정이다. - M step에서 나머지 log p(x,z)를 maximize한다.
결과적으로 variational inference의 식을 iterative하게 maximize하는 것과 같다!
EM algorithm에서 posterior p(z|x)를 추정할 때, z의 dimension이 작다면 충분히 추정할 수 있겠지만 (ex. Gaussian mixture처럼 1-d categorical vector일 때), p(z|x)가 모델링해야하는 distribution이 복잡해진다면 EM algorithm을 사용하기 어렵다. 따라서 복잡한 이미지를 생성하는 VAE와 같은 모델에서는 variational inference를 사용하고, posterior q(z|x) 또한 parameterize하는 것을 볼 수 있다.
다음 포스팅에서는 variational inference와 MCMC sampling에 대해 소개할 예정이다!