Courses/Probabilistic Graphical Model

[4-1] Model Inference: Sampling Methods(MCMC, Gibbs sampling)

모끼주인 2024. 9. 23. 20:51

아래 references의 강의 및 자료들을 공부하고 짧게 정리한 내용입니다! 저도 공부하면서 정리한 내용이라 틀린 것이 있다면 언제든 댓글 달아주신다면 감사하겠습니다~ :D

____

 

만약 왼쪽 그림과 같은 변수들 간의 관계가 주어지고 분포 p(x)를 inference해야 한다면, 먼저 joint probability p(x, y, z) = p(x|y,z) p(z|y) p(y)를 계산한 뒤 y, z에 대해 marginalize하는 방식을 사용한다. 이 때 각각의 conditional probability를 모두 구할 수 있고 y,z에 대한 marginalization(적분)을 연산할 수 있다면 exact inference가 가능하다. 주로 p(x)가 gaussian과 같이 연산 가능한 형태를 가지고 있다고 가정했을 때 가능하다.

 

그러나 대부분 모델의 경우 변수들 간의 관계가 훨씬 더 복잡하거나 (bayesian network), 또는 실제 세계를 모델링하는 p(x)는 gaussian같은 간단한 확률 분포로 나타낼 수 없기 때문에 marginalization이 불가능하고, 결과적으로 exact inference가 불가능하다. 따라서 appoximate inference를 사용하는데, 이것에는 크게 두 가지 방법이 있다. 

 

그 중 하나는 variational inference이고, 오늘 포스팅할 내용은 sampling methods이다. 

 

1. Monte Carlo sampling

 

Monte Carlo Sampling은 이름은 거창하지만 생각보다 단순하다. 구하고자 하는 complex distribution p(x)에 대한 stochastic representation을 만들자는 것인데, 위의 수식과 같이 complex distribution p(x)에 대한 적분을 sampling을 통해 근사하는 것이다. 따라서 위와 같이 closed-form으로 나타냈을 때 생기는 p(x)에 대한 expectation을 approximation할 수 있다.

 

그런데 여기서 궁금한 점이 생길 것이다:

  1. 어떻게 p(x)에서 sample을 추출할 수 있는지?
    -> 어떤 simple distribution q(x)를 대신해서 이용하게 된다.
  2. p(x)에서 sampling한 것들 중 더 유용한 것을 어떻게 찾을 수 있는지?
    -> p(x)를 대신하는 distribution q(x)와 p(x) 간의 비율을 이용한다.
  3. 충분한 양이 sampling되었는지 어떻게 알 수 있을지?
    -> Section 4에 설명될 예정이다.

여기서 헷갈릴 만한 점은 2번에서 q(x)와 p(x) 간의 비율을 이용한다고 했는데, p(x)를 어떻게 구할 수 있냐는 것이다. 일단 p(x) 수식을 한번 살펴보자 (PGM 포스팅 2편 참조, 2편 식에서 cliques를 없앤 식입니다):

 

 

간단하게 다시 한 번 설명하면, 어떤 특정 값 x에 대한 logit을 함수 ϕ로 추정한다고 하면 위와 같은 식이 된다. 분모는 모든 가능한 x에 대해서 적분한 값으로, p(x)를 0-1사이 확률값으로 만들기 위한 normalizing constant다. 분모를 줄여서 Z라고 표기하기로 하자. 이게 뭔가 싶은 생각이 들수도 있는데, 만약 x가 language domain의 token들이라고 생각해보면 하나의 token을 예측하기 위한 softmax(=boltzmann distribution)와 같다는 것을 알 수 있을 것이다.

 

현재 모델을 "inference"하는 것을 다루고 있으므로, 함수 ϕ는 어떻게든 주어져 있다고 가정하자. (inference가 가능해야 loss를 구할 수 있어 training도 가능해진다. training방법에는 MLE 또는 EM algorithm이 있다.) 문제는 Z를 구하는 것인데, p(x)가 exponential family이고, x가 한정된 discrete space에 속하는 variable이라 값을 구할 수 있다고 쳐도 전체 sample에 대한 mean을 구해야하므로 비효율적이고 샘플링이 어려워진다. 일반적인 경우 Z는 구할 수 없다.

 

따라서 sampling 방법들에서도 이 Z를 처리하는 것이 중요하다!

1) [Example] Rejection Sampling

대표적인 예시로 rejection sampling은 monte carlo sampling에 rejection을 추가한 것이다. 만약 sampling해야하는 distribution Π(x) = Π'(x)/Z가 있다고 해보자. 앞서 언급한 대로 Π(x)는 Z 때문에 연산이 어려우므로, evaluation이 쉬운 Π'(x)를 대신 사용한다.

 

또한 같은 이유로 Π(x)의 Z 때문에, Π(x) 대신 어떤 simple distribution Q(x)로부터 sampling하기로 한다. 즉 수식으로 나타내면:

 

 

이 되는데, 오른쪽에 p(x)/q(x)를 기반으로 한 acceptance 조건이 있는 것을 볼 수 있다. 이 acceptance 조건이 어떻게 유도되었는지는 아래 식을 통해서 알 수 있다:

 

 

Q(x)앞에 k는 Π'(x)/kQ(x)가 0에서 1사이의 확률값을 갖도록 해주는 coefficient이다. (Π(x)대신 Π'(x)를 사용하다보니 확률값이 되도록 강제로 만들어준다.)

2) Limitations

이러한 monte carlo sampling methods의 한계는 바로 Q(x)를 설정하는 것에 있다. 일단 p(x)와 매우 유사한 Q(x)를 만드는 것은 매우 어렵다. 또 p(x)를 아주 간단한 gaussian distribution으로 가정한 뒤 variance만 아주 조금 차이나도록 Q(x)를 세팅하더라도, input이 high-dimension space로 갈수록 두 distribution의 volume차이가 매우 커지고, acceptance rate가 매우 낮아진다고 한다. 즉, 아주 간단한 distribution에서 아주 가까운 Q(x)를 설정하더라도 실제 p(x)에서 뽑힌 것과 유사한 샘플을 뽑기 매우 어려워진다는 것이다.

 

2. Markov Chain Monte Carlo

따라서 fixed Q(x)를 사용하는 대신 adaptive하게 Q(x'|x)를 사용하는 방법이 바로 Markov Chain Monte Carlo 방식이다. 여기서 x는 기존에 가지고 있던 sample, x'은 새로운 sample을 의미하는데, 즉 바로 전 step에서 가지고 있던 sample만을 기반으로 다음 sample을 뽑겠다는 것이다. 즉, Q(x'|x)는 markov property를 가짐을 가정한다.

 

중요한 점은 Q(x'|x)는 conditional distribution이므로, x가 변함에 따라 Q값도 변하게 되고 즉 Q(x'|x)는 x에 대한 함수가 된다. 따라서 Q(x'|x)를 간단한 distribution (ex. conditional gaussian)으로 두더라도 훨씬 다양한 sample을 뽑을 수 있다.

1) Metropolis-Hastings algorithm

이 MCMC 방식으로 sampling하는 알고리즘이 Metropolis-Hastings 알고리즘이다. 먼저 알고리즘 방식에 대해 간단히 요약하면 다음과 같다:

 

 

즉 매 timestep 마다 Q(x'|x)로부터 x'을 sampling하고, acceptance probability인 A(x'|x)의 값에 따라 새로운 데이터로 이동할지, 현재 데이터에 머물러 있을지 결정하는 것이다. 이렇게 알고리즘이 수렴할 때까지 반복하는 것을 burn-in이라고 표현한다. sampling한 data의 distribution이 실제 p(x)로 수렴하게 되면 이제 E_p(x)[f(x)]를 inference할 수 있을 것이다.

 

그렇다면 이 acceptance probability는 어떻게 생겼을지 살펴보자:

 

맨 오른쪽을 보면, 결국 앞서 살펴본 rejection sampling과 유사하게 importance weight를 사용하지만, 두 importance weight의 ratio를 사용하는 것을 볼 수 있다. 즉, 분자는 새로 뽑은 data x'의 importance, 분모는 다시 원래 data로 move back하는 것의 importance가 된다. 만약 두 값의 비율이 1보다 커지면 어쨌든 accept하는 것이므로 A(x'|x) = 1로 둔다. (A(x'|x)는 0-1사이 확률값)

 

이렇게 ratio를 사용하는 이유는, P(x)나 P(x')을 직접 구하지 않고 P(x')/P(x)을 구함으로써 Z를 제거할 수 있기 때문이다. (PGM 포스팅 2편에서 봤던, restricted boltzmann machine의 learning에서 Z를 없애는 방법과 상당히 유사하다.)

 

이 알고리즘이 동작하는 방식을 눈으로 살펴보자.

 

기존 방식(왼쪽)의 경우, p(x)와 최대한 가까운 q(x)를 fix해놓고 샘플링했다면, MCMC sampling(오른쪽)의 경우 distribution Q가 data point x에 따라 계속 움직이는 것이다. 오른쪽 그림을 보면 두 mode 사이 확률이 작아지는 구간이 있는데, 이 때 왼쪽 mode에서 오른쪽 mode로 reject되지 않고 어떻게 움직일 수 있을지 살펴보자.

 

출처: https://www.cs.cmu.edu/~epxing/Class/10708-20/lectures/lecture09-MC.pdf

 

위 그림에서와 같이 두 mode 사이 확률이 낮은 구간에서 데이터가 샘플된다면, A(x'|x)의 값이 0에 가까워져 reject될 것이다. 반면 조금 더 오른쪽 mode에 가까운 데이터를 샘플한다면 accept될 확률이 높아질 것이고, 맨 오른쪽 그림처럼 오른쪽 mode도 샘플링할 수 있도록 이동할 수 있을 것이다. 따라서 여러 mode를 샘플링하기 위해서는 적당히 variance가 큰 gaussian을 Q로 두는 것이 유리하다. 반면 variance가 너무 크면 A(x'|x)의 값이 작아지기 때문에 적당한 variance를 선택하는 것이 필요하다. 아래 Section 4에 좀 더 자세히 설명되어 있다.

2) Why this work?

이 A(x'|x)를 acceptance probability로 사용하면, sufficient sample이 주어진 경우 true distribution p(x)에 가까워 지는 것이 보장된다고 한다. 왜 보장되는 것일까?

 

그 이유는 Q(x'|x)가 markov property를 가진다고 가정했기 때문이다. Markov chain에서는 다음과 같이 reversible한 경우(=detailed balance 조건을 만족하는 경우) π(x)는 stationary distribution으로 수렴한다고 한다:

 

 

증명은 아래 reference에도 첨부한 Eric Xing의 CMU강의에 되어있다..! 일단 증명은 생략하고, 여기서 π(x)는 어떤 state x의 probability distribution이다. 여기서는 sample x가 발생할 probability distribution으로 보면 될 것 같다. T(x'|x)는 transition matrix로, state x에서 state x'으로 이동할 확률을 의미한다. 여기서는 x ~ Q(x'|x)에서 샘플링하고, A(x'|x)에 따라 acceptance 여부를 결정하므로 T(x'|x) = Q(x'|x)A(x'|x)가 된다. 이렇게 transition matrix를 정의했을 때 앞서 언급한 detailed balance 조건을 만족하고, 결과적으로 π(x)는 true data distribution p(x)로 수렴하게 된다고 한다! (증명은 Eric Xing의 CMU강의 참조)

 

3. Gibbs Sampling

Gibbs sampling은 Metropolis-Hastings의 special case로, 기존 방식보다 computation 및 memory cost를 더 줄일 수 있다. cost를 줄여주는 이유는 한 번에 하나의 random variable을 샘플링하기 때문이다. 이 때 sampling variance를 줄이기 위해 몇 random variable은 marginalize하고 계산하기도 한다고 한다.

 

즉 p(x) = p(x1, ..., xn)일 때, Metropolis-Hastings에서는 한 timestep 당 x1~xn을 전부 sampling했다면 Gibbs sampling에서는 한 timestep 당 하나의 variable만 sampling하는 것이다. 알고리즘을 간단히 표현해보면:

 

출처: https://www.cs.cmu.edu/~epxing/Class/10708-20/lectures/lecture09-MC.pdf

 

즉 매 step마다 어떤 random variable을 먼저 sampling할 것인지 순서를 정한 후에, 그 variable만 빼고 나머지 variable에 condition된 distribution p(x_i | x_1, ..., x_i-1, x_i+1, ..., x_n)에서 새로운 x_i'을 sampling하는 것이다. 이 때 다음 random varaible x_j를 sampling하기 전 x_i 자리에 먼저 x_i'값을 update한다.

 

즉, Gibbs Sampling은 다음과 같은 distribution에서 sampling한다고 볼 수 있다:

 

Metropolis-Hastings에서와 하나의 variable이 들어가고 빠지고의 차이가 있는 것을 볼 수 있다. 앞서 Gibbs sampling은 Metropolis-Hastings의 special case라고 했는데, 그 이유는 아래와 같다:

 

출처: https://www.cs.cmu.edu/~epxing/Class/10708-20/lectures/lecture09-MC.pdf

 

결과적으로 위에서와 같이 Gibbs Sampling에서의 A(x'|x)식을 정리해보면, 항상 acceptance probability가 1이 됨을 알 수 있다. 따라서 Gibbs Sampling은 모든 sample을 accept하는 special case이다.

 

4. Practical aspects of MCMC

MCMC로 실제로 샘플링을 하게 된다면, (1) proposal Q(x'|x)를 잘 고른 것인지, (2) 언제 iteration(burn-in)을 멈춰야 할지에 대해 고민이 될 것이다. 여기서는 두 가지 질문에 대한 방법을 설명한다.

1) proposal Q(x'|x)의 적합성 확인

Acceptance rate.

앞서 예시에서 잠시 언급했지만, 아래 그림과 같이 Q(x'|x)의 variance가 practical setting에서 중요하다. 왜냐하면 variance가 낮은 경우 acceptance rate가 크겠지만 p(x)를 exploration하지 못하고, variance가 너무 크면 그 반대의 상황이 일어나기 때문이다. 따라서 적절한 중간값을 찾는 것이 중요한데, 일반적으로 MH algorithm에서의 general guideline으로는 ~0.5까지의 acceptance rate을 가지는게 좋다고 한다.

출처: https://www.cs.cmu.edu/~epxing/Class/10708-20/lectures/lecture09-MC.pdf

 

Autocorrelation function.

accept rate 외에도 중요하게 볼 점은 sample간의 autocorrelation인데, MCMC chain에서는 항상 서로 가까운 timestep에서 샘플링 된 샘플들이 highly correlated 되어있다고 한다. 따라서 k개의 sample sequence간의 covariance를 측정함으로써 correlation 정도가 얼마나 심한지 확인할 수 있다고 한다. correlation이 낮을수록 sample들이 IID(independent and identically distributed)에 가깝다고 볼 수 있으므로 sample efficiency가 크다.

출처: https://www.cs.cmu.edu/~epxing/Class/10708-20/lectures/lecture09-MC.pdf

 

2) 언제 burn-in을 멈출 것인지 확인

Sample values per time.

 

만약 iteration step t에서 random variable x와 y의 값을 위와 같이 찍어본다고 하자. 이 때 4개의 서로 다른 sampling process를 돌렸다고 하자. 만약 data distribution이 p(x)에 가까워졌다면 왼쪽 그림과 같이 4개 process간 값의 차이가 크지 않을 것이다. 반면 샘플링을 시작한지 얼마 되지 않았다면 오른쪽 그림과 같이 차이가 난다고 한다.

 

Log-likelihood per time.

 

또는 우리가 loss를 tracking하는 것처럼 log-likelihood를 tracking할 수 있다고 한다. 각각의 random variable의 dimension이 커질수록 위와 같이 값을 찍는 방법은 어려워지므로 이 방법을 많이 쓴다고 한다. data distribution이 p(x)와 가까워질수록 log likelihood가 점점 커지고, 증가율이 어느 수준 이하로 작아져 수렴하게 되면 burn-in을 멈춰야 될 타이밍이라고 한다.

 

 

다음 포스팅에서는 sampling methods보다 최근 더 많이 쓰이는 방법인 variational inference에 대해 정리해보려고 한다!

 

[References]

https://ermongroup.github.io/cs228-notes/

https://www.cs.cmu.edu/~epxing/Class/10708-20/lectures.html