mmdetection2.6利用自定義數據集訓練模型
mmdetection2.6利用自定義數據集訓練模型
在這篇文章中,介紹如何採用自定義數據集來訓練,測試,推理預定義的模型。
基本的步驟如下:
- 準備自定義數據集
- 準備config文件
- 在自定義數據集上訓練,測試,推理模型
1. 準備自定義數據集
mmdetection支援三種方法來自定義數據集:
- 將數據集組織為coco格式
- 將數據集組織為middle格式
- 應用一個新的數據集
前兩種方式更簡單,更容易操作一點
在這篇文章中,提供一個將自己數據集轉換為coco數據集格式的例子。
**注意: ** 對於評估mask AP這種勢力分割任務, mmdetection 僅僅支援COCO格式的數據集,因此如果進行實例分割就需要將數據集轉換為coco格式的數據集。
1.1 coco標註格式
對於實例分割任務, coco格式的數據集必要的鍵為:
{
"images": [image],
"annotations": [annotation],
"categories": [category]
}
image = {
"id": int,
"width": int,
"height": int,
"file_name": str,
}
annotation = {
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float,
"bbox": [x,y,width,height],
"iscrowd": 0 or 1,
}
categories = [{
"id": int,
"name": str,
"supercategory": str,
}]
需要將自己的數據集轉換為coco格式的標註資訊,然後利用coco格式的數據集load數據,使用CocoDataset來訓練與評估模型。
2. 準備config文件
假設採用Mask RCNN+FPN這種結構來訓練檢測器。假設這個config文件位於configs/ballon
中,命名為mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py
config文件的內容如下:
# The new config inherits a base config to highlight the necessary modification
_base_ = 'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_coco.py'
# We also need to change the num_classes in head to match the dataset's annotation
model = dict(
roi_head=dict(
bbox_head=dict(num_classes=1),
mask_head=dict(num_classes=1)))
# Modify dataset related settings
dataset_type = 'COCODataset'
classes = ('balloon',)
data = dict(
train=dict(
img_prefix='balloon/train/',
classes=classes,
ann_file='balloon/train/annotation_coco.json'),
val=dict(
img_prefix='balloon/val/',
classes=classes,
ann_file='balloon/val/annotation_coco.json'),
test=dict(
img_prefix='balloon/val/',
classes=classes,
ann_file='balloon/val/annotation_coco.json'))
# We can use the pre-trained Mask RCNN model to obtain higher performance
load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
3. 訓練測試 推理模型
訓練
python tools/train.py configs/ballon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py
測試
python tools/test.py configs/ballon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py work_dirs/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py/latest.pth --eval bbox segm