안녕하세요 pulluper 입니다! 😁
이번 포스팅은 generative model중 하나인 GAN에 대하여 리뷰 / 코드분석 을 해 보겠습니다.
INTRO
요슈아 뱅지오 교수님과 이얀 굿펠로우등의 연구자들이 NIPS2014 년에 발표한 논문입니다.
나오고 엄청난 열풍이 붑니다. 2022년 3월 기준 4만회의 citation 이 있으며,
대표적인 generative model 의 한 분야로 자리 잡았습니다.
https://arxiv.org/pdf/1406.2661.pdf
GAN(Generative Adversarial Net) 을 직역해 보면,
Generative : 생성의
Adversarial : 적대적인
Network : 네트워크
한국말로 "적대적 생성 신경망" 이라고 불리기도 합니다.
IDEA
Generarive model의 목표는 학습 data와 유사한 data를 생성하는 것 입니다.
즉, 학습 data 의 분포를 학습 하는 것으로 생각 할 수 있습니다.
GAN은 적대적인(서로 반대되는 학습 목적을 가진) 딥뉴럴 네트워크를 통해서
Generative model 의 목표를 이루고자 합니다.
GAN의 아이디어는 다음과 같습니다.
예시로 많이 쓰이는 도둑(위조지폐범)과 경찰(감별사)에 대하여 생각해 봅시다. 🥲
파란색 생성 부분에서 도둑(위조지폐범)은 위조지폐를 진짜처럼 만드는게 목표이고,
분홍색 감별 부분에서 경찰(감별사)는 진짜지폐와 위조지폐를 잘 구분하는게 목표입니다.
이렇듯, 도둑(위조지폐범)과 경찰(감별사)을 각각 Deep neural network로 두고
적절한 Loss 로 서로 반대되는 목표를 학습하도록 만든 것이 GAN의 아이디어 입니다.
그렇다면, 어떻게 생성을 하고 감별을 하도록 적절한 Loss를 만들수 있을까요? 🤪
IDEA to LOSS
그럼 이 아이디어를 어떻게 적절한 Loss로 만드는지 알아봅시다.
"감별"한다는 것은 Real(진짜지폐)와 Fake(위조지폐)의 2개의 class 를 구분한다는 것입니다.
우리가 익히 알고있는 2개의 class 를 구분하는 방법 즉, binary cross entropy 를 이용하면 되지 않을까요?
BCE loss 는 다음과 같습니다.
그리고 다음과 같이 설정합니다.
$D(.)$ : Discriminator,
$G(.)$ : Generator,
$x$ : real image,
$z$ : latent vector,
자 이제 D 를 학습할 때, G 를 학습할 때를 나누어 생각을 해 보겠습니다.
D(Discriminator) 학습
D의 입장에서 BCE의 Loss 가 줄면, D는 분류를 잘 하도록 학습이 됩니다.
D의 학습을 위해서는 real image(x) 의 경우와 fake image(G(z))의 경우를 모두 학습하여야 합니다.
real image 일 때 $(a, b) = (D(x), 1)$ 이고, loss 는 $-\frac{1}{n}\sum_i^n{logD(x)}$ 입니다.
fake image 일 때 $(a, b) = (D(G(z)), 0)$ 이고, loss 는 $-\frac{1}{n}\sum_i^n{(1-logD(G(z)))}$ 입니다.
G(Generator) 학습
G의 학습을 위해서는 fake image(G(z)) 경우만 고려하면 됩니다. 왜냐하면 real image가 들어갈때는 G가 관여를 하지 않기 때문입니다. 따라서 fake image 일 때 $(a, b) = (D(G(z)), 0)$ 이고, loss 는 $-\frac{1}{n}\sum_i^n{(1-logD(G(z)))}$ 입니다. 이때, loss 가 minimize 되는것은 D를 학습시키는 것 이므로 반대로 maximize 시켜야 합니다.
학습 Loss 정리
$$ \begin{align} -\frac{1}{n}\sum_i^n{logD(x)}\tag{1} \\ -\frac{1}{n}\sum_i^n{(1-logD(G(z)))} \tag{2}\\ \end{align} $$
(1)의 수식은 loss가 결국 전체 학습 data의 분포 $P_{data}(x)$ 로부터 sampling된 x에 대한 expectation 입니다.
마찬가지로, (2)의 수식은 이번에는 latent space 분포인 $P_z(z)$ 로부터 sampling된 z에 대한 기댓값입니다.
따라서 (1) 의 수식은 (3) 으로, (2) 의 수식은 (4)로 서술될 수 있습니다.
$$ \begin{align} -\mathbb{E}_{x\sim P_{data(x)}} [logD(x)]\tag{3} \\ \\ -\mathbb{E}_{z\sim P_z(z)} [log(1-D(G(z)))]\tag{4} \end{align} $$
D의 학습은 fake image, real image일 때 모두 구분을 잘 하도록 다음 위의 loss 들의 합인 $ -\mathbb{E}_{x\sim P_{data(x)}} [logD(x)] - \mathbb{E}_{z\sim P_z(z)} [log(1-D(G(x)))]$ 를 최소화 하는 것 이므로 $ \mathbb{E}_{x\sim P_{data(x)}} [logD(x)] + \mathbb{E}_{z\sim P_z(z)} [log(1-D(G(z)))]$ 를 최대화 하는것과 같습니다.
G의 학습은 $ -\mathbb{E}_{z\sim P_z(z)} [log(1-D(G(z)))]$를 최대화 하는 것 이므로 $ \mathbb{E}_{z\sim P_z(z)} [log(1-D(G(z)))]$를 최소화 하는것과 같습니다.
GAN 전체 loss D와 G를 iterative 하게 각각 학습시키며 loss 는 다음과 같습니다. 특히 D는 loss 를 최대화, G는 loss 를 최소화 하도록 학습합니다.
GAN CODE
Gan 의 코드를 통해서 이해해보는 chapter 입니다.
1. configuration
configuration 부분에서는 학습에 필요한 각종정보들을 구성/설정해주는 부분입니다.
########################################################
# 1. configuration
########################################################
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--save_path', type=str, default='./saves')
parser.add_argument('--data_path', type=str, default='D:\data\MNIST')
parser.add_argument('--save_file_name', type=str, default='finetuned_pruned')
opts = parser.parse_args()
2. datasets
간단한 실험을 위해서 torchvision 에서 제공하는 MNIST dataset 을 이용하였습니다.
data transforms 및 augmentation 은 normalization 만 수행했습니다.
########################################################
# 2. datasets
########################################################
import torchvision.transforms as tfs
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
# transform
transform_mnist = tfs.Compose([tfs.ToTensor(),
tfs.Normalize((0.1307,), (0.3081,))
])
test_transfrom_mnist = tfs.Compose([tfs.ToTensor(),
tfs.Normalize((0.1307,), (0.3081,))
])
# dataset
train_set = MNIST(root=opts.data_path, train=True, download=True, transform=transform_mnist)
test_set = MNIST(root=opts.data_path, train=False, download=True, transform=test_transfrom_mnist)
# data loader
train_loader = DataLoader(dataset=train_set,
shuffle=True,
batch_size=opts.batch_size)
test_loader = DataLoader(dataset=test_set,
shuffle=False,
batch_size=opts.batch_size)
3. models
models 부분에서는 generator, discriminator 를 각각 구현하였습니다.
########################################################
# 3. models
########################################################
import torch.nn as nn
# generator
d_noise = 100
d_hidden = 256
class Generator(nn.Module):
def __init__(self, d_noise=100, d_hidden=256):
super().__init__()
self.generator = nn.Sequential(
nn.Linear(d_noise, d_hidden),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(d_hidden, d_hidden),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(d_hidden, 28 * 28),
nn.Tanh()
)
def forward(self, x):
return self.generator(x)
# discriminator
class Discriminator(nn.Module):
def __init__(self, d_hidden=256):
super().__init__()
self.discriminator = nn.Sequential(
nn.Linear(28 * 28, d_hidden),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(d_hidden, d_hidden),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(d_hidden, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.discriminator(x)
4. loss
MNIST
할 때, D 를 학습할 때를 나누어 생각을 해 보겠습니다.
Target of training | Real / Fake | (a, b) | BCE Loss |
G | Real | x | x |
G | Fake | $(D(G(z)), 1)$ | $-\frac{1}{n}\sum_i^n{logD(G(z))}$ |
D | Real | $(D(x), 1)$ | $-\frac{1}{n}\sum_i^n{logD(x)}$ |
D | Fake | $(D(G(z)), 0)$ | $-\frac{1}{n}\sum_i^n{log(1-D(G(z)))}$ |
자 이제 G 를 학습할 때, D 를 학습할 때를 나누어 생각을 해 보겠습니다.
1) Real image 가 들어갈 때를 생각해 봅시다.
이때, label을 1로 가정하면, $(D(x), 1)$ 가 들어가서 BCE 의 뒤의 term 이 사라지고
$-\frac{1}{n}\sum_i^n{logD(x)}$ 가 남습니다. 여기서 이 loss 값이 줄면, discriminator network $D(.)$ 가 진짜 이미지를 잘 구별하도록 학습시키는 것 입니다.
여기에 음수 negative 가 붙었으니, $-\frac{1}{n}\sum_i^n{logD(x)}$를 최소화 시키는 것은 $\frac{1}{n}\sum_i^n{logD(x)}$ 를 최대화 시키는 것으로 볼 수 있습니다. 이러한 D를 구하는 것이 Real image 를 잘 판별하는 D를 학습하는 것 입니다.
음수의 min은 양수의 max 로 변할 수 있습니다. (1) 또한 이 loss 를 학습데이터의 전체 분포($P_{data(x})$)로 생각해 본다면, 거기서 뽑은 sample $x$에 대한 expectation으로 표현 가능합니다. (2) 따라서 BCE loss 로 D 를 학습하는 과정은 (3) 으로 표현 할 수 있습니다
2) 이번에는 Fake image 가 들어갈 때 입니다.
(a, b) 대신에 $(D(G(z)), 0)$ 이 bce 의 input으로 들어갑니다. Loss는 오히려 앞쪽의 term이 사라지고 $-\frac{1}{n}\sum_i^n{log(1-D(G(z)))}$ 만 남게 됩니다. 이 Loss term 이 줄어든다는 뜻은 D가 fake image 를 잘 구별한다는 것 입니다. 그런데, G의 입장에서는 이 Loss를 오히려 크게 만들어야지 IDEA에 맞게 됩니다.
따라서 (4) 수식처럼 G는 loss 를 키우고 이것은 - 가 사라지는 대신 min 으로 바뀌게 됩니다.
위와 마찬가지로 이번에는 $P_z(z)$ 에서 sampling된 latent vector의 expectation 으로 사용할 수 있고(5),
결국 G는 이 loss * (-1) 를 min하는 방향으로 학습 하게 됩니다. (6)
(3) 식과 (6) 식을 차례로 optimize 한다면 GAN의 loss 와 같은 형태가 됩니다.
즉 BCE에서부터 D, G를 학습시키도록 loss 를 정해주면 됩니다. 😎😎😎😎😎
작성중..
댓글