【一文学会】Gumbel-Softmax的采样技巧

目录

基于softmax的采样

基于gumbel-max的采样

基于gumbel-softmax的采样

基于ST-gumbel-softmax的采样

Gumbel分布

回答问题一

回答问题二

回答问题三

附录


 

以强化学习为例,假设网络输出的三维向量代表三个动作(前进、停留、后退)在下一步的收益,value=[-10,10,15],那么下一步我们就会选择收益最大的动作(后退)继续执行,于是输出动作[0,0,1]。选择值最大的作为输出动作,这样做本身没问题,但是在网络中这种取法有个问题是不能计算梯度,也就不能更新网络。

基于softmax的采样

这时通常的做法是加上softmax函数,把向量归一化,这样既能计算梯度,同时值的大小还能表示概率的含义(多项分布)。

                                                    \fn_phv \large \pi_k = \frac{e^{x_k}}{\sum_{i=1}^{K} e^{x_{i}}}

于是value=[-10,10,15]通过softmax函数后有σ(value)=[0,0.007,0.993],这样做不会改变动作或者说类别的选取,同时softmax倾向于让最大值的概率显著大于其他值,比如这里15和10经过softmax放缩之后变成了0.993和0.007,这有利于把网络训成一个one-hot输出的形式,这种方式在分类问题中是常用方法。

但这样就不会体现概率的含义了,因为σ(value)=[0,0.007,0.993]与σ(value)=[0.3,0.2,0.5]在类别选取的结果看来没有任何差别,都是选择第三个类别,但是从概率意义上讲差别是巨大的。

很直接的方法是依概率采样完事了,比如直接用np.random.choice函数依照概率生成样本值,这样概率就有意义了。所以,经典的采样方法就是用softmax函数加上轮盘赌方法(np.random.choice)。但这样还是会有个问题,这种方式怎么计算梯度?不能计算梯度怎么更新网络?

def sample_with_softmax(logits, size):
# logits为输入数据
# size为采样数
    pro = softmax(logits)
    return np.random.choice(len(logits), size, p=pro)

 

基于gumbel-max的采样

gumbel分布的具体介绍会放在后文,我们先看看结论。对于K维概率向量\large \alpha,对\large \alpha对应的离散变量x_{i}=log(\alpha _i)添加Gumbel噪声,再取样

                                   \large x=\mathop{argmax}_i(\log(\alpha _i)+G_i)

其中,\fn_phv \large G_i是独立同分布的标准Gumbel分布的随机变量,标准Gumbel分布的CDF为F(x)=e^{-e^{-x}}.所以\large G_i可以通过Gumbel分布求逆从均匀分布生成,即G_i=-\log(-\log(U_i)),U_i\sim U(0,1)x_{i}=log(\alpha _i)代入计算可知,这里的\large \alpha就是上面softmax采样的\large \pi,这样就得到了基于gumbel-max的采样过程:

  • 对于网络输出的一个K维向量v,生成K个服从均匀分布U(0,1)的独立样本ϵ1,...,ϵK;

  • 通过G_i=-\log(-\log(\varepsilon _i))计算得到G_i;

  • 对应相加得到新的值向量v′=[v1+G1,v2+G2,...,vK+GK];

  • 取最大值作为最终的类别

可以证明,gumbel-max 方法的采样效果等效于基于 softmax 的方式(后文也会证明)。由于 Gumbel 随机数可以预先计算好,采样过程也不需要计算 softmax,因此,某些情况下,gumbel-max 方法相比于 softmax,在采样速度上会有优势。当然,可以看到由于这中间有一个argmax操作,这是不可导的,依旧没法用于计算网络梯度。

def sample_with_gumbel_noise(logits, size):
    noise = sample_gumbel((size, len(logits)))    # 产生gumbel noise
    return np.argmax(logits + noise, axis=1)

 

基于gumbel-softmax的采样

如果仅仅是提供一种常规 softmax 采样的替代方案, gumbel 分布似乎应用价值并不大。幸运的是,我们可以利用 gumbel 实现多项分布采样的 reparameterization(再参数化)。

VAE中,假设隐变量(latent variables)服从标准正态分布。而现在,利用 gumbel-softmax 技巧,我们可以将隐变量建模为服从离散的多项分布。在前面的两种方法中,random.choice和argmax注定了这两种方法不可导,但我们可以将后一种方法中的argmax soft化,变为softmax。

                              \large x=\mathop{softmax}((\log(\alpha _i)+G_i)/temperature)

temperature 是在大于零的参数,它控制着 softmax 的 soft 程度。温度越高,生成的分布越平滑;温度越低,生成的分布越接近离散的 one-hot 分布。训练中,可以通过逐渐降低温度,以逐步逼近真实的离散分布。

这样就得到了基于gumbel-max的采样过程:

  • 对于网络输出的一个K维向量v,生成K个服从均匀分布U(0,1)的独立样本ϵ1,...,ϵK;

  • 通过

评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值