排查方法:在backward()前面加上torch.autograd.set_detect_anomaly(True)
torch.autograd.set_detect_anomaly(True)
loss.backward()
我报错的原因:输入网络的值为NAN,或存在极端值(过大/过小)
排查方法:打出torch最大值和最小值
print("x_path", x_path.min(), x_path.max(), torch.isnan(x_path).any())
三种解决方法:
删除极端值
进行归一化
将极端值赋值(可能会影响学习的准确率)
wsi_bag = torch.nan_to_num(wsi_bag, nan=0.0)
wsi_bag = torch.clamp(wsi_bag, min=0.0, max=1.0)