法一
保存
保存该模型的架构和参数
import torch
import torchvision
vgg16=torchvision.models.vgg16(weights=None)
# 保存方式1,保存了该模型结构和参数
torch.save(vgg16,"vgg16_method1.pth")
torch.save(模型名称,路径)
运行之后,当前目录下会出现相应的名字
加载
import torch
import torchvision
# 方式1
model=torch.load("vgg16_method1.pth")
print(model)
法二
保存
将模型中的参数以字典形式保存,通过这种方式可以保存所保存模型的状态,只保存模型的参数(官方推荐)减小了保存文件的大小
import torch
import torchvision
vgg16=torchvision.models.vgg16(weights=None)
# 保存方式2,保存模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
加载
import torch
import torchvision
model=torch.load("vgg16_method2.pth")
print(model)
可以看到是个字典形式的数据,不再是网络模型了
如果要恢复成网络模型需要先创建网络模型结构,再用这个模型的load_state_dict 用于加载模型状态字典的重要工具
import torch
import torchvision
# 方式2
vgg16=torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
实战: 将自定义的网络模型保存
保存
import torch
import torchvision
from mmcv.cnn import Conv2d
from torch import nn
from torch.nn import Flatten, MaxPool2d, Linear
class Network(nn.Module):
def __init__(self):
super().__init__()
self.model1 = nn.Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x=self.model1(x)
return x
network=Network()
torch.save(network,"network_method1.pth")
加载
import torch
model=torch.load("network_method1.pth")
print(model)
如果直接加载,会报错,
显示,你不能得到Network这个属性
主要原因是找不到该模型的类,因此方式一读取自定义的神经网络时需要写上类的定义
import torch
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
class Network(nn.Module):
def __init__(self):
super().__init__()
self.model1 = nn.Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x=self.model1(x)
return x
model=torch.load("network_method1.pth")
print(model)
但是不需要实例化,即network=Network()