主要是测试下模型裁剪后转onnx的问题。删除vgg16网络全连接层,加载预训练模型并重新保存模型参数,将该参数用于转onnx模型格式。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time :2022/8/4 14:45
# @Author :weiz
# @ProjectName :cbir
# @File :vgg.py
# @Description :
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
class VGG16(nn.Module):
def __init__(self):
super(VGG16, self).__init__()
# 1 * 3 * 224 * 224
self.conv1_1 = nn.Conv2d(3, 64, 3) # conv1_1:1 * 64 * 222 * 222
self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1)) # conv1_2:1 * 64 * 222* 222
self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1)) # maxpool1: 1 * 64 * 112 * 112
self.conv2_1 = nn.Conv2d(64, 128, 3) # conv2_1:1 * 128 * 110 * 110
self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1)) # conv2_2:1 * 128 * 110 * 110
self.maxpool2 = nn.MaxPool2d((2, 2), pa