TensorFlow Learning 9-4: Generating TFRecord Files for Google's Inception-v3 Model

Generate TFRecord files that can be used either to train your own model from scratch or to fine-tune a pre-trained model.

#A TFRecord file is, under the hood, just serialized protobuf messages
import tensorflow as tf
import os
import random
import math
import sys

#Number of validation images
_NUM_VALID = 1000
#Random seed
_RANDOM_SEED = 7
#Number of shards per split
_NUM_SHARDS = 2
#Dataset directory (one sub-folder per class)
DATASET_DIR = "D:/Tensorflow/slim/images2/"
#Label file name (written into DATASET_DIR by write_label_file)
LABELS_FILENAME = "labels.txt"
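The conversion code that follows assumes DATASET_DIR is laid out with one sub-folder per class and the JPEG images inside those folders; the sub-folder names become the class names. The class and file names below (cat, dog, 001.jpg, ...) are only placeholders for illustration:

images2/
    cat/
        001.jpg
        002.jpg
        ...
    dog/
        001.jpg
        ...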
#Build the path + file name of a TFRecord shard
def _get_dataset_filename(dataset_dir, split_name, shard_id):
    output_filename = "image_%s_%05d-of-%05d.tfrecord" % (split_name, shard_id, _NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)

#Check whether all TFRecord files already exist
def _dataset_exists(dataset_dir):
    for split_name in ["train", "validation"]:
        for shard_id in range(_NUM_SHARDS):
            output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
            if not tf.gfile.Exists(output_filename):
                return False
    return True

#Collect all image file names and the class names (sub-folder names) under the dataset directory
def _get_filenames_and_classes(dataset_dir):
    #Per-class image directories
    directories = []
    #Class names
    class_names = []
    for filename in os.listdir(dataset_dir):
        #Build the full path
        path = os.path.join(dataset_dir, filename)
        #Only sub-directories count as classes
        if os.path.isdir(path):
            #Record the class directory
            directories.append(path)
            #Record the class name
            class_names.append(filename)

    photo_filenames = []
    #Walk through every class directory
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            #Add the image to the image list
            photo_filenames.append(path)

    return photo_filenames, class_names

#Wrap an int (or a list of ints) in an Int64List feature
def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

#Wrap a byte string in a BytesList feature
def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

#Build a tf.train.Example protobuf holding the encoded image, its format and its label
def image_to_tfexample(image_data, image_format, class_id):
    return tf.train.Example(features=tf.train.Features(feature={
        "image/encoded": bytes_feature(image_data),
        "image/format": bytes_feature(image_format),
        "image/class/label": int64_feature(class_id),
    }))

#Write the "label_id:class_name" mapping to the labels file
def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME):
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, "w") as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.write("%d:%s\n" % (label, class_name))

#Convert the data to TFRecord format
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    assert split_name in ["train", "validation"]
    #Split the data into several shards; compute how many images go into each shard
    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
    with tf.Graph().as_default():
        with tf.Session() as sess:
            for shard_id in range(_NUM_SHARDS):
                #Path of the TFRecord file for this shard
                output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    #First image index of this shard
                    start_ndx = shard_id * num_per_shard
                    #Last image index of this shard
                    end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        try:
                            sys.stdout.write("\r>> Converting image(%s) %d/%d shard %d" % (split_name, i+1, len(filenames), shard_id))
                            sys.stdout.flush()
                            #Read the raw (already encoded) image data
                            image_data = tf.gfile.FastGFile(filenames[i], "rb").read()
                            #The class name is the name of the folder containing the image
                            class_name = os.path.basename(os.path.dirname(filenames[i]))
                            #Look up the numeric ID of the class name
                            class_id = class_names_to_ids[class_name]
                            #Build the Example and write it into the TFRecord file
                            example = image_to_tfexample(image_data, b"jpg", class_id)
                            tfrecord_writer.write(example.SerializeToString())
                        except IOError as e:
                            print("Could not read:", filenames[i])
                            print("Error:", e)
                            print("Skip the pic.\n")
    sys.stdout.write("\n")
    sys.stdout.flush()


if __name__ == "__main__":
    #Check whether the TFRecord files already exist
    if _dataset_exists(DATASET_DIR):
        print("The TFRecord files already exist")
    else:
        #Collect all image paths and class names
        photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
        #Map class names to IDs, e.g. {"house": 3, "flower": 1}
        class_names_to_ids = dict(zip(class_names, range(len(class_names))))

        #Split the data into a training set and a validation set
        random.seed(_RANDOM_SEED)
        random.shuffle(photo_filenames)
        training_filenames = photo_filenames[_NUM_VALID:]   #everything after the first _NUM_VALID images is used for training
        validation_filenames = photo_filenames[:_NUM_VALID] #the first _NUM_VALID images are used for validation
        # for var in training_filenames:
        #     print("training_filenames: ", os.path.basename(var))
        # for var in validation_filenames:
        #     print("validation_filenames: ", os.path.basename(var))

        #Convert both splits
        _convert_dataset("train", training_filenames, class_names_to_ids, DATASET_DIR)
        _convert_dataset("validation", validation_filenames, class_names_to_ids, DATASET_DIR)

        #Write the labels file
        labels_to_class_names = dict(zip(range(len(class_names)), class_names))
        write_label_file(labels_to_class_names, DATASET_DIR)

Converting image(train) 3800/3800 shard 1
Converting image(validation) 1000/1000 shard 1
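
After a successful run, DATASET_DIR holds two training shards and two validation shards (image_train_00000-of-00002.tfrecord, image_train_00001-of-00002.tfrecord, and the same for validation) plus labels.txt with one id:class_name line per class. As a quick sanity check, the sketch below (using the same TF 1.x API as the script above; the shard name is derived from the settings at the top) reads a single record back and prints its label ID and the size of the encoded JPEG:

import tensorflow as tf
import os

DATASET_DIR = "D:/Tensorflow/slim/images2/"
record_path = os.path.join(DATASET_DIR, "image_train_00000-of-00002.tfrecord")

#Iterate over the serialized tf.train.Example records in the first training shard
for serialized in tf.python_io.tf_record_iterator(record_path):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    feat = example.features.feature
    label_id = feat["image/class/label"].int64_list.value[0]
    jpg_size = len(feat["image/encoded"].bytes_list.value[0])
    print("label id: %d, encoded jpg size: %d bytes" % (label_id, jpg_size))
    break  #one record is enough for a quick check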

꧁༺The༒End༻꧂