본문 바로가기
Pytorch

[Pytorch] torch.tensor.repeat() 사용하기

by pulluper 2023. 3. 8.
반응형

torch.repeat 이거 dim 늘릴때 참 좋은 함수이다. 

 

예를들어 anchor point를 다음과 같이 [100, 2] 개를 만들고 싶다. 

그림2. 100개의 anchor point

이때, 이를  batch (=8) 개 만큼 늘리고

각 100개 point를 3번씩 가지는 anchor points 를 만들고싶을때, repeat을 쓰면 좋다.

다음 코드의 결과는 [8, 300, 2] 이다. 

import torch

anchor_points = torch.rand(100, 2)
anchor_points = anchor_points.unsqueeze(0).repeat(8, 3, 1)
print(anchor_points.size())

 

repeat 은 차원이 안 맞아도 가능하다. repeat의 뒷쪽부터 맞추는 것 같다. 

그치만 헷갈리니 차원을 맞춰서 하는걸 추천

import torch

anchor_points = torch.rand(100, 2)
anchor_points = anchor_points.repeat(8, 3, 1)
print(anchor_points.size())
반응형

댓글