
深度学习需要巨大的计算量。开发原型时使用 Python、MATLAB 比较方便,但真正要实用化部署的话,C++ 作为底层语言明显更具优势。
目前大部分的优秀的深度学习项目仍然是以python为主流语言,但随着libtorch的出现,相信未来会有更多利用C++作为深度学习开发的项目。
这篇文章主要是给大家介绍如何将经典的maskrcnn_benchmark的模型转为libtorch可以用的C++模型。
maskrcnn的整体框架图如下

利用torch.jit.trace保存模型权重值
import argparse
import torch
import os
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
class wrapper(torch.nn.Module):
    """Adapter that converts a model's list output into a tuple.

    torch.jit.trace only accepts tensors or (nested) tuples of tensors as
    outputs; a plain Python list is not traceable, so this wrapper converts
    it on the fly.
    """

    def __init__(self, model):
        super(wrapper, self).__init__()
        self.model = model

    def forward(self, input):
        # Delegate to the wrapped model and turn the list into a tuple.
        output = self.model(input)
        return tuple(output)


def output_tuple_or_tensor(model, tensor):
    """Return a version of ``model`` whose output torch.jit.trace accepts.

    Runs one forward pass with ``tensor``: if the output is already a Tensor
    or a tuple the model is returned unchanged; if it is a list the model is
    wrapped so the output becomes a tuple.

    Raises:
        TypeError: if the model produces any other output type (previously
            this case silently returned ``None``).
    """
    output = model(tensor)
    if isinstance(output, (torch.Tensor, tuple)):
        return model
    elif isinstance(output, list):
        return wrapper(model)
    raise TypeError(
        "Cannot make model traceable: unsupported output type %s"
        % type(output).__name__
    )
def main():
    """Trace each sub-module of a maskrcnn_benchmark detector and save the
    traced weights as individual TorchScript .pth files.

    torch.jit.trace cannot record dynamic control flow (NMS, proposal
    selection, ...), so the full detector cannot be exported end-to-end.
    Instead, every weight-carrying building block is traced separately with a
    dummy input; the C++ side re-implements the control flow and re-assembles
    these parts.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection jit")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--weight_path",
        help="weight file path"
    )
    parser.add_argument(
        "--output_path",
        help="jit output path",
        type=str,
        default='./'
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    model = build_detection_model(cfg)
    checkpoint = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
    _ = checkpoint.load(args.weight_path, None)

    def trace_and_save(module, example, filename):
        # Trace `module` with a dummy `example` input and save it to output_path.
        traced = torch.jit.trace(module, example)
        traced.save(os.path.join(args.output_path, filename))

    # --- backbone ---
    # Wrap the backbone if its forward returns a list (trace needs a tuple).
    backbone_tensor = torch.zeros(2, 3, 320, 320)
    tmp_model = output_tuple_or_tensor(model.backbone, backbone_tensor)
    # Print all parameter/buffer names: this is the checklist of weights the
    # C++ loader has to map.
    for name, _ in model.named_parameters():
        print(name)
    print("-------------------------------------------------------------------------------------------------")
    for name, _ in model.named_buffers():
        print(name)
    trace_and_save(tmp_model, backbone_tensor, 'backbone.pth')

    # --- RPN head: conv / cls_logits / bbox_pred ---
    # The dummy spatial size is arbitrary for convolutions; only the channel
    # count (weight.shape[1]) must match.
    for module, filename in (
        (model.rpn.head.conv, 'rpn_conv.pth'),
        (model.rpn.head.cls_logits, 'rpn_logits.pth'),
        (model.rpn.head.bbox_pred, 'rpn_bbox.pth'),
    ):
        in_channels = module.weight.shape[1]
        trace_and_save(module, torch.zeros(2, in_channels, 10, 10), filename)

    # --- box head: feature extractor ---
    for k, v in model.roi_heads.box.feature_extractor._modules.items():
        if 'head' in k:
            # ResNet stage head: infer channels from its downsample conv.
            in_channels = v._modules['layer4'][0]._modules['downsample'][0].weight.shape[1]
            trace_and_save(v, torch.zeros(2, in_channels, 10, 10), 'extractor_' + k + '.pth')
        elif 'fc' in k:
            # Fully-connected extractor: 2D dummy input.
            in_channels = v.weight.shape[1]
            trace_and_save(v, torch.zeros(2, in_channels), 'extractor_' + k + '.pth')
        elif 'conv' in k:
            in_channels = v.weight.shape[1]
            trace_and_save(v, torch.zeros(2, in_channels, 10, 10), 'extractor_' + k + '.pth')

    # --- box head: predictor ---
    cls_score_in_channels = model.roi_heads.box.predictor.cls_score.weight.shape[1]
    trace_and_save(model.roi_heads.box.predictor.cls_score,
                   torch.zeros(2, cls_score_in_channels), 'cls_score.pth')
    bbox_pred_in_channels = model.roi_heads.box.predictor.bbox_pred.weight.shape[1]
    trace_and_save(model.roi_heads.box.predictor.bbox_pred,
                   torch.zeros(2, bbox_pred_in_channels), 'bbox_pred.pth')

    # --- mask head: feature extractor ---
    for k, v in model.roi_heads.mask.feature_extractor._modules.items():
        if 'head' in k:
            in_channels = v._modules['layer4'][0]._modules['downsample'][0].weight.shape[1]
            trace_and_save(v, torch.zeros(2, in_channels, 10, 10), 'extractor_mask_' + k + '.pth')
        elif 'fc' in k:
            # NOTE(review): mask modules named "mask_fcn*" also match this
            # 'fc' branch even though they are convolutions, hence the 4D
            # dummy input — confirm this is intentional before changing it.
            in_channels = v.weight.shape[1]
            trace_and_save(v, torch.zeros(256, in_channels, 3, 3), 'extractor_mask_' + k + '.pth')
        elif 'conv' in k:
            in_channels = v.weight.shape[1]
            trace_and_save(v, torch.zeros(2, in_channels, 10, 10), 'extractor_mask_' + k + '.pth')

    # --- mask head: predictor ---
    conv5_in_channels = model.roi_heads.mask.predictor.conv5_mask.weight.shape[1]
    trace_and_save(model.roi_heads.mask.predictor.conv5_mask,
                   torch.zeros(2, conv5_in_channels, 2, 2), 'mask_cls_score.pth')
    logits_in_channels = model.roi_heads.mask.predictor.mask_fcn_logits.weight.shape[1]
    trace_and_save(model.roi_heads.mask.predictor.mask_fcn_logits,
                   torch.zeros(2, logits_in_channels, 2, 2), 'mask_pred.pth')
# Script entry point: run the export then report completion.
if __name__ == "__main__":
    main()
    print('Complete!')
注意!我们这边的操作仅仅是把模型的权重给保存下来,由于torch.jit.trace功能并不能将动态图和控制流等操作保存下来,所以并不能直接形成一个python一样的端到端的模型!所以我们这边退而求其次的保存模型权重,其他的操作在C++部分完成(比如NMS等)。
我们可以在python端,通过打印模型的named_parameters()和named_buffers()知道我们要保存哪些权重信息。
在得到各个模块trace的模型之后,我们利用libtorch将所有模型整合在一起:
由于整体的代码工程较为庞大(我们把整个maskrcnn_benchmark的代码转为了C++版本的),所以我们附上其中模型转换比较核心的部分和大家交流。
其中
modeling::GeneralizedRCNN model = modeling::BuildDetectionModel();
是根据python版maskrcnn_benchmark写的一个C++的模型类,内容完全一样,只是语言不同。
这边转换的核心思想是,将模型中的named_parameters和named_buffers中带有backbone,rpn和roi_heads.box关键词的权重读到这个实体类中。
调用jit_to_cpp.h中的jit_to_cpp(模型权重存放位置, 模型的yaml配置文件, 模型权重名<向量形式>)就可以整合出C++的模型。
通过搜索关键词,将模型中的权重信息整合起来
jit_to_cpp.h
#pragma once
#include <torch/torch.h>
#include <torch/script.h>
#include <defaults.h>
#include <modeling.h>
#include <iostream>
#include <bisect.h>
namespace rcnn{
namespace utils{
// Recursively collect every parameter/attribute tensor of a traced jit
// module into `saved`, keyed by dotted path (defined in jit_to_cpp.cpp).
void recur(torch::jit::script::Module& module, std::string name, std::map<std::string, torch::Tensor>& saved);

// Load all traced .pth parts from `weight_dir` and copy their tensors into
// the parameters/buffers of the C++ GeneralizedRCNN built from
// `config_path`. The mapper type T (e.g. ResNetMapper) translates C++
// parameter names into the keys used by the traced python modules. The
// merged model is written to ../models/new_pth_from_python_cpp.pth.
template<typename T>
void jit_to_cpp(std::string weight_dir, std::string config_path, std::vector<std::string> weight_files){
  T mapper = T();
  std::map<std::string, torch::Tensor> saved;    // traced key -> tensor
  std::set<std::string> updated;                 // model keys already mapped
  std::map<std::string, std::string> mapping;    // model key -> traced key
  rcnn::config::SetCFGFromFile(config_path);
  modeling::GeneralizedRCNN model = modeling::BuildDetectionModel();
  torch::NoGradGuard guard;  // weight copies must not record gradients
  for(auto& weight_file : weight_files){
    auto module_part = torch::jit::load(weight_dir + "/" + weight_file);
    // Strip the ".pth" extension; the file stem becomes the key prefix.
    recur(module_part, weight_file.substr(0, weight_file.size() - 4), saved);
  }
  for(const auto& i : saved){
    std::cout << i.first << "\n";
  }
  // Map every parameter: the mapper copies the traced tensor into the model
  // in place and returns the traced key it used.
  for(auto& i : model->named_parameters()){
    std::cout << i.key() << "\n";
    std::string new_name;
    if(i.key().find("backbone") != std::string::npos){
      new_name = mapper.backboneMapping(i.key(), i.value(), saved);
    }
    else if(i.key().find("rpn") != std::string::npos){
      new_name = mapper.rpn(i.key(), i.value(), saved);
    }
    else if((i.key().find("roi_heads.box") != std::string::npos) || (i.key().find("roi_heads.mask") != std::string::npos)){
      new_name = mapper.roiHead(i.key(), i.value(), saved);
    }
    else{
      assert(false);  // unknown parameter group
      continue;
    }
    updated.insert(i.key());
    mapping[i.key()] = new_name;
  }
  // Same mapping pass for buffers (e.g. BatchNorm running statistics).
  for(auto& i : model->named_buffers()){
    std::string new_name;
    if(i.key().find("backbone") != std::string::npos){
      new_name = mapper.backboneMapping(i.key(), i.value(), saved);
    }
    else if(i.key().find("rpn") != std::string::npos){
      new_name = mapper.rpn(i.key(), i.value(), saved);
    }
    else if((i.key().find("roi_heads.box") != std::string::npos) || (i.key().find("roi_heads.mask") != std::string::npos)){
      new_name = mapper.roiHead(i.key(), i.value(), saved);
    }
    else{
      assert(false);  // unknown buffer group
      continue;
    }
    updated.insert(i.key());
    mapping[i.key()] = new_name;
  }
  torch::serialize::OutputArchive archive;
  for(auto& i : model->named_parameters()){
    std::cout << i.key() << " parameter loaded from " << mapping[i.key()] << "\n";
    assert(updated.count(i.key()));
    // The mapper already copied in place; verify it matches the source.
    assert((saved.at(mapping[i.key()]) != i.value()).sum().item<int>() == 0);
    archive.write(i.key(), i.value());
  }
  for(auto& i : model->named_buffers()){
    std::cout << i.key() << " buffer loaded from " << mapping[i.key()] << "\n";
    assert(updated.count(i.key()));
    // Anchor buffers are generated, not traced, so they have no source
    // tensor to compare against.
    if(i.key().find("anchor_generator") == std::string::npos){
      assert((saved.at(mapping[i.key()]) != i.value()).sum().item<int>() == 0);
    }
    archive.write(i.key(), i.value(), /*is_buffer=*/true);
  }
  archive.save_to("../models/new_pth_from_python_cpp.pth");
  std::cout << "saved as /models/new_pth_from_python_cpp.pth\n";
}
// Maps C++ GeneralizedRCNN parameter/buffer names (ResNet backbone variant)
// to the keys of the traced python modules, copying each tensor in place.
// Each method returns the traced key the value was loaded from.
class ResNetMapper{
public:
ResNetMapper() = default;
// Backbone weights: matched by substring search over the saved keys.
std::string backboneMapping(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
// roi_heads.box / roi_heads.mask weights (cls_score, bbox_pred, mask head, ...).
std::string roiHead(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
// RPN head weights (conv, bbox_pred, cls_logits); anchor buffers pass through.
std::string rpn(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
};
// Same mapping interface for a VoVNet backbone variant.
// NOTE(review): only the declaration is visible here; the implementation
// presumably lives in another translation unit — confirm before use.
class VoVNetMapper{
public:
VoVNetMapper() = default;
std::string backboneMapping(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
std::string roiHead(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
std::string rpn(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved);
};
}
}
jit_to_cpp.cpp:
#include "jit_to_cpp.h"
#include <cassert>
#include <iostream>
namespace rcnn{
namespace utils{
// Recursively walk a traced jit module, storing every parameter and
// attribute tensor in `saved` under its fully-qualified dotted name
// (e.g. "backbone.body.layer1.0.conv1.weight").
void recur(torch::jit::script::Module& module, std::string name, std::map<std::string, torch::Tensor>& saved){
  std::string new_name;
  if(name.compare("") != 0)
    new_name = name + ".";  // children are keyed "<prefix>.<child>"
  for(auto u : module.get_parameters()){
    torch::Tensor tensor = u.value().toTensor();
    saved[new_name + u.name()] = tensor;
  }
  for(auto u : module.get_attributes()){
    // Buffers (e.g. BatchNorm running stats) are stored as attributes.
    torch::Tensor tensor = u.value().toTensor();
    saved[new_name + u.name()] = tensor;
  }
  for(auto i : module.get_modules())
    recur(i, new_name + i.name().name(), saved);
}
// Copy a backbone parameter/buffer from `saved` into `value` and return the
// traced key it came from. C++ names carry a fixed-length prefix (dropped
// via substr(20)); the remainder is substring-matched against the traced
// keys.
std::string ResNetMapper::backboneMapping(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved){
  std::string new_name = name.substr(20);
  if(name.find("fpn") != std::string::npos){
    // FPN layers: keep only the fixed-length layer path, then re-append
    // the weight/bias suffix.
    new_name = new_name.substr(0, 14);
    new_name += (name.find("weight") != std::string::npos) ? ".weight" : ".bias";
  }
  for(auto s = saved.begin(); s != saved.end(); ++s){
    if((s->first).find(new_name) != std::string::npos){
      value.copy_(s->second);
      return s->first;
    }
  }
  assert(false && "no traced tensor matches backbone parameter");
  return std::string();  // unreachable in debug; avoids UB (missing return) with NDEBUG
}
// Copy a roi_heads.box / roi_heads.mask parameter from `saved` into `value`
// and return the traced key used. Branch order matters: more specific
// names (e.g. "mask_fcn_logits") are tested before generic ones ("mask_fcn",
// "fc").
std::string ResNetMapper::roiHead(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved){
  const std::string suffix =
      (name.find("weight") != std::string::npos) ? ".weight" : ".bias";
  std::string new_name;
  if(name.find("cls_score") != std::string::npos){
    new_name = "cls_score" + suffix;
  }
  else if(name.find("bbox_pred") != std::string::npos){
    new_name = "bbox_pred" + suffix;
  }
  else if(name.find(".head.") != std::string::npos){
    // ResNet C4/C5 head: strip the fixed C++ prefix and substring-match
    // against the traced keys.
    new_name = name.substr(39);
    for(auto s = saved.begin(); s != saved.end(); ++s){
      if((s->first).find(new_name) != std::string::npos){
        value.copy_(s->second);
        return s->first;
      }
    }
    assert(false && "no traced tensor matches roi head parameter");
    return std::string();  // unreachable in debug; avoids UB with NDEBUG
  }
  else if(name.find("conv5_mask") != std::string::npos){
    new_name = "mask_cls_score" + suffix;
  }
  else if(name.find("mask_fcn_logits") != std::string::npos){
    new_name = "mask_pred" + suffix;
  }
  else if(name.find("mask_fcn") != std::string::npos){
    // Erase two characters so the C++ name lines up with the traced file's
    // key; the [24,26) position is tied to the fixed naming scheme used by
    // the python export script.
    new_name = "extractor_mask_" + name.substr(name.find("mask_fcn"));
    new_name.erase(new_name.begin() + 24, new_name.begin() + 26);
  }
  else if(name.find("fc") != std::string::npos){
    new_name = "extractor_" + name.substr(name.find("fc"));
  }
  else{
    assert(false && "unhandled roi head parameter");
    return std::string();  // unreachable in debug; avoids UB with NDEBUG
  }
  value.copy_(saved.at(new_name));
  return new_name;
}
// Copy an RPN head parameter from `saved` into `value` and return the traced
// key used. Anchor buffers are generated at runtime, not loaded, so they
// pass through unchanged.
std::string ResNetMapper::rpn(const std::string& name, torch::Tensor& value, std::map<std::string, torch::Tensor>& saved){
  const std::string suffix =
      (name.find("weight") != std::string::npos) ? ".weight" : ".bias";
  std::string new_name;
  if(name.find("conv") != std::string::npos){
    new_name = "rpn_conv" + suffix;
  }
  else if(name.find("bbox") != std::string::npos){
    new_name = "rpn_bbox" + suffix;
  }
  else if(name.find("logits") != std::string::npos){
    new_name = "rpn_logits" + suffix;
  }
  else if(name.find("anchors") != std::string::npos){
    // No traced counterpart: keep the name, leave the value untouched.
    return name;
  }
  else{
    assert(false && "unhandled rpn parameter");
    return std::string();  // unreachable in debug; avoids UB with NDEBUG
  }
  value.copy_(saved.at(new_name));
  return new_name;
}
}
}