Gumbel-Softmax 离散化sac
时间: 2025-05-10 10:14:32 浏览: 39
### Gumbel-Softmax在SAC中的离散化实现
#### 背景介绍
软演员-评论家(Soft Actor-Critic, SAC)是一种基于最大熵框架的强化学习算法,适用于连续动作空间的任务。然而,在面对离散动作空间时,传统的SAC无法直接应用,因为其核心依赖于对动作分布进行梯度计算的操作。为了克服这一限制,可以采用Gumbel-Softmax技巧将离散分布转化为可微分的形式。
Gumbel-Softmax的核心思想是在softmax函数的基础上引入噪声项,从而使得原本不可导的离散采样过程变得平滑且可微分[^2]。这种方法允许我们通过反向传播优化神经网络参数,即使目标涉及离散决策。
---
#### Gumbel-Softmax的工作机制
Gumbel-Softmax通过对logits加上经过缩放的Gumbel噪声来模拟离散采样的效果。具体而言:
1. **定义Logits**: 假设输入的动作概率由 logits $\pi(a|s)$ 表示,则对于每个可能的动作 $a_i$ 的 logit 记作 $z_i$。
2. **加入Gumbel噪声**: 对每个 $z_i$ 加上独立抽样的Gumbel随机变量 $g_i \sim Gumble(0, 1)$,得到新的值:
$$
y_i = z_i + g_i
$$
3. **应用Softmax**: 将上述加噪后的值送入 softmax 函数中,形成一个松弛化的概率分布:
$$
p_i = \frac{\exp(y_i / \tau)}{\sum_j \exp(y_j / \tau)}
$$
其中 $\tau > 0$ 是温度超参数,控制分布的锐利程度。当 $\tau \to 0$ 时,该分布逐渐逼近 one-hot 编码;而较大的 $\tau$ 则会使分布更加平坦。
4. **近似Argmax**: 使用最终的概率分布作为动作的选择依据,或者进一步取 argmax 来获得具体的离散动作。
这种技术不仅保留了原始策略的学习能力,还能够兼容 SAC 中所需的梯度估计操作。
---
#### 在SAC中的集成方式
以下是将Gumbel-Softmax融入SAC的具体步骤说明以及相应的PyTorch代码片段:
##### 步骤描述
- 政策网络输出未经激活的 logits 向量;
- 应用Gumbel-Softmax转换这些 logits 成为光滑的概率分布;
- 根据所得分布执行探索行为或评估期望回报。
下面是完整的Python实现例子:
```python
import torch
from torch.distributions import Categorical
def sample_gumbel(shape, eps=1e-20):
"""Sample from the Gumbel distribution."""
U = torch.rand(shape).cuda() # or .cpu()
return -torch.log(-torch.log(U + eps) + eps)
def gumbel_softmax_sample(logits, temperature):
"""Draw a sample from the Gumbel-Softmax distribution."""
y = logits + sample_gumbel(logits.size())
return torch.softmax(y / temperature, dim=-1)
def gumbel_softmax(logits, temperature, hard=False):
"""
Sample from the Gumbel-Softmax distribution and optionally discretize.
Args:
logits: [..., n_classes] unnormalized log-probs
temperature: non-negative scalar
hard: if True, take argmax; otherwise use soft samples
Returns:
[..., n_classes] tensor of probabilities
"""
y = gumbel_softmax_sample(logits, temperature)
if not hard:
return y
shape = y.size()
_, ind = y.max(dim=-1)
y_hard = torch.zeros_like(y).view(-1, shape[-1])
y_hard.scatter_(1, ind.view(-1, 1), 1)
y_hard = y_hard.view(*shape)
return (y_hard - y).detach() + y
```
在此基础上构建actor部分即可完成整个系统的搭建工作[^3]。
---
#### 关键差异分析
需要注意的是,尽管标准 Softmax 和 Gumbel-Softmax 都会产生某种形式的概率质量分配,但由于后者额外包含了随机扰动成分,因此两者之间必然存在区别。尤其体现在极端情况下——比如非常低温条件下($\tau → 0$)—前者趋向均匀分布,而后者的输出则更接近真实样本点位置处的高度集中形态。
至于为何 `argmax(Gumbel-Softmax)` 不总是等于 `argmax(Softmax)` ,这是因为即便两者的理论峰值相同,实际数值仍会受到不同初始化条件的影响而导致细微偏差出现。这种情况属于正常现象,并不会影响整体性能表现。
---
阅读全文
相关推荐




















