import torch import torch.nn as nn class RNNCell(nn.Module): def __init__(self, input_dim, hidden_dim): super().__init__() self.W_hx = nn.Linear(input_dim, hidden_dim) self.W_hh = nn.Linear(hidden_dim, hidden_dim) def forward(self, x, h_prev): h_next = torch.tanh(self.W_hx(x) + self.W_hh(h_prev)) return h_next
时间: 2025-07-26 10:32:02 浏览: 11
<think>我们正在实现一个基本的RNN单元。RNN单元在每个时间步接收当前输入和上一个隐藏状态,输出新的隐藏状态。公式:$$h_t =\tanh(W_{ih} x_t +b_{ih}+ W_{hh} h_{t-1} +b_{hh})$$在PyTorch中,我们可以通过继承`torch.nn.Module`来创建自定义RNN单元。步骤:1.初始化:定义需要的参数,包括输入维度和隐藏层维度。2.定义参数:权重矩阵(W_ih,W_hh)和偏置(b_ih,b_hh)。3.前向传播:根据公式计算当前隐藏状态。注意:在PyTorch中,我们使用`torch.nn.Parameter`来定义可训练参数。我们将构建一个名为`SimpleRNNCell`的类。代码实现如下:</think>### 实现基础RNN单元(PyTorch)
在PyTorch中实现基础RNN单元的核心是理解其数学公式:
$$h_t = \tanh(W_{ih}x_t + b_{ih} + W_{hh}h_{t-1} + b_{hh})$$
其中:
- $x_t$:时间步$t$的输入向量
- $h_t$:当前隐藏状态
- $h_{t-1}$:前一步隐藏状态
- $W_{ih}, W_{hh}$:权重矩阵
- $b_{ih}, b_{hh}$:偏置向量
#### 代码实现
```python
import torch
import torch.nn as nn
class SimpleRNNCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
# 初始化参数
self.weight_ih = nn.Parameter(torch.randn(hidden_size, input_size))
self.weight_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))
self.bias_ih = nn.Parameter(torch.randn(hidden_size))
self.bias_hh = nn.Parameter(torch.randn(hidden_size))
def forward(self, x, hidden):
# RNN前向计算
hy = torch.tanh(
x @ self.weight_ih.t() + self.bias_ih +
hidden @ self.weight_hh.t() + self.bias_hh
)
return hy
# 示例用法
input_size = 10
hidden_size = 20
rnn_cell = SimpleRNNCell(input_size, hidden_size)
# 模拟输入序列 (seq_len=5, batch_size=3)
inputs = torch.randn(5, 3, input_size)
hidden = torch.zeros(3, hidden_size) # 初始隐藏状态
# 按时间步处理序列
outputs = []
for x in inputs:
hidden = rnn_cell(x, hidden)
outputs.append(hidden)
```
#### 关键说明:
1. **参数初始化**:
- `weight_ih`:输入到隐藏层的权重矩阵(维度$hidden\_size \times input\_size$)
- `weight_hh`:隐藏层到隐藏层的权重矩阵(维度$hidden\_size \times hidden\_size$)
- 使用`nn.Parameter`确保参数可训练
2. **前向传播**:
- 实现公式$h_t = \tanh(W_{ih}x_t + W_{hh}h_{t-1} + b)$
- `@`运算符进行矩阵乘法,`.t()`表示转置
- 激活函数使用`tanh`保持输出在$[-1,1]$范围
3. **序列处理**:
- 需手动循环处理时间步
- 每个时间步更新隐藏状态并收集输出
#### 与PyTorch内置RNN对比
PyTorch内置`nn.RNNCell`已提供相同功能[^2],但自定义实现有助于理解底层机制:
```python
# 内置实现等效代码
builtin_rnn = nn.RNNCell(input_size, hidden_size)
```
#### 训练注意事项
1. 使用`nn.init.xavier_uniform_()`初始化权重可改善收敛
2. 长序列需结合梯度裁剪防止梯度爆炸
3. 实际应用建议使用更高级单元(LSTM/GRU)解决梯度消失问题[^2]
阅读全文
相关推荐




















