Converting the maskrcnn_benchmark Python model to a C++ libtorch model

This article describes how to convert the Python Mask R-CNN model into a model that can be used from C++, covering how the model weights are saved and loaded and how the pieces are assembled with libtorch.



Deep learning requires an enormous amount of computation. Python or MATLAB is convenient for prototyping, but for real-world deployment C++, as a lower-level language, has a clear advantage.

Most of the leading deep learning projects are still written primarily in Python, but with the arrival of libtorch more projects are likely to adopt C++ for deep learning development.

This article explains how to convert the classic maskrcnn_benchmark model into a C++ model that libtorch can use.

The overall architecture of Mask R-CNN is shown below.

[Figure: Mask R-CNN architecture diagram]

Saving the model weights with torch.jit.trace

import argparse
import torch

import os

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer


class wrapper(torch.nn.Module):
    # Some submodules return a list, which torch.jit.trace cannot handle
    # directly, so we wrap them and return a tuple instead.
    def __init__(self, model):
        super(wrapper, self).__init__()
        self.model = model

    def forward(self, input):
        output = self.model(input)
        return tuple(output)


def output_tuple_or_tensor(model, tensor):
    # Run the module once on a dummy tensor; wrap it if its output is a list
    # rather than a tensor or a tuple.
    output = model(tensor)
    if isinstance(output, (torch.Tensor, tuple)):
        return model
    elif isinstance(output, list):
        return wrapper(model)


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection jit")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--weight_path",
        help="weight file path"
    )
    parser.add_argument(
        "--output_path",
        help="jit output path",
        type=str,
        default='./'
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    model = build_detection_model(cfg)
    checkpoint = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
    _ = checkpoint.load(args.weight_path, None)
    
    #backbone save
    backbone_tensor = torch.zeros(2, 3, 320, 320)
    tmp_model = output_tuple_or_tensor(model.backbone, backbone_tensor)
    for i, j in model.named_parameters():
        print(i)
    print("-------------------------------------------------------------------------------------------------")
    for i, j in model.named_buffers():
        print(i)
    backbone_jit = torch.jit.trace(tmp_model, backbone_tensor)
    backbone_jit.save(os.path.join(args.output_path, 'backbone.pth'))

    #rpn
    #conv
    conv_in_channels = model.rpn.head.conv.weight.shape[1]
    conv_tensor = torch.zeros(2, conv_in_channels, 10, 10)
    conv_jit = torch.jit.trace(model.rpn.head.conv, conv_tensor)
    conv_jit.save(os.path.join(args.output_path, 'rpn_conv.pth'))

    #cls_logits
    logits_in_channels = model.rpn.head.cls_logits.weight.shape[1]
    logits_tensor = torch.zeros(2, logits_in_channels, 10, 10)
    logits_jit = torch.jit.trace(model.rpn.head.cls_logits, logits_tensor)
    logits_jit.save(os.path.join(args.output_path, 'rpn_logits.pth'))

    #bbox_pred
    bbox_in_channels = model.rpn.head.bbox_pred.weight.shape[1]
    bbox_tensor = torch.zeros(2, bbox_in_channels, 10, 10)
    bbox_jit = torch.jit.trace(model.rpn.head.bbox_pred, bbox_tensor)
    bbox_jit.save(os.path.join(args.output_path, 'rpn_bbox.pth'))

    #box_head
    #feature extractor
    for k, v in model.roi_heads.box.feature_extractor._modules.items():
        if 'head' in k:
            #resnet head
            extractor_in_channels = v._modules['layer4'][0]._modules['downsample'][0].weight.shape[1]
            extractor_tensor = torch.zeros(2, extractor_in_channels, 10, 10)
            extractor_jit = torch.jit.trace(v, extractor_tensor)
            extractor_jit.save(os.path.join(args.output_path, 'extractor_' + k + '.pth'))
        elif 'fc' in k:
            extractor_in_channels = v.weight.shape[1]
            extractor_tensor = torch.zeros(2, extractor_in_channels)
            extractor_jit = torch.jit.trace(v, extractor_tensor)
            extractor_jit.save(os.path.join(args.output_path, 'extractor_' + k + '.pth'))
        elif 'conv' in k:
            extractor_in_channels = v.weight.shape[1]
            extractor_tensor = torch.zeros(2, extractor_in_channels, 10, 10)
            extractor_jit = torch.jit.trace(v, extractor_tensor)
            extractor_jit.save(os.path.join(args.output_path, 'extractor_' + k + '.pth'))

    #box_head
    #predictor
    cls_score_in_channels = model.roi_heads.box.predictor.cls_score.weight.shape[1]
    cls_score_tensor = torch.zeros(2, cls_score_in_channels)
    cls_score_jit = torch.jit.trace(model.roi_heads.box.predictor.cls_score, cls_score_tensor)
    cls_score_jit.save(os.path.join(args.output_path, 'cls_score.pth'))

    bbox_pred_in_channels = model.roi_heads.box.predictor.bbox_pred.weight.shape[1]
    bbox_pred_tensor = torch.zeros(2, bbox_pred_in_channels)
    bbox_pred_jit = torch.jit.trace(model.roi_heads.box.predictor.bbox_pred, bbox_pred_tensor)
    bbox_pred_jit.save(os.path.join(args.output_path, 'bbox_pred.pth'))


    #mask_head
    #feature extractor
    for k, v in model.roi_heads.mask.feature_extractor._modules.items():
        if 'head' in k:
            #resnet head
            extractor_in_channels = v._modules['layer4'][0]._modules['downsample'][0].weight.shape[1]
            extractor_tensor = torch.zeros(2, extractor_in_channels, 10, 10)
            extractor_jit = torch.jit.trace(v, extractor_tensor)
            extractor_jit.save(os.path.join(args.output_path, 'extractor_mask_' + k + '.pth'))
        elif 'fc' in k:
            extractor_in_channels = v.weight.shape[1]
            extractor_tensor = torch.zeros(256, extractor_in_channels, 3, 3)
            extractor_jit = torch.jit.trace(v, extractor_tensor)
            extractor_jit.save(os.path.join(args.output_path, 'extractor_mask_' + k + '.pth'))
        elif 'conv' in k:
            extractor_in_channels = v.weight.shape[1]
            extractor_tensor = torch.zeros(2, extractor_in_channels, 10, 10)
            extractor_jit = torch.jit.trace(v, extractor_tensor)
            extractor_jit.save(os.path.join(args.output_path, 'extractor_mask_' + k + '.pth'))

    #mask head
    #predictor
    mask_cls_score_in_channels = model.roi_heads.mask.predictor.conv5_mask.weight.shape[1]
    mask_cls_score_tensor = torch.zeros(2, mask_cls_score_in_channels, 2, 2)
    mask_cls_score_jit = torch.jit.trace(model.roi_heads.mask.predictor.conv5_mask, mask_cls_score_tensor)
    mask_cls_score_jit.save(os.path.join(args.output_path, 'mask_cls_score.pth'))

    mask_pred_in_channels = model.roi_heads.mask.predictor.mask_fcn_logits.weight.shape[1]
    mask_pred_tensor = torch.zeros(2, mask_pred_in_channels, 2, 2)
    mask_pred_jit = torch.jit.trace(model.roi_heads.mask.predictor.mask_fcn_logits, mask_pred_tensor)
    mask_pred_jit.save(os.path.join(args.output_path, 'mask_pred.pth'))


if __name__ == "__main__":
    main()
    print('Complete!')

Note: this step only saves the model's weights. Because torch.jit.trace cannot capture dynamic graphs or control flow, it cannot produce an end-to-end model like the Python one. We therefore settle for saving the weights and implement the remaining operations (such as NMS) on the C++ side.

On the Python side, printing the model's named_parameters() and named_buffers() tells us which weights need to be saved.
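Before assembling the full C++ model, each traced file can also be loaded back on the libtorch side and run on a dummy input as a quick sanity check. The sketch below is not part of the original project; the file name and input shape follow the tracing script above, and depending on your libtorch version torch::jit::load may return a shared_ptr rather than a Module by value:

#include <torch/script.h>
#include <iostream>

int main(){
  // Load the traced backbone written by the Python script and run it on the
  // same dummy shape that was used for tracing.
  torch::jit::script::Module backbone = torch::jit::load("backbone.pth");
  std::vector<torch::jit::IValue> inputs{torch::zeros({2, 3, 320, 320})};
  torch::jit::IValue out = backbone.forward(inputs);
  std::cout << "backbone output is a tuple: " << out.isTuple() << "\n";
  return 0;
}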

Once every module has been traced, we use libtorch to combine all of the pieces into a single model.

Since the overall code base is quite large (we ported the whole of maskrcnn_benchmark to C++), only the core model-conversion code is shown here for discussion.

In the code below, the line

modeling::GeneralizedRCNN model = modeling::BuildDetectionModel();

constructs a C++ model class written after the Python maskrcnn_benchmark; the content is exactly the same, only the language differs.

The core idea of the conversion is to take the entries of named_parameters and named_buffers whose names contain the keywords backbone, rpn, or roi_heads.box / roi_heads.mask and copy the corresponding weights into this C++ model object.

Calling jit_to_cpp(weight directory, model yaml config file, weight file names as a vector) from jit_to_cpp.h then assembles the C++ model.
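As a concrete, purely illustrative example, assuming the traced files from the Python script were written to ./jit_weights and a ResNet backbone is used, the call could look like this; the directory, config path, and the exact list of weight files depend on your setup and on which extractor modules the tracing script exported:

#include "jit_to_cpp.h"

int main(){
  // File names must match what the Python tracing script saved, because the
  // mappers look the weights up by these prefixes.
  std::vector<std::string> weight_files{
    "backbone.pth", "rpn_conv.pth", "rpn_logits.pth", "rpn_bbox.pth",
    "extractor_head.pth", "cls_score.pth", "bbox_pred.pth",
    "extractor_mask_head.pth", "mask_cls_score.pth", "mask_pred.pth"
  };
  // <weight directory>, <yaml config file>, <traced weight file names>
  rcnn::utils::jit_to_cpp<rcnn::utils::ResNetMapper>(
      "./jit_weights",
      "./configs/e2e_mask_rcnn_R_50_C4_1x.yaml",
      weight_files);
  return 0;
}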

Matching on these keywords, the weight tensors from the traced modules are gathered and copied into the model:

jit_to_cpp.h:

#pragma once
#include <torch/torch.h>
#include <torch/script.h>

#include <defaults.h>
#include <modeling.h>

#include <iostream>
#include <bisect.h>


namespace rcnn{
namespace utils{

void recur(torch::jit::script::Module& module, std::string name, std::map<std::string, torch::Tensor>& saved);

template<typename T>
void jit_to_cpp(std::string weight_dir, std::string config_path, std::vector<std::string> weight_files){
  T mapper = T();
  std::map<std::string, torch::Tensor> saved;
  std::set<std::string> updated;
  std::map<std::string, std::string> mapping;

  rcnn::config::SetCFGFromFile(config_path);
  modeling::GeneralizedRCNN model = modeling::BuildDetectionModel();
  torch::NoGradGuard guard;

  for(auto& weight_file : weight_files){
    auto module_part = torch::jit::load(weight_dir + "/" + weight_file);
    recur(module_part, weight_file.substr(0, weight_file.size()-4), saved);
  }
  for(auto i : saved){
    std::cout << i.first << "\n";
  }
  for(auto& i : model->named_parameters()){
    std::cout << i.key() << "\n";
    std::string new_name;
    if(i.key().find("backbone") != std::string::npos){
      new_name = mapper.backboneMapping(i.key(), i.value(), saved);
      updated.insert(i.key());
      mapping[i.key()] = new_name;
    }
    else if(i.key().find("rpn") != std::string::npos){
      new_name = mapper.rpn(i.key(), i.value(), saved);
      updated.insert(i.key());
      mapping[i.key()] = new_name;
    }
    else if((i.key().find("roi_heads.box") != std::string::npos)||(i.key().find("roi_heads.mask") != std::string::npos)){
      new_name = mapper.roiHead(i.key(), i.value(), saved);
      updated.insert(i.key());
      mapping[i.key()] = new_name;
    }
    else{
      assert(false);
    }
  }
  
  for(auto& i : model->named_buffers()){
    std::string new_name;
    if(i.key().find("backbone") != std::string::npos){
      new_name = mapper.backboneMapping(i.key(), i.value(), saved);
      updated.insert(i.key());
      mapping[i.key()] = new_name;
    }
    else if(i.key().find("rpn") != std::string::npos){
      new_name = mapper.rpn(i.key(), i.value(), saved);
      updated.insert(i.key());
      mapping[i.key()] = new_name;
    }
    else if((i.key().find("roi_heads.box") != std::string::npos) || (i.key().find("roi_heads.mask") != std::string::npos)) {
      new_name = mapper.roiHead(i.key(), i.value(), saved);
      updated.insert(i.key());
      mapping[i.key()] = new_name;
    }
    else{
      assert(false);
    }
  }

  torch::serialize::OutputArchive archive;
  for(auto& i : model->named_parameters()){
    std::cout << i.key() << " parameter loaded from " << mapping[i.key()] << "n";
    assert(updated.count(i.key()));
    assert((saved.at(mapping[i.key()]) != i.value()).sum().item<int>() == 0);
    archive.write(i.key(), i.value());
  }

  for(auto& i : model->named_buffers()){
      std::cout << i.key() << " buffer loaded from " << mapping[i.key()] << "n";
    assert(updated.count(i.key()));
    if(i.key().find("anchor_generator") == std::string::npos){
      assert( (saved.at(mapping[i.key()]) != i.value()).sum().item<int>() == 0);
    }
    archive.write(i.key(), i.value(), true);
  }
  archive.save_to("../models/new_pth_from_python_cpp.pth");
  std::cout << "saved as /models/new_pth_from_python_cpp.pthn"; 
}

class ResNetMapper{

public:
  ResNetMapper() = default;
  std::string backboneMapping(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
  std::string roiHead(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
  std::string rpn(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);

};

class VoVNetMapper{

public:
  VoVNetMapper() = default;
  std::string backboneMapping(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
  std::string roiHead(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
  std::string rpn(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
};


}
}

jit_to_cpp.cpp:

#include "jit_to_cpp.h"

#include <cassert>
#include <iostream>


namespace rcnn{
namespace utils{

// Flatten a traced module into "<prefix>.<parameter name>" -> tensor entries.
void recur(torch::jit::script::Module& module, std::string name, std::map<std::string, torch::Tensor>& saved){
  std::string new_name;
  if(name.compare("") != 0)
    new_name = name + ".";
  
  for(auto u : module.get_parameters()){
    torch::Tensor tensor = u.value().toTensor();
    saved[new_name + u.name()] = tensor;
  }
  for(auto u : module.get_attributes()){
    torch::Tensor tensor = u.value().toTensor();
    saved[new_name + u.name()] = tensor;
  }
  for(auto i : module.get_modules())
    recur(i, new_name + i.name().name(), saved);
}

std::string ResNetMapper::backboneMapping(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved){
  std::string new_name = name.substr(20);
  if(name.find("fpn") != std::string::npos){
    new_name = new_name.substr(0, 14);
    if(name.find("weight") != std::string::npos){
      new_name += ".weight";
    }
    else{
      new_name += ".bias";
    }
  }

  for(auto s = saved.begin(); s != saved.end(); ++s){
    if((s->first).find(new_name) != std::string::npos){
      value.copy_(s->second);
      return s->first;
    }
  }
  assert(false);
}

std::string ResNetMapper::roiHead(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved){
  std::string new_name;
  if(name.find("cls_score") != std::string::npos){
    if(name.find("weight") != std::string::npos)
      new_name = "cls_score.weight";
    else
      new_name = "cls_score.bias";
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else if(name.find("bbox_pred") != std::string::npos){
    if(name.find("weight") != std::string::npos)
      new_name = "bbox_pred.weight";
    else
      new_name = "bbox_pred.bias";
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else if(name.find(".head.") != std::string::npos){
    new_name = name.substr(39);
    for(auto s = saved.begin(); s != saved.end(); ++s){
      if((s->first).find(new_name) != std::string::npos){
        value.copy_(s->second);
        return s->first;
      }
    }
    assert(false);
  }
  else if(name.find("conv5_mask") != std::string::npos){
    if(name.find("weight") != std::string::npos)
      new_name = "mask_cls_score.weight";
    else
      new_name = "mask_cls_score.bias";
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else if(name.find("mask_fcn_logits") != std::string::npos){
    if(name.find("weight") != std::string::npos)
      new_name = "mask_pred.weight";
    else
      new_name = "mask_pred.bias";
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else if(name.find("mask_fcn") != std::string::npos){
    new_name = "extractor_mask_" + name.substr(name.find("mask_fcn"));
    new_name.erase(new_name.begin() + 24, new_name.begin() + 26);
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else if(name.find("fc") != std::string::npos){
    new_name = "extractor_" + name.substr(name.find("fc"));
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else{
    assert(false);
  }
}

std::string ResNetMapper::rpn(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved){
  std::string new_name;
  if(name.find("conv") != std::string::npos){
    if(name.find("weight") != std::string::npos)
      new_name = "rpn_conv.weight";
    else
      new_name = "rpn_conv.bias";
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else if(name.find("bbox") != std::string::npos){
    if(name.find("weight") != std::string::npos)
      new_name = "rpn_bbox.weight";
    else
      new_name = "rpn_bbox.bias";
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else if(name.find("logits") != std::string::npos){
    if(name.find("weight") != std::string::npos)
      new_name = "rpn_logits.weight";
    else
      new_name = "rpn_logits.bias";
    value.copy_(saved.at(new_name));
    return new_name;
  }
  else if (name.find("anchors") != std::string::npos) {
      new_name = name;
      return new_name;
  }
  else{
    assert(false);
  }
}

}
}
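After jit_to_cpp has written ../models/new_pth_from_python_cpp.pth, the C++ model can be rebuilt and filled from that archive by reading it back key by key, mirroring how it was written. The sketch below is not part of the original project; the config path is illustrative and the rcnn:: namespace qualification is assumed from the header above:

#include <torch/torch.h>
// project headers, as in jit_to_cpp.h
#include <defaults.h>
#include <modeling.h>

int main(){
  // Rebuild the same architecture, then copy the converted weights into it.
  rcnn::config::SetCFGFromFile("./configs/e2e_mask_rcnn_R_50_C4_1x.yaml");
  rcnn::modeling::GeneralizedRCNN model = rcnn::modeling::BuildDetectionModel();
  torch::NoGradGuard guard;

  // Read the flat archive key by key, matching how jit_to_cpp wrote it.
  torch::serialize::InputArchive archive;
  archive.load_from("../models/new_pth_from_python_cpp.pth");
  for(auto& p : model->named_parameters())
    archive.read(p.key(), p.value());
  for(auto& b : model->named_buffers())
    archive.read(b.key(), b.value(), /*is_buffer=*/true);

  model->eval();
  return 0;
}

From here the rest of the pipeline (pre-processing, RPN post-processing, NMS, and so on) runs entirely on the C++ side.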