Faster R-CNN 目標檢測體驗篇

  • 2019 年 12 月 2 日
  • 筆記

上次我們分享了目標檢測 One-Stage 的代表 YOLO,從體驗、理論到程式碼實戰。其實 One-Stage 還有一個代表是 SSD ,這個等到下一次我們再講解,因為 SSD 涉及到部分 Two-Stage 目標檢測的知識。

本期我們分享的是 Two-Stage 的代表作 Fater R-CNN ,這是屬於 R-CNN 系列中比較經典的一個,目前比較流行。今天我們就帶大家體驗一把 Faster R-CNN 的檢測,程式碼不多。

程式碼說明

我們程式碼使用的是 Pytorch 提供的目標檢測模型 fasterrcnn_resnet50_fpn

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

模型預測後得到的結果是

  • Bounding boxes [x0, y0, x1, y1] 邊框的四個值
  • Labels 所有預測的標籤
  • Scores 所有標籤的分數

以下就是本次內容的所有程式碼:

import torchvision  # 0.3.0  version  這裡指的是所使用包的版本  from torchvision import transforms as T  import cv2  # 4.1.1  version  import matplotlib.pyplot as plt  # 3.0.0  version  from PIL import Image  # 5.3.0  version  import random  import os  import torch    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).to(device)  #  載入模型  model.eval()  # 設置成評估模式    # 定義 Pytorch 官方給的類別名稱,有些是 'N/A' 是已經去掉的類別  COCO_INSTANCE_CATEGORY_NAMES = [      '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',      'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',      'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',      'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',      'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',      'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',      'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',      'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',      'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',      'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',      'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',      'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'  ]        #  獲取單張圖片的預測結果  def get_prediction(img_path, threshold):      img = Image.open(img_path)  # Load the image  載入圖片      transform = T.Compose([T.ToTensor()]) # Defing PyTorch Transform      img = transform(img)  # Apply the transform to the image  轉換成 torch 形式      pred = model([img.to(device)])  # Pass the image to the model  開始推理      pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].cpu().numpy())]  # Get the Prediction Score  獲取預測的類別      pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().cpu().numpy())]  # Bounding boxes  獲取各個類別的邊框      pred_score = list(pred[0]['scores'].cpu().detach().numpy())  #  獲取各個類別的分數        # Get list of index with score greater than threshold.      pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]  #  判斷分數大於閾值對於的分數的最大索引      #  因為預測後的分數是從大到小排序的,只要找到大於閾值最後一個的索引值即可      pred_boxes = pred_boxes[:pred_t+1]      pred_class = pred_class[:pred_t+1]      return pred_boxes, pred_class      #  根據預測的結果繪製邊框及類別  def object_detection_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):      boxes, pred_cls = get_prediction(img_path, threshold)  # Get predictions      img = cv2.imread(img_path)  # Read image with cv2      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB      result_dict = {}  #  用來保存每個類別的名稱及數量      for i in range(len(boxes)):          color = tuple(random.randint(0, 255) for i in range(3))          cv2.rectangle(img,                        boxes[i][0],                        boxes[i][1],                        color=color,                        thickness=rect_th)  # Draw Rectangle with the coordinates            cv2.putText(img,                      pred_cls[i],                      boxes[i][0],                      cv2.FONT_HERSHEY_SIMPLEX,                      text_size,                      color,                      thickness=text_th)  # Write the prediction class            #  將各個預測的結果保存到一個字典里          if pred_cls[i] not in result_dict:              result_dict[pred_cls[i]] = 1          else:              result_dict[pred_cls[i]] += 1          print(result_dict)      plt.figure(figsize=(20, 30))  # display the output image      plt.imshow(img)      plt.xticks([])      plt.yticks([])      plt.show()      if __name__ == "__main__":      object_detection_api('./people.jpg', threshold=0.5)

實驗效果

測試圖片

我們分別在 FasterR-CNN 和 YOLO 下進行測試,測試結果如下: 都存在漏檢的情況,這裡我們只是簡單的做個比較,等 Faster R-CNN 更新結束後,我們再統一來分析兩者的差別


Faster R-CNN 預測結果


YOLO 預測結果


大家如果不想找測試圖片的話,這裡給大家提供幾張

wget https://www.wsha.org/wp-content/uploads/banner-diverse-group-of-people-2.jpg -O people.jpg  wget https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/10best-cars-group-cropped-1542126037.jpg -O car.jpg  wget https://cdn.pixabay.com/photo/2013/07/05/01/08/traffic-143391_960_720.jpg -O traffic.jpg  wget https://images.unsplash.com/photo-1458169495136-854e4c39548a -O girl_cars.jpg