안녕하세요 pulluper 입니다. :)
오늘은 u-net 구조와 unet 을 활용한 colorization 에 대하여 알아보겠습니다.
U-net
U-net 이 처음에 제안된 논문은 medical 분야인 MICCAI 2015 학회에서 발표 되었으며
"U-Net: Convolutional Networks for Biomedical Image Segmentation"
위와 같은 논문입니다. 이 네트워크의 구조는 image segmentation 을 위한 것이며, 이를 더 잘 하기 위해서 만든 구조입니다. 특징으로는 upsampling 구조에 더 많은 channel 이 있어 네트워크가 resolution 을 키울 때 도움을 준다는 점 입니다.
구조
u-net 은 그림과 같이 u자형 형태로 되어 있으며, convolution 과 pooling 을 통해서 feature map 이 줄어드는 부분과 다시 upsampling 을 한 부분을 concatenation 을 하여 그 다음의 feature 로 넘겨주는 구조를 하고 있습니다.
Code
각 층의 convolution 은 2개의 convolution 으로 되어있습니다. 이 2개의 conv 를 사용하는 모듈을 만들면 다음과 같습니다.
class DoubleConv(nn.Module):
def __init__(self, nin, nout):
super().__init__()
self.double_conv = nn.Sequential(nn.Conv2d(nin, nout, 3, padding=1, stride=1),
nn.BatchNorm2d(nout),
nn.ReLU(inplace=True),
nn.Conv2d(nout, nout, 3, padding=1, stride=1),
nn.BatchNorm2d(nout),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
그리고 feature map 을 줄이는 down 부분에서는 다음과 같은 모듈을 사용 할 수 있습니다.
class Down(nn.Module):
def __init__(self, nin, nout):
super().__init__()
self.down_conv = nn.Sequential(nn.MaxPool2d(2),
DoubleConv(nin, nout))
def forward(self, x):
return self.down_conv(x)
그리고 feature map 을 늘리고 concat 을 하는 up 부분에서는 다음과 같은 모듈을 사용 할 수 있습니다.
class Up(nn.Module):
def __init__(self, nin, nout):
super().__init__()
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.double_conv = DoubleConv(nin, nout)
def forward(self, x1, x2):
x1 = self.up(x1)
# padding
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2))
x = torch.cat([x2, x1], dim=1)
x = self.double_conv(x)
return x
그리고 마지막 output conv 에서는 다음과 같은 모듈을 사용합니다.
class OutConv(nn.Module):
def __init__(self, nin, nout):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(nin, nout, kernel_size=1)
def forward(self, x):
return self.conv(x)
u-net model의 총 코드입니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
def __init__(self, nin, nout):
super().__init__()
self.double_conv = nn.Sequential(nn.Conv2d(nin, nout, 3, padding=1, stride=1),
nn.BatchNorm2d(nout),
nn.ReLU(inplace=True),
nn.Conv2d(nout, nout, 3, padding=1, stride=1),
nn.BatchNorm2d(nout),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
def __init__(self, nin, nout):
super().__init__()
self.down_conv = nn.Sequential(nn.MaxPool2d(2),
DoubleConv(nin, nout))
def forward(self, x):
return self.down_conv(x)
class Up(nn.Module):
def __init__(self, nin, nout):
super().__init__()
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.double_conv = DoubleConv(nin, nout)
def forward(self, x1, x2):
x1 = self.up(x1)
# padding
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2))
x = torch.cat([x2, x1], dim=1)
x = self.double_conv(x)
return x
class OutConv(nn.Module):
def __init__(self, nin, nout):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(nin, nout, kernel_size=1)
def forward(self, x):
return self.conv(x)
class UNet(nn.Module):
def __init__(self, nin, nout):
super().__init__()
self.in_conv = DoubleConv(nin, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 1024 // 2)
self.up1 = Up(1024, 512 // 2)
self.up2 = Up(512, 256 // 2)
self.up3 = Up(256, 128 // 2)
self.up4 = Up(128, 64)
self.out_conv = OutConv(64, nout)
def forward(self, x):
x1 = self.in_conv(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.out_conv(x)
return x
if __name__ == "__main__":
img = torch.rand([10, 1, 256, 256]).cuda()
model = UNet(nin=1, nout=2).cuda()
print(model.forward(img).size())
네 이렇게 u-net 을 사용하면, input 과 output 의 resolution 이 같은 feature 를 사용할 수 있습니다.
코드는 다음에서 확인 할 수 있습니다.
github.com/csm-kr/dog_cat_colorization
Reference
https://github.com/milesial/Pytorch-UNet
https://github.com/usuyama/pytorch-unet/blob/master/pytorch_unet.py
'Network' 카테고리의 다른 글
[DNN] VGG구현을 위한 리뷰(ICLR 2015) (0) | 2022.07.27 |
---|---|
[DNN] ViT 구조 (0) | 2022.06.08 |
[DNN] Alexnet 리뷰 및 구현 (NIPS 2012) (0) | 2021.09.24 |
[DNN] Densenet 논문리뷰 및 구현 (CVPR2017) (0) | 2021.04.13 |
ILSVRC(Imagenet classification)validation set torchvision 으로 성능평가하기 (12) | 2021.03.01 |
댓글