mmdetection2.6利用自定义数据集训练模型
mmdetection2.6利用自定义数据集训练模型
在这篇文章中,介绍如何采用自定义数据集来训练,测试,推理预定义的模型。
基本的步骤如下:
- 准备自定义数据集
- 准备config文件
- 在自定义数据集上训练,测试,推理模型
1. 准备自定义数据集
mmdetection支持三种方法来自定义数据集:
- 将数据集组织为coco格式
- 将数据集组织为middle格式
- 应用一个新的数据集
前两种方式更简单,更容易操作一点
在这篇文章中,提供一个将自己数据集转换为coco数据集格式的例子。
**注意: ** 对于评估mask AP这种势力分割任务, mmdetection 仅仅支持COCO格式的数据集,因此如果进行实例分割就需要将数据集转换为coco格式的数据集。
1.1 coco标注格式
对于实例分割任务, coco格式的数据集必要的键为:
{
"images": [image],
"annotations": [annotation],
"categories": [category]
}
image = {
"id": int,
"width": int,
"height": int,
"file_name": str,
}
annotation = {
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float,
"bbox": [x,y,width,height],
"iscrowd": 0 or 1,
}
categories = [{
"id": int,
"name": str,
"supercategory": str,
}]
需要将自己的数据集转换为coco格式的标注信息,然后利用coco格式的数据集load数据,使用CocoDataset来训练与评估模型。
2. 准备config文件
假设采用Mask RCNN+FPN这种结构来训练检测器。假设这个config文件位于configs/ballon
中,命名为mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py
config文件的内容如下:
# The new config inherits a base config to highlight the necessary modification
_base_ = 'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_coco.py'
# We also need to change the num_classes in head to match the dataset's annotation
model = dict(
roi_head=dict(
bbox_head=dict(num_classes=1),
mask_head=dict(num_classes=1)))
# Modify dataset related settings
dataset_type = 'COCODataset'
classes = ('balloon',)
data = dict(
train=dict(
img_prefix='balloon/train/',
classes=classes,
ann_file='balloon/train/annotation_coco.json'),
val=dict(
img_prefix='balloon/val/',
classes=classes,
ann_file='balloon/val/annotation_coco.json'),
test=dict(
img_prefix='balloon/val/',
classes=classes,
ann_file='balloon/val/annotation_coco.json'))
# We can use the pre-trained Mask RCNN model to obtain higher performance
load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
3. 训练测试 推理模型
训练
python tools/train.py configs/ballon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py
测试
python tools/test.py configs/ballon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py work_dirs/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py/latest.pth --eval bbox segm