pytorch如何知道某个Parameter是在哪一个Module中的创建的
在定位pytorch精度问题时,发现optimizer中某些Parameter值异常,想知道它属于哪个模块的.本文提供二种方法
1.全局搜索
2.在创建Parameter的地方加一个属性,写明所在的模块名,需要的时候直接获取
代码
import torch
import sys
sys.setrecursionlimit(1000)
def search_recursive(var,stack,_id,depth):
if var.__class__.__name__ in [
"module","type","NoneType",
"str","int","function","method-wrapper",
"builtin_function_or_method",
"method","_TensorMeta",
"Tensor","method_descriptor",
"bool","device","dtype",
"getset_descriptor","layout",
"wrapper_descriptor","property",
"_ParameterM