
[Pytorch] Getting indices in PyTorch the way np.where does

by pulluper 2022. 8. 17.

import numpy as np
a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])

 

np.where
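For reference, here is what np.where returns when it is given only a condition. The condition `a > 4` is a hypothetical example chosen for illustration, since the post does not show which condition was used. The sketch also shows np.argwhere, which returns the same indices as one `[row, col]` row per match. That is the layout PyTorch's `nonzero()` uses by default.

```python
import numpy as np

a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])

# Hypothetical condition for illustration
cond = a > 4

# With only a condition, np.where returns a tuple: one index array per dimension
rows, cols = np.where(cond)
print(rows)  # [1 1 1 1]
print(cols)  # [0 1 2 3]

# np.argwhere returns the same indices as an (N, ndim) array, one [row, col] per match
coords = np.argwhere(cond)
print(coords.tolist())  # [[1, 0], [1, 1], [1, 2], [1, 3]]
```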

 

(iou == IoU_max_per_object).nonzero()
C:\Users\csm81\Desktop\projects_3 (detection)\Faster_RCNN_Pytorch\model\target_builder.py:1: UserWarning: This overload of nonzero is deprecated:
nonzero()
Consider using one of the following signatures instead:
nonzero(*, bool as_tuple) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:882.)
  import torch
Out[4]: 
tensor([[4253,    2],
        [4887,    0],
        [5389,    1],
        [5666,    1]], device='cuda:0')
(iou == IoU_max_per_object).nonzero()[0]
Out[5]: tensor([4253,    2], device='cuda:0')
(iou == IoU_max_per_object).nonzero()[:, 0]
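Below is a minimal CPU-only sketch of the same pattern. It relies on a few assumptions. The `iou` matrix is made up. Rows are anchors and columns are ground-truth objects, which matches the column indices 0 to 2 in the output above. `IoU_max_per_object` is assumed to be the column-wise maximum, although the post does not show how it was computed. Passing `as_tuple` explicitly is the change the deprecation warning above recommends. `torch.where(cond)` with a single argument behaves like `nonzero(as_tuple=True)`, which mirrors `np.where(cond)`.

```python
import torch

# Hypothetical IoU matrix: rows = anchors, columns = ground-truth objects
iou = torch.tensor([[0.1, 0.7, 0.2],
                    [0.9, 0.3, 0.1],
                    [0.2, 0.8, 0.6]])

# Assumed definition: best IoU per object (column-wise max), kept 2-D so it broadcasts over rows
IoU_max_per_object = iou.max(dim=0, keepdim=True).values

mask = iou == IoU_max_per_object

# as_tuple=False (explicit, so no deprecation warning): (N, 2) tensor of [row, col] pairs,
# the same layout as np.argwhere
pairs = mask.nonzero(as_tuple=False)

# Anchor (row) indices only, as in nonzero()[:, 0]
anchor_idx = pairs[:, 0]

# as_tuple=True: one index tensor per dimension, the same layout as np.where(cond)
rows, cols = mask.nonzero(as_tuple=True)

# Single-argument torch.where is equivalent to nonzero(as_tuple=True)
rows2, cols2 = torch.where(mask)

print(pairs.tolist())       # [[1, 0], [2, 1], [2, 2]]
print(anchor_idx.tolist())  # [1, 2, 2]
```

If several anchors tie for an object's maximum IoU, all of them are returned. In that case `anchor_idx` can contain more entries than there are objects. It can also repeat an anchor that is the best match for more than one object, as anchor 2 does here.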

