본문 바로가기
Pytorch

[Pytorch] torch.flatten() 사용하기

by pulluper 2023. 1. 26.
반응형

그림1. torch.flatten()

torch.flatten(t, s) 함수는 s 차원 이후에 평평하게 펴라는 뜻이다. 

torch.flatten(t, s, e) 이렇게 사용하면 t 차원부터 e차원까지만 평평하게 펴라는 뜻

 

아래는 예제

 

t = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])
                   
# t.size() is [2, 2, 2]      

torch.flatten(t, 0)
tensor([1, 2, 3, 4, 5, 6, 7, 8])

torch.flatten(t, 1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
        
torch.flatten(t, 2)
tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])

torch.equal(t, torch.flatten(t, 2)) 
#True

torch.flatten(t[:, :, :, None], 0, 2)
tensor([[1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8]])
torch.flatten(t[:, :, :, None], 0, 2).size()
torch.Size([8, 1])

torch.flatten(t[:, :, :, None], 0, 3)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
torch.flatten(t[:, :, :, None], 0, 3).size()
torch.Size([8])

 

반응형

댓글