본문 바로가기
Pytorch

[Pytorch] torch.roll 설명 및 예제

by pulluper 2023. 4. 10.
반응형

torch.roll 설명

https://pytorch.org/docs/stable/generated/torch.roll.html

 

torch.roll — PyTorch 2.0 documentation

Shortcuts

pytorch.org

 

위의 그림과 같이 torch.roll 은 원하는 dim으로 tensor를 이동시키는 것이다.  (굴린다고 표현)

dim기준으로 가장자리의 tensor의 요소들은 이동하기 때문에 생기는 빈 공간을 다시 채워준다. 

 

다음 그림을 보면 이해가 확실하다. 

 

다음은 위 그림의 코드이다. 

 

torch.roll(image, (-3, -3), dim=(2, 3))

image 의 shape 은 [1, 1, 28, 28].

따라서 dim 2, 3이 각각 h, w를 뜻하고 -3씩 이동시킨것을 의미. 

 

import torch
from matplotlib import pyplot as plt
from torchvision import datasets, transforms


if __name__ == '__main__':
    train_data = datasets.MNIST(root='./data/',
                                train=True,
                                download=True,
                                transform=transforms.ToTensor())

    test_data = datasets.MNIST(root='./data/',
                               train=False,
                               download=True,
                               transform=transforms.ToTensor())

    image, label = train_data[0]
    # image : 1, 28, 28
    # make a batch
    image = image.unsqueeze(0)  # 1, 1, 28, 28

    plt.figure("origin")
    # 1, 28, 28
    plt.imshow(image.squeeze().numpy(), cmap='gray')

    shifted_image = torch.roll(image, shifts=(-3, -3), dims=(2, 3))
    plt.figure('roll')
    plt.imshow(shifted_image.squeeze().numpy(), cmap='gray')
    plt.show()
반응형

댓글