Skip to content

Commit

Permalink
improve data_loader
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardyehuang committed Feb 15, 2022
1 parent 54ce0f9 commit e7e1dd0
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions utils/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ def load_image_tensor_from_path(image_path, label_path=None):
return image_tensor, label_tensor


def simple_load_image(image_path):
def simple_load_image(image_path, label_path=None, ignore_label=255):

image_tensor, _ = load_image_tensor_from_path(image_path)
image_tensor, label_tensor = load_image_tensor_from_path(image_path, label_path)
image_tensor = tf.expand_dims(tf.cast(image_tensor, tf.float32), axis=0) # [1, H, W, 3]

if label_tensor is not None:
label_tensor = tf.expand_dims(label_tensor, axis=0) # [1, H, W, 1]

image_size = tf.shape(image_tensor)[1:3]

pad_height = tf.cast(tf.math.ceil(image_size[0] / 32) * 32, tf.int32)
Expand All @@ -55,5 +58,8 @@ def simple_load_image(image_path):
pad_image_tensor = pad_to_bounding_box(image_tensor, 0, 0, pad_height, pad_width, pad_value=[127.5, 127.5, 127.5])
pad_image_tensor = normalize_value_range(pad_image_tensor)

return pad_image_tensor, image_size
if label_tensor is not None:
label_tensor = pad_to_bounding_box(label_tensor, 0, 0, pad_height, pad_width, pad_value=ignore_label)

return pad_image_tensor, label_tensor, image_size

0 comments on commit e7e1dd0

Please sign in to comment.