Skip to content

Commit

Permalink
Add sanity check for image range
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Feb 13, 2025
1 parent 53bb86a commit 7a81132
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 10 deletions.
18 changes: 14 additions & 4 deletions lerobot/common/datasets/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,32 @@ def wrapper(*args, **kwargs):
return wrapper


def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
    """Convert a numpy image array into a PIL image.

    A 3-dimensional array whose first axis has length 1 or 3 is transposed
    with ``(1, 2, 0)`` before conversion. Any array whose dtype is not
    ``uint8`` is multiplied by 255 and cast to ``uint8``.

    Args:
        image_array: Image data, either ``uint8`` or another dtype holding
            values in ``[0.0, 1.0]``.
        range_check: When True (default), verify that non-``uint8`` data lies
            in ``[0.0, 1.0]`` before scaling. When False the check is skipped
            and values are NOT clipped, so out-of-range inputs go straight
            into the ``uint8`` cast.

    Returns:
        The result of ``PIL.Image.fromarray`` on the (possibly transposed and
        rescaled) array.

    Raises:
        ValueError: If ``range_check`` is True and the non-``uint8`` data has
            a minimum below 0.0, a maximum above 1.0, or contains NaN.
    """
    # TODO(aliberts): handle 1 channel and 4 for depth images
    if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
        # Transpose from pytorch convention (C, H, W) to (H, W, C)
        image_array = image_array.transpose(1, 2, 0)

    if image_array.dtype != np.uint8:
        if range_check:
            max_ = image_array.max().item()
            min_ = image_array.min().item()
            # Phrased as a negated "inside the range" test so that NaN, for
            # which every comparison is False, is rejected as well.
            if not (min_ >= 0.0 and max_ <= 1.0):
                raise ValueError(
                    "The image data type is float, which requires values in the range [0.0, 1.0]. "
                    f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
                    "provide a uint8 image with values in the range [0, 255]."
                )

        image_array = (image_array * 255).astype(np.uint8)

    return PIL.Image.fromarray(image_array)


def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
try:
if isinstance(image, np.ndarray):
img = image_array_to_image(image)
img = image_array_to_pil_image(image)
elif isinstance(image, PIL.Image.Image):
img = image
else:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.image_writer import image_array_to_pil_image
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
MultiLeRobotDataset,
Expand Down Expand Up @@ -346,6 +347,28 @@ def test_add_frame_string(create_dataset):
assert dataset[0]["caption"] == "dummy_caption"


def test_image_array_to_pil_image_wrong_range_float_0_255():
    """A float image scaled up by 255 is expected to fail the range check."""
    out_of_range = np.random.rand(*DUMMY_HWC) * 255
    with pytest.raises(ValueError):
        image_array_to_pil_image(out_of_range)


def test_image_array_to_pil_image_wrong_range_float_neg_1_1():
    """A float image spread over [-1, 1) is expected to fail the range check."""
    signed_image = np.random.rand(*DUMMY_HWC) * 2 - 1
    with pytest.raises(ValueError):
        image_array_to_pil_image(signed_image)


def test_image_array_to_pil_image_float():
    """Smoke test: a float image drawn from [0, 1) converts without raising."""
    unit_image = np.random.rand(*DUMMY_HWC)
    image_array_to_pil_image(unit_image)


def test_image_array_to_pil_image_uint8():
    """Smoke test: a uint8 image converts without raising."""
    scaled = np.random.rand(*DUMMY_HWC) * 255
    image_array_to_pil_image(scaled.astype(np.uint8))


# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
Expand Down
12 changes: 6 additions & 6 deletions tests/test_image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from lerobot.common.datasets.image_writer import (
AsyncImageWriter,
image_array_to_image,
image_array_to_pil_image,
safe_stop_image_writer,
write_image,
)
Expand Down Expand Up @@ -50,15 +50,15 @@ def test_zero_threads():

def test_image_array_to_image_rgb(img_array_factory):
    """An array built by ``img_array_factory(100, 100)`` converts to a 100x100 RGB PIL image."""
    img_array = img_array_factory(100, 100)
    # Only the renamed converter is called; the earlier call through the old
    # name was immediately overwritten and has been dropped.
    result_image = image_array_to_pil_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"


def test_image_array_to_image_pytorch_format(img_array_factory):
    """An array with its last axis moved to the front still converts to a 100x100 RGB image."""
    # transpose(2, 0, 1) moves the last axis first -- presumably (H, W, C) -> (C, H, W),
    # the PyTorch-style layout the test name refers to; confirm against the fixture.
    img_array = img_array_factory(100, 100).transpose(2, 0, 1)
    # Only the renamed converter is called; the earlier call through the old
    # name was immediately overwritten and has been dropped.
    result_image = image_array_to_pil_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"
Expand All @@ -67,15 +67,15 @@ def test_image_array_to_image_pytorch_format(img_array_factory):
@pytest.mark.skip("TODO: implement")
def test_image_array_to_image_single_channel(img_array_factory):
    """Currently skipped: a ``channels=1`` array should convert to a 100x100 mode-"L" image."""
    img_array = img_array_factory(channels=1)
    # Only the renamed converter is called; the earlier call through the old
    # name was immediately overwritten and has been dropped.
    result_image = image_array_to_pil_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "L"


def test_image_array_to_image_float_array(img_array_factory):
img_array = img_array_factory(dtype=np.float32)
result_image = image_array_to_image(img_array)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
Expand All @@ -85,7 +85,7 @@ def test_image_array_to_image_float_array(img_array_factory):
def test_image_array_to_image_out_of_bounds_float():
# Float array with values out of [0, 1]
img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
result_image = image_array_to_image(img_array)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
Expand Down

0 comments on commit 7a81132

Please sign in to comment.