Tensorflow学习7:长短期记忆网络LSTM

用LSTM网络进行手写数字识别

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Base working directory; the MNIST data directory is built by appending
# "DATASET/MNIST_data" to this path.
PWS_DIR = "C:/Users/lenovo/Desktop/Python WORK SPACE/"
# Log directory. NOTE(review): not referenced by the code visible here.
LOG_DIR = "D:/Tensorflow/logs/"
# Load MNIST with one-hot encoded labels.
mnist = input_data.read_data_sets(PWS_DIR + "DATASET/MNIST_data", one_hot=True)

n_inputs = 28  # images are fed row by row; each row holds 28 values
max_time = 28  # 28 rows per image, i.e. the sequence length
lstm_size = 100  # number of hidden units (blocks) in the LSTM layer
n_class = 10
batch_size = 50
n_batch = mnist.train.num_examples // batch_size

# Placeholder widths are derived from the hyperparameters above instead of
# repeating the literals 784 and 10, so the values cannot drift apart.
x = tf.placeholder(tf.float32, [None, max_time * n_inputs])
y = tf.placeholder(tf.float32, [None, n_class])

# Output projection applied to the final LSTM hidden state.
weights = tf.Variable(tf.truncated_normal([lstm_size, n_class], stddev=0.1))
biases = tf.Variable(tf.constant(0.1, shape=[n_class]))

def LSTM(X, weights, biases):
    """Run a single-layer LSTM over a batch of images and score each class.

    Each image is consumed row by row: ``max_time`` time steps with
    ``n_inputs`` values per step.

    Args:
        X: Input tensor; reshaped here to ``[-1, max_time, n_inputs]``.
        weights: Projection matrix multiplied with the final hidden state.
        biases: Bias vector added after the projection.

    Returns:
        The unnormalized class scores (logits), one row per example.
        Softmax is deliberately NOT applied here: the result is passed as
        ``logits`` to ``tf.nn.softmax_cross_entropy_with_logits``, which
        applies softmax internally. Applying it here as well would
        normalize twice and distort the loss. ``tf.argmax`` over logits
        picks the same class as over softmax probabilities.
    """
    # inputs: [batch_size, max_time, n_inputs]
    inputs = tf.reshape(X, [-1, max_time, n_inputs])
    # Basic LSTM cell with lstm_size hidden units.
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
    # dynamic_rnn returns (outputs, final_state).
    #   outputs: per-step outputs, [batch_size, max_time, cell.output_size]
    #            when time_major is False (the default), otherwise
    #            [max_time, batch_size, cell.output_size]. Unused here.
    #   final_state[0] is the cell state, final_state[1] is the hidden
    #   state (the output of the last step); each has cell.state_size
    #   (i.e. lstm_size) values per example.
    _, final_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
    return tf.matmul(final_state[1], weights) + biases

# Build the network output for the input placeholder.
predict = LSTM(x, weights, biases)

# Cross-entropy cost: per-example values first, then the batch mean.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=y, logits=predict
)
loss_cross_entropy = tf.reduce_mean(per_example_loss)

# Minimize the cost with the Adam optimizer.
optimizer = tf.train.AdamOptimizer(1e-4)
train_step = optimizer.minimize(loss_cross_entropy)

# Boolean vector: True where the predicted class equals the labelled class.
true_class = tf.argmax(y, 1)
predicted_class = tf.argmax(predict, 1)
predict_bool = tf.equal(true_class, predicted_class)

# Accuracy is the mean of the matches cast to float.
accuracy = tf.reduce_mean(tf.cast(predict_bool, tf.float32))

# Train for six epochs and report test accuracy after each one.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(6):
        # One pass over the training set in mini-batches.
        for _ in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})

        # Evaluate on the full test set at the end of every epoch.
        acc = sess.run(
            accuracy,
            feed_dict={x: mnist.test.images, y: mnist.test.labels},
        )
        print("Iter" + str(epoch) + ", Testing Accuracy= " + str(acc))

Extracting C:/Users/lenovo/Desktop/Python WORK SPACE/DATASET/MNIST_data\train-images-idx3-ubyte.gz
Extracting C:/Users/lenovo/Desktop/Python WORK SPACE/DATASET/MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:/Users/lenovo/Desktop/Python WORK SPACE/DATASET/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:/Users/lenovo/Desktop/Python WORK SPACE/DATASET/MNIST_data\t10k-labels-idx1-ubyte.gz
Iter0, Testing Accuracy= 0.7837
Iter1, Testing Accuracy= 0.8641
Iter2, Testing Accuracy= 0.8951
Iter3, Testing Accuracy= 0.9087
Iter4, Testing Accuracy= 0.9263
Iter5, Testing Accuracy= 0.9325

꧁༺The༒End༻꧂