본문 바로가기
Pytorch

[Pytorch] torch.nn.Unfold

by pulluper 2023. 3. 16.
반응형

torch.nn.Unfold 는 다음과 같다. 

batched tensor에 대하여 마치 convolution 처럼 sliding 움직이면서 그 local block을 구하는 것이다. 

 

다음 사이트에서 자세한 설명을 볼 수 있다. 

https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html

 

Unfold — PyTorch 2.0 documentation

Shortcuts

pytorch.org

 

torch.nn.Unfold

예를들면 다음과 같다.

MNIST train 1번 그림인데, 여기에 Unfold 를 kernel=7, stride=7 padding=0 으로 주면  

오른쪽 그림과 같이 kernel 기준으로 나뉘게 되고 하나의 local block 은 1(=c) x (kernel_w *kernel_h)

의 tensor 가 되고 이를 1차원으로 모은것이 된다. (kernel 은 n차원 가능)

import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision import datasets, transforms


if __name__ == '__main__':
    train_data = datasets.MNIST(root='./data/',
                                train=True,
                                download=True,
                                transform=transforms.ToTensor())

    test_data = datasets.MNIST(root='./data/',
                               train=False,
                               download=True,
                               transform=transforms.ToTensor())

    image, label = train_data[0]
    # image : 1, 28, 28
    # make a batch
    image = image.unsqueeze(0)  # 1, 1, 28, 28
    print(image.size())

    unfold = nn.Unfold(kernel_size=7, stride=7, padding=0)
    x = unfold(image)
    print(x.size())

 

결과는 다음과 같다. 

 

torch.Size([1, 1, 28, 28])
torch.Size([1, 49, 16])

 

이를 이용해서 kernel 을 2 x 2 로 두고 맨 각 kernel 왼쪽 위의 값만 ([0][0]) 가져와서 

downsampling 등을 쉽게 할 수 있다. 

kernel=2, stride=2 and view

 

 

import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision import datasets, transforms


if __name__ == '__main__':
    train_data = datasets.MNIST(root='./data/',
                                train=True,
                                download=True,
                                transform=transforms.ToTensor())

    test_data = datasets.MNIST(root='./data/',
                               train=False,
                               download=True,
                               transform=transforms.ToTensor())

    image, label = train_data[0]
    # image : 1, 28, 28
    # make a batch
    image = image.unsqueeze(0)  # 1, 1, 28, 28
    print(image.size())
    unfold = nn.Unfold(kernel_size=2, stride=2, padding=0)
    x = unfold(image)
    print(x.size())
    # 1, 4, 196
    x = x[:, 0, ...].view(1, 1, 14, 14)
    plt.figure("origin")
    plt.imshow(image.squeeze().numpy(), cmap='gray')
    plt.figure('down')
    plt.imshow(x.squeeze().numpy(), cmap='gray')
    plt.show()

 

반응형

댓글