반응형
torch.nn.Unfold 는 다음과 같다.
batched tensor에 대하여 마치 convolution 처럼 sliding 움직이면서 그 local block을 구하는 것이다.
다음 사이트에서 자세한 설명을 볼 수 있다.
https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
예를들면 다음과 같다.
MNIST train 1번 그림인데, 여기에 Unfold 를 kernel=7, stride=7 padding=0 으로 주면
오른쪽 그림과 같이 kernel 기준으로 나뉘게 되고 하나의 local block 은 1(=c) x (kernel_w *kernel_h)
의 tensor 가 되고 이를 1차원으로 모은것이 된다. (kernel 은 n차원 가능)
import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
if __name__ == '__main__':
train_data = datasets.MNIST(root='./data/',
train=True,
download=True,
transform=transforms.ToTensor())
test_data = datasets.MNIST(root='./data/',
train=False,
download=True,
transform=transforms.ToTensor())
image, label = train_data[0]
# image : 1, 28, 28
# make a batch
image = image.unsqueeze(0) # 1, 1, 28, 28
print(image.size())
unfold = nn.Unfold(kernel_size=7, stride=7, padding=0)
x = unfold(image)
print(x.size())
결과는 다음과 같다.
torch.Size([1, 1, 28, 28])
torch.Size([1, 49, 16])
이를 이용해서 kernel 을 2 x 2 로 두고 맨 각 kernel 왼쪽 위의 값만 ([0][0]) 가져와서
downsampling 등을 쉽게 할 수 있다.
import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
if __name__ == '__main__':
train_data = datasets.MNIST(root='./data/',
train=True,
download=True,
transform=transforms.ToTensor())
test_data = datasets.MNIST(root='./data/',
train=False,
download=True,
transform=transforms.ToTensor())
image, label = train_data[0]
# image : 1, 28, 28
# make a batch
image = image.unsqueeze(0) # 1, 1, 28, 28
print(image.size())
unfold = nn.Unfold(kernel_size=2, stride=2, padding=0)
x = unfold(image)
print(x.size())
# 1, 4, 196
x = x[:, 0, ...].view(1, 1, 14, 14)
plt.figure("origin")
plt.imshow(image.squeeze().numpy(), cmap='gray')
plt.figure('down')
plt.imshow(x.squeeze().numpy(), cmap='gray')
plt.show()
반응형
'Pytorch' 카테고리의 다른 글
[Pytorch] torch.roll 설명 및 예제 (0) | 2023.04.10 |
---|---|
[Pytorch] RTX3060 window에서 최신 anaconda, 그래픽 드라이버, cuda11.7, cudnn, pytorch2.0 설치 (0) | 2023.03.29 |
[Pytorch] torch.tensor.repeat() 사용하기 (0) | 2023.03.08 |
[Pytorch] network parameter 갯수 확인 (0) | 2023.02.01 |
[Pytorch] torch.flatten() 사용하기 (1) | 2023.01.26 |
댓글