본문 바로가기
GAN

[GAN] Generative Adversarial Nets(NIPS2014) 리뷰 및 코드 구현

by pulluper 2022. 3. 23.
반응형

안녕하세요 pulluper 입니다! 😁

이번 포스팅은 generative model중 하나인 GAN에 대하여 리뷰 / 코드분석 을 해 보겠습니다. 


INTRO

 

요슈아 뱅지오 교수님과 이얀 굿펠로우등의 연구자들이 NIPS2014 년에 발표한 논문입니다. 

나오고 엄청난 열풍이 붑니다. 2022년 3월 기준 4만회의 citation 이 있으며,
대표적인 generative model 의 한 분야로 자리 잡았습니다. 

https://arxiv.org/pdf/1406.2661.pdf

GAN(Generative Adversarial Net) 을 직역해 보면, 

Generative  : 생성의

Adversarial  : 적대적인

Network      : 네트워크 

한국말로 "적대적 생성 신경망" 이라고 불리기도 합니다. 


IDEA

 

Generarive model의 목표는 학습 data와 유사한 data를 생성하는 것 입니다. 

즉, 학습 data 의 분포를 학습 하는 것으로 생각 할 수 있습니다. 

GAN은 적대적인(서로 반대되는 학습 목적을 가진) 딥뉴럴 네트워크를 통해서 

Generative model 의 목표를 이루고자 합니다. 

 

GAN의 아이디어는 다음과 같습니다.

예시로 많이 쓰이는 도둑(위조지폐범)과 경찰(감별사)에 대하여 생각해 봅시다. 🥲

idea of gan

파란색 생성 부분에서 도둑(위조지폐범)은 위조지폐를 진짜처럼 만드는게 목표이고,

분홍색 감별 부분에서 경찰(감별사)는 진짜지폐와 위조지폐를 잘 구분하는게 목표입니다. 

 

이렇듯, 도둑(위조지폐범)과 경찰(감별사)을 각각 Deep neural network로 두고

적절한 Loss 로 서로 반대되는 목표를 학습하도록 만든 것이 GAN의 아이디어 입니다.  

그렇다면, 어떻게 생성을 하고 감별을 하도록 적절한 Loss를 만들수 있을까요? 🤪


IDEA to LOSS

 

그럼 이 아이디어를 어떻게 적절한 Loss로 만드는지 알아봅시다. 

"감별"한다는 것은 Real(진짜지폐)와 Fake(위조지폐)의 2개의 class 를 구분한다는 것입니다. 

우리가 익히 알고있는 2개의 class 를 구분하는 방법 즉, binary cross entropy 를 이용하면 되지 않을까요?

BCE loss 는 다음과 같습니다.

BCE loss

그리고 다음과 같이 설정합니다. 

$D(.)$ :  Discriminator,

$G(.)$ :  Generator,

$x$      :  real image,

$z$      :  latent vector,

 

자 이제 D 를 학습할 때, G 를 학습할 때를 나누어 생각을 해 보겠습니다.

 

D(Discriminator) 학습

D의 입장에서 BCE의 Loss 가 줄면, D는 분류를 잘 하도록 학습이 됩니다. 

D의 학습을 위해서는 real image(x) 의 경우와 fake image(G(z))의 경우를 모두 학습하여야 합니다. 

real image 일 때 $(a, b) =  (D(x), 1)$ 이고, loss 는 $-\frac{1}{n}\sum_i^n{logD(x)}$ 입니다. 

fake image 일 때 $(a, b) =  (D(G(z)), 0)$ 이고, loss 는 $-\frac{1}{n}\sum_i^n{(1-logD(G(z)))}$ 입니다. 

 

G(Generator) 학습

G의 학습을 위해서는 fake image(G(z)) 경우만 고려하면 됩니다. 왜냐하면 real image가 들어갈때는 G가 관여를 하지 않기 때문입니다. 따라서 fake image 일 때 $(a, b) =  (D(G(z)), 0)$ 이고, loss 는 $-\frac{1}{n}\sum_i^n{(1-logD(G(z)))}$ 입니다. 이때, loss 가 minimize 되는것은 D를 학습시키는 것 이므로 반대로 maximize 시켜야 합니다. 

 

학습 Loss 정리 

$$ \begin{align} -\frac{1}{n}\sum_i^n{logD(x)}\tag{1} \\ -\frac{1}{n}\sum_i^n{(1-logD(G(z)))} \tag{2}\\ \end{align} $$

(1)의 수식은 loss가 결국 전체 학습 data의 분포 $P_{data}(x)$ 로부터 sampling된 x에 대한 expectation 입니다. 

마찬가지로, (2)의 수식은 이번에는 latent space 분포인 $P_z(z)$ 로부터 sampling된 z에 대한 기댓값입니다. 

따라서 (1) 의 수식은 (3) 으로, (2) 의 수식은 (4)로 서술될 수 있습니다. 

$$ \begin{align} -\mathbb{E}_{x\sim P_{data(x)}} [logD(x)]\tag{3} \\ \\ -\mathbb{E}_{z\sim P_z(z)} [log(1-D(G(z)))]\tag{4} \end{align} $$

 

D의 학습은 fake image, real image일 때 모두 구분을 잘 하도록 다음 위의 loss 들의 합인 $ -\mathbb{E}_{x\sim P_{data(x)}} [logD(x)] - \mathbb{E}_{z\sim P_z(z)} [log(1-D(G(x)))]$ 를 최소화 하는 것 이므로 $ \mathbb{E}_{x\sim P_{data(x)}} [logD(x)] + \mathbb{E}_{z\sim P_z(z)} [log(1-D(G(z)))]$ 를 최대화 하는것과 같습니다. 

 

G의 학습 $ -\mathbb{E}_{z\sim P_z(z)} [log(1-D(G(z)))]$를 최대화 하는 것 이므로 $ \mathbb{E}_{z\sim P_z(z)} [log(1-D(G(z)))]$최소화 하는것과 같습니다. 

 

GAN 전체 loss D와 G를 iterative 하게 각각 학습시키며 loss 는 다음과 같습니다. 특히 D는 loss 를 최대화, G는 loss 를 최소화 하도록 학습합니다. 

loss of GAN


GAN CODE

 

Gan 의 코드를 통해서 이해해보는 chapter 입니다. 

 

1. configuration

 

configuration 부분에서는 학습에 필요한 각종정보들을 구성/설정해주는 부분입니다. 

 

########################################################
# 1. configuration
########################################################

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--save_path', type=str, default='./saves')
parser.add_argument('--data_path', type=str, default='D:\data\MNIST')
parser.add_argument('--save_file_name', type=str, default='finetuned_pruned')
opts = parser.parse_args()

 

2. datasets

 

간단한 실험을 위해서 torchvision 에서 제공하는 MNIST dataset 을 이용하였습니다.

data transforms 및 augmentation 은 normalization 만 수행했습니다. 

 

########################################################
# 2. datasets
########################################################

import torchvision.transforms as tfs
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# transform
transform_mnist = tfs.Compose([tfs.ToTensor(),
                               tfs.Normalize((0.1307,), (0.3081,))
                               ])
test_transfrom_mnist = tfs.Compose([tfs.ToTensor(),
                                    tfs.Normalize((0.1307,), (0.3081,))
                                    ])
# dataset
train_set = MNIST(root=opts.data_path, train=True, download=True, transform=transform_mnist)
test_set = MNIST(root=opts.data_path, train=False, download=True, transform=test_transfrom_mnist)

# data loader
train_loader = DataLoader(dataset=train_set,
                          shuffle=True,
                          batch_size=opts.batch_size)

test_loader = DataLoader(dataset=test_set,
                         shuffle=False,
                         batch_size=opts.batch_size)

 

3. models

 

models 부분에서는 generator, discriminator 를 각각 구현하였습니다. 

 

########################################################
# 3. models
########################################################
import torch.nn as nn

# generator
d_noise  = 100
d_hidden = 256


class Generator(nn.Module):
    def __init__(self, d_noise=100, d_hidden=256):
        super().__init__()

        self.generator = nn.Sequential(
            nn.Linear(d_noise, d_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, 28 * 28),
            nn.Tanh()
        )

    def forward(self, x):
        return self.generator(x)


# discriminator
class Discriminator(nn.Module):
    def __init__(self, d_hidden=256):
        super().__init__()

        self.discriminator = nn.Sequential(
            nn.Linear(28 * 28, d_hidden),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, d_hidden),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.discriminator(x)

 

4. loss

 

 

 

 

 

MNIST

 

할 때, D 를 학습할 때를 나누어 생각을 해 보겠습니다. 

Target of training Real / Fake (a, b) BCE Loss
G Real x x
G Fake $(D(G(z)), 1)$ $-\frac{1}{n}\sum_i^n{logD(G(z))}$
D Real $(D(x), 1)$ $-\frac{1}{n}\sum_i^n{logD(x)}$
D Fake $(D(G(z)), 0)$                                                          $-\frac{1}{n}\sum_i^n{log(1-D(G(z)))}$

 

자 이제 G 를 학습할 때, D 를 학습할 때를 나누어 생각을 해 보겠습니다. 

 

1) Real image 가 들어갈 때를 생각해 봅시다.

 

이때, label을 1로 가정하면, $(D(x), 1)$ 가 들어가서 BCE 의 뒤의 term 이 사라지고 

$-\frac{1}{n}\sum_i^n{logD(x)}$ 가 남습니다. 여기서 이 loss 값이 줄면, discriminator network $D(.)$ 가 진짜 이미지를 잘 구별하도록 학습시키는 것 입니다.

 

여기에 음수 negative 가 붙었으니, $-\frac{1}{n}\sum_i^n{logD(x)}$를 최소화 시키는 것은  $\frac{1}{n}\sum_i^n{logD(x)}$ 를 최대화 시키는 것으로 볼 수 있습니다. 이러한 D를 구하는 것이 Real image 를 잘 판별하는 D를 학습하는 것 입니다. 

음수의 min은 양수의 max 로 변할 수 있습니다. (1) 또한 이 loss 를 학습데이터의 전체 분포($P_{data(x})$)로 생각해 본다면, 거기서 뽑은 sample $x$에 대한 expectation으로 표현 가능합니다. (2) 따라서 BCE loss 로 D 를 학습하는 과정은 (3) 으로 표현 할 수 있습니다 

 

2) 이번에는 Fake image 가 들어갈 때 입니다. 

 

(a, b) 대신에 $(D(G(z)), 0)$ 이 bce 의 input으로 들어갑니다. Loss는 오히려 앞쪽의 term이 사라지고 $-\frac{1}{n}\sum_i^n{log(1-D(G(z)))}$ 만 남게 됩니다. 이 Loss term 이 줄어든다는 뜻은 D가 fake image 를 잘 구별한다는 것 입니다. 그런데, G의 입장에서는 이 Loss를 오히려 크게 만들어야지 IDEA에 맞게 됩니다. 

따라서 (4) 수식처럼 G는 loss 를 키우고 이것은 - 가 사라지는 대신 min 으로 바뀌게 됩니다. 

위와 마찬가지로 이번에는 $P_z(z)$ 에서 sampling된 latent vector의 expectation 으로 사용할 수 있고(5),

결국 G는 이 loss * (-1) 를 min하는 방향으로 학습 하게 됩니다. (6)

 

(3) 식과 (6) 식을 차례로 optimize 한다면 GAN의 loss 와 같은 형태가 됩니다. 

즉 BCE에서부터 D, G를 학습시키도록 loss 를 정해주면 됩니다. 😎😎😎😎😎

loss of GAN


작성중..

반응형

댓글