MMDetectionのFaster R-CNNモデルは、設定ファイル(例: faster_rcnn_r50_fpn.py)の中でtype='FasterRCNN'として指定される。対応する実装はmmdet/models/detectors/faster_rcnn.pyにある:
from ..builder import DETECTORS
from .two_stage import TwoStageDetector
@DETECTORS.register_module()
class FasterRCNN(TwoStageDetector):
    """Faster R-CNN detector (paper: https://arxiv.org/abs/1506.01497).

    A thin specialisation of ``TwoStageDetector``. Here ``rpn_head``,
    ``roi_head``, ``train_cfg`` and ``test_cfg`` are required arguments,
    whereas ``neck`` and ``pretrained`` remain optional.
    """

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None):
        # Everything is forwarded by keyword, because this signature orders
        # its parameters differently from the parent constructor.
        model_parts = dict(backbone=backbone,
                           neck=neck,
                           rpn_head=rpn_head,
                           roi_head=roi_head)
        run_cfgs = dict(train_cfg=train_cfg, test_cfg=test_cfg)
        super(FasterRCNN, self).__init__(
            pretrained=pretrained, **model_parts, **run_cfgs)
FasterRCNNクラスはTwoStageDetectorを継承している。TwoStageDetectorは2段階型物体検出器の共通基盤を提供するクラスであり、その定義はmmdet/models/detectors/two_stage.pyにある:
import torch
import torch.nn as nn
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector
@DETECTORS.register_module()
class TwoStageDetector(BaseDetector):
    """Base class for two-stage detectors.

    Supports the common architecture made of a backbone, an optional neck,
    an optional RPN head and an optional RoI head.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build every configured sub-module and initialise its weights.

        Args:
            backbone: Config passed to ``build_backbone``.
            neck: Optional config passed to ``build_neck``.
            rpn_head: Optional RPN head config. A copy is made before the
                ``rpn`` entries of ``train_cfg``/``test_cfg`` are merged in,
                so the caller's config is not modified.
            roi_head: Optional RoI head config. It is also copied before the
                ``rcnn`` entries of ``train_cfg``/``test_cfg`` are merged in.
            train_cfg: Optional training config (may be ``None``).
            test_cfg: Optional test config (may be ``None``).
            pretrained: Forwarded to ``init_weights``.
        """
        super().__init__()
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        if rpn_head is not None:
            rpn_cfg = rpn_head.copy()
            rpn_cfg.update(
                train_cfg=train_cfg.rpn if train_cfg else None,
                # Guarded like train_cfg: test_cfg defaults to None, and
                # an unguarded ``test_cfg.rpn`` would raise AttributeError.
                test_cfg=test_cfg.rpn if test_cfg else None)
            self.rpn_head = build_head(rpn_cfg)
        if roi_head is not None:
            # Work on a copy, as the RPN branch does, so that the caller's
            # config dict is not mutated by the merged train/test settings.
            roi_cfg = roi_head.copy()
            roi_cfg.update(
                train_cfg=train_cfg.rcnn if train_cfg else None,
                test_cfg=test_cfg.rcnn if test_cfg else None)
            self.roi_head = build_head(roi_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)

    @property
    def with_rpn(self):
        """bool: Whether an RPN head was built."""
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

    @property
    def with_roi_head(self):
        """bool: Whether an RoI head was built."""
        return hasattr(self, 'roi_head') and self.roi_head is not None

    @property
    def with_neck(self):
        """bool: Whether a neck was built."""
        return hasattr(self, 'neck') and self.neck is not None

    def init_weights(self, pretrained=None):
        """Initialise the weights of every built sub-module.

        Args:
            pretrained: Forwarded to the parent class, the backbone and the
                RoI head. The neck and RPN head are initialised without it.
        """
        super().init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            # A neck can be an nn.Sequential, whose members are then
            # initialised one by one.
            if isinstance(self.neck, nn.Sequential):
                for module in self.neck:
                    module.init_weights()
            else:
                self.neck.init_weights()
        if self.with_rpn:
            self.rpn_head.init_weights()
        if self.with_roi_head:
            self.roi_head.init_weights(pretrained)

    def extract_feat(self, img):
        """Run the backbone, then the neck if one exists, and return the result."""
        feats = self.backbone(img)
        if self.with_neck:
            feats = self.neck(feats)
        return feats

    def forward_dummy(self, img):
        """Run each built head once and collect its raw outputs.

        Returns:
            tuple: The RPN head output (if an RPN head exists), followed by
            the RoI head output (if an RoI head exists).
        """
        outs = ()
        feats = self.extract_feat(img)
        if self.with_rpn:
            rpn_outs = self.rpn_head(feats)
            outs += (rpn_outs,)
        # The RoI head is optional (see __init__), so guard it just like the RPN.
        if self.with_roi_head:
            # 1000 random 4-column rows stand in for real proposals
            # (presumably box coordinates; confirm against the RoI head).
            dummy_proposals = torch.randn(1000, 4).to(img.device)
            roi_outs = self.roi_head.forward_dummy(feats, dummy_proposals)
            outs += (roi_outs,)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """Compute the training losses of the RPN (if any) and the RoI head.

        Args:
            img: Input passed to ``extract_feat``.
            img_metas: Per-image meta information, forwarded to both heads.
            gt_bboxes: Ground-truth boxes, forwarded to both heads.
            gt_labels: Ground-truth labels, forwarded to the RoI head only.
                The RPN head receives ``None`` in their place.
            gt_bboxes_ignore: Optional ignored boxes, forwarded to both heads.
            gt_masks: Optional masks, forwarded to the RoI head.
            proposals: Precomputed proposals, used only when no RPN exists.
            **kwargs: Forwarded to the RoI head's ``forward_train``.

        Returns:
            dict: The RPN losses (if any) merged with the RoI head losses.
        """
        feats = self.extract_feat(img)
        losses = {}
        if self.with_rpn:
            # Resolve the fallback lazily: test_cfg.rpn is consulted only
            # when train_cfg has no 'rpn_proposal' entry, so test_cfg may
            # be None whenever 'rpn_proposal' is configured.
            if 'rpn_proposal' in self.train_cfg:
                proposal_cfg = self.train_cfg['rpn_proposal']
            else:
                proposal_cfg = self.test_cfg.rpn
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                feats, img_metas, gt_bboxes, None, gt_bboxes_ignore,
                proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals
        roi_losses = self.roi_head.forward_train(
            feats, img_metas, proposal_list, gt_bboxes, gt_labels,
            gt_bboxes_ignore, gt_masks, **kwargs)
        losses.update(roi_losses)
        return losses

    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Test without augmentation.

        Proposals are generated by the RPN head unless they are passed in.
        ``self.with_bbox`` is not defined in this class; presumably it is
        provided by BaseDetector (confirm).
        """
        assert self.with_bbox, "BBoxヘッドが実装されていません"
        feats = self.extract_feat(img)
        if proposals is None:
            proposals = self.rpn_head.simple_test_rpn(feats, img_metas)
        return self.roi_head.simple_test(feats, proposals, img_metas, rescale)

    async def async_simple_test(self, img, img_meta, proposals=None, rescale=False):
        """Async counterpart of ``simple_test``."""
        assert self.with_bbox
        feats = self.extract_feat(img)
        if proposals is None:
            proposals = await self.rpn_head.async_simple_test_rpn(feats, img_meta)
        return await self.roi_head.async_simple_test(feats, proposals, img_meta, rescale)

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations, using one entry of ``imgs`` per augmentation.

        ``self.extract_feats`` (plural) is not defined in this class;
        presumably BaseDetector provides it (confirm).
        """
        feats = self.extract_feats(imgs)
        proposals = self.rpn_head.aug_test_rpn(feats, img_metas)
        return self.roi_head.aug_test(feats, proposals, img_metas, rescale)
TwoStageDetectorはさらにBaseDetector(mmdet/models/detectors/base.pyで定義)を継承している。BaseDetectorは、MMDetectionにおけるすべての検出モデルに共通のインターフェースを提供するクラスである。