Tensorflow学习9-1:谷歌inception-v3模型之下载模型和查看结构

1
2
3
4
5
6
7
8
9
10
11
import tensorflow as tf
import os
import tarfile
import requests

#模型下载地址
MOD_URL = "http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
#模型存放目录
MOD_DIR = "D:/Tensorflow/models/inception/"
#模型结构存放目录
LOG_DIR = "D:/Tensorflow/logs/inception/"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

if not os.path.exists(MOD_DIR):
os.makedirs(MOD_DIR)

#取得文件名,以及完整路径
file_name = MOD_URL.split("/")[-1]
file_path = os.path.join(MOD_DIR, file_name)
#file_path = MOD_DIR + file_name

#下载模型
if not os.path.exists(file_path):
print("download: ", file_name)
r = requests.get(MOD_URL, stream=True)
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
if(chunk):
f.write(chunk)
print("finish: ", file_name)

#解压模型文件到指定目录
tarfile.open(file_path, "r:gz").extractall(MOD_DIR)

#存放模型结构(用tensorboard查看)
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)

# 解压后的xxx.pb是训练好的模型
inception_model_path = os.path.join(MOD_DIR, "classify_image_graph_def.pb")
with tf.Session() as sess:
#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile(inception_model_path, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")
#保存图的结构
writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
writer.close()

finish: inception-2015-12-05.tgz

结构如下图

꧁༺The༒End༻꧂