MMDeploy项目:如何为模型部署添加新模型支持

MMDeploy项目:如何为模型部署添加新模型支持

mmdeploy OpenMMLab Model Deployment Framework mmdeploy 项目地址: https://gitcode.com/gh_mirrors/mm/mmdeploy

前言

在深度学习模型部署过程中,经常会遇到模型无法直接转换为目标推理引擎格式的问题。MMDeploy提供了一套强大的工具集,帮助开发者解决模型转换中的各种兼容性问题。本文将详细介绍三种核心重写工具的使用方法,帮助开发者快速为MMDeploy添加新模型支持。

函数重写器(Function Rewriter)

基本概念

函数重写器是MMDeploy中最常用的工具之一,它允许开发者在模型转换过程中动态替换不支持的Python函数。这种机制类似于"猴子补丁"(monkey patch),但更加结构化且针对特定后端引擎。

使用场景

当遇到以下情况时,可以考虑使用函数重写器:

  1. PyTorch原生函数在目标推理引擎中不支持
  2. 第三方库函数需要特殊处理才能转换
  3. 需要优化某些函数的实现以提高推理效率

代码示例解析

from mmdeploy.core import FUNCTION_REWRITER

@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(input, *size):
    ctx = FUNCTION_REWRITER.get_context()
    origin_func = ctx.origin_func
    if input.dim() == 1 and len(size) == 1:
        return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
    else:
        return origin_func(input, *size)

这个例子展示了如何为TensorRT后端重写torch.Tensor.repeat函数。当输入是一维张量且只有一个size参数时,会先将其扩展为二维再进行repeat操作,最后再压缩回一维。

关键参数说明

  • func_name: 需要重写的函数全名,可以是PyTorch内置函数或自定义函数
  • backend: 指定重写生效的后端引擎,不指定则作为默认实现
  • ctx: 上下文对象,提供原始函数、部署配置等有用信息

模块重写器(Module Rewriter)

基本概念

模块重写器用于替换整个神经网络模块,比函数重写器作用范围更大。当需要对模型中的某个子网络进行整体替换时,这个工具特别有用。

使用场景

  1. 整个模块在目标引擎中不被支持
  2. 需要优化模块实现以提高推理性能
  3. 模块中包含无法转换的控制流或动态结构

代码示例解析

@MODULE_REWRITER.register_rewrite_module(
    'mmagic.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
class SRCNNWrapper(nn.Module):
    def __init__(self, module, cfg, channels=(3, 64, 32, 3), 
                 kernel_sizes=(9, 1, 5), upscale_factor=4):
        super(SRCNNWrapper, self).__init__()
        self._module = module
        module.img_upsampler = nn.Upsample(
            scale_factor=module.upscale_factor,
            mode='bilinear',
            align_corners=False)

这个例子展示了如何为SRCNN模型创建一个TensorRT兼容的包装器。它在初始化时修改了原始模块的上采样器实现。

关键参数说明

  • module_type: 需要重写的模块类全名
  • backend: 指定重写生效的后端引擎
  • 构造函数接收原始模块实例和部署配置作为前两个参数

符号重写器(Symbolic Rewriter)

基本概念

符号重写器专门用于处理PyTorch到ONNX的符号转换问题。它允许开发者自定义PyTorch操作到ONNX节点的映射规则。

使用场景

  1. PyTorch操作没有对应的ONNX实现
  2. 默认的ONNX转换结果不被目标推理引擎支持
  3. 需要优化ONNX图结构以提高推理效率

代码示例解析

@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(g, self, dim=None):
    if dim is None:
        dims = []
        for i, size in enumerate(self.type().sizes()):
            if size == 1:
                dims.append(i)
    else:
        dims = [sym_help._get_const(dim, 'i', 'dim')]
    return g.op('Squeeze', self, axes_i=dims)

这个例子展示了如何自定义squeeze操作的ONNX符号转换。它处理了dim参数为None时自动检测可压缩维度的情况。

关键参数说明

  • func_name: 需要重写的函数名
  • is_pytorch: 标记是否为PyTorch内置函数
  • arg_descriptors: 符号函数参数的描述符
  • g: ONNX图的构建器对象
  • self: 操作的输入张量

最佳实践

  1. 优先使用函数级重写:尽可能在函数级别解决问题,保持修改范围最小化
  2. 考虑多后端兼容性:为不同后端提供特定的重写实现
  3. 保留原始功能:在重写实现中尽量保留原始函数/模块的语义
  4. 充分测试:重写后应在目标引擎上验证数值精度和性能

总结

MMDeploy提供的三种重写工具覆盖了模型转换过程中的不同层次需求。理解这些工具的特点和适用场景,可以帮助开发者高效解决模型部署中的兼容性问题。在实际应用中,通常需要结合使用这些工具,才能完成复杂模型的完整转换。

通过本文的介绍,希望读者能够掌握为MMDeploy添加新模型支持的核心方法,在实际项目中灵活运用这些工具解决模型部署难题。

mmdeploy OpenMMLab Model Deployment Framework mmdeploy 项目地址: https://gitcode.com/gh_mirrors/mm/mmdeploy

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

解雁淞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值