안녕하세요. pulluper 입니다. :)
최근 torchvision을 둘러보니, 2023년 5월 기준 version이 main (0.15.0a0+282001f) 까지 나왔습니다.

최신 version의 torchvision에서 제공하는 object detection은 다음과 같습니다.
(segmentation, keypoint detection 등도 존재..torchvision 짱.. 😍😘)

이 중에서 Faster R-CNN 모델의 종류는 다음과 같습니다.
원 논문에서 사용한 backbone인 VGG16은 없고,
resnet50_fpn과 mobilenet_v3_large_fpn 계열의 모델들이 있네요.

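torchvision detection model의 output인 pred는 입력 이미지마다 하나씩의 dict를 담은 list이며, 각 dict는 {'boxes': Tensor[N, 4], 'labels': Tensor[N], 'scores': Tensor[N]} 형태를 갖습니다. (N은 검출된 객체 수, boxes는 (x_min, y_min, x_max, y_max) 형식)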

https://github.com/pytorch/vision/blob/main/torchvision/models/detection/roi_heads.py#L777
GitHub - pytorch/vision: Datasets, Transforms and Models specific to Computer Vision
Datasets, Transforms and Models specific to Computer Vision - GitHub - pytorch/vision: Datasets, Transforms and Models specific to Computer Vision
github.com
따라서 pred에서 첫 번째 원소(이미지 한 장에 대한 결과 dict)를 꺼내면 쉽게 detection result를 뽑아낼 수 있습니다. 예제는 축구를 하는 마라도나님을 이용해 보겠습니다. ⚽️

다음은 torchvision의 fasterrcnn_resnet50_fpn 모델을 이용해서 예측 박스, 레이블, 스코어를 가져오는 demo 함수입니다.
pretrained weight는 COCO dataset으로 학습되어 있습니다. (background와 'N/A' 자리를 포함한 classes = 91)
threshold는 기준 스코어로, 이 점수 이상인 박스만 출력하도록 합니다.
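또한 object detection 모델에는 입력 전에 이미지 정규화(Normalize)를 따로 하지 않는다는 점이 특징입니다. ToTensor로 [0, 1] 범위의 텐서로만 바꿔서 넣으면 되고, 정규화는 모델 내부(GeneralizedRCNNTransform)에서 수행됩니다. (다음 이슈 참고)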
https://github.com/pytorch/vision/issues/2397
Normalization for object detection · Issue #2397 · pytorch/vision
Migrated from discuss.pytorch.org. Requests were made by @mattans. 📚 Documentation The reference implementations for classification, segmentation, and video classification all use a normalization t...
github.com
def demo(img_path, threshold):
    """Run a COCO-pretrained Faster R-CNN on one image and show the detections.

    :param img_path: path to the input image (anything PIL can open).
    :param threshold: minimum detection score; detections with
        ``score >= threshold`` are kept and drawn.
    :return: None
    """
    import torch  # local import: only needed for torch.no_grad() below

    # 1. load image -> float tensor [3, H, W] in [0, 1].
    # No Normalize() step for torchvision detection models
    # (see pytorch/vision issue #2397).
    img_pil = Image.open(img_path).convert('RGB')
    transform = T.Compose([T.ToTensor()])
    img = transform(img_pil)
    batch_img = [img]  # detection models take a list of images

    # NOTE: `pretrained=True` is the legacy spelling; newer torchvision
    # versions prefer `weights=...`.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    # Inference only: disable autograd so no graph is recorded.
    with torch.no_grad():
        pred = model(batch_img)

    # 2. pred is a list with one dict per input image; take the only one.
    pred_dict = pred[0]
    # pred_dict: {'boxes': Tensor[N, 4], 'labels': Tensor[N], 'scores': Tensor[N]}

    # 3. get pred boxes and labels, scores
    pred_boxes = pred_dict['boxes']    # [N, 4] (x_min, y_min, x_max, y_max)
    pred_labels = pred_dict['labels']  # [N]
    pred_scores = pred_dict['scores']  # [N]

    # 4. keep only detections whose score reaches the threshold (boolean mask)
    keep = pred_scores >= threshold
    pred_boxes = pred_boxes[keep]
    pred_labels = pred_labels[keep]
    pred_scores = pred_scores[keep]

    # 5. visualize
    visualize_detection_result(img_pil, pred_boxes, pred_labels, pred_scores)
결과 사진입니다. 😀

전체 데모 및 시각화 코드는 다음과 같습니다.
import cv2 | |
import torchvision | |
import numpy as np | |
from PIL import Image | |
import torchvision.transforms as T | |
def visualize_detection_result(img_pil, boxes, labels, scores, label_names=None, colors=None):
    """Draw detection boxes with "<label> <score>" captions and display them.

    img_pil     : PIL image, RGB, range [0, 255], uint8
    boxes       : torch.Tensor [num_obj, 4], float32, (x_min, y_min, x_max, y_max)
    labels      : torch.Tensor [num_obj], int64, class id per box
    scores      : torch.Tensor [num_obj], float32, score per box
    label_names : optional sequence mapping class id -> name; defaults to the
                  script-level ``coco_labels_list``
    colors      : optional array-like [num_classes, 3] of colour triplets on the
                  same [0, 1] scale as the displayed image; defaults to the
                  script-level ``coco_colors_array``

    Opens an OpenCV window and blocks until a key is pressed. Returns 0.
    """
    # Index the label *list*, not coco_labels_map.keys(): the COCO list repeats
    # 'N/A', a dict keeps each key only once, so the dict's key order is
    # shifted after the first duplicate (label 37 'sports ball' would be shown
    # as 'baseball glove').
    if label_names is None:
        label_names = coco_labels_list
    if colors is None:
        colors = coco_colors_array

    # 1. uint8 [0, 255] RGB -> float32 [0, 1], then to OpenCV's BGR order
    image_np = np.array(img_pil).astype(np.float32) / 255.
    im_show = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    # One font for both measuring and drawing the caption, so the filled
    # background rectangle actually fits the text.
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.4
    font_thickness = 1

    for box, label, score in zip(boxes, labels, scores):
        x_min, y_min, x_max, y_max = (int(v) for v in box)  # cv2 needs ints
        label = int(label)
        # cv2 expects plain Python numbers for the colour, not a numpy row
        color = tuple(float(c) for c in colors[label])
        text = label_names[label] + ' {:.2f}'.format(float(score))

        # bounding box
        cv2.rectangle(im_show,
                      pt1=(x_min, y_min),
                      pt2=(x_max, y_max),
                      color=color,
                      thickness=2)

        # filled caption background anchored at the box's top-left corner
        (text_w, text_h), baseline = cv2.getTextSize(text, font_face, font_scale, font_thickness)
        cv2.rectangle(im_show,
                      pt1=(x_min, y_min),
                      pt2=(x_min + text_w + 4, y_min + text_h + baseline + 4),
                      color=color,
                      thickness=-1)

        # caption text; org is the bottom-left corner of the text
        cv2.putText(im_show,
                    text,
                    (x_min + 2, y_min + text_h + 2),
                    font_face,
                    font_scale,
                    (0, 0, 0),
                    font_thickness)

    # cv2.imshow accepts float images with values in [0, 1]
    cv2.imshow('result', im_show)
    cv2.waitKey(0)
    # To save instead, convert back to uint8 [0, 255] first, e.g.
    # cv2.imwrite("result.png", (im_show * 255).astype(np.uint8))
    return 0
def demo(img_path, threshold):
    """Run a COCO-pretrained Faster R-CNN on one image and show the detections.

    :param img_path: path to the input image (default in this script - soccer.png)
    :param threshold: minimum detection score; detections with
        ``score >= threshold`` are kept and drawn (default in this script - 0.9)
    :return: None
    """
    import torch  # local import: only needed for torch.no_grad() below

    # 1. load image -> float tensor [3, H, W] in [0, 1].
    # No Normalize() step for torchvision detection models
    # (see pytorch/vision issue #2397).
    img_pil = Image.open(img_path).convert('RGB')
    transform = T.Compose([T.ToTensor()])
    img = transform(img_pil)
    batch_img = [img]  # detection models take a list of images

    # NOTE: `pretrained=True` is the legacy spelling; newer torchvision
    # versions prefer `weights=...`.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    # Inference only: disable autograd so no graph is recorded.
    with torch.no_grad():
        pred = model(batch_img)

    # 2. pred is a list with one dict per input image; take the only one.
    pred_dict = pred[0]
    # pred_dict: {'boxes': Tensor[N, 4], 'labels': Tensor[N], 'scores': Tensor[N]}

    # 3. get pred boxes and labels, scores
    pred_boxes = pred_dict['boxes']    # [N, 4] (x_min, y_min, x_max, y_max)
    pred_labels = pred_dict['labels']  # [N]
    pred_scores = pred_dict['scores']  # [N]

    # 4. keep only detections whose score reaches the threshold (boolean mask)
    keep = pred_scores >= threshold
    pred_boxes = pred_boxes[keep]
    pred_labels = pred_labels[keep]
    pred_scores = pred_scores[keep]

    # 5. visualize
    visualize_detection_result(img_pil, pred_boxes, pred_labels, pred_scores)
if __name__ == '__main__':
    # COCO category names indexed by the model's label ids: 91 entries,
    # '__background__' at id 0 and 'N/A' placeholders at unused ids.
    coco_labels_list = [
        '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
        'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]
    # name -> id lookup. NOTE: 'N/A' appears several times, so this dict keeps
    # only its last id and has fewer entries than coco_labels_list; use
    # coco_labels_list (not this dict's keys) for id -> name.
    coco_labels_map = {k: v for v, k in enumerate(coco_labels_list)}
    # One fixed pseudo-random colour per class id, scaled to [0, 1] to match
    # the float image drawn by visualize_detection_result. The table size is
    # derived from the label list so the two can never disagree.
    np.random.seed(1)
    coco_colors_array = np.random.randint(256, size=(len(coco_labels_list), 3)) / 255
    # demo
    demo('./soccer.png', threshold=0.9)
다음 깃헙에서 전체 코드 및 사진을 참조하실 수 있습니다. 🥰
↓ 구현 코드 ↓
https://github.com/csm-kr/torchvision_fasterrcnn_tutorial/blob/master/demo.py
감사합니다.
Reference
https://github.com/spmallick/learnopencv/blob/master/PyTorch-Mask-RCNN/PyTorch_Mask_RCNN.ipynb
https://github.com/csm-kr/retinanet_pytorch/blob/master/demo.py
'Object Detection' 카테고리의 다른 글
도커환경에서 ultralytics yolo v11 학습 및 검증 (0) | 2025.01.16 |
---|
댓글