网络模型的保存与加载

 法一

保存

保存该模型的架构和参数

import torch
import torchvision

vgg16=torchvision.models.vgg16(weights=None)
# 保存方式1,保存了该模型结构和参数
torch.save(vgg16,"vgg16_method1.pth")

torch.save(模型名称,路径)

运行之后,当前目录下会出现相应的名字

加载

import torch
import torchvision

# 方式1
model=torch.load("vgg16_method1.pth")
print(model)

法二

保存

将模型中的参数以字典形式保存,通过这种方式可以保存所保存模型的状态,只保存模型的参数(官方推荐)减小了保存文件的大小

import torch
import torchvision

vgg16=torchvision.models.vgg16(weights=None)
# 保存方式2,保存模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

 

加载

import torch
import torchvision
model=torch.load("vgg16_method2.pth")
print(model)

 可以看到是个字典形式的数据,不再是网络模型了

如果要恢复成网络模型需要先创建网络模型结构,再用这个模型的load_state_dict 用于加载模型状态字典的重要工具

import torch
import torchvision

# 方式2
vgg16=torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

 

实战: 将自定义的网络模型保存

保存

import torch
import torchvision
from mmcv.cnn import Conv2d
from torch import nn
from torch.nn import Flatten, MaxPool2d, Linear

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x
network=Network()
torch.save(network,"network_method1.pth")

加载

import torch
model=torch.load("network_method1.pth")
print(model)

如果直接加载,会报错,

显示,你不能得到Network这个属性

主要原因是找不到该模型的类,因此方式一读取自定义的神经网络时需要写上类的定义

import torch
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x

model=torch.load("network_method1.pth")
print(model)

但是不需要实例化,即network=Network()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值