一、从头开始训练的时候正常,大概占了4/5的显存容量,然后中断了, 再resume 就会出现out of memory ?
因为你 load checkpoint 的时候把参数加载到 显存 里面了;而且大概率你的code是这么写的然后报的错:
model = Network().cuda()
# 这里如果你存下来的权重device在cuda上,那么将自动载入GPU; 此时 model在GPU,新载入的参数也在GPU,就很有可能放不下
# 而且就算不在,在内存上,你还是得 state_dict = state_dict.cuda() 还是会炸显存
state_dict = torch.load('xxxxx.pth.tar')
model.load_state_dict(state_dict)
改成这样就行了
model = Network()
model.load_state_dict(torch.load('xxxx.pth.tar', map_location='cpu' ))
model.cuda()
存储方面,能在 device='cpu' 上面做的事情,就别在 device='cuda'上面做。
二、Pytorch 训练与测试时爆显存(out of memory)的一个解决方案
Pytorch 训练时有时候会因为加载的东西过多而爆显存,有些时候这种情况还可以使用cuda的清理技术进行修整,当然如果