Skip to content

Commit

Permalink
Validate features during add_frame + Add 2D-to-5D + Add string (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene authored Feb 14, 2025
1 parent 9d6886d commit 7c2bbee
Show file tree
Hide file tree
Showing 8 changed files with 448 additions and 53 deletions.
28 changes: 23 additions & 5 deletions lerobot/common/datasets/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,40 @@ def wrapper(*args, **kwargs):
return wrapper


def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")

if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)

elif image_array.shape[-1] != 3:
raise NotImplementedError(
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
)

if image_array.dtype != np.uint8:
# Assume the image is in [0, 1] range for floating-point data
image_array = np.clip(image_array, 0, 1)
if range_check:
max_ = image_array.max().item()
min_ = image_array.min().item()
if max_ > 1.0 or min_ < 0.0:
raise ValueError(
"The image data type is float, which requires values in the range [0.0, 1.0]. "
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
"provide a uint8 image with values in the range [0, 255]."
)

image_array = (image_array * 255).astype(np.uint8)

return PIL.Image.fromarray(image_array)


def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
try:
if isinstance(image, np.ndarray):
img = image_array_to_image(image)
img = image_array_to_pil_image(image)
elif isinstance(image, PIL.Image.Image):
img = image
else:
Expand Down
21 changes: 9 additions & 12 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
TASKS_PATH,
append_jsonlines,
check_delta_timestamps,
check_frame_features,
check_timestamps_sync,
check_version_compatibility,
create_branch,
Expand Down Expand Up @@ -724,10 +725,12 @@ def add_frame(self, frame: dict) -> None:
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
then needs to be called.
"""
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
# check the dtype and shape matches, etc.
if "task" not in frame:
raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
# Convert torch to numpy if needed
for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()

check_frame_features(frame, self.features)

if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
Expand Down Expand Up @@ -757,8 +760,7 @@ def add_frame(self, frame: dict) -> None:
self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path))
else:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
self.episode_buffer[key].append(item)
self.episode_buffer[key].append(frame[key])

self.episode_buffer["size"] += 1

Expand Down Expand Up @@ -815,12 +817,7 @@ def save_episode(self, encode_videos: bool = True, episode_data: dict | None = N
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
episode_buffer[key] = np.stack(episode_buffer[key])
else:
raise ValueError(key)
episode_buffer[key] = np.stack(episode_buffer[key])

self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index)
Expand Down
103 changes: 100 additions & 3 deletions lerobot/common/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from torchvision import transforms

from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature

DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
Expand Down Expand Up @@ -203,7 +204,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None:
pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict


Expand Down Expand Up @@ -285,11 +286,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
else:
assert len(ft["shape"]) == 1
elif len(ft["shape"]) == 1:
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
elif len(ft["shape"]) == 2:
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 3:
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 4:
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 5:
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
else:
raise ValueError(f"Corresponding feature is not valid: {ft}")

return datasets.Features(hf_features)

Expand Down Expand Up @@ -606,3 +616,90 @@ def values(self):

def keys(self):
return vars(self).keys()


def check_frame_features(frame: dict, features: dict):
optional_features = {"timestamp"}
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())

error_message = check_features_presence(actual_features, expected_features, optional_features)

if "task" in frame:
error_message += check_feature_string("task", frame["task"])

common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += check_feature_dtype_and_shape(name, features[name], frame[name])

if error_message:
raise ValueError(error_message)


def check_features_presence(
actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - (expected_features | optional_features)

if missing_features or extra_features:
error_message += "Feature mismatch in `frame` dictionary:\n"
if missing_features:
error_message += f"Missing features: {missing_features}\n"
if extra_features:
error_message += f"Extra features: {extra_features}\n"

return error_message


def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return check_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return check_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")


def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
actual_shape = value.shape

if actual_dtype != np.dtype(expected_dtype):
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"

if actual_shape != expected_shape:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
else:
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"

return error_message


def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"

return error_message


def check_feature_string(name: str, value: str):
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
16 changes: 15 additions & 1 deletion lerobot/common/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import torch


Expand Down Expand Up @@ -200,5 +201,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple:
return shape


def has_method(cls: object, method_name: str):
def has_method(cls: object, method_name: str) -> bool:
return hasattr(cls, method_name) and callable(getattr(cls, method_name))


def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
"""
Return True if a given string can be converted to a numpy dtype.
"""
try:
# Attempt to convert the string to a numpy dtype
np.dtype(dtype_str)
return True
except TypeError:
# If a TypeError is raised, the string is not a valid dtype
return False
2 changes: 2 additions & 0 deletions tests/fixtures/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@
"video.is_depth_map": False,
"has_audio": False,
}
DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3)
18 changes: 18 additions & 0 deletions tests/fixtures/dataset_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_hf_features_from_features,
hf_transform_to_torch,
)
from lerobot.common.robot_devices.robots.utils import Robot
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
Expand Down Expand Up @@ -394,3 +395,20 @@ def _create_lerobot_dataset(
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

return _create_lerobot_dataset


@pytest.fixture(scope="session")
def empty_lerobot_dataset_factory():
def _create_empty_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
fps: int = DEFAULT_FPS,
robot: Robot | None = None,
robot_type: str | None = None,
features: dict | None = None,
):
return LeRobotDataset.create(
repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
)

return _create_empty_lerobot_dataset
Loading

0 comments on commit 7c2bbee

Please sign in to comment.