pytorch 保存模型和加载模型

这篇博客详细介绍了如何在PyTorch中保存已经训练好的模型,并且演示了如何在之后的训练过程中加载模型以继续训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

def save_model(save_dir, phase, name, epoch, f1score, model):
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    save_dir = os.path.join(save_dir, args.model)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    save_dir = os.path.join(save_dir, phase)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    state_dict = model.state_dict()
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].cpu()
    state_dict_all = {
        'state_dict': state_dict,
        'epoch': epoch,
        'f1score': f1score,
    }
    torch.save(state_dict_all, os.path.join(save_dir, '{:s}.ckpt'.format(name)))
    if 'best' in name and f1score > 0.3:
        torch.save(state_dict_all, os.path.join(save_dir, '{:s}_{:s}.ckpt'.format(name, str(epoch))))

pytorch 保存模型

pytorch 加载模型进行继续训练

    if args.resume:
        s
### 如何在PyTorch保存加载模型 #### 保存整个模型 为了持久化存储训练好的神经网络,在PyTorch中有多种方式可以实现这一目标。一种方法就是直接保存整个模型对象,这包含了架构以及参数的信息。 ```python import torch.nn as nn import torch.optim as optim from torchvision import models model = models.resnet50(pretrained=False) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) torch.save(model, 'entire_model.pth') ``` 这种方法简单直观但是不够灵活,因为如果想要改变模型结构的一部分就需要重新编写代码并调整其他部分[^1]。 #### 只保存模型状态字典 更推荐的做法是仅导出模型的状态字典`state_dict()`,它只记录了权重而忽略了具体的类定义其他属性设置。 ```python # Saving only the state dictionary of a model torch.save(model.state_dict(), 'model_weights.pth') # Loading back just the weights into an existing architecture instance new_model = models.resnet50(pretrained=False) new_model.load_state_dict(torch.load('model_weights.pth')) ``` 这种方式允许轻松迁移学习或微调预训练模型,并且文件体积通常较小。 #### 加载特定设备上的模型 有时需要指定是在CPU还是GPU上运行已加载模型。可以通过传递额外参数给`.to(device)`函数完成此操作。 ```python device = torch.device("cuda" if torch.cuda.is_available() else "cpu") loaded_model = models.resnet50(pretrained=False).to(device) loaded_model.load_state_dict(torch.load('model_weights.pth', map_location=device)) ``` 上述过程展示了如何针对不同场景有效地管理PyTorch中的模型存档与恢复工作流。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值