반응형
bce를 사용하는등의 거나 여러 상황에서 class의 label을one-hot label로 혹은 class label 로 변환해야 하는 때가 있다.
1. one-hot label 에서 class label 로 변환 : torch.argmax(dim)
shape 변환 (*, num_classes) 에서 (*) 로 바뀐다.
import torch
one_hot_label = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0]])
class_label = torch.argmax(one_hot_label, dim=-1)
print(class_label)
정답 : tensor([0, 2])
2. class label 에서 one-hot label로 변환 : torch.nn.functional.one_hot(num_classes)
shape 변환 (*) 에서 (*, num_classes) 로 바뀐다.
one_hot_label_ = F.one_hot(class_label, num_classes=4)
print(one_hot_label_)
정답 : tensor([[1, 0, 0, 0], [0, 0, 1, 0]])
반응형
'Pytorch' 카테고리의 다른 글
[Pytorch] network parameter 갯수 확인 (0) | 2023.02.01 |
---|---|
[Pytorch] torch.flatten() 사용하기 (1) | 2023.01.26 |
[Pytorch] pytorch 에서 np.where 처럼 index 가져오기 (0) | 2022.08.17 |
[Pytorch] Distributed package 를 이용한 분산학습으로 Multi-GPU 효율적으로 사용하기 (4) | 2022.06.15 |
[Pytorch] PIL, cv2, pytorch 이미지 처리 library 비교 (2) | 2022.04.11 |
댓글