今天遇到一个bug,报错显示RuntimeError: Error(s) in loading state_dict for DistributedDataParallel:
我理解就是模型的权重加载失败了,上网找了有下面几种说法 1.torch的版本不一致
2.模型保存的时候是使用分布式保存,加载时要用特定的方法
3.我修改了模型的结构,导致参数加载不匹配,需要在加载时指定一个参数,strict=false
但是我最后debug出来似乎都不是以上原因,而是因为我在一个网络(假设是net1)中嵌套了一个网络(net2),但是在train文件中,我进行了model.load_state_dict()(这里想要加载的是net1的权重)这时候会可能会把这个权重加载到嵌套中的另一个网络(net2),而不是本身的网络(net1)中,所以我将这两个网络单独定义,然后各自创建实例,加载权重,就解决这个问题了。