tensorflow系列(4)tfrecords的使用

Posted by grt1stnull on 2017-07-22

利用tfrecords文件格式来高效读取数据吧。

0x00.前言

最近涉及到模型的测试,需要读取数据。通常我们可以直接通过文件读取,自己写一个,方便快捷。但是考虑到要在集群上运行,这个数据文件(csv,约1.4G)在每台服务器上都要有,于是我想事先处理一下数据,看看能不能压缩或是缩小文件大小。于是我想到使用tfrecords这种文件格式。

这篇文章给我一种帮助文档的感觉,所以我没那么想发。考虑到关于tfrecords没有很多文档,网上的文章也是一篇文章模子刻出来的,这里我记录一下过程。

0x01.保存为tfrecords文件

将数据保存为tfrecords文件可以视为这样一个流程:提取features -> 保存为Example结构对象 -> TFRecordWriter写入文件

1.提取features

Features的源码在https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/core/example/feature.proto,从源码我们可以看到features总共有3种数据类型,分别是bytesfloatint64

要想将数据类型提取出features,我们首先要将数据转化为feature(即上面三种数据类型

对应类型,封装成如下三个函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def byte_feature(value):
return tf.train.Feature(
bytes_list = tf.train.BytesList(value=[value.encode()])
)
def float_feature(value):
return tf.train.Feature(
float_list = tf.train.FloatList(value=[value])
)
def int_feature(value):
return tf.train.Feature(
int64_list = tf.train.Int64List(value=[value])
)

其中,byte_feature中的value.encode()位置,当我不带.encode()时会报错TypeError: 'AAB0162' has type str, but expected one of: bytes。如果输入的是字符串类型,也可以这样:a = b'AAB0162'

函数构造好后,我们开始构造features,这里为示例。

1
2
3
4
5
features = tf.train.Features(feature={
'name': byte_feature(n),
'time': float_feature(t),
'data': byte_feature(d)
})

2.保存为Example结构对象

1
example = tf.train.Example(features=features)

3.TFRecordWriter写入文件

1
2
3
4
5
6
7
8
9
# writer对象
destination = 'data.tfrecords'
writer = tf.python_io.TFRecordWriter(destination)
# 写入
writer.write(example.SerializeToString())
#关闭
writer.close()

4.完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# 首先封装好三个函数
def byte_feature(value):
return tf.train.Feature(
bytes_list = tf.train.BytesList(value=[value.encode()])
)
def float_feature(value):
return tf.train.Feature(
float_list = tf.train.FloatList(value=[value])
)
def int_feature(value):
return tf.train.Feature(
int64_list = tf.train.Int64List(value=[value])
)
# 定义好文件名
source = 'test.csv'
destination = 'data.tfrecords'
# 创建对象
writer = tf.python_io.TFRecordWriter(destination)
# 打开文件
reader = open(source)
# 循环,按行处理
for line in reader.readlines():
# 对行的提取
n, t, d = deal(line)
# features构造
features = tf.train.Features(feature={
'name': byte_feature(n),
'time': float_feature(t),
'data': byte_feature(d)
})
# 构造example对象 及 写入
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
# 关闭指针
writer.close()
reader.close()

0x02.从tfrecords读取数据

这里分为两部,首先我们从文件中取出features,取出后的类型为Tensor,之后我们使用sess将数据还原。

1.取出features

首先定义一会要用的参数:

1
2
3
4
5
6
7
8
9
# tfrecords文件
file = 'data.tfrecords'
# 线程数量
num_threads = 2
num_epochs = 100
# 每批次数量
batch_size = 10
# 样本数量下限
min_after_dequeue = 10

首先定义reader:

1
reader = tf.TFRecordReader()

定义输入部分:

1
file_queue = tf.train.string_input_producer(file, num_epochs=num_epochs,)

读取:

1
_, example = reader.read(file_queue)

提取出features,并保存为列表:

1
2
3
4
5
6
7
8
9
10
features_dict = tf.parse_single_example(example,
features={
'name': tf.FixedLenFeature([], tf.string),
'time': tf.FixedLenFeature([], tf.float32),
'data': tf.FixedLenFeature([], tf.string)
})
n = features_dict['name']
t = features_dict['time']
d = features_dict['data']

之后我们将其转化为批次队列:

1
2
3
4
5
6
7
n, t, d = tf.train.shuffle_batch(
[n, t, d],
batch_size=batch_size,
num_threads=num_threads,
capacity = min_after_dequeue + 3 * batch_size,
min_after_dequeue = min_after_dequeue
)

2.数据还原

定义session,之后将数据还原:

1
2
3
4
5
6
7
8
9
10
with tf.Session() as sess:
# 初始化
tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer()).run()
tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([n, t, d])
print(a_val, b_val, c_val)

3.完整代码

(这里我将features的提取封装为了函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def records_to(file, num_threads=2, num_epochs=2, batch_size=2, min_after_dequeue=2):
reader = tf.TFRecordReader()
file_queue = tf.train.string_input_producer(file, num_epochs=num_epochs,)
_, example = reader.read(file_queue)
features_dict = tf.parse_single_example(example,
features={
'name': tf.FixedLenFeature([], tf.string),
'time': tf.FixedLenFeature([], tf.float32),
'data': tf.FixedLenFeature([], tf.string)
})
# n: Tensor("ParseSingleExample/Squeeze_name:0", shape=(), dtype=string)
n = features_dict['name']
t = features_dict['time']
d = features_dict['data']
n, t, d = tf.train.shuffle_batch(
[n, t, d],
batch_size=batch_size,
num_threads=num_threads,
capacity = min_after_dequeue + 3 * batch_size,
min_after_dequeue = min_after_dequeue
)
# 数据格式为Tensor
return n, t, d
def train():
n, t, d = records_to(['data.tfrecords'])
with tf.Session() as sess:
tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer()).run()
tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([n, t, d])
print(a_val, b_val, c_val)

4.不构造图的读取

这一种方法比较简单,也不需要构造图,可以看作是我们把数据写入tfrecords的逆过程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
record_iterator = tf.python_io.tf_record_iterator(path=file_name)
for string_record in record_iterator:
example = tf.train.Example()
example.ParseFromString(string_record)
n = example.features.feature['name'].bytes_list.value[0]
t = int(example.features.feature['time'].float_list.value[0])
d = (example.features.feature['data'].bytes_list.value[0])
print(n, t, d)

0x03.参考

TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读

Tfrecords Guide