반응형
torch.repeat 이거 dim 늘릴때 참 좋은 함수이다.
예를들어 anchor point를 다음과 같이 [100, 2] 개를 만들고 싶다.
이때, 이를 batch (=8) 개 만큼 늘리고
각 100개 point를 3번씩 가지는 anchor points 를 만들고싶을때, repeat을 쓰면 좋다.
다음 코드의 결과는 [8, 300, 2] 이다.
import torch
anchor_points = torch.rand(100, 2)
anchor_points = anchor_points.unsqueeze(0).repeat(8, 3, 1)
print(anchor_points.size())
repeat 은 차원이 안 맞아도 가능하다. repeat의 뒷쪽부터 맞추는 것 같다.
그치만 헷갈리니 차원을 맞춰서 하는걸 추천
import torch
anchor_points = torch.rand(100, 2)
anchor_points = anchor_points.repeat(8, 3, 1)
print(anchor_points.size())
반응형
'Pytorch' 카테고리의 다른 글
[Pytorch] RTX3060 window에서 최신 anaconda, 그래픽 드라이버, cuda11.7, cudnn, pytorch2.0 설치 (0) | 2023.03.29 |
---|---|
[Pytorch] torch.nn.Unfold (0) | 2023.03.16 |
[Pytorch] network parameter 갯수 확인 (0) | 2023.02.01 |
[Pytorch] torch.flatten() 사용하기 (1) | 2023.01.26 |
[Pytorch] 분류(classification)문제 에서 label 변환 (one-hot vs class) (0) | 2022.12.04 |
댓글