解析Torch中 `Embedding`

依次关系:
Embedding
( 1 )-> Parameter
( 2 )-> embedding -> has_torch_function_variadic -> handle_torch_function-> _get_overloaded_args

首先详细解析(1)和(2),最后基于上述知识解析Embedding

一. 开始解析(1)和(2):

( 1 )-> Parameter

class Parameter(torch.Tensor):
    r"""一种被视为模块参数的Tensor类型。

    参数是 :class:~torch.Tensor 的子类,当与 :class:Module 一起使用时,具有非常特殊的一个特性——当它们作为 :class:Module 的属性被赋值时,
    会自动被添加到该模块的参数列表中,并且会出现在如 :meth:~Module.parameters 的迭代器中。直接赋值一个普通的Tensor并不具备这样的效果。
    这是因为有时我们可能希望在模型中缓存一些临时状态,比如RNN的上一次隐藏状态。如果没有 :class:Parameter 这样的类,
    这些临时变量也会被注册为模型的参数。

    参数:
        data (Tensor): 参数的Tensor数据。
        requires_grad (布尔值, 可选): 表示该参数是否需要计算梯度。更多细节请参考 :ref:locally-disable-grad-doc。默认值: True
    """
    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.tensor([])
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
            memo[id(self)] = result
            return result

    def __repr__(self):
        return 'Parameter containing:\n' + super(Parameter, self).__repr__()

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        return (
            torch._utils._rebuild_parameter,
            (self.data, self.requires_grad, OrderedDict())
        )

    __torch_function__ = _disabled_torch_function_impl

(a)__new__

def __new__(cls, data=None, requires_grad=True):
    if data is None:
        data = torch.tensor([])
    return torch.Tensor._make_subclass(cls, data, requires_grad)

__new__方法在Python中是一个特殊的方法,用于控制一个类的实例化过程。在PyTorch中,Parameter类的__new__方法被重写以定制Parameter实例的创建过程。让我们深入分析这段代码:

def __new__(cls, data=None, requires_grad=True):

这里定义了Parameter类的__new__方法,它接受两个参数:

data:这是用于初始化Parameter对象的数据,可以是任何可以转换为Tensor的对象,默认为None。
requires_grad:一个布尔值,表示Parameter对象是否需要计算梯度,默认为True。

if data is None: data = torch.tensor([])

首先,检查data参数是否为None。如果是None,则创建一个空的Tensor。这确保即使用户没有提供初始数据,Parameter也可以被正确初始化。

return torch.Tensor._make_subclass(cls, data, requires_grad)

这是__new__方法的核心部分。torch.Tensor._make_subclass是一个内部方法,用于从现有的Tensor创建一个子类实例。在这个方法调用中:

cls:指的是Parameter类本身,这告诉_make_subclass方法创建一个Parameter类的实例。
data:这是用于初始化Parameter的数据,它应该是已经转换为Tensor的对象。
requires_grad:一个布尔值,表示这个Parameter实例是否需要计算梯度。
通过调用torch.Tensor._make_subclass,Parameter类的__new__方法实际上是在创建一个Tensor的子类实例,这个子类继承自Parameter类,并且具有datarequires_grad属性。

这个过程确保了Parameter实例不仅具有普通Tensor的所有属性和方法,而且还具有Parameter类的特有功能,比如自动被加入到所属Module的参数列表中,从而可以在训练过程中被优化算法正确地识别和更新。

总结来说,Parameter类的__new__方法通过torch.Tensor._make_subclass创建了一个既有Tensor特性又有Parameter特性的对象,这在PyTorch的模型构建和训练过程中起到了关键作用。
(b) __deepcopy__

def __deepcopy__(self, memo):
    if id(self) in memo:
        return memo[id(self)]
    else:
        result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
        memo[id(self)] = result
        return result

__deepcopy__方法是Python的内置方法,用于支持深度拷贝(deep copy)操作。深度拷贝意味着创建一个对象的完全独立副本,包括所有嵌套的对象。在PyTorch中,Parameter类的__deepcopy__方法实现了对Parameter对象的深度拷贝,确保拷贝的Parameter对象与其原对象在内存中是完全分离的。

让我们逐步解析__deepcopy__方法的实现:

if id(self) in memo:

这里的id(self)获取的是当前Parameter对象的唯一内存地址。memo是一个字典,用于存储已经拷贝过的对象,以避免重复拷贝相同的对象。这一行代码检查当前Parameter对象是否已经被拷贝过了。

return memo[id(self)]

如果当前Parameter对象已经在memo字典中,说明它之前已经被拷贝过,那么就直接返回这个已经拷贝好的对象,避免重复拷贝。

else:

如果当前Parameter对象还没有被拷贝过,那么就进入这个分支。

result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)

这一行代码是创建一个新的Parameter对象的关键。它做了两件事:

使用type(self)获取当前Parameter对象的类型,也就是Parameter类本身,这样可以创建一个新的同类型的Parameter对象。
调用self.data.clone(memory_format=torch.preserve_format)来创建data属性的深拷贝。memory_format=torch.preserve_format参数确保拷贝后的Tensor的内存布局与原Tensor相同。这是深度拷贝的核心,确保了data属性的独立性。
设置新Parameter对象的requires_grad属性,确保拷贝的Parameter对象与原对象的梯度计算需求一致。

memo[id(self)] = result

将新创建的Parameter对象存储到memo字典中,使用当前Parameter对象的内存地址作为键,这样下次如果遇到相同的对象就可以直接从memo中取出拷贝结果。

return result

最后,返回新创建的Parameter对象。

总的来说,__deepcopy__方法确保了在深度拷贝操作中,Parameter对象及其内部的data属性都被正确地独立复制,同时通过memo字典避免了不必要的重复拷贝,提高了效率。
( c ) __repr__

def __repr__(self):
    return 'Parameter containing:\n' + super(Parameter, self).__repr__()

( d ) __reduce_ex__
这是Parameter类的字符串表示方法,用于返回Parameter的可读性描述。它返回一个字符串,其中包含了Parameter的描述信息,以及内部Tensor的字符串表示。

def __reduce_ex__(self, proto):
    # See Note [Don't serialize hooks]
    return (
        torch._utils._rebuild_parameter,
        (self.data, self.requires_grad, OrderedDict())
    )

__reduce_ex__是Python对象序列化接口的一个重要组成部分,它主要用于支持Python的pickle模块进行对象的序列化和反序列化。pickle模块是Python内置的用于序列化和反序列化复杂对象的工具,它允许将Python对象保存到磁盘文件中或者在网络上传输,之后可以恢复为原来的对象状态。__reduce_ex__方法的实现对于支持pickle模块的序列化至关重要,尤其是在处理像Parameter这样的复杂对象时。

解析__reduce_ex__方法

def __reduce_ex__(self, proto):

proto参数是序列化协议的版本号,__reduce_ex__方法需要根据这个版本号来确定序列化的策略。但在大多数情况下,proto参数并不会直接影响序列化的过程。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值