在网上看了很多代码详解,感觉对数据的shape变化都没有一个简洁直观的解释过程,在这里,我尝试着自己总结一下。
X原始结构是[?, 784](对于X_batch而言,?=batch_size,下文都将用?表示),然后经过reshape变为[?, 28, 28]
lstm_cell包含100个节点
权重w的shape为[100, 10]
偏置b的shape为[10]
关键在于tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)的返回值是什么样的结构?
第一个返回值outputs的shape是[?, 28, 100]
第二个返回值final_state则包含两部分,一个是c,即final memory cell,另一个是h,即final hidden state,我们需要的是后者。这两部分的shape都是[?, 100]
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_DATA',one_hot=True)
n_inputs=28
n_rows=28
n_lstm=100
n_classes=10
batch_size=50
n_batch=mnist.train.num_examples//batch_size
X=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
w=tf.Variable(tf.truncated_normal([n_lstm,n_classes],stddev=0.1))
b=tf.Variable(tf.constant(0.1,shape=[n_classes]))
def RNN(X,w,b):
inputs=tf.reshape(X,[-1,n_rows,n_inputs])
'''
tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=1.0, state_is_tuple=True, activation=None, reuse=None, name)
num_units: int类型,指LSTM中的单元数量
forget_bias: 遗忘的偏置是0-1的数,1全记得,0全忘记
state_is_tuple: 如果是True,返回的是一个二元组,包含两个状态c_state和m_state。如果是False,沿着列方向将c_state和m_state拼接成一个向量。
activation: 内部状态的激活函数。 默认值:tanh
reuse(optional): 当值不为True时,重使用已存在的变量会报错
name: 这一层的名称,具有相同名称的层会共享权重参数,但此时reuse应该赋值为True
'''
lstm_cell=tf.contrib.rnn.BasicLSTMCell(n_lstm)
'''
tf.nn.dynamic_rnn(cell,inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None)
cell: 自己定义的cell内容,可以是BasicLSTMCell,BasicRNNCell,GRUCell等
inputs: 如果是time_major=True(默认值是False),input的维度是[max_time, batch_size, input_size],反之就是[batch_size, max_time, input_zise]
sequence_length=None:
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
return:
outputs: 如果是time_major=True,output的维度是[max_time, batch_size, Hidden_size],反之就是[batch_size, max_time, Hidden_size]
state: 最终状态
outputs里面,包含了所有时刻的输出 H;state里面,包含了最后一个时刻的输出 H 和 C
所以如果想用dynamic_rnn得到输出后,只需要最后一个时刻的状态输出,直接使用state 里面的 H 的就可以了。
'''
outputs, final_state=tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
results=tf.nn.softmax(tf.matmul(final_state[1],w)+b)
return results
prediction=RNN(X,w,b)
cross_entropy=tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y)
train=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
acc=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction,1),tf.argmax(y,1)),tf.float32))
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for epoch in range(21):
for batch in range(n_batch):
X_batch,y_batch=mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={X:X_batch,y:y_batch})
accuracy=sess.run(acc,feed_dict={X:mnist.test.images,y:mnist.test.labels})
print("Iteration "+str(epoch)+", accuracy = "+str(accuracy))