본문 바로가기

Pytorch15

[Pytorch] torch.tensor.repeat() 사용하기 torch.repeat 이거 dim 늘릴때 참 좋은 함수이다. 예를들어 anchor point를 다음과 같이 [100, 2] 개를 만들고 싶다. 이때, 이를 batch (=8) 개 만큼 늘리고 각 100개 point를 3번씩 가지는 anchor points 를 만들고싶을때, repeat을 쓰면 좋다. 다음 코드의 결과는 [8, 300, 2] 이다. import torch anchor_points = torch.rand(100, 2) anchor_points = anchor_points.unsqueeze(0).repeat(8, 3, 1) print(anchor_points.size()) repeat 은 차원이 안 맞아도 가능하다. repeat의 뒷쪽부터 맞추는 것 같다. 그치만 헷갈리니 차원을 맞춰서 하는.. 2023. 3. 8.
[Pytorch] network parameter 갯수 확인 1. print(sum(p.numel() for p in model.parameters() if p.requires_grad)) 예를들어 다양한 모델에 대하여 from torchvision.models import * if __name__ == '__main__': model = vgg11() print("vgg11 : ", sum(p.numel() for p in model.parameters() if p.requires_grad)) model = vgg13() print("vgg13 : ", sum(p.numel() for p in model.parameters() if p.requires_grad)) model = vgg16() print("vgg16 : ", sum(p.numel() for p i.. 2023. 2. 1.
[Pytorch] torch.flatten() 사용하기 torch.flatten(t, s) 함수는 s 차원 이후에 평평하게 펴라는 뜻이다. torch.flatten(t, s, e) 이렇게 사용하면 t 차원부터 e차원까지만 평평하게 펴라는 뜻 아래는 예제 t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # t.size() is [2, 2, 2] torch.flatten(t, 0) tensor([1, 2, 3, 4, 5, 6, 7, 8]) torch.flatten(t, 1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) torch.flatten(t, 2) tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) torch.equal(t, torch.flatten(t, 2.. 2023. 1. 26.
[Pytorch] 분류(classification)문제 에서 label 변환 (one-hot vs class) bce를 사용하는등의 거나 여러 상황에서 class의 label을one-hot label로 혹은 class label 로 변환해야 하는 때가 있다. 1. one-hot label 에서 class label 로 변환 : torch.argmax(dim) shape 변환 (*, num_classes) 에서 (*) 로 바뀐다. import torch one_hot_label = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0]]) class_label = torch.argmax(one_hot_label, dim=-1) print(class_label) 정답 : tensor([0, 2]) 2. class label 에서 one-hot label로 변환 : torch.nn.functional.. 2022. 12. 4.
[Pytorch] pytorch 에서 np.where 처럼 index 가져오기 a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) np.where (iou == IoU_max_per_object).nonzero() C:\Users\csm81\Desktop\projects_3 (detection)\Faster_RCNN_Pytorch\model\target_builder.py:1: UserWarning: This overload of nonzero is deprecated: nonzero() Consider using one of the following signatures instead: nonzero(*, bool as_tuple) (Triggered internally at ..\torch\csrc\utils\python_arg_parser.cpp:882.. 2022. 8. 17.
[Pytorch] Distributed package 를 이용한 분산학습으로 Multi-GPU 효율적으로 사용하기 안녕하세요 pulluper 입니다 😁😁 이번 포스팅에서는 pytorch 의 분산(distributed) pakage를 이용해서 multi-gpu 를 모두 효율적으로 사용하는 방법을 알아보겠습니다. 이번 포스팅의 목차는 다음과 같습니다. 1. 용어 2. init 3. dataset 4. distributed data-parallel 5. train 6. 실행 7. CIFAR 10 example 1. 용어 (terminology) 먼저 pytorch.distributed 를 이용하는 것은 멀티프로세스방법을 이용하는 것 입니다. 정확하게는 여러 process를 이용해병렬적으로 연산을 수행하여, 각 프로세스가 효율적으로 gpu를 사용하고 모으는것을 할 수 가 있습니다. 이를 이용하기 위해 이해해야 할 몇가지 용.. 2022. 6. 15.