본문 바로가기
Network

[DNN] U-net 구조와 code 구현 (MICCAI2015)

by pulluper 2021. 5. 2.
반응형

안녕하세요 pulluper 입니다. :)

오늘은 u-net 구조와 unet 을 활용한 colorization 에 대하여 알아보겠습니다. 


U-net

U-net 이 처음에 제안된 논문은 medical 분야인 MICCAI 2015 학회에서 발표 되었으며 

"U-Net: Convolutional Networks for Biomedical Image Segmentation"

arxiv.org/abs/1505.04597

 

U-Net: Convolutional Networks for Biomedical Image Segmentation

There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated

arxiv.org

위와 같은 논문입니다. 이 네트워크의 구조는 image segmentation 을 위한 것이며, 이를 더 잘 하기 위해서 만든 구조입니다.  특징으로는 upsampling 구조에 더 많은 channel 이 있어 네트워크가 resolution 을 키울 때 도움을 준다는 점 입니다. 


구조

u-net structure

u-net 은 그림과 같이 u자형 형태로 되어 있으며, convolution 과 pooling 을 통해서 feature map 이 줄어드는 부분과 다시 upsampling 을 한 부분을 concatenation 을 하여 그 다음의 feature 로 넘겨주는 구조를 하고 있습니다.  


Code

 

각 층의 convolution 은 2개의 convolution 으로 되어있습니다.  이 2개의 conv 를 사용하는 모듈을 만들면 다음과 같습니다. 

class DoubleConv(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.double_conv = nn.Sequential(nn.Conv2d(nin, nout, 3, padding=1, stride=1),
                                         nn.BatchNorm2d(nout),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(nout, nout, 3, padding=1, stride=1),
                                         nn.BatchNorm2d(nout),
                                         nn.ReLU(inplace=True)
                                         )

    def forward(self, x):
        return self.double_conv(x)

그리고 feature map 을 줄이는 down 부분에서는 다음과 같은 모듈을 사용 할 수 있습니다. 

class Down(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.down_conv = nn.Sequential(nn.MaxPool2d(2),
                                       DoubleConv(nin, nout))

    def forward(self, x):
        return self.down_conv(x)

그리고 feature map 을 늘리고 concat 을 하는 up 부분에서는 다음과 같은 모듈을 사용 할 수 있습니다. 

class Up(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.double_conv = DoubleConv(nin, nout)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # padding
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        x = self.double_conv(x)
        return x

그리고 마지막 output conv 에서는 다음과 같은 모듈을 사용합니다. 

class OutConv(nn.Module):
    def __init__(self, nin, nout):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

u-net model의 총 코드입니다. 

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.double_conv = nn.Sequential(nn.Conv2d(nin, nout, 3, padding=1, stride=1),
                                         nn.BatchNorm2d(nout),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(nout, nout, 3, padding=1, stride=1),
                                         nn.BatchNorm2d(nout),
                                         nn.ReLU(inplace=True)
                                         )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.down_conv = nn.Sequential(nn.MaxPool2d(2),
                                       DoubleConv(nin, nout))

    def forward(self, x):
        return self.down_conv(x)


class Up(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.double_conv = DoubleConv(nin, nout)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # padding
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        x = self.double_conv(x)
        return x


class OutConv(nn.Module):
    def __init__(self, nin, nout):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.in_conv = DoubleConv(nin, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // 2)
        self.up1 = Up(1024, 512 // 2)
        self.up2 = Up(512, 256 // 2)
        self.up3 = Up(256, 128 // 2)
        self.up4 = Up(128, 64)
        self.out_conv = OutConv(64, nout)

    def forward(self, x):
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.out_conv(x)
        return x


if __name__ == "__main__":
    img = torch.rand([10, 1, 256, 256]).cuda()
    model = UNet(nin=1, nout=2).cuda()
    print(model.forward(img).size())  
    

네 이렇게 u-net 을 사용하면, input 과 output 의 resolution 이 같은 feature 를 사용할 수 있습니다. 

코드는 다음에서 확인 할 수 있습니다. 

 

github.com/csm-kr/dog_cat_colorization

 

csm-kr/dog_cat_colorization

Contribute to csm-kr/dog_cat_colorization development by creating an account on GitHub.

github.com

Reference

 

https://github.com/milesial/Pytorch-UNet

 

GitHub - milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images

PyTorch implementation of the U-Net for image semantic segmentation with high quality images - GitHub - milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation wi...

github.com

 

https://github.com/usuyama/pytorch-unet/blob/master/pytorch_unet.py

 

GitHub - usuyama/pytorch-unet: Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation

Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation - GitHub - usuyama/pytorch-unet: Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation

github.com

반응형

댓글