LSTM结构
lstm具体内容参考#继RNN之后-LSTM_rnn和lstm区别-CSDN博客
首先回顾一下LSTM的结构:
正向传播公式
还有LSTM正向传播的公式:
总体的代码如下:
代码
import numpy as np
from torch import sigmoid, softmax
def forward_prop(self,x):
#序列样本时间长度 例如 abcd a是t=0时刻 b是t=1时刻 c是t=2时刻 d是t=3时刻
T = len(x)
#初始化LSTM门控结构各个状态向量
states = self.init_state(T)
#对abcd这个序列进行累计求和 所以用for循环遍历abcd时刻
for t in range(T):
#前一时刻的隐藏层状态 就是ht-1
ht_pre = np.array(states["ht"][t-1]).reshape(-1,1)
xt = np.row_stack((ht_pre,x[t]))
#遗忘门
states["ft"][t] = sigmoid(np.dot(self.wf,xt)+self.bf)
#输入门
states["it"][t] = sigmoid(np.dot(self.wi,xt)+self.bi)
states["at"][t] = np.tanh(np.dot(self.wa,xt)+self.ba)
#更新细胞状态 ct = ft * ct_pre + it * at
states["ct"][t] = states["ft"][t] * states["ct"][t-1]+states["it"][t] * states["at"][t]
#输出门
states["ot"][t] = sigmoid(np.dot(self.wo,xt)+self.bo)
states["ht"][t] = states["ot"][t] * np.tanh(states["ct"][t])
#预测输出
states["yt"][t] = softmax(np.dot(self.wy,states["ht"][t])+self.by)
return states
代码分析
接下来逐行进行分析:
1.确定时间长度
T = len(x):
确定序列样本x的时间长度
- 计算了输入序列
x