반응형
torch.roll 설명
https://pytorch.org/docs/stable/generated/torch.roll.html
위의 그림과 같이 torch.roll 은 원하는 dim으로 tensor를 이동시키는 것이다. (굴린다고 표현)
dim기준으로 가장자리의 tensor의 요소들은 이동하기 때문에 생기는 빈 공간을 다시 채워준다.
다음 그림을 보면 이해가 확실하다.
다음은 위 그림의 코드이다.
torch.roll(image, (-3, -3), dim=(2, 3))
image 의 shape 은 [1, 1, 28, 28].
따라서 dim 2, 3이 각각 h, w를 뜻하고 -3씩 이동시킨것을 의미.
import torch
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
if __name__ == '__main__':
train_data = datasets.MNIST(root='./data/',
train=True,
download=True,
transform=transforms.ToTensor())
test_data = datasets.MNIST(root='./data/',
train=False,
download=True,
transform=transforms.ToTensor())
image, label = train_data[0]
# image : 1, 28, 28
# make a batch
image = image.unsqueeze(0) # 1, 1, 28, 28
plt.figure("origin")
# 1, 28, 28
plt.imshow(image.squeeze().numpy(), cmap='gray')
shifted_image = torch.roll(image, shifts=(-3, -3), dims=(2, 3))
plt.figure('roll')
plt.imshow(shifted_image.squeeze().numpy(), cmap='gray')
plt.show()
반응형
'Pytorch' 카테고리의 다른 글
[Pytorch] Remote server에서의 visdom설정 (0) | 2023.05.10 |
---|---|
[Pytorch] Distributed package으로 Multi-Node Multi-GPU 학습 알아보기 (3) | 2023.04.26 |
[Pytorch] RTX3060 window에서 최신 anaconda, 그래픽 드라이버, cuda11.7, cudnn, pytorch2.0 설치 (0) | 2023.03.29 |
[Pytorch] torch.nn.Unfold (0) | 2023.03.16 |
[Pytorch] torch.tensor.repeat() 사용하기 (0) | 2023.03.08 |
댓글