-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinput_pipeline.py
34 lines (29 loc) · 1.61 KB
/
input_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# -*- use coding:utf-8 -*-
import tensorflow as tf
def read_from_tfrecord(tfrecord_file_queue, patch_size):
    """Build the ops that read and decode one example from a TFRecord queue.

    Each serialized example is parsed for two scalar byte-string features:
    'patch_raw', decoded as uint8, and 'label', decoded as int32.

    Args:
        tfrecord_file_queue: Queue passed to ``tf.TFRecordReader.read``;
            presumably a filename queue of TFRecord files (see the caller).
        patch_size: Side length used to reshape the decoded patch to
            ``[patch_size, patch_size, 3]``.

    Returns:
        A ``(image, ground_truth)`` pair: the patch cast to float32 and passed
        through ``tf.image.per_image_standardization``, and the decoded label
        reshaped to ``[1]``.
    """
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.string),
        'patch_raw': tf.FixedLenFeature([], tf.string),
    }

    record_reader = tf.TFRecordReader()
    _, serialized_example = record_reader.read(tfrecord_file_queue)
    parsed = tf.parse_single_example(
        serialized_example, features=feature_spec, name='features')

    # Both features are stored as raw bytes and must be reinterpreted.
    patch_bytes = tf.decode_raw(parsed['patch_raw'], tf.uint8)
    label_values = tf.decode_raw(parsed['label'], tf.int32)

    # Restore the HxWx3 layout first, then move to float for standardization.
    patch = tf.reshape(patch_bytes, [patch_size, patch_size, 3])
    patch = tf.cast(patch, tf.float32)
    # NOTE: no random flip augmentation is applied here; only standardization.
    patch = tf.image.per_image_standardization(patch)

    label = tf.reshape(label_values, [1])
    return patch, label
def input_pipeline(filenames, batch_size, patch_size, read_threads=2,
                   num_epochs=None, min_after_dequeue=500):
    """Build a shuffled, batched input pipeline over TFRecord files.

    A shuffled filename queue feeds ``read_threads`` independent readers
    (one ``read_from_tfrecord`` call each); their outputs are merged and
    shuffled into batches by ``tf.train.shuffle_batch_join``.

    Args:
        filenames: Filenames passed to ``tf.train.string_input_producer``.
        batch_size: Number of examples per emitted batch.
        patch_size: Side length of the square patches; forwarded to
            ``read_from_tfrecord``.
        read_threads: Number of parallel readers to create. Must be >= 1.
        num_epochs: Forwarded to ``tf.train.string_input_producer``;
            ``None`` is the default.
        min_after_dequeue: Minimum number of examples kept in the shuffle
            queue after a dequeue. Defaults to 500, the value that was
            previously hard-coded.

    Returns:
        The ``(example_batch, label_batch)`` pair produced by
        ``tf.train.shuffle_batch_join``.

    Raises:
        ValueError: If ``read_threads`` is smaller than 1.
    """
    # Fail early with a clear message instead of handing an empty tensor
    # list to shuffle_batch_join after graph ops were already created.
    if read_threads < 1:
        raise ValueError(
            'read_threads must be >= 1, got {!r}'.format(read_threads))

    filename_queue = tf.train.string_input_producer(
        filenames, num_epochs=num_epochs, shuffle=True)

    # One reader per thread; shuffle_batch_join interleaves their outputs.
    example_list = [read_from_tfrecord(filename_queue, patch_size)
                    for _ in range(read_threads)]

    # Leave headroom of three batches above the shuffle minimum.
    capacity = min_after_dequeue + 3 * batch_size

    example_batch, label_batch = tf.train.shuffle_batch_join(
        example_list, batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch