From e7e1dd0d454b78114f258a088af47add2efdf8cb Mon Sep 17 00:00:00 2001 From: edwardyehuang Date: Wed, 16 Feb 2022 02:39:44 +0800 Subject: [PATCH] improve data_loader --- utils/data_loader.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/utils/data_loader.py b/utils/data_loader.py index e997508..a5fb8fd 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -37,11 +37,14 @@ def load_image_tensor_from_path(image_path, label_path=None): return image_tensor, label_tensor -def simple_load_image(image_path): +def simple_load_image(image_path, label_path=None, ignore_label=255): - image_tensor, _ = load_image_tensor_from_path(image_path) + image_tensor, label_tensor = load_image_tensor_from_path(image_path, label_path) image_tensor = tf.expand_dims(tf.cast(image_tensor, tf.float32), axis=0) # [1, H, W, 3] + if label_tensor is not None: + label_tensor = tf.expand_dims(label_tensor, axis=0) # [1, H, W, 1] + image_size = tf.shape(image_tensor)[1:3] pad_height = tf.cast(tf.math.ceil(image_size[0] / 32) * 32, tf.int32) @@ -55,5 +58,8 @@ def simple_load_image(image_path): pad_image_tensor = pad_to_bounding_box(image_tensor, 0, 0, pad_height, pad_width, pad_value=[127.5, 127.5, 127.5]) pad_image_tensor = normalize_value_range(pad_image_tensor) - return pad_image_tensor, image_size + if label_tensor is not None: + label_tensor = pad_to_bounding_box(label_tensor, 0, 0, pad_height, pad_width, pad_value=ignore_label) + + return pad_image_tensor, label_tensor, image_size