Pytorch 有两种保存模型的方式,都是通过调用pickle序列化方法实现的。
第一种方法只保存模型参数。第二种方法保存完整模型。推荐使用第一种,第二种方法可能在切换设备和目录的时候出现各种问题。
1.保存模型参数方法:
print(model.state_dict().keys()) # 输出模型参数名称
# 保存模型参数到路径"./data/model_parameter.pkl"
torch.save(model.state_dict(), "./data/model_parameter.pkl")
new_model = Model() # 调用模型Model
new_model.load_state_dict(torch.load("./data/model_parameter.pkl")) # 加载模型参数
new_model.forward(input) # 进行使用
2.保存完整模型(不推荐)
torch.save(model, './data/model.pkl') # 保存整个模型
new_model = torch.load('./data/model.pkl') # 加载模型
3.Transfomers库预训练模型的加载
# 使用transformers预训练后进行保存
model.save_pretrained(model_path)
tokenizer.save_pretrained(tokenizer_path)
# 预训练模型使用 `from_pretrained()` 重新加载
model.from_pretrained(model_path)
tokenizer.from_pretrained(tokenizer_path)