一. pytorch 保存、读取 tensor
- 首先导包:
import torch
save_torch = torch.Tensor([[1, 2, 3, 4],
[2, 34, 5, 6]])
- 保存 tensor
torch.save(save_torch, 'test_save_tensor.pt')
- 读取 tensor
load_torch = torch.load('test_save_tensor.pt')
- 完整测试代码
import torch
save_torch = torch.Tensor([[1, 2, 3, 4],
[2, 34, 5, 6]])
print(save_torch)
torch.save(save_torch, 'test_save_tensor.pt') # 保存
load_torch = torch.load('test_save_tensor.pt') # 读取
print(load_torch)
- 保存网络结构:model是自己定义的网络结构:
# 保存整个网络
torch.save(net, PATH.pth)
# 保存网络中的参数, 速度快,占空间少
torch.save(net.state_dict(),PATH.pth)
#--------------------------------------------------
#针对上面一般的保存方法,加载的方法分别是:
model_dict=torch.load(PATH)
model_dict=model.load_state_dict(torch.load(PATH))
二. pytorch 保存、读取numpy
- Numpy保存数据:利用numpy.save()函数将array保存为.npy格式的数据:
import numpy as np
np.save('where/you/wanto/store/output',arr) #numpy 会自动加上.npy后缀
- Numpy读取数据
b = np.load('here/you/wanto/store/output.npy')
三. 相关链接
pytorch 保存、读取 tensor 数据
Python Numpy Pytorch 保存数据
pytorch中的tensor以numpy形式进行输出保存
PyTorch教程-7:PyTorch中保存与加载tensor和模型详解
pytorch 模型输出特征 保存npy