Reparameterization 重参数/Gumbel-Max/Gumbel-Softmax

本文探讨了重参数化在期望形式目标函数处理中的作用,介绍了如何通过Gumbel-Softmax处理离散分布采样并保持梯度可导性。对比了softmax与Gumbel-Softmax在分类任务中的区别,以及它们在实际应用中的关键技术和原理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、三者之间的关系

  重参数是一种处理期望形式目标函数的方法,处理这种目标函数意味着要从分布中采样。该分布中带有参数,如果直接采样的话,就会失去参数的梯度。而重参数提供这样一种变换,是我们可以直接从分布中采样,并且保留参数的梯度。

  如果从离散分布中采样,需要定义一种可微分的与离散分布近似的取样方法,这就需要用到Gumbel-Softmax。

  Gumbel-Max是Gumbel-Softmax的基础,提供了一种从类别分布中采样的方法。但Gumbel-Max因为存在取最大值的操作,故其不可导。而Gumbel-Softmax是Gumbel-Max的光滑近似版,其可导,可以保证参数的反向传播。

  参考资料:

漫谈重参数:从正态分布到Gumbel-Softmax

函数光滑化杂谈:不可导函数的可导逼近

词向量和embedding

2、重参数的一般方法

 首先从无参分布中采样,然后通过变换实现要完成分布的采样(连续分布是直接进行变换,离散分布是依概率进行采样)。

3、softmax和Gumbel-Softmax的区别

  分类任务用softmax,它并不是采样过程,因为它不包含随机性。Gumbel-Softmax从均匀分布中进行采样,进行随机扰动,保证了采样的随机性。

参考资料:

如何理解Gumbel-Max trick

Gumbel-Softmax Trick和Gumbel分布

 

 

 

### Argmax采样与Gumbel-Softmax采样的对比 #### 定义与机制 Argmax操作是从给定的一组概率中选取具有最高概率的类别作为输出。这一过程是非随机性的,总是返回最可能的结果。 相比之下,Gumbel-Softmax引入了随机性,在logits上加上来自Gumbel分布的噪声之后再执行Softmax函数来近似离散分布抽样[^2]。这使得模型能够在训练期间保留梯度信息并允许反向传播算法更新参数。 #### 应用场景 当需要确切的最大值预测时,比如分类任务中的最终决策阶段,通常采用argmax方法。然而,在涉及强化学习或变分自编码器等情况下,其中期望对不确定性和探索有所表示,则更倾向于使用带有温度参数控制平滑程度的Gumbel-Softmax技巧[^1]。 #### 优点分析 ##### Argmax采样 - **简单直观**:直接选择最大可能性选项,易于理解和实现。 - **确定性强**:每次都会给出相同的决定,适合于那些不需要考虑多样性的场合。 ##### Gumbel-Softmax采样 - **可微分性**:由于加入了连续松弛(即通过调整τ),因此可以在神经网络框架内方便地计算损失相对于输入logit的导数。 - **支持不确定性建模**:能够表达不同样本间的差异以及提供一定程度上的多样性,有助于防止过拟合现象的发生。 #### 缺点探讨 ##### Argmax采样 - **缺乏灵活性**:一旦做出选择就无法改变,不利于处理动态环境下的问题求解。 - **不可导特性**:因为它是硬性分配而不是软性加权平均的形式,所以在某些特定架构下难以与其他组件协同工作。 ##### Gumbel-Softmax采样 - **复杂度增加**:相较于简单的取大运算而言,增加了额外的操作步骤和超参调节需求(如设置合适的τ值)。 - **潜在偏差风险**:如果设定不当的话可能会导致估计结果偏离真实情况较远;另外,随着维度的增长也可能面临数值稳定性挑战。 ```python import torch from torch.distributions import gumbel def sample_gumbel_softmax(logits, temperature=1.0): """Sample from the Gumbel-Softmax distribution and optionally discretize. Args: logits: [batch_size, n_class] unnormalized log-probs temperature: non-negative scalar Returns: Sampled tensor of shape [batch_size, n_class] """ y = logits + gumbel.Gumbel(0, 1).sample(logits.shape) return torch.nn.functional.softmax(y / temperature, dim=-1) # Example usage with argmax vs gumbel-softmax sampling logits = torch.tensor([[1., 2., 3.]]) print("Argmax Sampling:", torch.argmax(logits, dim=-1)) for temp in [0.5, 1.0, 2.0]: print(f"Gumbel-Softmax Sampling (temp={temp}):", sample_gumbel_softmax(logits, temperature=temp)) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值