Tensorflow学习9-2:谷歌inception-v3模型之图像分类测试

1
2
3
4
5
6
7
8
9
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
#模型存放目录
MOD_DIR = "D:/Tensorflow/models/inception/"
TEST_IMG_DIR = "D:/Tensorflow/Test Images/"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#节点映射解析类【目标获得1-1000分类数字 =》分类名称 的映射字典】
class NodeLookup(object):
def __init__(self):
label_lookup_path = MOD_DIR + "imagenet_2012_challenge_label_map_proto.pbtxt"
uid_lookup_path = MOD_DIR + "imagenet_synset_to_human_label_map.txt"
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

def load(self, label_lookup_path, uid_lookup_path):
# 加载分类字符串n*******对应分类名称的文件
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {}
#一行一行读取数据
for line in proto_as_ascii_lines:
#去掉换行符 \n
line=line.strip("\n")
#按照 \t 分割
parsed_items = line.split("\t")
#获取分类编号
uid = parsed_items[0]
#获取分类名称
human_string = parsed_items[1]
#保存分类编号n*******和分类名称的映射关系
uid_to_human[uid] = human_string

# 加载分类字符串n*******对应分类编号1-1000的文件
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
node_id_to_uid = {}

for line in proto_as_ascii:
if line.startswith(" target_class:"):
#获取分类编号1-1000
target_class = int(line.split(": ")[1])
if line.startswith(" target_class_string:"):
#获取编号字符串n*******
target_class_string = line.split(": ")[1]
#保存分类编号1-1000和编号字符串n*******的映射关系
node_id_to_uid[target_class] = target_class_string[1:-2]

#现在联立2个映射,合成新的 分类编号1-1000到分类名称的映射
node_id_to_name = {}
for key, val in node_id_to_uid.items():
#获得分类名称
name = uid_to_human[val]
#建立映射
node_id_to_name[key] = name
return node_id_to_name

#查询函数【传入分类编号1-1000返回分类名称】
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
print("node_id not in self.node_lookup")
return ""
return self.node_lookup[node_id]

#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile(MOD_DIR + "classify_image_graph_def.pb", "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")

with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name("softmax:0")
#遍历用于测试的图片目录
for root,dirs,files in os.walk(TEST_IMG_DIR):
for file in files:
#载入图片
image_data = tf.gfile.FastGFile(os.path.join(root,file), "rb").read()
predictions = sess.run(softmax_tensor, {"DecodeJpeg/contents:0" : image_data}) #jpg格式图片
predictions = np.squeeze(predictions) #吧结果转为1维数据

#打印图片路径及名称
image_path = os.path.join(root,file)
print(image_path)
#显示图片
img = Image.open(image_path)
plt.imshow(img)
plt.axis("off")
plt.show()

#排序
top_k = predictions.argsort()[-5:][::-1] #取得从大到小的5个值
node_lookup = NodeLookup()
for node_id in top_k:
#获取分类名称
human_string = node_lookup.id_to_string(node_id)
#获取该分类的概率
score = predictions[node_id]
print("%s (score = %.5f)" % (human_string, score))
print()

D:/Tensorflow/Test Images/555.jpg

gown (score = 0.27292)
hoopskirt, crinoline (score = 0.14043)
maillot (score = 0.10369)
brassiere, bra, bandeau (score = 0.06863)
bikini, two-piece (score = 0.05091)

D:/Tensorflow/Test Images/cl.jpg

torch (score = 0.40945)
volleyball (score = 0.11208)
racket, racquet (score = 0.09447)
tennis ball (score = 0.06729)
soccer ball (score = 0.04869)

D:/Tensorflow/Test Images/nissan jk.jpg

sports car, sport car (score = 0.49891)
beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon (score = 0.12817)
car wheel (score = 0.07217)
grille, radiator grille (score = 0.03533)
cab, hack, taxi, taxicab (score = 0.01612)

꧁༺The༒End༻꧂