pytorch保存模型方法

本文介绍了Pytorch中两种保存和加载模型的方法。推荐使用保存模型参数的方式,以确保在不同环境间迁移的兼容性。首先,通过`state_dict()`保存和加载模型参数;其次,虽然可以保存整个模型,但这种方式可能存在设备和目录适配问题。对于Transformers库的预训练模型,使用`save_pretrained()`和`from_pretrained()`进行保存和加载。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Pytorch 有两种保存模型的方式,都是通过调用pickle序列化方法实现的。

第一种方法只保存模型参数。第二种方法保存完整模型。推荐使用第一种,第二种方法可能在切换设备和目录的时候出现各种问题。

1.保存模型参数方法:

print(model.state_dict().keys())                                # 输出模型参数名称

# 保存模型参数到路径"./data/model_parameter.pkl"
torch.save(model.state_dict(), "./data/model_parameter.pkl")
new_model = Model()                                                    # 调用模型Model
new_model.load_state_dict(torch.load("./data/model_parameter.pkl"))    # 加载模型参数     
new_model.forward(input)                                               # 进行使用

2.保存完整模型(不推荐)

torch.save(model, './data/model.pkl')        # 保存整个模型
new_model = torch.load('./data/model.pkl')   # 加载模型

3.Transfomers库预训练模型的加载

# 使用transformers预训练后进行保存
model.save_pretrained(model_path)                              
tokenizer.save_pretrained(tokenizer_path)

# 预训练模型使用 `from_pretrained()` 重新加载
model.from_pretrained(model_path)                              
tokenizer.from_pretrained(tokenizer_path)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值