依次关系:
Embedding
( 1 )-> Parameter
( 2 )-> embedding
-> has_torch_function_variadic
-> handle_torch_function
-> _get_overloaded_args
首先详细解析(1)和(2),最后基于上述知识解析Embedding
。
一. 开始解析(1)和(2):
( 1 )-> Parameter
class Parameter(torch.Tensor):
r"""一种被视为模块参数的Tensor类型。
参数是 :class:~torch.Tensor 的子类,当与 :class:Module 一起使用时,具有非常特殊的一个特性——当它们作为 :class:Module 的属性被赋值时,
会自动被添加到该模块的参数列表中,并且会出现在如 :meth:~Module.parameters 的迭代器中。直接赋值一个普通的Tensor并不具备这样的效果。
这是因为有时我们可能希望在模型中缓存一些临时状态,比如RNN的上一次隐藏状态。如果没有 :class:Parameter 这样的类,
这些临时变量也会被注册为模型的参数。
参数:
data (Tensor): 参数的Tensor数据。
requires_grad (布尔值, 可选): 表示该参数是否需要计算梯度。更多细节请参考 :ref:locally-disable-grad-doc。默认值: True
"""
def __new__(cls, data=None, requires_grad=True):
if data is None:
data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
memo[id(self)] = result
return result
def __repr__(self):
return 'Parameter containing:\n' + super(Parameter, self).__repr__()
def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
return (
torch._utils._rebuild_parameter,
(self.data, self.requires_grad, OrderedDict())
)
__torch_function__ = _disabled_torch_function_impl
(a)__new__
def __new__(cls, data=None, requires_grad=True):
if data is None:
data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad)
__new__方法在Python中是一个特殊的方法,用于控制一个类的实例化过程。在PyTorch中,Parameter类的__new__方法被重写以定制Parameter实例的创建过程。让我们深入分析这段代码:
def __new__(cls, data=None, requires_grad=True):
这里定义了Parameter类的__new__方法,它接受两个参数:
data:这是用于初始化Parameter对象的数据,可以是任何可以转换为Tensor的对象,默认为None。
requires_grad:一个布尔值,表示Parameter对象是否需要计算梯度,默认为True。
if data is None: data = torch.tensor([])
首先,检查data参数是否为None。如果是None,则创建一个空的Tensor。这确保即使用户没有提供初始数据,Parameter也可以被正确初始化。
return torch.Tensor._make_subclass(cls, data, requires_grad)
这是__new__
方法的核心部分。torch.Tensor._make_subclass
是一个内部方法,用于从现有的Tensor
创建一个子类实例。在这个方法调用中:
cls
:指的是Parameter类本身,这告诉_make_subclass
方法创建一个Parameter类的实例。
data
:这是用于初始化Parameter的数据,它应该是已经转换为Tensor的对象。
requires_grad
:一个布尔值,表示这个Parameter实例是否需要计算梯度。
通过调用torch.Tensor._make_subclass
,Parameter类的__new__
方法实际上是在创建一个Tensor的子类实例,这个子类继承自Parameter类,并且具有data
和requires_grad
属性。
这个过程确保了Parameter
实例不仅具有普通Tensor
的所有属性和方法,而且还具有Parameter
类的特有功能,比如自动被加入到所属Module
的参数列表中,从而可以在训练过程中被优化算法正确地识别和更新。
总结来说,Parameter
类的__new__
方法通过torch.Tensor._make_subclass
创建了一个既有Tensor
特性又有Parameter
特性的对象,这在PyTorch
的模型构建和训练过程中起到了关键作用。
(b) __deepcopy__
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
memo[id(self)] = result
return result
__deepcopy__
方法是Python的内置方法,用于支持深度拷贝(deep copy)操作。深度拷贝意味着创建一个对象的完全独立副本,包括所有嵌套的对象。在PyTorch中,Parameter类的__deepcopy__
方法实现了对Parameter对象的深度拷贝,确保拷贝的Parameter对象与其原对象在内存中是完全分离的。
让我们逐步解析__deepcopy__
方法的实现:
if id(self) in memo:
这里的id(self)
获取的是当前Parameter对象的唯一内存地址。memo是一个字典,用于存储已经拷贝过的对象,以避免重复拷贝相同的对象。这一行代码检查当前Parameter对象是否已经被拷贝过了。
return memo[id(self)]
如果当前Parameter对象已经在memo字典中,说明它之前已经被拷贝过,那么就直接返回这个已经拷贝好的对象,避免重复拷贝。
else:
如果当前Parameter对象还没有被拷贝过,那么就进入这个分支。
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
这一行代码是创建一个新的Parameter对象的关键。它做了两件事:
使用type(self)
获取当前Parameter对象的类型,也就是Parameter类本身,这样可以创建一个新的同类型的Parameter对象。
调用self.data.clone(memory_format=torch.preserve_format)
来创建data属性的深拷贝。memory_format=torch.preserve_format
参数确保拷贝后的Tensor的内存布局与原Tensor相同。这是深度拷贝的核心,确保了data属性的独立性。
设置新Parameter对象的requires_grad
属性,确保拷贝的Parameter对象与原对象的梯度计算需求一致。
memo[id(self)] = result
将新创建的Parameter对象存储到memo
字典中,使用当前Parameter对象的内存地址作为键,这样下次如果遇到相同的对象就可以直接从memo
中取出拷贝结果。
return result
最后,返回新创建的Parameter对象。
总的来说,__deepcopy__
方法确保了在深度拷贝操作中,Parameter对象及其内部的data属性都被正确地独立复制,同时通过memo字典避免了不必要的重复拷贝,提高了效率。
( c ) __repr__
def __repr__(self):
return 'Parameter containing:\n' + super(Parameter, self).__repr__()
( d ) __reduce_ex__
这是Parameter类的字符串表示方法,用于返回Parameter的可读性描述。它返回一个字符串,其中包含了Parameter的描述信息,以及内部Tensor的字符串表示。
def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
return (
torch._utils._rebuild_parameter,
(self.data, self.requires_grad, OrderedDict())
)
__reduce_ex__
是Python对象序列化接口的一个重要组成部分,它主要用于支持Python的pickle模块进行对象的序列化和反序列化。pickle模块是Python内置的用于序列化和反序列化复杂对象的工具,它允许将Python对象保存到磁盘文件中或者在网络上传输,之后可以恢复为原来的对象状态。__reduce_ex__
方法的实现对于支持pickle模块的序列化至关重要,尤其是在处理像Parameter这样的复杂对象时。
解析__reduce_ex__
方法
def __reduce_ex__(self, proto):
proto参数是序列化协议的版本号,__reduce_ex__
方法需要根据这个版本号来确定序列化的策略。但在大多数情况下,proto
参数并不会直接影响序列化的过程。