논문 리뷰/information theory

[논문 읽기] MINE : Mutual Information Neural Estimation(ICML 2018)

모끼주인 2019. 4. 17. 17:56

 논문 링크 : https://arxiv.org/abs/1801.04062

 

MINE: Mutual Information Neural Estimation

We argue that the estimation of mutual information between high dimensional continuous random variables can be achieved by gradient descent over neural networks. We present a Mutual Information Neural Estimator (MINE) that is linearly scalable in dimension

arxiv.org

*이 글은 논문 스터디를 진행하며 도움을 받아 작성한 글입니다.

*딥러닝 + 수학 초보자 주의(;;)

 

 

1. Introduction

결론부터 이야기하자면, neural network에 사용할 수 있는(scalable, flexible, trainable via back-prop) mutual information estimator인 MINE에 관하여 설명하고, 그 성능을 증명한 논문이다.

 

즉 요약하면, mutual information의 lower bound식에 neural network parameter를 추가하여, back-prop으로 이를 학습시킬 수 있다는 것을 수학적으로 증명하고, 이렇게 얻은 mutual information식을 실제 모델의 objective function에 적용하여 그 성능이 향상됨을 보여준 것이다.

 

구체적으로는 GAN의 objective function에 MINE을 추가하여, GAN에서 일반적으로 발생되는 문제인 mode-dropping을 해결할 수 있다고 주장한다. (mode-dropping관련 이해는 이 블로그를 참고 : http://dl-ai.blogspot.com/2017/08/gan-problems.html)

 

2. Background

1) Mutual Information

 

*Shannon Entropy, KL divergence 등 information theory에 관한 내용입니다. 관련 내용은 검색하면 쉽게 찾아보실 수 있어욤 ㅎㅅㅎ

 

Mutual Information이란, 두 random variable X와 Z가 얼마나 상호의존적(mutual dependence)인지에 관한 정보량을 의미한다. 아래 수식과 같이 Shannon entropy를 통하여 나타낼 수 있다.

또한 이 수식은 KL divergence로도 표현 가능한데, 아래와 같다.

(Px * Pz의 뜻은, marginal distribution인 Px와 Pz의 곱을 의미)

KL divergence란 확률 분포 P와 Q사이의 거리(유사도)를 나타낸 것으로, 아래와 같은 수식으로 표현 가능하다.

 

여기서 잠깐 확률과 통계를 짚고 넘어가 보면 ㅇㅅㅇ,,

dP는 확률 변수 P를 미분하였으므로 확률 밀도 함수 p(x)를 의미하고, 마찬가지로 dQ는 q(x)를 의미하게 된다.

이를 P에 대한 기댓값으로 나타내었으니 log(p(x)/q(x))에 p(x)를 곱하고 x에 관하여 적분한다는 의미가 된다!

 

 

2) Dual representations of the KL-divergence

 

KL divergence의 두 가지 representation에 대해 소개하고 있는데, MINE에서는 이 representation들을 사용하여 mutual information 수식을 최적화한다!

 

왜 굳이 위에서 정의한 KL term을 놔두고 Lower bound로 approximation하는가? 는 직접 KL term을 계산하기 위해서는 P와 Q의 분포를 정의할 수 있어야 하기 때문이다. (보통 training dataset X의 분포는 아무도 모른다.. 그걸 알기 위해서 neural net을 학습시킨다 ㅇㅅㅇ)

 

2-1. The Donsker-Varadhan representation

위의 수식처럼 Donsker-Varadhan representation은, KL term을 오른쪽 항과 같은 lower bound로 표현할 수 있다는 것이다. (sup의 의미는 상계)

여기서의 T는 결과값이 실수로 대응하는 어떠한 함수를 의미하는데, 즉 P와 Q의 KL term 값을 어떠한 함수 T의 연산으로 계산할 수 있다는 것이다.

 

논문에서 함수 T에 관한 KL term의 Lower Bound를 소개한 이유는, T 자리에 후에 parameter θ를 가진 neural network의 함수를 대입할 것이기 때문에다.

 

이 lower bound는 아래와 같은 수식을 만족할 때 tight하다고 할 수 있다.

이게 몬소리지...??;; 하시는 분들이 계실텐데(본인도 마찬가지임) 뒤에 Appendix8.2.1을 참고하면 수식 증명이 나와있다. 증명을 보면 tight할 때의 조건이 무슨 뜻인지 조금은 이해할 수 있다!

내용은 아래와 같다.

 

 

2-2. The f-divergence representation

 

위 수식 역시 2-1과 마찬가지로 KL term의 lower bound를 다른 방법으로 나타낸 것이다. 그러나 일반적으로, 2-2의 bound는 2-1의 bound보다 tight하지 않다. (오른쪽 항에서 뺄셈 부분이 2-2의 값이 더 크다) 논문에서는 두 방법 모두에 대한 성능을 증명하였지만, 더 KL term 값에 가까운 하한인 2-1이 더 strong하다고 이야기하고 있다.

 

3. MINE(The Mutual Information Neural Estimator)

 

3.1 Method

 

먼저 mutual information을 neural network의 가중치 parameter θ에 관해 나타낼 수 있다면, 아래와 같은 식이 성립한다.

이유는 왼쪽 항은 기존 정의의 mutual information이고(optimal 값), 오른쪽 항은 θ값 update를 통해 mutual information을 estimation한 값이기 때문이다.

I(θ)의 수식은 2-1에서의 Donsker-Varadhan representation에 의해 아래와 같이 나타낼 수 있다.

  • I(X, Z) 는 Pxz와 Px * Pz의 KL divergence로 나타낼 수 있음 -> 2-1 식에서 P, Q distribution 대신 Pxz, Px * Pz에 대하여 식을 변형한 것
  • minibatch상에서, X와 Z의 joint distribution이 있을 때 X를 무시하고 sampling하면, X에 관하여 적분한 "Z의 marginal distribution"이라고 생각할 수 있다.
  • 따라서 아래에 적힌 X와 Z의 joint distribution과, X와 Z 각각에 대한 marginal distribution의 곱에 대해 기댓값(모든 x, z에 관하여 적분) 계산을 할 수 있다!!
  • 위 objective function은 gradient ascent에 의해 maximize

 

즉 Sampling을 통해 위 수식의 기댓값들을 계산함으로써 mutual information을 update시키겠다는 것이다.

 

 

n크기의 minibatch만큼 sampling하여 계산한다면, 위 수식은 아래와 같은 I의 "추정치"로 나타낼 수 있다.

여기서 Tθ 는 neural network에 의해 parameterized된 함수이다. 이들의 집합을 F라고 표현한다.(올린 수식에는 F가 없지만 논문에는 F에 관한 설명이 약간 있으니 참고)

 

즉, minibatch로 샘플링하여 계산하는 알고리즘을 종합적으로 정리하여 보면 아래 그림과 같다.

 

  1. X와 Z의 joint distribution에서 b개 만큼의 미니배치 샘플을 추출한다.
  2. Z의 marginal distribution에서 n개 만큼의 미니배치 샘플을 추출한다.
  3. 위에서 정의한 I의 추정 식에 대입 : sampling한 x와 z값들을 식에 대입하여 값을 얻고, 전체 값의 평균 값을 구함
  4. gradient를 구하고 bias를 조정(아래 3.2절의 내용)
  5. 조정한 gradient로 parameter θ update

 

3.2 Correcting the bias from the stochastic gradients

 

논문에서는 그냥 stochastic gradient값을 일반적으로 구하면 bias를 일으킨다고 한다. 그 이유는 바로 아래 수식 오른쪽 항의 분모(denominator)에 있다는뎁..

이유는 minibatch에서 구한 E[e^Tθ]로 나누어주기 때문에, 실제 full batch에서 구한 값과 차이가 있다는 것 이다. gradient가 좋아야 최적점으로 잘 수렴할 수 있는데, 이것은 학습에서 중요한 문제라고 할 수 있다. 이러한 차이를 줄이기 위해서, 분모를 exponential moving average를 사용하여 지속적으로 업데이트 하고, 학습이 진행될수록 full batch에서 구한 값에 가까워지도록 한다.

*신기하게도 f-divergence를 사용했을 때 bias 발생 정도가 더 적었다고 한다!

 

**exponential moving average란...

순차적으로 측정한 값 a,b,c,d가 있고 새로운 값 e가 들어왔을 때, 오래된 값은 지수적으로 가중치를 줄이고, 새로운 값에 가중치를 더해 이동 평균을 구하는 방법을 말한다. 위의 수식을 통하여 구하며 일반적으로 α 값으로는 2/(N+1)을 사용한다고 한다. (학습 parameter 아님!!!) 자세한 설명은 아래 위키피디아 링크를 참조하시면 좋을거 같습니당:)

https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average

 

Moving average - Wikipedia

An example of two moving average curves In statistics, a moving average (rolling average or running average) is a calculation to analyze data points by creating a series of averages of different subsets of the full data set. It is also called a moving mean

en.wikipedia.org

 

3.3 Theoretical properties

 

여기서는 과연 MINE의 I추정치 값이 수렴할까? 에 대해 수학적인 근거를 제시하고 있다.

위의 definition은 실제 I와 I추정치의 차이가 epsilon 이내로 수렴한다는 것인데, 위를 만족하기 위해서는 아래 두 가지 조건을 만족해야 한다.

 

1) approximation problem : IΘ와 I의 차이가 수렴하는가(approximation이 가능한가)의 문제로 size of family F(함수 T의 집합)와 관계

2) estimation problem : empirical measures와 관계, 즉 batch size로 추정한 I^과 IΘ의 차이가 수렴하는가

 

1), 2)에 대해 증명한 정리들이 있는데, 먼저 1)에 대해 살펴보면 아래와 같다.

이것에 대한 증명은 appendix 8.2.2 Consistency Proofs에서 찾아볼 수 있다. 증명 내용은 아래와 같다. (T가 M bounded인 경우에 대해서만 정리하였는데, bounded되지 않은 경우에 대해서도 비슷하게 appendix에서 증명하고 있다!)

 

 

2)를 증명한 정리는 아래와 같다.

결과적으로 1과 2를 증명하였으므로 맨 처음의 definition 3.2가 성립하고, MINE은 strongly consistent(수렴)하다고 할 수 있다.

 

3.2.2 Sample Complexity

 

원하는 accuracy를 얻기 위해 sampling을 몇 개 이상 해야 할 지에 관한 내용이다. 아래의 Theorem 3이 그 내용이며, 수식(15)에서는 sampling 개수 n의 최소 크기를 구체적으로 나타내고 있다. 이것을 증명하기 위해 가정한 내용은 TΘ가 M bounded(최댓값이 M을 넘지 않음), domain Θ 또한 K bounded라는 것이며, 이렇게 가정했을 때 알고리즘의 시간복잡도에 대해서도 언급하고 있다.

이에 대한 증명 역시 appendix 8.2.3에 소개되어 있다!

 

4. Empirical comparisons

이 부분에서 저자는 MINE이 mutual information을 잘 계산해 내고, non-linear함수를 어떤 것을 사용하던지 dependent하다고 주장한다. 결과를 보면 먼저 첫 번째로, non-parametric estimation인 KNN based estimator(Kraskov et al.)와 어떤 것이 mutual information을 더 잘 계산하는지 비교하였다. 실험을 위해 사용한 두 분포 Xa와 Xb는 gaussian이며, correlation값 δ를 변화시켜 가며 측정하였다. 결과 그래프는 아래와 같은데, 왼쪽의 경우 random variable의 차원이 2dimensional, 오른쪽은 20dimensional인 경우이다.

그림에서 보듯이 MINE의 mutual information값이 더 높은 것을 볼 수 있고, 이는 random variable의 차원이 커질 수록 정도가 커진다.

 

두 번째로는 non-linear transform에서 어떤 non-linear함수를 사용하는지에 관계 없이 mutual information의 양을 측정한다는 것이다. 실험 결과 그래프는 아래와 같다.

위 그림에서 보듯 sinx, x, x^3어떤 함수를 썼냐에 따라서보다는, Y = f(x) + σ * ε에서의 noise σ에 따라 mutual information 측정 결과가 달라지는 것을 볼 수 있다.(색이 진할수록 mutual information 값이 큰 것)

 

 

5. Applications

저자는 MINE에서 parameter update를 통해 mutual information을 구하는 방식을 GAN, bi-directional adversarial models, information bottleneck에 적용한 결과를 보여주고 있다. 여기서는 GAN에 적용한 방식에 대해서만 이야기하도록 하겠다.

 

5.1 Maximizing mutual information to improve GANs

 

저자는 MINE을 통해 GAN에서 일반적으로 발생하는 문제인 mode collapse를 해결할 수 있다고 한다. Mode collapse에 대해 자세히 설명한 블로그 링크는 맨 위에 걸어 두었는데, 간단히 설명하면 data가 multi-modal인 경우 discriminator가 하나의 mode로만 치우치는 현상이다. 예를 들어, MNIST dataset의 경우 0~9까지의 mode가 있다면, GAN은 4만 생성하는 것이다. 이는 generator가 충분한 다양성을 가진 샘플을 생성하도록 학습하는 데 실패하였기 때문이다.

 

이전 글 [CS231n] GAN 편에서 자세히 설명했지만, 일반적인 GAN의 objective function은 아래와 같다.

여기서의 오른쪽 항은 generator의 목적식인데, 저자는 이 부분을 mutual information식으로 변환하였다. (기존에는 sample들의 negative entropy를 사용하여 regularizing하는 것을 사용하였는데, 저자는 이것보다 성능이 더 좋다고 주장한다.) 아래가 loss function 변환 식인데, infoGAN과 거의 비슷한 모습이라고 한다.

 

먼저 input인 prior distribution은 논문에서 Z = [ε, c]로 나타내었는데, ε은 noise, c는 code variable로 물체의 각도, shape 등 학습하길 원하는 class를 concatenation한 것을 의미한다. 위 수식에서 첫 번째 항은 generator의 loss 그대로이며, ((16)에서 오른쪽 항을 (1 - ~~)로 계산하지 않고 그냥 계산하여 maximize한다, gradient문제 때문 : [CS231n]GAN 글에서 설명) 여기에 generator의 분포와 원하는 class(mode)의 분포 간의 mutual information값을 더해주었다. 즉 앞의 항에서 discriminator의 지시대로 학습할 뿐만 아니라, 원하는 데이터의 분포와 mutual information까지 더 학습하면서 성능을 높이겠다는 뜻이다.

 

이에 대한 결과는 Experiments : Spiral, 25-Gaussians datasets에서 보여주고 있다. 기존의 GAN 식에서는 β = 0, MINE에서는 mutual information maximization학습 결과에 따라 β = 1.0이다.(β는 위의 수식 (17)의 값) 이에 따른 결과는 아래 그림과 같다.

 

여기서 보듯이 GAN+MINE이 GAN보다 mutl-modal data의 분포를 더 잘 나타낸다고 볼 수 있다!