[Pytorch] torch.flatten() 사용하기
torch.flatten(t, s) 함수는 s 차원 이후에 평평하게 펴라는 뜻이다. torch.flatten(t, s, e) 이렇게 사용하면 t 차원부터 e차원까지만 평평하게 펴라는 뜻 아래는 예제 t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # t.size() is [2, 2, 2] torch.flatten(t, 0) tensor([1, 2, 3, 4, 5, 6, 7, 8]) torch.flatten(t, 1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) torch.flatten(t, 2) tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) torch.equal(t, torch.flatten(t, 2..
2023. 1. 26.
[Pytorch] 분류(classification)문제 에서 label 변환 (one-hot vs class)
bce를 사용하는등의 거나 여러 상황에서 class의 label을one-hot label로 혹은 class label 로 변환해야 하는 때가 있다. 1. one-hot label 에서 class label 로 변환 : torch.argmax(dim) shape 변환 (*, num_classes) 에서 (*) 로 바뀐다. import torch one_hot_label = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0]]) class_label = torch.argmax(one_hot_label, dim=-1) print(class_label) 정답 : tensor([0, 2]) 2. class label 에서 one-hot label로 변환 : torch.nn.functional..
2022. 12. 4.