본문 바로가기
Pytorch

[Pytorch] 분류(classification)문제 에서 label 변환 (one-hot vs class)

by pulluper 2022. 12. 4.
반응형

bce를 사용하는등의 거나 여러 상황에서 class의 label을one-hot label로 혹은 class label 로 변환해야 하는 때가 있다. 

 

1. one-hot label 에서  class label 로 변환 : torch.argmax(dim)

 

shape 변환 (*, num_classes) 에서 (*) 로 바뀐다. 

 

import torch

one_hot_label = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0]])
class_label = torch.argmax(one_hot_label, dim=-1)
print(class_label)

 

정답 : tensor([0, 2])

 

 

2. class label 에서 one-hot label로 변환 : torch.nn.functional.one_hot(num_classes)

 

shape 변환 (*) 에서 (*, num_classes) 로 바뀐다. 

 

one_hot_label_ = F.one_hot(class_label, num_classes=4)
print(one_hot_label_)

 

정답 : tensor([[1, 0, 0, 0], [0, 0, 1, 0]])

반응형

댓글