# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_upsample_layer
from mmcv.ops.carafe import CARAFEPack
from mmengine.model import BaseModule, ModuleList, xavier_init
from mmdet.registry import MODELS
@MODELS.register_module()
class FPN_CARAFE(BaseModule):
"""FPN_CARAFE is a more flexible implementation of FPN. It allows more
choice for upsample methods during the top-down pathway.
It can reproduce the performance of ICCV 2019 paper
CARAFE: Content-Aware ReAssembly of FEatures
Please refer to https://arxiv.org/abs/1905.02188 for more details.
Args:
in_channels (list[int]): Number of channels for each input feature map.
out_channels (int): Output channels of feature pyramids.
num_outs (int): Number of output stages.
start_level (int): Start level of feature pyramids.
(Default: 0)
end_level (int): End level of feature pyramids.
(Default: -1 indicates the last level).
norm_cfg (dict): Dictionary to construct and config norm layer.
act_cfg (dict): Config of the activation layer in ConvModule
(Default: None indicates w/o activation).
order (tuple[str]): Order of components in ConvModule. Must be
('conv', 'norm', 'act') or ('act', 'conv', 'norm').
upsample_cfg (dict): Dictionary to construct and config upsample layer.
Its ``type`` entry selects the upsample method: 'nearest', 'bilinear',
'deconv', 'pixel_shuffle', 'carafe' or None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
norm_cfg=None,
act_cfg=None,
order=('conv', 'norm', 'act'),
upsample_cfg=dict(
type='carafe',
up_kernel=5,
up_group=1,
encoder_kernel=3,
encoder_dilation=1),
init_cfg=None):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(FPN_CARAFE, self).__init__(init_cfg)
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.with_bias = norm_cfg is None
self.upsample_cfg = upsample_cfg.copy()
self.upsample = self.upsample_cfg.get('type')
self.relu = nn.ReLU(inplace=False)
self.order = order
assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')]
assert self.upsample in [
'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None
]
if self.upsample in ['deconv', 'pixel_shuffle']:
assert hasattr(
self.upsample_cfg,
'upsample_kernel') and self.upsample_cfg.upsample_kernel > 0
self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel')
if end_level == -1 or end_level == self.num_ins - 1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
# if end_level is not the last level, no extra level is allowed
self.backbone_end_level = end_level + 1
assert end_level < self.num_ins
assert num_outs == end_level - start_level + 1
self.start_level = start_level
self.end_level = end_level
self.lateral_convs = ModuleList()
self.fpn_convs = ModuleList()
self.upsample_modules = ModuleList()
for i in range(self.start_level, self.backbone_end_level):
l_conv = ConvModule(
in_channels[i],
out_channels,
1,
norm_cfg=norm_cfg,
bias=self.with_bias,
act_cfg=act_cfg,
inplace=False,
order=self.order)
fpn_conv = ConvModule(
out_channels,
out_channels,
3,
padding=1,
norm_cfg=self.norm_cfg,
bias=self.with_bias,
act_cfg=act_cfg,
inplace=False,
order=self.order)
if i != self.backbone_end_level - 1:
upsample_cfg_ = self.upsample_cfg.copy()
if self.upsample == 'deconv':
upsample_cfg_.update(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=self.upsample_kernel,
stride=2,
padding=(self.upsample_kernel - 1) // 2,
output_padding=(self.upsample_kernel - 1) // 2)
elif self.upsample == 'pixel_shuffle':
upsample_cfg_.update(
in_channels=out_channels,
out_channels=out_channels,
scale_factor=2,
upsample_kernel=self.upsample_kernel)
elif self.upsample == 'carafe':
upsample_cfg_.update(channels=out_channels, scale_factor=2)
else:
# suppress warnings
align_corners = (None
if self.upsample == 'nearest' else False)
upsample_cfg_.update(
scale_factor=2,
mode=self.upsample,
align_corners=align_corners)
upsample_module = build_upsample_layer(upsample_cfg_)
self.upsample_modules.append(upsample_module)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
# add extra conv layers (e.g., RetinaNet)
extra_out_levels = (
num_outs - self.backbone_end_level + self.start_level)
if extra_out_levels >= 1:
for i in range(extra_out_levels):
in_channels = (
self.in_channels[self.backbone_end_level -
1] if i == 0 else out_channels)
extra_l_conv = ConvModule(
in_channels,
out_channels,
3,
stride=2,
padding=1,
norm_cfg=norm_cfg,
bias=self.with_bias,
act_cfg=act_cfg,
inplace=False,
order=self.order)
if self.upsample == 'deconv':
upsampler_cfg_ = dict(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=self.upsample_kernel,
stride=2,
padding=(self.upsample_kernel - 1) // 2,
output_padding=(self.upsample_kernel - 1) // 2)
elif self.upsample == 'pixel_shuffle':
upsampler_cfg_ = dict(
in_channels=out_channels,
out_channels=out_channels,
scale_factor=2,
upsample_kernel=self.upsample_kernel)
elif self.upsample == 'carafe':
upsampler_cfg_ = dict(
channels=out_channels,
scale_factor=2,
**self.upsample_cfg)
else:
# suppress warnings
align_corners = (None
if self.upsample == 'nearest' else False)
upsampler_cfg_ = dict(
scale_