一、网络模型的保存
Pytorch提供了两种方式进行保存模型。
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1:模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth") # 保存模型结构及参数
# 保存方式2:模型参数,保存成字典的形式(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 陷阱1:方式1保存模型,陷阱在加载处
class Model(nn.Module):
def __init__(self) -> None:
super().__init__() # 初始化父类属性
self.model1 = Sequential(
Conv2d(3, 32, 5, stride=1, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5