논문 리뷰/GAN

[논문 읽기] Wasserstein GAN

모끼주인 2019. 5. 5. 18:15

논문 링크 : https://arxiv.org/pdf/1701.07875.pdf

불러오는 중입니다...

 

 

아래 블로그가 정말 알기쉽게 설명이 잘 되어있습니다!! 많이 참고하였고 다른 분들도 참고하시면 좋을거 같습니다ㅎㅎ

https://medium.com/@jonathan_hui/gan-wasserstein-gan-wgan-gp-6a1a2aa1b490

 

1. Introduction

Unsupervised Learning(Self-supervised Learning)은 학습 데이터 x에 대한 정답 라벨 y가 존재한 것과는 달리, 데이터 x의 분포 P(x)를 직접 학습하겠다는 것이다. 이를 위해서 P(x)를 parameter θ에 대해 아래와 같이 표현하고, 이를 학습시킬 수 있다.

그러나 P(x)의 식을 직접 표현하는 것은 어렵기 때문에(정답을 이미 알고있다는 의미가 된다), GAN에서는 x를 결정하는 latent variable z의 분포를 가정하여 입력으로 대입하고, discriminator와 generator간의 관계를 학습시킴으로써 generator의 분포를 P(x)에 가깝게 학습시키고자 한다. (이전 [CS231n] GAN 글 참고)

 

그러나 GAN의 문제점은 discriminator와 generator간의 균형을 유지하며 학습하기 어렵고, 학습이 완료된 이후에도 mode dropping이 발생한다는 것이다. 이러한 문제가 발생하는 이유는, discriminator가 선생님 역할을 충분히 해 주지 못해 모델이 최적점까지 학습되지 못했기 때문이다. (mode dropping : http://dl-ai.blogspot.com/2017/08/gan-problems.html)

 

Wasserstein GAN에서는 이러한 문제점을 해결하기 위하여 기존의 GAN에 비해 아래와 같은 차이점을 둔다.

  • discriminator대신 새로 정의한 critic을 사용한다. discriminator는 가짜/진짜를 판별하기 위해 sigmoid를 사용하고, output은 가짜/진짜에 대한 예측 확률 값이다.
  • 반면 critic은 EM(Earth Mover) distance로부터 얻은 scalar 값을 이용한다.
  • EM distance는 확률 분포 간의 거리를 측정하는 척도 중 하나인데, 그 동안 일반적으로 사용된 척도는 KL divergence이다. KL divergence는 매우 strict 하게 거리를 측정하는 방법이라서, continuous하지 않은 경우가 있고 학습시키기 어렵다.

 

결과적으로, GAN의 discriminator보다 선생님 역할을 잘 할 수 있는 critic을 사용함으로써 gradient를 잘 전달시키고 critic과 generator를 최적점까지 학습할 수 있다는 것이다. 때문에 아래와 같은 이점을 얻을 수 있다고 주장한다!

  • training 시 discriminator와 generator간의 balance를 주의깊게 살피고 있지 않아도 된다!
  • GAN에서 일반적으로 발생되는 문제인 mode dropping을 해결 가능하다!

 

2. Different Distances

먼저 Wasserstein GAN의 알고리즘에 대해 소개하기 전에, 여기서 사용한 확률 거리 척도의 당위성에 대해 이야기 하는 부분이다.

 

즉, 가장 널리 사용되는 KL divergence를 왜 사용하지 않았나?? 에 대해 설명하는 부분이다!

아래에서는 KL divergence를 포함한 네 가지 거리 척도에 대해 소개하고, 이들에 대해 비교한다.

 

<아래부터는 논문에서 수학적인 내용이 굉장히 많습니다! 요약하여 정리하였고, 자세히 이해하고 싶으신 경우 아래 링크를 참조하시면 좋을거 같습니다. 2절 내용은 아래 링크에서 대부분 이해하고 참조하여 작성하였습니다.>

<문제 시 복사해 온 그림 및 참조 내용은 삭제하도록 하겠습니다.>

https://www.slideshare.net/ssuser7e10e4/wasserstein-gan-i

 

Wasserstein GAN 수학 이해하기 I

이 슬라이드는 Martin Arjovsky, Soumith Chintala, Léon Bottou 의 Wasserstein GAN (https://arxiv.org/abs/1701.07875v2) 논문 중 Example 1 을 해설하는 자료입니다

www.slideshare.net

[1] 네 가지 거리 종류 정의

 

1. Total Variation(TV)

두 확률 분포의 측정값이 벌어질 수 있는 가장 큰 값을 뜻한다. 아래 그림을 보면 쉽게 이해할 수 있다! 즉 아래 그림에서 빨간색 A의 영역 안에 있는 A들을 대입하였을 때, Pr(A)와 Pg(A)의 값의 차 중 가장 큰 것을 뜻한다.

출처 : Wasserstein GAN 수학 이해하기 1 (위 링크 참조)

2. Kullback-Leibler(KL) divergence

3. Jensen-Shannon(JS) divergence

*KL divergence와 JS divergence는 가장 대표적으로 쓰이는 확률 분포 간 거리 척도로, 자세히 설명하지는 않겠습니다!

검색하시면 쉽게 내용을 찾아보실 수 있습니다.

 

4. Earth-Mover(EM) distance

두 확률 분포의 결합확률분포 Π(Pr, Pg)중에서 d(X, Y) (x와 y의 거리)의 기댓값을 가장 작게 추정한 값이다.

즉 아래 그림에서 파란색 원이 X의 분포, 빨간색 원이 Y의 분포, 𝛘가 결합 확률 분포를 의미하며,

초록색 선의 길이가 ||x-y||를 의미한다. 즉, 초록색 선 길이들의 기댓값을 가장 작게 추정한 값이다.

 

출처 : Wasserstein GAN 수학 이해하기 1 (위 링크 참조)

 

[2] EM distance의 타당성

각 거리 함수들의 정의는 위와 같고, 논문에서는 아래 Example 1을 통해 EM distance의 타당성을 이야기하고 있다.

 

아래는 임의의 distribution P0와 Pθ를 정의하고, 이들 간의 확률 거리를 구해 본 결과이다.

결과에서 볼 수 있듯이 Wasserstein 거리(EM distance)의 경우 θ에 관계 없이 일정한 수식을 가지고 있으나,

다른 거리의 경우 θ값에 따라 거리가 달라질 뿐만 아니라 그 값이 상수 또는 무한대인 것을 볼 수 있다.

 

 

다시 말하면, KL/JS divergence나 TV같은 경우에는 두 분포가 서로 겹치는 경우에는 0, 겹치지 않는 경우에는 무한대 또는 상수로 극단적인 거리 값을 나타낸다는 것이다. 

 

이는 discriminator와 generator가 분포를 학습할 때 위 세 가지 distance를 기반으로 학습하게 된다면 굉장히 어려움을 겪을 것이라는 것을 이야기 해 주는 것이다!!

(초반에는 실제 데이터의 분포와 겹치지 않을 것이므로 무한대 또는 일정한 상수 값을 갖다가, 갑자기 0으로 변해버리므로 gradient가 제대로 전달되지 않는다)

 

반면 EM distance의 경우 분포가 겹치던 겹치지 않던 간에 |θ|를 유지하므로, 학습에 사용하기 쉽다는 것을 말하고 싶은 것이다.

 

 

위의 결과가 나온 이유를 그림으로 그려 설명해보면...

 

<위 Example1에 대한 그림 설명은, 맨 위 링크를 걸어둔 "Wasserstein GAN 수학 이해 하기 1"을 읽고 요약한 내용입니다. 자세한 수학적 내용이 이해가 가지 않으신다면, 저 링크를 참조하여 읽어보시면 좋을 것 같습니다!>

 

1) EM distance

2) KL divergence

3) JS divergence

4) TV

 

[3] EM distance를 사용하기 위한 제약조건

 

EM distance를 loss function으로 사용하기 위해서는 미분이 가능해야 한다.

아래 내용은 EM distance의 연속성을 증명하고, 이를 위해 어떤 조건이 필요한지 명시한 것이다.

 

이 부분이 중요한 이유는, 이후 3절에서 연속성 제약 조건을 만족하기 위해 매번마다 clipping이라는 것을 해 준다!

 

Pr은 학습하고자 하는 목표 distribution이며, Pθ는 학습시키고 있는 현재의 distribution으로 생각할 수 있다.

z는 latent variable의 space이며, 함수 g는 latent variable z를 x로 mapping하는 함수이다. 이 때 gθ(z)의 distribution이 Pθ가 된다. (GAN의 기본적인 가정 사항에 관한 내용이다. 자세한 내용은 [CS231n] GAN 참조)

 

이 때,

1. g가 θ에 대해 연속한다면, Pr와 Pθ의 EM distance 또한 연속한다.

2. g가 Lipschitz조건을 만족한다면, Pr와 Pθ의 EM distance 또한 연속한다.

 

여기서 Lipschitz조건이란, 두 점 사이의 거리를 일정 비 이상으로 증가시키지 않는 함수를 뜻한다. (자세한 내용은 아래 참조)

https://ko.wikipedia.org/wiki/%EB%A6%BD%EC%8B%9C%EC%B8%A0_%EC%97%B0%EC%86%8D_%ED%95%A8%EC%88%98

 

립시츠 연속 함수 - 위키백과, 우리 모두의 백과사전

위키백과, 우리 모두의 백과사전. 해석학에서, 립시츠 연속 함수(영어: Lipschitz-continuous function)는 두 점 사이의 거리를 일정 비 이상으로 증가시키지 않는 함수이다. 이름은 독일의 수학자인 루돌프 립시츠의 이름을 땄다. 두 거리 공간 ( X , d X ) {\displaystyle (X,d_{X})} , ( Y , d Y ) {\displaystyle (Y,d_{Y})} 사이의 함수 f : X → Y {\displaystyle

ko.wikipedia.org

출처 : GAN - Wasserstein GAN & WGAN-GP (https://medium.com/@jonathan_hui/gan-wasserstein-gan-wgan-gp-6a1a2aa1b490)

위 수식을 보면 무슨 뜻인지 쉽게 감을 잡을 수 있다! (위 수식은 K-Lipschitz 조건, 자세한 내용은 블로그 참조)

이 Lipschitz조건을 만족시키기 위해 아래 3절에서 계속 clipping을 해 준다!!

 

 

 

3. Wasserstein GAN

 

이제 원래 Loss function에서 계산해야 할 것이 이것이었다면, 사실 여기서 inf 부분을 계산할 수가 없다. 이유는 Pr과 Pg의 joint distribution을 계산해야 하는데, Pr은 우리가 알고자 하는 대상이기 때문이다... ㅠㅠ

 

그래서 Kantorovich-Rubinstein duality를 이용하여 식을 바꿔보면

이렇게 된다! (sup 아래 의미는 f가 1-Lipschitz 조건을 만족한다는 뜻)

이것을 학습시키기 위해 parameter가 추가된 f 로 수식을 바꾸고, Pθ를 g(θ)에 대한 식으로 바꾸면 아래와 같은 수식이 된다.

 

 

기존 GAN의 loss와 비슷한 모습임을 볼 수 있다!

 

그런데 여기서도 Pr이 있는데??? 앞의 항은 어떻게 구하지?? 라는 생각이 든다면 괜찮다. 

앞 부분은 잘 학습된 discriminator(사실은 critic이다, 아래에서 설명)가 Pr의 역할을 하여 줄 것이고, 아래와 같이 gradient update를 할 때에는 θ에 대해 미분하면 앞의 항이 사라지게 되기 때문이다.

 

 

https://medium.com/@jonathan_hui/gan-wasserstein-gan-wgan-gp-6a1a2aa1b490

 

GAN — Wasserstein GAN & WGAN-GP

Training GAN is hard. Models may never converge and mode collapses are common. To move forward, we can make incremental improvements or…

medium.com

<아래 그림들은 위 블로그에서 가져온 그림인데, GAN과 WGAN의 차이를 알기 쉽게 잘 살명하여 놓았습니다.

꼭 읽어보시는걸 추천드립니다!>

 

출처 : GAN - Wasserstein GAN & WGAN-GP

즉 f(x)는 Lipschitz조건을 만족하는 함수로, discriminator역할을 하는 함수이다(여기서는 critic이라고 이름 붙였다.)

 

앞서 보아왔듯이 critic의 loss function 항 자체가 EM distance(Wasserstein distance)를 의미하므로, 위 loss function을 최대화(최대화 이유 : Kantorovich-Rubinstein duality를 통해 sup으로 변형됨, 이후 parameter 식으로 변형하면서 maximize로 바뀜) 하는 함수 f를 찾는 문제가 된다. 여기서 w는 함수 f의 parameter, 즉 critic의 parameter이며, maximize이므로 w에 대한 gradient ascent이다.

 

generator의 loss function역시 Theorem3에서 정의한 대로, 변형한 Wasserstein distance 식을 θ에 대해 미분하여 앞의 식을 사라지게 하면 얻을 수 있다. generator의 경우, Theorem3에서 미분 결과에서 볼 수 있듯이 앞에 -가 붙어있다!! 즉, θ에 대한  gradient descent이다.

 

출처 : GAN - Wasserstein GAN & WGAN-GP

 

 

위 내용은 WGAN의 최종적인 학습 알고리즘이다.

 

먼저 n critic번 만큼 critic을 학습시키는 부분이 보이는데,

Pr과 p(z) (Pθ역할)를 미니배치만큼 샘플링한 후에, critic의 loss function을 이용하여 parameter w(즉 함수 f)를 update시킨다.

*왜 Adam을 안쓰고 RMSProp을 쓰는지 이유는 뒤에 나옵니다:)

 

여기서 update 후 clip(w, -c, c)라는 부분이 있는데, Lipschitz조건을 만족하도록 parameter w가 [-c, c]공간에 안쪽에 존재하도록 강제하는 것이다!! 이를 Weight clipping이라고 한다.

 

이는 WGAN의 한계점이라고 할 수 있는데, 실험 결과 clipping parameter c 가 크면 limit(c나 -c)까지 도달하는 시간이 오래 걸리기 때문에, optimal 지점까지 학습하는 데 시간이 오래 걸렸다고 한다. 반면 c가 작으면, gradient vanish 문제가 발생하였다고 한다. 이미 간결하고 성능이 좋기 때문에 사용하였지만, 이후의 발전된 방법으로 Lipschitz조건을 만족시키는 것은 다른 학자들에게 맡긴다...라고 쓰여있다(ㅠㅠ)

 

 

[Weight Clipping Parameter c에 따른 변화]

 

출처 : GAN - Wasserstein GAN & WGAN-GP

그림은 batch normalization없이 실험한 결과인데, clipping parameter c 에 매우 민감하게 gradient가 변화함을 볼 수 있다고 한다.

또한 clipping의 문제점은, 이것이 regularizer로써 작용하여 함수 f의 capacity를 줄인다는 것도 있다고 한다.

 

때문에 이러한 점을 보완하기 위해 만들어 진 것이 gradient penalty(그래프에서 파란색 선)를 준 WGAN-GP이다!

 

*자세한 내용은 GAN - Wasserstein GAN & WGAN-GP 내용을 참조하시면 알 수 있습니당...

 

[Discriminator vs Critic]

 

discriminator의 경우 일반적인 분류 neural net과 같이 이미지가 진짜인지, 가짜인지 sigmoid확률값으로 판별해 낸다.

그러나 critic의 경우 Wasserstein GAN 식 자체를 사용하기 때문에, scalar 값이 output이다. 이는 이미지가 진짜인지 아닌지에 대한 점수를 의미하는 것으로, sigmoid와 달리 saturation현상이 없고 좋은 gradient를 만들어 낸다.

 

따라서 진짜 optimal 지점까지 쉽게 학습이 가능하고, 앞서 서론에서 언급했던

  • discriminator와 generator간의 balance 맞추기
  • mode dropping (mode collapse) 문제

두 가지가 해결된다는 것이다!!

 

 

[RMSProp 사용 이유]

 

이 부분은 4절에서 나오는데 먼저 언급하자면, 실험 결과 critic을 학습 할 때 Adam과 같은 mometum 베이스 optimizer를 사용하면 학습이 unstable 하다는 것이다!

 

이유는, loss값이 튀고 샘플이 좋지 않은 경우(일반적으로 학습 초반) Adam이 가고자 하는 방향, 즉 이전에 기억했던 방향(Adam step)과 gradient의 방향 간의 cosine값이 음수가 된다는 것이다. 일반적으로 nonstationary 문제(극한값이 존재하지 않음)에 대해서는 momentum계열보다 RMSProp이 성능이 더 좋다고 한다.(여기서 정의한 문제도 nonstationary problem)

 

 

4. Empirical Results

전반적인 실험 결과와 성능에 대해 이야기하고 있다.

 

4.1 Experimental Procedure

맨 위의 그래프들은 discriminator대신에 critic을 적용한 것이고, 왼쪽은 generator로 Multi Layer Perceptron, 오른쪽은 DCGAN을 이용한 결과이다. sigmoid를 사용하지 않아 wasserstein거리가 점차적으로 줄어들고, sample의 결과도 훨씬 좋아진 것을 볼 수 있다.

아래 그림은 discriminator와 generator모두 MLP를 사용한 결과이다. Sample 그림은 무엇인지 알아보기 어렵고, 각 sample에 대해 wasserstein distance를 계산하여 보았을 때 상수값으로 변화하지 않는 것을 볼 수 있다.

 

4.2 Meaningful Loss Metric

여기서는 4.1과 같은 모델 구조(critic + MLP, critic + DCGAN, MLP + MLP)를 사용하였지만 generator iteration마다 JS 거리를 측정하여 그래프를 측정한 결과를 보여준다. 결과는 아래와 같으며, sample quality가 좋아져도 JS distance는 증가하거나 상수 값을 유지하는 것을 볼 수 있다.

즉, 말하고자 하는 바는 EM distance를 잘 사용했다!! 이다.

WGAN은 GAN역사상 처음으로 수렴(convergence)한 모습을 보여 준 경우라고 한다.(ㅇㅅㅇ...놀랍)

 

4.3 Improved Stability

여기서는 DCGAN generator를 이용하고, 다양한 변화를 주며 실험하여 discriminator와 critic의 성능을 비교한 부분이다.

결론적으로, discriminator와 critic간의 balance를 더 이상 신경쓰지 않아도 되며,

실험에서 WGAN알고리즘을 썼을 때는 mode collapse현상이 발생하지 않았다!!!라고 한다.

 

figure5는 일반적인 GAN과 WGAN으로 이미지를 생성한 결과이다.(DCGAN generator 이용) 둘 다 좋은 질의 sample을 생성한다.

 

그러나 figure6에서는, batch normalization을 없애고 generator의 DCGAN부분에서 filter수를 고정함으로써 전체적으로 parameter수를 줄인 결과이다. WGAN에서는 여전히 잘 작동하지만, 일반적인 GAN은 그렇지 않은 것을 볼 수 있다.

 

마지막으로 figure7은, generator를 MLP + ReLU로 변형하여 실험한 결과이다. 왼쪽은 WGAN, 오른쪽은 일반적인 GAN의 결과로, DCGAN을 사용했을 때 보다 퀄리티는 떨어졌지만 mode collapse 현상에 대해서 비교해 볼 수 있다. 오른쪽의 경우 비슷한 그림이 많고 특정 그림에 대해서는 생성하지 못하였지만, 왼쪽은 다양한 그림을 비슷하게 생성한 것을 볼 수 있다!

 

 

결론적으로, WGAN은 discriminator식을 EM distance를 사용하여 변형함으로써

gradient가 잘 흘러가도록 하여 GAN이 실제 optimal 지점까지 도달할 수 있도록 도왔고,

결과적으로 GAN학습을 안정화시키고 mode collapse 문제까지 해결한 것이라고 할 수 있다!!!