Courses/cs231n

[CS231n] Generative Models (2) - GAN

모끼주인 2019. 4. 15. 15:54

* 이 글은 Stanford 대학의 CS231n 강의를 듣고 요약한 글입니다.

 

 

1. GAN의 network 구조

 

GAN은 SOTA Generative model이다. 앞서 언급한 PixelRNN/CNN, VAE와 무슨 차이가 있냐면..

 

PixelCNN에서는 P(x) 식을 Chain Rule을 이용하여 직접 정의하였다.

VAE에서는 latent variable z를 이용하여 P(x)를 간접적으로 표현하였고, 이를 계산하기 위해 Lower Bound를 최적화 하였다.

 

GAN에서는?

P(x)를 수식으로 정의하려고 하지 않는다.

대신에 게임 이론의 방식을 취하여, implicit하게(수식으로 직접 정의하지 않고) training distribution을 학습하고자 한다.

 

그럼 어떻게 implicit 하게 training distribution을 학습하는지 보면,

 

Input으로 gaussian random noise를 주고(VAE에서도 p(z)를 가우시안으로 가정했었음), 이를 neural network를 통해 training distribution으로 변화시키겠다는 것이다. 즉, 모든 random noise 입력을 학습 분포 sample에 mapping시키겠다는 것이다.

 

그렇다면 Generator Network는 어떻게 학습시켜야 할까? Generator에게 선생님 역할을 해 줄 추가적인 neural network구조가 필요하다. 아래 그림과 같이 Discriminator가 Generator의 선생님 역할을 해 준다.

 

GAN은 Generator와 Discriminator의 two player game이다.

전체적으로 random noise z로부터 Generator가 이미지를 생성해 내면, Discriminator는 이 이미지가 진짜 이미지(training set의 이미지)인지 가짜 이미지(generator가 생성한 이미지)인지 구별하는 구조이다.

즉, Generator는 Discriminator를 속이고자 노력하고 Discriminator는 Generator가 만든 이미지를 판별하고자 노력하는 것이다.

 

 

2. GAN의 objective function

 

이것을 어떻게 학습시킬지 Loss function을 정의한 것을 살펴보자.

먼저 첫 번째 항은, training data 집합인 Pdata의 x에 대한 D(x)의 기댓값이라는 뜻이다.

D(x)는 x에 대한 Discriminator의 판별 결과가 되는데, 포함하고 있는 parameter는 θd이다.

θd는 discriminator의 가중치 행렬을 뜻하는 것으로, 맨 앞의 max θd는 목적식을 최대화 하는 θd를 구하라는 것이다.

D(x)는 확률 값이므로 0~1사이의 값을 가지고, discriminator의 목적에 따라(training data를 참으로 판별) 1에 가까운 값을 가지도록 학습되어야 한다.

 

두 번째 항은, random noise 분포(gaussian)인 P(z)의 z에 대한 D(G(z))의 기댓값이라는 뜻이다.

D(G(z))는 말 그대로, generator가 만든 이미지에 대한 discriminator의 결과이다.

여기서 θg는 generator의 가중치 행렬로, 맨 앞의 min θg는 목적식을 최소화하는 θg를 구하라는 것이다.

 

즉 두 번째 항에서, generator입장에서는 D(G(z))가 1에 가까워져야 좋을 것이고(자신이 만든 data를 참으로 판별) discriminator입장에서는 D(G(z))가 0에 가까워져야 좋을 것이다. (생성 data를 거짓으로 판별)

 

 

결국, GAN은 두 가지 network를 동시 학습 시키는 것으로 아래와 같은 두 가지의 목적식을 갖게 된다.

 

generator의 경우, 첫 번째 항이 θg를 포함하지 않으므로 objective function에 포함하지 않는다.

 

그런데 여기서!!! 이대로 학습시키면 generator가 학습이 잘 되지 않는다.

아래 그래프에서 파란색 선은 generator의 objective function을 나타낸 것으로, x축은 D(G(x))값, y축은 function값을 나타낸다.

generator 학습 시, D(G(x))가 0에 가까울수록 1쪽으로 이동하게 학습시켜야 하는데, (discriminator판별 거짓 -> 참 쪽으로)

아래 파란 선에서는 D(G(x))가 0에 가까운 쪽에서는 gradient가 매우 평평하고, 1에 가까운 쪽에서는 gradient가 매우 가파르다.

즉 학습을 많이 해야할 구간에서는 학습이 거의 되지 않고, 학습을 천천히 진행해야 할 구간에서는 학습이 매우 가파르게 진행되는 것이다.

이러한 문제를 해결하기 위해, GAN은 기존 objective function을 아래와 같이 변형하고 gradient ascent를 구한다. 아래 objective function을 그래프 상에 나타낸 것이 위 그림에서의 초록 선이다.

새롭게 정의된 generator의 objective function

여기서 알 수 있는 GAN의 발전 가능 부분은, objective function을 어떻게 정의해야 generator를 더 잘 학습시킬 수 있는가이다.

Wasserstein GAN, LSGAN 등 다양한 종류가 loss function만 살짝 바꿔 성능을 향상시킬 수 있는 것이라고 한다. 

실제로 2017년에는 GAN의 해라 불릴 만큼 관련하여 엄청나게 많은 논문들이 쏟아져 나왔다고 한당....ㅇㅅㅇ

 

 

3. GAN의 학습

 

GAN의 학습 알고리즘은 아래와 같다.

먼저 k step만큼 discriminator를 학습(가중치 update)시킨 후, 이후에 generator를 한 번 학습시킨다는 것이다.

GAN의 최종 학습이 완료될 때 까지 이러한 과정을 계속 반복한다.

 

여기서 k step!!에 대한 논쟁이 끊이지 않고 있다고 한다. 

일반적으로 초반에는 discriminator가 더 많이 학습되어야 좋다고 하는데, 그리하여 어떤 사람들은 k > 1인 경우가 좋다고 주장한다. 하지만 또 다른 사람들은 k = 1인 경우가 가장 stable하다고 주장한다.

 

Wasserstein GAN이라는 최근의 연구는 이러한 Discriminator와 Generator 사이의 불균형을 조금 더 해소시켰다고 한다.

 

또한, 초기의 GAN은 discriminator와 generator로 fully connected를 사용하였는데, 

Convolutional Architecture를 도입한 DCGAN등 다양한 GAN이 CNN구조를 도입하기 시작하면서 그 성능이 크게 향상되었다고 한다.

 

4. GAN의 생성 결과

역시 test time에는 VAE와 마찬가지로 generator만 떼네어 사용한다.

 

GAN에서도 역시 과학자들은 latent variable의 변화에 따른 이미지 생성 변화를 측정하여 보았다.

맨 왼쪽 그림은 latent variable z1, 맨 오른쪽 그림은 z2에서 샘플링한 이미지이다. 가운데 그림들은 맨 왼쪽과 오른쪽 그림을 interpolation하여 구한 결과이다.

 

부드럽게 이미지가 변화하는 것을 볼 수 있으며, 이를 통해 latent variable인 z가 data 구조의 중심 축 역할을 한다는 것을 알 수 있다.

 

아래에서도 흥미로운 사실을 발견할 수 있다.

latent variable z를 smiling woman, neutral woman, neutral man으로 고정하고 이미지들을 샘플링 한 다음, 이들의 평균값으로 벡터 연산을 수행하면 smiling man을 얻을 수 있다는 것이다.

 

참 신기하당 ㅇㅅㅇ

 

또한 "The GAN Zoo"에서 다양한 GAN관련 논문들을 찾아볼 수 있고,

https://github.com/soumith/ganhacks 이 사이트에서 GAN학습에 관한 팁들을 얻을 수 있다고 하니 참고하면 좋을 것 같다.