Files
DeepHSV/dataset/dataset_paris.py
2019-05-06 17:43:25 +08:00

105 lines
3.8 KiB
Python

"""
@file: dataset_paris.py
@time: 2018/7/31 15:03
@desc: Create the input data pipeline using `tf.data`
"""
import numpy as np
import tensorflow as tf
# Module-level configuration consumed by _read_image inside dataset.map
# (map functions cannot take extra Python args directly). These are
# populated by input_fn() from the params object before the pipeline runs.
image_width = None
image_height = None
images_dir = None
channels = 1
def _read_image(filename, is_augment):
    """Read one PNG from `images_dir`, optionally flip it, then resize.

    Args:
        filename: scalar string tensor; path appended to the global images_dir.
        is_augment: scalar int32 tensor; 1 requests a left-right flip.

    Returns:
        The decoded image resized via [image_width, image_height].
        NOTE(review): tf.image.resize_images expects size as [height, width];
        passing [image_width, image_height] swaps the axes for non-square
        images — confirm the intended orientation of these globals.
    """
    raw = tf.read_file(tf.string_join([images_dir, filename]))
    decoded = tf.image.decode_png(raw, channels=channels)
    flip_requested = tf.constant(1, dtype=tf.int32, name="true_constant")
    decoded = tf.cond(tf.equal(flip_requested, is_augment),
                      lambda: tf.image.flip_left_right(decoded),
                      lambda: decoded)
    return tf.image.resize_images(decoded, [image_width, image_height])
def _parse_function(item):
    """Map one string record (file0, file1, label, aug_flag) to (image, label).

    The two signature images are stacked along the channel axis (axis 2),
    so the model receives the pair as a single multi-channel input.
    """
    aug_flag = tf.string_to_number(item[3], out_type=tf.int32)
    left = _read_image(item[0], aug_flag)
    right = _read_image(item[1], aug_flag)
    stacked = tf.concat([left, right], 2)
    return stacked, tf.string_to_number(item[2])
def _input_fn(params, is_training, is_augment=False, pos_repeating=1, only_label=None):
    """Build the tf.data pipeline from a pair-list file.

    Args:
        params: hyperparameter object; reads signature_train_list /
            signature_val_list, positive_size, negative_size,
            num_parallel_calls, num_epochs, batch_size, num_gpus.
        is_training: if True read the train list and apply rebalancing.
        is_augment: if True (and training) also emit a flipped copy of each pair.
        pos_repeating: when > 0 in training, positive pairs are repeated this
            many times and an equal-sized random subset of negatives is kept.
        only_label: if not None, keep only pairs with this label.

    Returns:
        A shuffled, repeated, fixed-size-batched tf.data.Dataset of
        (image, label) elements.
    """
    listfile_path = params.signature_train_list if is_training else params.signature_val_list
    data = []
    shuffle_neg = []
    size_per_signer = params.positive_size + params.negative_size
    rebalance = is_training and pos_repeating > 0
    # listfile format: one pair per line -> "file0 file1 label"
    with open(listfile_path) as listfile:
        for i, line in enumerate(listfile):
            items = line.split(' ')
            file0 = items[0]
            file1 = items[1]
            label = int(items[2])
            # label 2 pairs are always discarded; optionally keep one label only.
            # (The original code had a later `label == 2` branch that was
            # unreachable because of this filter; it has been removed.)
            if (only_label is not None and label != only_label) or label == 2:
                continue
            repeating = 1
            if rebalance and i % size_per_signer == 0:
                # The positive/negative pair counts are imbalanced (e.g. 276/996):
                # at the start of each signer's section, draw a fresh random
                # subset of negative in-signer indices sized to match the
                # expanded positive count.
                shuffle_neg = np.arange(params.positive_size, size_per_signer)
                np.random.shuffle(shuffle_neg)
                shuffle_neg = shuffle_neg[:params.positive_size * pos_repeating]
            if rebalance:
                if label == 0:
                    # reduce negatives: keep this pair only if its index was drawn
                    repeating = 1 if i % size_per_signer in shuffle_neg else 0
                elif label == 1:
                    # expand positives
                    repeating = pos_repeating
            for _ in range(repeating):
                # tuple layout: (file0, file1, label, is_augment_flag)
                data.append((file0, file1, label, 0))
                if is_augment and is_training:
                    data.append((file0, file1, label, 1))
    np.random.shuffle(data)
    print("examples of data: -> %d" % len(data))
    dataset = tf.data.Dataset.from_tensor_slices(np.array(data))
    dataset = dataset.map(_parse_function, num_parallel_calls=params.num_parallel_calls)
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat(params.num_epochs)
    # drop the last partial batch so every GPU replica receives a full batch
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(
        params.batch_size * params.num_gpus))
    dataset = dataset.prefetch(10)
    return dataset
def input_fn(params, is_training, repeating=1, is_augment=False, only_label=None):
    """Public entry point: publish image/config globals, then build the dataset.

    The module-level globals are consumed by _read_image, which runs inside
    dataset.map and therefore cannot receive extra Python arguments directly.
    """
    global image_width, image_height, images_dir, channels
    image_width, image_height = params.image_width, params.image_height
    images_dir, channels = params.images_dir, params.channels
    return _input_fn(params, is_training, is_augment=is_augment,
                     pos_repeating=repeating, only_label=only_label)