Paper
"Convolutional, Long Short-Term Memory, Fully Connected Deep Neural Networks" (Sainath et al., ICASSP 2015)
CLDNN network architecture
Motivated by the limitations of LSTMs, CLDNN combines three kinds of layers to address them (the walkthrough below maps this onto the MNIST code):
- features are first fed into CNN layers to reduce spectral variation;
- the CNN output is fed into LSTM layers to model the temporal structure;
- the LSTM output is fed into DNN (fully connected) layers, which reduce the variation of the LSTM hidden states and map the features into a more separable space.
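In the MNIST test below, that ordering becomes: a two-layer CNN over the 28x28 input image; a linear layer that flattens the resulting 14x14x64 feature maps and reduces them to 256 dimensions (the dimensionality-reduction layer the paper places between the CNN and the LSTM); an LSTM that reads those 256 values as 16 time steps of 16 features each; and a final fully connected projection to the 10 digit classes, which plays the DNN role.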
TensorFlow implementation, tested on MNIST
# TensorFlow 1.x (uses tf.contrib and the tutorials MNIST loader)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
def cnn(input_tensor):
    # Reshape the flat 784-dim input back into a 28x28 single-channel image.
    input_tensor = tf.reshape(input_tensor, [-1, 28, 28, 1])
    with tf.name_scope('conv-1') as scope:
        conv1_weights = tf.Variable(
            initial_value=tf.truncated_normal(shape=[5, 5, 1, 64], dtype=tf.float32, stddev=0.01),
            name='weights'
        )
        conv1_bias = tf.Variable(
            initial_value=tf.truncated_normal(shape=[64], dtype=tf.float32, stddev=0.01),
            name='bias'
        )
        conv1 = tf.nn.conv2d(input_tensor, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')
        conv1 = tf.nn.bias_add(conv1, conv1_bias)
        relu1 = tf.nn.relu(conv1, name=scope)
    with tf.name_scope('pool-1'):
        # 2x2 max pooling halves the spatial size: 28x28x64 -> 14x14x64.
        pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
    with tf.name_scope('conv-2') as scope:
        conv2_weights = tf.Variable(
            initial_value=tf.truncated_normal(shape=[5, 5, 64, 64], dtype=tf.float32, stddev=0.01),
            name='weights'
        )
        conv2_bias = tf.Variable(
            initial_value=tf.truncated_normal(shape=[64], dtype=tf.float32, stddev=0.01),
            name='bias'
        )
        conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')
        conv2 = tf.nn.bias_add(conv2, conv2_bias)
        relu2 = tf.nn.relu(conv2, name=scope)
    # Output shape: [batch, 14, 14, 64].
    return relu2
def linear(input_tensor):
    # Flatten the CNN feature maps: 14 * 14 * 64 = 12544 features per example.
    input_shape = input_tensor.get_shape().as_list()
    width = input_shape[1]
    height = input_shape[2]
    maps = input_shape[3]
    flatten_size = width * height * maps
    input_tensor = tf.reshape(input_tensor, shape=[-1, flatten_size])
    with tf.name_scope('Linear'):
        # Dimensionality-reduction layer between the CNN and the LSTM:
        # projects the 12544 flattened features down to 256.
        dense_weights = tf.Variable(tf.truncated_normal(shape=[flatten_size, 256], stddev=0.1), dtype=tf.float32)
        dense_bias = tf.Variable(tf.truncated_normal(shape=[256]), dtype=tf.float32)
        result = tf.nn.bias_add(tf.matmul(input_tensor, dense_weights), dense_bias)
    return result
# parameters init
l_r = 0.001              # learning rate
training_iters = 100000  # number of training steps
batch_size = 128
n_inputs = 16            # features per LSTM time step
n_steps = 16             # time steps; n_steps * n_inputs must equal the 256-dim linear output
n_hidden_units = 128     # LSTM hidden units
n_classes = 10           # MNIST digit classes
def lstm(input_tensor):
    with tf.variable_scope('lstm1'):
        lstm1_weights_in = tf.get_variable("weight_in", [n_inputs, n_hidden_units], initializer=tf.random_normal_initializer())
        lstm1_biases_in = tf.get_variable("bias_in", [n_hidden_units], initializer=tf.constant_initializer(0.1))
        lstm1_weights_out = tf.get_variable("weight_out", [n_hidden_units, n_classes], initializer=tf.random_normal_initializer())
        lstm1_biases_out = tf.get_variable("bias_out", [n_classes], initializer=tf.constant_initializer(0.1))
    # Treat the 256-dim linear output as n_steps time steps of n_inputs features.
    input_tensor = tf.reshape(input_tensor, [-1, n_inputs])
    x_in = tf.matmul(input_tensor, lstm1_weights_in) + lstm1_biases_in
    x_in = tf.reshape(x_in, [-1, n_steps, n_hidden_units])
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
    # Note: zero_state bakes batch_size into the graph, so every fed batch
    # must contain exactly batch_size examples.
    _init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
    outputs, states = tf.nn.dynamic_rnn(lstm_cell, x_in, initial_state=_init_state, time_major=False)
    # Take the output of the last time step and project it to the classes;
    # this final fully connected layer plays the DNN role of the CLDNN.
    outputs = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
    results = tf.matmul(outputs[-1], lstm1_weights_out) + lstm1_biases_out
    return results
def inference(input_tensor):
    # CLDNN ordering: CNN -> linear dimensionality reduction -> LSTM -> FC output.
    net = cnn(input_tensor)
    net = linear(net)
    net = lstm(net)
    return net
#load mnist data
mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
#define placeholder for input
x = tf.placeholder(tf.float32, [None, 28*28*1])
y_ = tf.placeholder(tf.float32, [None, n_classes])
y = inference(x)
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_, 1), logits=y))
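# Note: since y_ is already one-hot, an equivalent loss that avoids the
# argmax round-trip would be (hedged alternative, same result):
#   cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))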
train_op = tf.train.AdamOptimizer(l_r).minimize(cost)
correct_pred = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
#init session
sess = tf.Session()
#init all variables
sess.run(tf.global_variables_initializer())
#start training
for i in range(training_iters):
    # get the next mini-batch (exactly batch_size examples)
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    sess.run(train_op, feed_dict={x: batch_x, y_: batch_y})
    if i % 50 == 0:
        print(sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y}))
# Evaluate on one batch of test data; the LSTM's zero_state was built with a
# fixed batch_size, so the graph only accepts batches of exactly 128 examples.
test_x, test_y = mnist.test.next_batch(batch_size)
print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_x, y_: test_y}))
Reference:
https://github.com/xxoospring/CLDNNs