TensorFlow Learning 8: Saving and Loading Models

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
MOD_DIR = "D:/Tensorflow/models/"
# Saving the model
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

batch_size = 100
batch_num = mnist.train.num_examples // batch_size
input_x = tf.placeholder(tf.float32, [None, 784])
input_y = tf.placeholder(tf.float32, [None, 10])

W1 = tf.Variable(tf.truncated_normal([784,128], 0.,0.5))
b1 = tf.Variable(tf.zeros([128]) + 0.1)
L1 = tf.nn.relu(tf.matmul(input_x, W1) + b1)

W2 = tf.Variable(tf.truncated_normal([128,10], 0.,0.5))
b2 = tf.Variable(tf.zeros([10]) + 0.1)
L2 = tf.matmul(L1, W2) + b2


loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=L2, labels=input_y))
train = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

correct_indices = tf.equal(tf.argmax(input_y, 1), tf.argmax(L2, 1))
accuracy = tf.reduce_mean(tf.cast(correct_indices, tf.float32))

# STEP 1: define the Saver object
saver = tf.train.Saver()

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(11):
        for batch in range(batch_num):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train, feed_dict={input_x:batch_xs, input_y:batch_ys})
        acc = sess.run(accuracy, feed_dict={input_x:mnist.test.images, input_y:mnist.test.labels})
        print("epoch:" + str(epoch) + ", accuracy:" + str(acc))

    # STEP 2: save the trained model with the Saver
    saver.save(sess, MOD_DIR + "test_net.ckpt")
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
epoch:0, accuracy:0.9057
epoch:1, accuracy:0.9222
epoch:2, accuracy:0.9296
epoch:3, accuracy:0.9394
epoch:4, accuracy:0.9423
epoch:5, accuracy:0.9448
epoch:6, accuracy:0.9489
epoch:7, accuracy:0.9479
epoch:8, accuracy:0.9516
epoch:9, accuracy:0.9507
epoch:10, accuracy:0.9552
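
In practice it is often useful to write one checkpoint per epoch instead of a single file at the end. The following is a minimal sketch of that variation, not part of the original code: the per-epoch saving, max_to_keep=5, and the stand-in variable are my own example choices, while max_to_keep and global_step themselves are standard tf.train.Saver options.

import tensorflow as tf

MOD_DIR = "D:/Tensorflow/models/"

# Stand-in variable so the sketch is self-contained; in the real script the
# Saver would pick up W1, b1, W2, b2 defined above.
w = tf.Variable(tf.zeros([784, 128]), name="W1")

# max_to_keep limits how many recent checkpoints stay on disk (5 is just an example).
saver = tf.train.Saver(max_to_keep=5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(11):
        # ... run the training batches exactly as in the code above ...

        # global_step appends the epoch number to the file name, producing
        # test_net.ckpt-0, test_net.ckpt-1, ... instead of one file.
        saver.save(sess, MOD_DIR + "test_net.ckpt", global_step=epoch)

# The newest checkpoint in the directory can then be located with:
print(tf.train.latest_checkpoint(MOD_DIR))

saver.save() also writes a small text file named "checkpoint" into MOD_DIR, and that index is what tf.train.latest_checkpoint reads.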
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

batch_size = 100
batch_num = mnist.train.num_examples // batch_size
input_x = tf.placeholder(tf.float32, [None, 784])
input_y = tf.placeholder(tf.float32, [None, 10])

W1 = tf.Variable(tf.truncated_normal([784,128], 0.,0.5))
b1 = tf.Variable(tf.zeros([128]) + 0.1)
L1 = tf.nn.relu(tf.matmul(input_x, W1) + b1)

W2 = tf.Variable(tf.truncated_normal([128,10], 0.,0.5))
b2 = tf.Variable(tf.zeros([10]) + 0.1)
L2 = tf.matmul(L1, W2) + b2


loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=L2, labels=input_y))
train = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

correct_indices = tf.equal(tf.argmax(input_y, 1), tf.argmax(L2, 1))
accuracy = tf.reduce_mean(tf.cast(correct_indices, tf.float32))

# STEP 1: define the Saver object
saver = tf.train.Saver()

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(accuracy, feed_dict={input_x:mnist.test.images, input_y:mnist.test.labels}))

    # STEP 2: restore the model with the Saver and print how the accuracy changes
    # after the trained weights are loaded
    saver.restore(sess, MOD_DIR + "test_net.ckpt")
    print(sess.run(accuracy, feed_dict={input_x:mnist.test.images, input_y:mnist.test.labels}))
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0.1311
INFO:tensorflow:Restoring parameters from D:/Tensorflow/models/test_net.ckpt
0.9552
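
The loading script above has to rebuild exactly the same network before saver.restore can map the checkpointed values onto it. An alternative, not used in this post, is to load the graph structure from the .meta file that saver.save also writes. A minimal sketch, assuming the checkpoint from the saving step already exists under MOD_DIR:

import tensorflow as tf

MOD_DIR = "D:/Tensorflow/models/"

tf.reset_default_graph()  # start from an empty graph instead of redefining the net

with tf.Session() as sess:
    # test_net.ckpt.meta holds the graph structure; import_meta_graph rebuilds
    # it and returns a Saver bound to that graph.
    saver = tf.train.import_meta_graph(MOD_DIR + "test_net.ckpt.meta")
    # Then restore the variable values into the rebuilt graph.
    saver.restore(sess, MOD_DIR + "test_net.ckpt")

    # To feed and fetch tensors by name afterwards it helps to give the ops
    # explicit name= arguments when building the network; the available op
    # names can be inspected like this:
    graph = tf.get_default_graph()
    print([op.name for op in graph.get_operations()][:10])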

Summary:
For a .pb model file,

saving uses the graph's graph_def()

and importing uses tf.import_graph_def()
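
Here is a hedged sketch of that .pb workflow using the standard TF 1.x APIs (tf.graph_util.convert_variables_to_constants on the save side, tf.import_graph_def on the load side). The file name test_net.pb and the output node name "logits" are assumptions for illustration, not from this post; "logits" would only exist if the output op were named explicitly, e.g. tf.identity(L2, name="logits").

import tensorflow as tf

MOD_DIR = "D:/Tensorflow/models/"
PB_PATH = MOD_DIR + "test_net.pb"   # hypothetical file name

# Saving: freeze the variables into constants and serialize the GraphDef.
# Assumes a live session `sess` whose graph has an output op named "logits".
def freeze_to_pb(sess, output_node="logits"):
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), [output_node])
    with tf.gfile.GFile(PB_PATH, "wb") as f:
        f.write(frozen.SerializeToString())

# Loading: parse the serialized GraphDef and import it with tf.import_graph_def.
def load_pb():
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(PB_PATH, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original tensor names instead of prefixing "import/"
        tf.import_graph_def(graph_def, name="")
    return graph

Unlike a checkpoint, the frozen .pb contains both the graph structure and the (now constant) weights, so loading it needs no Saver and no network definition.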

꧁༺The༒End༻꧂