[Pytorch] torch.flatten() 사용하기
torch.flatten(t, s) 함수는 s 차원 이후에 평평하게 펴라는 뜻이다. torch.flatten(t, s, e) 이렇게 사용하면 t 차원부터 e차원까지만 평평하게 펴라는 뜻 아래는 예제 t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # t.size() is [2, 2, 2] torch.flatten(t, 0) tensor([1, 2, 3, 4, 5, 6, 7, 8]) torch.flatten(t, 1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) torch.flatten(t, 2) tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) torch.equal(t, torch.flatten(t, 2..
2023. 1. 26.