diff --git a/robomimic/data/__init__.py b/robomimic/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/robomimic/data/common_transformations.py b/robomimic/data/common_transformations.py new file mode 100644 index 00000000..f9560505 --- /dev/null +++ b/robomimic/data/common_transformations.py @@ -0,0 +1,177 @@ +import dlimp as dl +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +import tqdm +from typing import Any, Callable, Dict, Sequence, Union +from tensorflow_datasets.core.dataset_builder import DatasetBuilder + +import robomimic.utils.tensor_utils as TensorUtils +from robomimic.utils.data_utils import * + + +def add_next_obs(traj: Dict[str, Any], pad: bool = True) -> Dict[str, Any]: + """ + Given a trajectory with a key "observation", add the key "next_observation". If pad is False, discards the last + value of all other keys. Otherwise, the last transition will have "observation" == "next_observation". + """ + if not pad: + traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) + traj_truncated["next_observation"] = tf.nest.map_structure( + lambda x: x[1:], traj["observation"] + ) + return traj_truncated + else: + traj["next_observation"] = tf.nest.map_structure( + lambda x: tf.concat((x[1:], x[-1:]), axis=0), traj["observation"] + ) + return traj + + +def normalize_obs_and_actions(traj, config, metadata): + ''' + For now, only normalize appropriate action keys + ''' + action_config = config.train.action_config + normal_keys = [key for key in config.train.action_keys + if key in action_config.keys() and action_config[key].get('normalization', None) == 'normal'] + + min_max_keys = [key for key in config.train.action_keys + if key in action_config.keys() and action_config[key].get('normalization', None) == 'min_max'] + + for key in normal_keys: + map_nested_dict_index( + traj, + key, + lambda x: (x - metadata[key]["mean"]) / metadata[key]["std"] + ) + + for key in min_max_keys: + 
map_nested_dict_index( + traj, + key, + lambda x: tf.clip_by_value(2 * (x - metadata[key]["min"]) + / (metadata[key]["max"] - metadata[key]["min"]) - 1, + -1, + 1) + ) + + return traj + + +def random_dataset_sequence_transform(traj, sequence_length): + ''' + Extract a random subsequence of the data given sequence_length given keys + ''' + traj_len = len(traj["action"]) + index_in_demo = tf.cast(tf.random.uniform(shape=[]) + * tf.cast(traj_len, dtype=tf.float32), dtype=tf.int32) + last_index = tf.math.minimum(traj_len, index_in_demo + sequence_length) + seq_end_pad = tf.math.maximum(0, index_in_demo + sequence_length - traj_len) + padding = [0, seq_end_pad] + keys = ["observation", "action", "action_dict", "goal"] + + def random_sequence_func(x): + sequence = x[index_in_demo: last_index] + padding = tf.repeat([x[0]], repeats=[seq_end_pad], axis=0) + return tf.concat((sequence, padding), axis=0) + + traj = dl.transforms.selective_tree_map( + traj, + match=keys, + map_fn=random_sequence_func + ) + return traj + + +def random_dataset_sequence_transform_v2(traj, frame_stack, seq_length, + pad_frame_stack, pad_seq_length): + ''' + Extract a random subsequence of the data given sequence_length given keys + ''' + traj_len = tf.shape(traj["action"])[0] + seq_begin_pad, seq_end_pad = 0, 0 + if pad_frame_stack: + seq_begin_pad = frame_stack - 1 + if pad_seq_length: + seq_end_pad = seq_length - 1 + index_in_demo = tf.random.uniform(shape=[], + maxval=traj_len + seq_end_pad - (seq_length - 1), + dtype=tf.int32) + pad_mask = tf.concat((tf.repeat(0, repeats=seq_begin_pad), + tf.repeat(1, repeats=traj_len), + tf.repeat(0, repeats=seq_end_pad)), axis=0)[:, None] + traj['pad_mask'] = tf.cast(pad_mask, dtype=tf.bool) + keys = ["observation", "action", "action_dict", "goal"] + + def random_sequence_func(x): + begin_padding = tf.repeat([x[0]], repeats=[seq_begin_pad], axis=0) + end_padding = tf.repeat([x[-1]], repeats=[seq_end_pad], axis=0) + sequence = tf.concat((begin_padding, x, 
end_padding), axis=0) + return sequence[index_in_demo: index_in_demo + seq_length + frame_stack - 1] + + traj = dl.transforms.selective_tree_map( + traj, + match=keys, + map_fn=random_sequence_func + ) + return traj + + + +def relabel_goals_transform(traj, goal_mode): + traj_len = len(traj["action"]) + + if goal_mode == "last": + goal_idxs = tf.ones(traj_len) * (traj_len - 1) + goal_idxs = tf.cast(goal_idxs, tf.int32) + elif goal_mode == "uniform": + rand = tf.random.uniform([traj_len]) + low = tf.cast(tf.range(traj_len) + 1, tf.float32) + high = tf.cast(traj_len, tf.float32) + goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) + + traj["goal_observation"] = tf.nest.map_structure( + lambda x: tf.gather(x, goal_idxs), traj["observation"] + ) + return traj + + +def concatenate_action_transform(traj, action_keys): + ''' + Concatenates the action_keys + ''' + traj["action"] = tf.concat( + list(index_nested_dict(traj, key) for key in action_keys), + axis=-1 + ) + + return traj + + +def frame_stack_transform(traj, num_frames): + ''' + Stacks the previous num_frame-1 frames with the current frame + Converts the trajectory into size + traj_len x num_frames x ... 
+ ''' + traj_len = len(traj["action"]) + + #Pad beginning of observation num_frames times: + traj["observation"] = tf.nest.map_structure( + lambda x: tf.concat((tf.repeat([x[0]], repeats=[num_frames], axis=0) + , x), axis=0) + , traj["observation"]) + + def stack_func(x): + indices = tf.reshape(tf.range(traj_len), [-1, 1]) + tf.range(num_frames) + return tf.gather(x, indices) + + #Concatenate and clip to original size + traj["observation"] = tf.nest.map_structure( + stack_func, + traj["observation"] + ) + + return traj + diff --git a/robomimic/data/dataset.py b/robomimic/data/dataset.py new file mode 100644 index 00000000..f2a01edb --- /dev/null +++ b/robomimic/data/dataset.py @@ -0,0 +1,269 @@ +from typing import Any, Callable, Dict, Sequence, Union, List, Optional, Tuple +import dlimp as dl +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +#Don't use GPU for dataloading +tf.config.set_visible_devices([], "GPU") +import tqdm +import logging +from tensorflow_datasets.core.dataset_builder import DatasetBuilder +from collections import OrderedDict +from functools import partial +import numpy as np +import hashlib +import json +import pickle +import torch + +import robomimic.utils.torch_utils as TorchUtils +from .dataset_transformations import RLDS_TRAJECTORY_MAP_TRANSFORMS +import robomimic.data.common_transformations as CommonTransforms +import robomimic.utils.data_utils as DataUtils + + +class RLDSTorchDataset: + def __init__(self, dataset_iterator, try_to_use_cuda=True): + self.dataset_iterator = dataset_iterator + self.device = TorchUtils.get_torch_device(try_to_use_cuda) + self.keys = ['obs', 'goal_obs', 'actions'] + + def __iter__(self): + for batch in self.dataset_iterator: + torch_batch = {} + for key in self.keys: + if key in batch.keys(): + torch_batch[key] = DataUtils.tree_map( + batch[key], + map_fn=lambda x: torch.tensor(x).to(self.device) + ) + yield torch_batch + + +def 
get_action_normalization_stats_rlds(obs_action_metadata, config): + action_config = config.train.action_config + normal_keys = [key for key in config.train.action_keys + if action_config[key].get('normalization', None) == 'normal'] + min_max_keys = [key for key in config.train.action_keys + if action_config[key].get('normalization', None) == 'min_max'] + + stats = OrderedDict() + for key in config.train.action_keys: + if key in normal_keys: + normal_stats = { + 'scale': obs_action_metadata[key]['std'].reshape(1, -1), + 'offset': obs_action_metadata[key]['mean'].reshape(1, -1) + } + stats[key] = normal_stats + elif key in min_max_keys: + min_max_range = obs_action_metadata[key]['max'] - obs_action_metadata[key]['min'] + min_max_stats = { + 'scale': (min_max_range / 2).reshape(1, -1), + 'offset': (obs_action_metadata[key]['min'] + min_max_range / 2).reshape(1, -1) + } + stats[key] = min_max_stats + else: + identity_stats = { + 'scale': np.ones_like(obs_action_metadata[key]['std']).reshape(1, -1), + 'offset': np.zeros_like(obs_action_metadata[key]['mean']).reshape(1, -1) + } + stats[key] = identity_stats + return stats + + +def get_obs_normalization_stats_rlds(obs_action_metadata, config): + stats = OrderedDict() + for key, obs_action_stats in obs_action_metadata.items(): + feature_type, feature_key = key.split('/') + if feature_type != 'observation': + continue + stats[feature_key] = { + 'mean': obs_action_stats['mean'][None], + 'std': obs_action_stats['std'][None], + } + return stats + + +def get_obs_action_metadata( + builder: DatasetBuilder, dataset: tf.data.Dataset, keys: List[str], + load_if_exists=True +) -> Dict[str, Dict[str, List[float]]]: + # get statistics file path --> embed unique hash that catches if dataset info changed + data_info_hash = hashlib.sha256( + (str(builder.info) + str(keys)).encode("utf-8") + ).hexdigest() + path = tf.io.gfile.join( + builder.info.data_dir, f"obs_action_stats_{data_info_hash}.pkl" + ) + + # check if stats already exist and 
load, otherwise compute + if tf.io.gfile.exists(path) and load_if_exists: + print(f"Loading existing statistics for normalization from {path}.") + with tf.io.gfile.GFile(path, "rb") as f: + metadata = pickle.load(f) + else: + print("Computing obs/action statistics for normalization...") + eps_by_key = {key: [] for key in keys} + + i, n_samples = 0, 500 + dataset_iter = dataset.as_numpy_iterator() + for _ in tqdm.tqdm(range(n_samples)): + episode = next(dataset_iter) + i = i + 1 + for key in keys: + eps_by_key[key].append(DataUtils.index_nested_dict(episode, key)) + eps_by_key = {key: np.concatenate(values) for key, values in eps_by_key.items()} + + metadata = {} + for key in keys: + metadata[key] = { + "mean": eps_by_key[key].mean(0), + "std": eps_by_key[key].std(0), + "max": eps_by_key[key].max(0), + "min": eps_by_key[key].min(0), + } + with tf.io.gfile.GFile(path, "wb") as f: + pickle.dump(metadata, f) + logging.info("Done!") + + return metadata + + +def decode_dataset( + dataset: tf.data.Dataset + ): + + #Decode images + dataset = dataset.frame_map( + DataUtils.decode_images + ) + return dataset + + +def apply_common_transforms( + dataset: tf.data.Dataset, + config: dict, + *, + train: bool, + obs_action_metadata: Optional[dict] = None, + ): + + #Normalize observations and actions + if obs_action_metadata is not None: + dataset = dataset.map( + partial( + CommonTransforms.normalize_obs_and_actions, + config=config, + metadata=obs_action_metadata, + ), + num_parallel_calls=tf.data.AUTOTUNE + ) + #Relabel goals + if config.train.goal_mode == 'last' or config.train.goal_mode == 'uniform': + dataset = dataset.map( + partial( + CommonTransforms.relabel_goals_transform, + goal_mode=config.goal_mode + ), + num_parallel_calls=tf.data.AUTOTUNE + ) + + #Concatenate actions + if config.train.action_keys != None: + dataset = dataset.map( + partial( + CommonTransforms.concatenate_action_transform, + action_keys=config.train.action_keys + ), + 
num_parallel_calls=tf.data.AUTOTUNE + ) + #Get a random subset of length frame_stack + seq_length - 1 + dataset = dataset.map( + partial( + CommonTransforms.random_dataset_sequence_transform_v2, + frame_stack=config.train.frame_stack, + seq_length=config.train.seq_length, + pad_frame_stack=config.train.pad_frame_stack, + pad_seq_length=config.train.pad_seq_length + ), + num_parallel_calls=tf.data.AUTOTUNE + ) + #augmentation? #chunking? + + return dataset + +def decode_trajectory(builder, obs_keys, episode): + steps = episode + new_steps = dict() + new_steps['action_dict'] = dict() + new_steps['observation'] = dict() + for key in steps["action_dict"]: + new_steps['action_dict'][key] = builder.info.features["steps"]['action_dict'][ + key + ].decode_batch_example(steps["action_dict"][key]) + for key in obs_keys: + new_steps['observation'][key] = builder.info.features["steps"]['observation'][ + key + ].decode_batch_example(steps["observation"][key]) + return new_steps + +def make_dataset( + config: dict, + train: bool = True, + shuffle: bool = True, + resize_size: Optional[Tuple[int, int]] = None, + normalization_metadata: Optional[Dict] = None, + **kwargs, +) -> tf.data.Dataset: + + data_info = config.train.data[0] + name = data_info['name'] + data_dir = data_info['path'] + + builder = tfds.builder(name, data_dir=data_dir) + + if "val" not in builder.info.splits: + split = "train[:95%]" if train else "train[95%:]" + else: + split = "train" if train else "val" + + dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, + num_parallel_reads=8) + if name in RLDS_TRAJECTORY_MAP_TRANSFORMS: + if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'] is not None: + dataset = dataset.map(partial( + RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'], + config=config), + ) + metadata_keys = [k for k in config.train.action_keys] + if config.all_obs_keys is not None: + metadata_keys.extend([f'observation/{k}' + for k in config.all_obs_keys]) + if normalization_metadata is None: + 
normalization_metadata = get_obs_action_metadata( + builder, + dataset, + keys=metadata_keys, + load_if_exists=True#False + ) + dataset = apply_common_transforms( + dataset, + config=config, + train=train, + obs_action_metadata=normalization_metadata, + **kwargs, + ) + if name in RLDS_TRAJECTORY_MAP_TRANSFORMS: + if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'] is not None: + dataset = dataset.map(partial( + RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'], + config=config), + ) + dataset = decode_dataset(dataset) + dataset = dataset.repeat().batch(config.train.batch_size).prefetch(tf.data.experimental.AUTOTUNE) + dataset = dataset.as_numpy_iterator() + dataset = RLDSTorchDataset(dataset) + + return builder, dataset, normalization_metadata + + diff --git a/robomimic/data/dataset_shapes.py b/robomimic/data/dataset_shapes.py new file mode 100644 index 00000000..c3ef7658 --- /dev/null +++ b/robomimic/data/dataset_shapes.py @@ -0,0 +1,13 @@ +import numpy as np + +r2d2_dataset_shapes = { + 'action_dict/abs_pos': (3,), + 'action_dict/abs_rot_6d': (6,), +} + + +DATASET_SHAPES = { + 'r2_d2': r2d2_dataset_shapes, +} + + diff --git a/robomimic/data/dataset_transformations.py b/robomimic/data/dataset_transformations.py new file mode 100644 index 00000000..b89575ad --- /dev/null +++ b/robomimic/data/dataset_transformations.py @@ -0,0 +1,119 @@ +from typing import Any, Callable, Dict, Sequence, Union +import robomimic.utils.tensorflow_utils as TensorflowUtils +import tensorflow as tf + + +def r2d2_dataset_pre_transform(traj: Dict[str, Any], + config: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + keep_keys = [ + 'observation', + 'action', + ] + ac_keys = ['cartesian_position', 'cartesian_velocity'] + new_traj = {k: v for k, v in traj.items() if k in keep_keys} + new_traj['action_dict'] = { + 'gripper_position': traj['action_dict']['gripper_position'] + } + for key in ac_keys: + in_action = traj['action_dict'][key] + pos = 
traj['action_dict'][key][:, :3] + rot = traj['action_dict'][key][:, 3:6] + + rot_6d = TensorflowUtils.euler_angles_to_rot_6d( + rot, convention="XYZ", + ) + if key == 'cartesian_position': + prefix = 'abs_' + else: + prefix = 'rel_' + + new_traj['action_dict'].update({ + prefix + 'pos': pos, + prefix + 'rot_euler': rot, + prefix + 'rot_6d': rot_6d + }) + return new_traj + + +def r2d2_dataset_post_transform(traj: Dict[str, Any], + config: Dict[str, Any]) -> Dict[str, Any]: + + new_traj = {'observation': {}} + for key in config.all_obs_keys: + nested_keys = key.split('/') + value = traj['observation'] + assign = new_traj['observation'] + for i, nk in enumerate(nested_keys): + if i == len(nested_keys) - 1: + assign[nk] = value[nk] + break + value = value[nk] + if nk not in assign.keys(): + assign[nk] = dict() + assign = assign[nk] + #Set obs key + new_traj['obs'] = new_traj['observation'] + + #Set actions key + new_traj['actions'] = traj['action'] + + #Use one goal per sequence + if 'goal_observation' in traj.keys(): + new_traj['goal_obs'] = traj['goal_observation'][0] + keep_keys = ['obs', + 'goal_obs', + 'actions', + 'action_dict'] + new_traj = {k: v for k, v in new_traj.items() if k in keep_keys} + return new_traj + + +def robomimic_dataset_pre_transform(traj: Dict[str, Any], + config: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + keep_keys = [ + "observation", + "action", + "action_dict", + "language_instruction", + "is_terminal", + "is_last", + "_traj_index", + ] + traj = {k: v for k, v in traj.items() if k in keep_keys} + return traj + + +def robomimic_dataset_post_transform(traj: Dict[str, Any], + config: Dict[str, Any]) -> Dict[str, Any]: + new_traj = dict() + #Set obs key + traj['obs'] = traj['observation'] + + #Set actions key + traj['actions'] = traj['action'] + + #Use one goal per sequence + if 'goal_observation' in traj.keys(): + new_traj['goal_obs'] = traj['goal_observation'][0] + + keep_keys = 
['obs', + 'goal_obs', + 'actions', + 'action_dict'] + traj = {k: v for k, v in traj.items() if k in keep_keys} + return traj + + +RLDS_TRAJECTORY_MAP_TRANSFORMS = { + 'r2_d2': { + 'pre': r2d2_dataset_pre_transform, + 'post': r2d2_dataset_post_transform, + }, + 'robomimic_dataset': { + 'pre': robomimic_dataset_pre_transform, + 'post': robomimic_dataset_post_transform, + } +} + diff --git a/robomimic/envs/env_base.py b/robomimic/envs/env_base.py index ee3184c2..3b5e8cf3 100644 --- a/robomimic/envs/env_base.py +++ b/robomimic/envs/env_base.py @@ -14,6 +14,7 @@ class EnvType: ROBOSUITE_TYPE = 1 GYM_TYPE = 2 IG_MOMART_TYPE = 3 + DUMMY_TYPE = 4 class EnvBase(abc.ABC): diff --git a/robomimic/envs/env_dummy.py b/robomimic/envs/env_dummy.py new file mode 100644 index 00000000..d539a68c --- /dev/null +++ b/robomimic/envs/env_dummy.py @@ -0,0 +1,98 @@ +import json +import numpy as np +from copy import deepcopy +import robomimic.envs.env_base as EB + + +class EnvDummy(EB.EnvBase): + """Dummy env used for real-world cases when env doesn't exist""" + def __init__( + self, + env_name, + render=False, + render_offscreen=False, + use_image_obs=False, + postprocess_visual_obs=True, + **kwargs, + ): + self._env_name = env_name + + def step(self, action): + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def reset_to(self, state): + raise NotImplementedError + + def render(self, mode="human", height=None, width=None, camera_name=None, **kwargs): + raise NotImplementedError + + def get_observation(self, obs=None): + raise NotImplementedError + + def get_state(self): + raise NotImplementedError + + def get_reward(self): + raise NotImplementedError + + def get_goal(self): + raise NotImplementedError + + def set_goal(self, **kwargs): + raise NotImplementedError + + def is_done(self): + raise NotImplementedError + + def is_success(self): + raise NotImplementedError + + @property + def action_dimension(self): + raise NotImplementedError + + @property + def 
name(self): + """ + Returns name of environment name (str). + """ + return self._env_name + + @property + def type(self): + """ + Returns environment type (int) for this kind of environment. + This helps identify this env class. + """ + return EB.EnvType.GYM_TYPE + + def serialize(self): + """ + Save all information needed to re-instantiate this environment in a dictionary. + This is the same as @env_meta - environment metadata stored in hdf5 datasets, + and used in utils/env_utils.py. + """ + return dict(env_name=self.name, type=self.type) + + @classmethod + def create_for_data_processing(cls, env_name, camera_names, camera_height, camera_width, reward_shaping, **kwargs): + raise NotImplementedError + + @property + def rollout_exceptions(self): + """ + Return tuple of exceptions to except when doing rollouts. This is useful to ensure + that the entire training run doesn't crash because of a bad policy that causes unstable + simulation computations. + """ + raise NotImplementedError + + def __repr__(self): + """ + Pretty-print env description. 
+ """ + return f'{self.name} Dummy Env' + diff --git a/robomimic/scripts/config_gen/diffusion_gen.py b/robomimic/scripts/config_gen/diffusion_gen.py index 93e53b32..3fe11a7b 100644 --- a/robomimic/scripts/config_gen/diffusion_gen.py +++ b/robomimic/scripts/config_gen/diffusion_gen.py @@ -1,4 +1,5 @@ from robomimic.scripts.config_gen.helper import * +import os def make_generator_helper(args): algo_name_short = "diffusion_policy" @@ -27,6 +28,16 @@ def make_generator_helper(args): values=[1000], ) + generator.add_param( + key="train.data_format", + name="df", + group=1123, + values=[args.data_format], + value_names=[ + "data_format", + ] + ) + # use ddim by default generator.add_param( key="algo.ddim.enabled", @@ -51,74 +62,158 @@ def make_generator_helper(args): if args.env == "r2d2": generator.add_param( - key="train.data", - name="ds", - group=2, - values=[ - [{"path": p} for p in scan_datasets("~/Downloads/example_pen_in_cup", postfix="trajectory_im128.h5")], - ], + key="train.data_format", + name="df", + group=1123, + values=[args.data_format], value_names=[ - "pen-in-cup", - ], - ) - generator.add_param( - key="train.action_keys", - name="ac_keys", - group=-1, - values=[ - [ - "action/abs_pos", - "action/abs_rot_6d", - "action/gripper_position", - ], - ], - value_names=[ - "abs", - ], - hidename=True, - ) - generator.add_param( - key="observation.modalities.obs.rgb", - name="cams", - group=130, - values=[ - # ["camera/image/hand_camera_left_image"], - # ["camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image"], - ["camera/image/hand_camera_left_image", "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_2_left_image"], - # [ - # "camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image", - # "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_1_right_image", - # "camera/image/varied_camera_2_left_image", "camera/image/varied_camera_2_right_image", - # ], - ], - value_names=[ - # 
"wrist", - # "wrist-stereo", - "3cams", - # "3cams-stereo", + "data_format", ] - ) + ) + if args.data_format == 'hdf5': + generator.add_param( + key="train.data", + name="ds", + group=2, + values=[ + [{"path": p} for p in scan_datasets("~/Downloads/example_pen_in_cup", postfix="trajectory_im128.h5")], + ], + value_names=[ + "pen-in-cup", + ], + ) + generator.add_param( + key="train.action_keys", + name="ac_keys", + group=-1, + values=[ + [ + "action/abs_pos", + "action/abs_rot_6d", + "action/gripper_position", + ], + ], + value_names=[ + "abs", + ], + hidename=True, + ) + generator.add_param( + key="observation.modalities.obs.rgb", + name="cams", + group=130, + values=[ + # ["camera/image/hand_camera_left_image"], + # ["camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image"], + ["camera/image/hand_camera_left_image", "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_2_left_image"], + # [ + # "camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image", + # "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_1_right_image", + # "camera/image/varied_camera_2_left_image", "camera/image/varied_camera_2_right_image", + # ], + ], + value_names=[ + # "wrist", + # "wrist-stereo", + "3cams", + # "3cams-stereo", + ] + ) + generator.add_param( + key="observation.modalities.obs.low_dim", + name="ldkeys", + group=2498, + values=[ + ["robot_state/cartesian_position", "robot_state/gripper_position"], + # [ + # "robot_state/cartesian_position", "robot_state/gripper_position", + # "camera/extrinsics/hand_camera_left", "camera/extrinsics/hand_camera_left_gripper_offset", + # "camera/extrinsics/hand_camera_right", "camera/extrinsics/hand_camera_right_gripper_offset", + # "camera/extrinsics/varied_camera_1_left", "camera/extrinsics/varied_camera_1_right", + # "camera/extrinsics/varied_camera_2_left", "camera/extrinsics/varied_camera_2_right", + # ] + ], + value_names=[ + "proprio", + # "proprio-extrinsics", 
+ ] + ) - generator.add_param( - key="observation.modalities.obs.low_dim", - name="ldkeys", - group=2498, - values=[ - ["robot_state/cartesian_position", "robot_state/gripper_position"], - # [ - # "robot_state/cartesian_position", "robot_state/gripper_position", - # "camera/extrinsics/hand_camera_left", "camera/extrinsics/hand_camera_left_gripper_offset", - # "camera/extrinsics/hand_camera_right", "camera/extrinsics/hand_camera_right_gripper_offset", - # "camera/extrinsics/varied_camera_1_left", "camera/extrinsics/varied_camera_1_right", - # "camera/extrinsics/varied_camera_2_left", "camera/extrinsics/varied_camera_2_right", - # ] - ], - value_names=[ - "proprio", - # "proprio-extrinsics", - ] - ) + elif args.data_format == 'rlds': + generator.add_param( + key="train.data", + name="ds", + group=2, + values=[ + [ + {"path": "/iris/u/jyang27/rlds_data", + "name": "r2_d2"}, # replace with your own path + ], + ], + value_names=[ + "pen-in-cup" + ], + ) + generator.add_param( + key="train.action_keys", + name="ac_keys", + group=-1, + values=[ + [ + "action_dict/abs_pos", + "action_dict/abs_rot_6d", + "action_dict/gripper_position", + ], + ], + value_names=[ + "abs", + ], + hidename=True, + ) + generator.add_param( + key="observation.modalities.obs.rgb", + name="cams", + group=130, + values=[ + # ["camera/image/hand_camera_left_image"], + # ["camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image"], + ["wrist_image_left", "exterior_image_1_left", "exterior_image_2_left"], + # [ + # "camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image", + # "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_1_right_image", + # "camera/image/varied_camera_2_left_image", "camera/image/varied_camera_2_right_image", + # ], + ], + value_names=[ + # "wrist", + # "wrist-stereo", + "3cams", + # "3cams-stereo", + ] + ) + generator.add_param( + key="observation.modalities.obs.low_dim", + name="ldkeys", + group=2498, + values=[ + 
["cartesian_position", "gripper_position"], + # [ + # "robot_state/cartesian_position", "robot_state/gripper_position", + # "camera/extrinsics/hand_camera_left", "camera/extrinsics/hand_camera_left_gripper_offset", + # "camera/extrinsics/hand_camera_right", "camera/extrinsics/hand_camera_right_gripper_offset", + # "camera/extrinsics/varied_camera_1_left", "camera/extrinsics/varied_camera_1_right", + # "camera/extrinsics/varied_camera_2_left", "camera/extrinsics/varied_camera_2_right", + # ] + ], + value_names=[ + "proprio", + # "proprio-extrinsics", + ] + ) + + else: + raise ValueError generator.add_param( key="observation.encoder.rgb.core_kwargs.backbone_class", name="backbone", @@ -186,19 +281,37 @@ def make_generator_helper(args): hidename=True, ) elif args.env == "square": - generator.add_param( - key="train.data", - name="ds", - group=2, - values=[ - [ - {"path": "~/datasets/square/ph/square_ph_abs_tmp.hdf5"}, # replace with your own path + if args.data_format == 'hdf5': + generator.add_param( + key="train.data", + name="ds", + group=2, + values=[ + [ + {"path": "/iris/u/jyang27/dev/robomimic/datasets/square/ph/low_dim_v141.hdf5"}, # replace with your own path + ], ], - ], - value_names=[ - "square", - ], - ) + value_names=[ + "square", + ], + ) + elif args.data_format == 'rlds': + generator.add_param( + key="train.data", + name="ds", + group=2, + values=[ + [ + {"path": "/iris/u/jyang27/rlds_data", + "name": "robomimic_dataset"}, # replace with your own path + ], + ], + value_names=[ + "square", + ], + ) + else: + raise ValueError # update env config to use absolute action control generator.add_param( @@ -236,7 +349,7 @@ def make_generator_helper(args): name="", group=-1, values=[ - "~/expdata/{env}/{mod}/{algo_name_short}".format( + "/iris/u/jyang27/expdata/{env}/{mod}/{algo_name_short}".format( env=args.env, mod=args.mod, algo_name_short=algo_name_short, @@ -250,4 +363,4 @@ def make_generator_helper(args): parser = get_argparser() args = parser.parse_args() 
- make_generator(args, make_generator_helper) \ No newline at end of file + make_generator(args, make_generator_helper) diff --git a/robomimic/scripts/config_gen/helper.py b/robomimic/scripts/config_gen/helper.py index 48a3af07..d4dfd4e6 100644 --- a/robomimic/scripts/config_gen/helper.py +++ b/robomimic/scripts/config_gen/helper.py @@ -183,56 +183,143 @@ def set_env_settings(generator, args): "r2d2" ], ) - - # here, we list how each action key should be treated (normalized etc) - generator.add_param( - key="train.action_config", - name="", - group=-1, - values=[ - { - "action/cartesian_position":{ - "normalization": "min_max", - }, - "action/abs_pos":{ - "normalization": "min_max", - }, - "action/abs_rot_6d":{ - "normalization": "min_max", - "format": "rot_6d", - "convert_at_runtime": "rot_euler", - }, - "action/abs_rot_euler":{ - "normalization": "min_max", - "format": "rot_euler", - }, - "action/gripper_position":{ - "normalization": "min_max", - }, - "action/cartesian_velocity":{ - "normalization": None, - }, - "action/rel_pos":{ - "normalization": None, - }, - "action/rel_rot_6d":{ - "format": "rot_6d", - "normalization": None, - "convert_at_runtime": "rot_euler", - }, - "action/rel_rot_euler":{ - "format": "rot_euler", - "normalization": None, - }, - "action/gripper_velocity":{ - "normalization": None, - }, - } - ], - ) + + if args.data_format == 'hdf5': + # here, we list how each action key should be treated (normalized etc) + generator.add_param( + key="train.action_config", + name="", + group=-1, + values=[ + { + "action/cartesian_position":{ + "normalization": "min_max", + }, + "action/abs_pos":{ + "normalization": "min_max", + }, + "action/abs_rot_6d":{ + "normalization": "min_max", + "format": "rot_6d", + "convert_at_runtime": "rot_euler", + }, + "action/abs_rot_euler":{ + "normalization": "min_max", + "format": "rot_euler", + }, + "action/gripper_position":{ + "normalization": "min_max", + }, + "action/cartesian_velocity":{ + "normalization": None, + 
}, + "action/rel_pos":{ + "normalization": None, + }, + "action/rel_rot_6d":{ + "format": "rot_6d", + "normalization": None, + "convert_at_runtime": "rot_euler", + }, + "action/rel_rot_euler":{ + "format": "rot_euler", + "normalization": None, + }, + "action/gripper_velocity":{ + "normalization": None, + }, + } + ], + ) + generator.add_param( + key="train.shuffled_obs_key_groups", + name="", + group=-1, + values=[[[ + ( + "camera/image/varied_camera_1_left_image", + "camera/image/varied_camera_1_right_image", + "camera/extrinsics/varied_camera_1_left", + "camera/extrinsics/varied_camera_1_right", + ), + ( + "camera/image/varied_camera_2_left_image", + "camera/image/varied_camera_2_right_image", + "camera/extrinsics/varied_camera_2_left", + "camera/extrinsics/varied_camera_2_right", + ), + ]]], + ) + elif args.data_format == 'rlds': + # here, we list how each action key should be treated (normalized etc) + generator.add_param( + key="train.action_config", + name="", + group=-1, + values=[ + { + "action_dict/cartesian_position":{ + "normalization": "min_max", + }, + "action_dict/abs_pos":{ + "normalization": "min_max", + }, + "action_dict/abs_rot_6d":{ + "normalization": "min_max", + "format": "rot_6d", + "convert_at_runtime": "rot_euler", + }, + "action_dict/abs_rot_euler":{ + "normalization": "min_max", + "format": "rot_euler", + }, + "action_dict/gripper_position":{ + "normalization": "min_max", + }, + "action_dict/cartesian_velocity":{ + "normalization": None, + }, + "action_dict/rel_pos":{ + "normalization": None, + }, + "action_dict/rel_rot_6d":{ + "format": "rot_6d", + "normalization": None, + "convert_at_runtime": "rot_euler", + }, + "action_dict/rel_rot_euler":{ + "format": "rot_euler", + "normalization": None, + }, + "action_dict/gripper_velocity":{ + "normalization": None, + }, + } + ], + ) + generator.add_param( + key="train.shuffled_obs_key_groups", + name="", + group=-1, + values=[[[ + ( + "wrist_image_left", + "exterior_image_1_left", + 
"exterior_image_2_left", + ), + ( + "wrist_image_right", + "exterior_image_1_right", + "exterior_image_2_right", + ), + ]]], + ) + + else: + raise ValueError generator.add_param( key="train.dataset_keys", - name="", + name="", group=-1, values=[[]], ) @@ -252,26 +339,6 @@ def set_env_settings(generator, args): "rel", ], ) - # observation key groups to swap - generator.add_param( - key="train.shuffled_obs_key_groups", - name="", - group=-1, - values=[[[ - ( - "camera/image/varied_camera_1_left_image", - "camera/image/varied_camera_1_right_image", - "camera/extrinsics/varied_camera_1_left", - "camera/extrinsics/varied_camera_1_right", - ), - ( - "camera/image/varied_camera_2_left_image", - "camera/image/varied_camera_2_right_image", - "camera/extrinsics/varied_camera_2_left", - "camera/extrinsics/varied_camera_2_right", - ), - ]]], - ) elif args.env == "kitchen": generator.add_param( key="train.action_config", @@ -630,6 +697,17 @@ def set_env_settings(generator, args): else: raise ValueError + if args.data_format == 'rlds': + #If using rlds, overwrite data format to rlds + generator.add_param( + key="train.data_format", + name="", + group=-1, + values=[ + "rlds" + ], + ) + def set_mod_settings(generator, args): if args.mod == 'ld': @@ -840,6 +918,12 @@ def get_argparser(): default='im', ) + parser.add_argument( + "--data_format", + type=str, + default='hdf5', + ) + parser.add_argument( "--ckpt_mode", type=str, diff --git a/robomimic/scripts/train.py b/robomimic/scripts/train_rlds.py similarity index 90% rename from robomimic/scripts/train.py rename to robomimic/scripts/train_rlds.py index ffbc4666..5d586ddd 100644 --- a/robomimic/scripts/train.py +++ b/robomimic/scripts/train_rlds.py @@ -40,7 +40,8 @@ from robomimic.config import config_factory from robomimic.algo import algo_factory, RolloutPolicy from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings - +from robomimic.data.dataset import (make_dataset, get_obs_normalization_stats_rlds, + 
get_action_normalization_stats_rlds) def train(config, device): """ @@ -68,27 +69,40 @@ def train(config, device): # read config to set up metadata for observation modalities (e.g. detecting rgb observations) ObsUtils.initialize_obs_utils_with_config(config) - # make sure the dataset exists - eval_dataset_cfg = config.train.data[0] - dataset_path = os.path.expanduser(eval_dataset_cfg["path"]) + # Load the datasets + train_builder, train_loader, normalization_metadata = make_dataset( + config, + train=True, + shuffle=True + ) ds_format = config.train.data_format - if not os.path.exists(dataset_path): - raise Exception("Dataset at provided path {} not found!".format(dataset_path)) + assert ds_format == 'rlds' + + if config.experiment.validate: + # cap num workers for validation dataset at 1 + num_workers = min(config.train.num_data_workers, 1) + valid_builder, valid_loader, _ = make_dataset( + config, + train=True, + shuffle=True, + normalization_metadata=normalization_metadata + ) + + else: + valid_loader = None + # load basic metadata from training file print("\n============= Loaded Environment Metadata =============") - env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path, ds_format=ds_format) - + env_meta = FileUtils.get_env_metadata_from_dataset_rlds(train_builder) # update env meta if applicable from robomimic.utils.script_utils import deep_update deep_update(env_meta, config.experiment.env_meta_update_dict) - shape_meta = FileUtils.get_shape_metadata_from_dataset( - dataset_path=dataset_path, + shape_meta = FileUtils.get_shape_metadata_from_dataset_rlds( + train_builder, action_keys=config.train.action_keys, all_obs_keys=config.all_obs_keys, - ds_format=ds_format, - verbose=True ) if config.experiment.env is not None: @@ -150,51 +164,21 @@ def train(config, device): print(model) # print model summary print("") - # load training data - trainset, validset = TrainUtils.load_data_for_training( - config, obs_keys=shape_meta["all_obs_keys"]) 
- train_sampler = trainset.get_dataset_sampler() - print("\n============= Training Dataset =============") - print(trainset) - print("") - if validset is not None: - print("\n============= Validation Dataset =============") - print(validset) - print("") # maybe retrieve statistics for normalizing observations obs_normalization_stats = None if config.train.hdf5_normalize_obs: - obs_normalization_stats = trainset.get_obs_normalization_stats() + obs_normalization_stats = get_obs_normalization_stats_rlds( + normalization_metadata, + config + ) # maybe retrieve statistics for normalizing actions - action_normalization_stats = trainset.get_action_normalization_stats() - - # initialize data loaders - train_loader = DataLoader( - dataset=trainset, - sampler=train_sampler, - batch_size=config.train.batch_size, - shuffle=(train_sampler is None), - num_workers=config.train.num_data_workers, - drop_last=True + action_normalization_stats = get_action_normalization_stats_rlds( + normalization_metadata, + config ) - if config.experiment.validate: - # cap num workers for validation dataset at 1 - num_workers = min(config.train.num_data_workers, 1) - valid_sampler = validset.get_dataset_sampler() - valid_loader = DataLoader( - dataset=validset, - sampler=valid_sampler, - batch_size=config.train.batch_size, - shuffle=(valid_sampler is None), - num_workers=num_workers, - drop_last=True - ) - else: - valid_loader = None - # print all warnings before training begins print("*" * 50) print("Warnings generated by robomimic have been duplicated here (from above) for convenience. 
Please check them carefully.") @@ -218,7 +202,7 @@ def train(config, device): data_loader=train_loader, epoch=epoch, num_steps=train_num_steps, - obs_normalization_stats=obs_normalization_stats, + obs_normalization_stats=obs_normalization_stats ) model.on_epoch_end(epoch) diff --git a/robomimic/utils/data_utils.py b/robomimic/utils/data_utils.py new file mode 100644 index 00000000..5fa8bbcb --- /dev/null +++ b/robomimic/utils/data_utils.py @@ -0,0 +1,117 @@ +import numpy as np +import tensorflow as tf +from functools import partial +from typing import Any, Callable, Dict, Sequence, Union + + +def index_nested_dict(d: Dict[str, Any], index: int): + """ + Indexes a nested dictionary with forward slashes ("/") separating hierarchies + """ + indices = index.split("/") + for i in indices: + if i not in d.keys(): + raise ValueError(f"Index {index} not found") + d = d[i] + return d + + +def set_nested_dict_index(d: Dict[str, Any], index: int, value): + """ + Sets an index in a nested dictionary with a value + Indexes have forward slashes ("/") separating hierarchies + """ + indices = index.split("/") + for i in indices[:-1]: + if i not in d.keys(): + raise ValueError(f"Index {index} not found") + d = d[i] + d[indices[-1]] = value + + +def map_nested_dict_index(d: Dict[str, Any], index: int, map_func): + """ + Maps an index in a nested dictionary with a function + Indexes have forward slashes ("/") separating hierarchies + """ + indices = index.split("/") + for i in indices[:-1]: + if i not in d.keys(): + raise ValueError(f"Index {index} not found") + d = d[i] + d[indices[-1]] = map_func(d[indices[-1]]) + + +def tree_map( + x: Dict[str, Any], + map_fn: Callable, + *, + _keypath: str = "", +) -> Dict[str, Any]: + + if not isinstance(x, dict): + out = map_fn(x) + return out + out = {} + for key in x.keys(): + if isinstance(x[key], dict): + out[key] = tree_map( + x[key], map_fn, _keypath=_keypath + key + "/" + ) + else: + out[key] = map_fn(x[key]) + return out + + +def selective_tree_map( + x: Dict[str, 
Any], + match: Union[str, Sequence[str], Callable[[str, Any], bool]], + map_fn: Callable, + *, + _keypath: str = "", +) -> Dict[str, Any]: + """Maps a function over a nested dictionary, only applying it leaves that match a criterion. + + Args: + x (Dict[str, Any]): The dictionary to map over. + match (str or Sequence[str] or Callable[[str, Any], bool]): If a string or list of strings, `map_fn` will only + be applied to leaves whose key path contains one of `match`. If a function, `map_fn` will only be applied to + leaves for which `match(key_path, value)` returns True. + map_fn (Callable): The function to apply. + """ + if not callable(match): + if isinstance(match, str): + match = [match] + match_fn = lambda keypath, value: any([s in keypath for s in match]) + else: + match_fn = match + + out = {} + for key in x: + if isinstance(x[key], dict): + out[key] = selective_tree_map( + x[key], match_fn, map_fn, _keypath=_keypath + key + "/" + ) + elif match_fn(_keypath + key, x[key]): + out[key] = map_fn(x[key]) + else: + out[key] = x[key] + return out + + +def decode_images( + x: Dict[str, Any], match: Union[str, Sequence[str]] = "image" +) -> Dict[str, Any]: + if isinstance(match, str): + match = [match] + + def match_fn(keypath, value): + image_in_keypath = any([s in keypath for s in match]) + return image_in_keypath + + return selective_tree_map( + x, + match=match_fn, + map_fn=partial(tf.io.decode_image, expand_animations=False), + ) + diff --git a/robomimic/utils/env_utils.py b/robomimic/utils/env_utils.py index 6514de32..8b56e922 100644 --- a/robomimic/utils/env_utils.py +++ b/robomimic/utils/env_utils.py @@ -41,6 +41,9 @@ def get_env_class(env_meta=None, env_type=None, env=None): elif env_type == EB.EnvType.IG_MOMART_TYPE: from robomimic.envs.env_ig_momart import EnvGibsonMOMART return EnvGibsonMOMART + elif env_type == EB.EnvType.DUMMY_TYPE: + from robomimic.envs.env_dummy import EnvDummy + return EnvDummy raise Exception("code should never reach this point") 
diff --git a/robomimic/utils/file_utils.py b/robomimic/utils/file_utils.py index eb590bd8..baaf6c03 100644 --- a/robomimic/utils/file_utils.py +++ b/robomimic/utils/file_utils.py @@ -12,10 +12,12 @@ from tqdm import tqdm import torch +import tensorflow_datasets as tfds import robomimic.utils.obs_utils as ObsUtils import robomimic.utils.env_utils as EnvUtils import robomimic.utils.torch_utils as TorchUtils +import robomimic.utils.data_utils as DataUtils from robomimic.config import config_factory from robomimic.algo import algo_factory from robomimic.algo import RolloutPolicy @@ -108,6 +110,39 @@ def get_env_metadata_from_dataset(dataset_path, ds_format="robomimic"): return env_meta +def get_env_metadata_from_dataset_rlds(builder): + """ + Retrieves env metadata from dataset. + + Args: + dataset_path (str): path to dataset + + Returns: + env_meta (dict): environment metadata. Contains 3 keys: + + :`'env_name'`: name of environment + :`'type'`: type of environment, should be a value in EB.EnvType + :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor + """ + if builder.info.metadata is None: + env_meta = None + else: + env_meta = builder.info.metadata.get('env_metadata', None) + if env_meta is not None: + #Fix weird json property that turns bool into _bool + DataUtils.tree_map(env_meta, + lambda x: bool(x) if isinstance(x, bool) else x + ) + else: + env_meta = { + 'env_name': 'rlds', + 'type': 4, + 'env_kwargs': {}, + } + + return env_meta + + def get_shape_metadata_from_dataset(dataset_path, action_keys, all_obs_keys=None, ds_format="robomimic", verbose=False): """ Retrieves shape metadata from dataset. 
@@ -208,6 +243,38 @@ def get_shape_metadata_from_dataset(dataset_path, action_keys, all_obs_keys=None return shape_meta +def get_shape_metadata_from_dataset_rlds(builder, action_keys, all_obs_keys=None): + from robomimic.data.dataset_shapes import DATASET_SHAPES + + shape_meta = {} + info = builder.info + name = builder.name + action_dim = 0 + for key in action_keys: + if name in DATASET_SHAPES.keys() and key in DATASET_SHAPES[name].keys(): + action_dim += DATASET_SHAPES[name][key][0] + else: + key_shape = DataUtils.index_nested_dict( + info.features['steps'], key).shape + assert len(key_shape) == 1 + action_dim += key_shape[0] + shape_meta["ac_dim"] = action_dim + + if all_obs_keys is None: + all_obs_keys = info.features['steps']['observation'].keys() + shape_meta['all_obs_keys'] = all_obs_keys + shape_meta['all_shapes'] = OrderedDict() + for key, feature in info.features['steps']['observation'].items(): + shape = feature.shape + if feature.shape[-1] == min(feature.shape): + shape = [shape[-1]] + list(shape[:-1]) + shape_meta['all_shapes'][key] = shape + shape_meta['use_images'] = np.any([isinstance(feature, tfds.features.Image) + for key, feature in info.features['steps']['observation'].items()]) + + return shape_meta + + def load_dict_from_checkpoint(ckpt_path): """ Load checkpoint dictionary from a checkpoint file. 
diff --git a/robomimic/utils/hyperparam_utils.py b/robomimic/utils/hyperparam_utils.py index 0ff6397e..cf422bae 100644 --- a/robomimic/utils/hyperparam_utils.py +++ b/robomimic/utils/hyperparam_utils.py @@ -294,8 +294,11 @@ def _script_from_jsons(self, json_paths): for path in json_paths: # write python command to file import robomimic - cmd = "python {}/scripts/train.py --config {}\n".format(robomimic.__path__[0], path) - + data_format = self.parameters.get('train.data_format', None) + if data_format is None or data_format.values[0] != 'rlds': + cmd = "python {}/scripts/train.py --config {}\n".format(robomimic.__path__[0], path) + else: + cmd = "python {}/scripts/train_rlds.py --config {}\n".format(robomimic.__path__[0], path) print() print(cmd) f.write(cmd) diff --git a/robomimic/utils/tensorflow_utils.py b/robomimic/utils/tensorflow_utils.py new file mode 100644 index 00000000..6407a06e --- /dev/null +++ b/robomimic/utils/tensorflow_utils.py @@ -0,0 +1,95 @@ +import tensorflow as tf + + +def euler_angles_to_rot_6d(euler_angles, convention="XYZ"): + """ + Converts tensor with rot_6d representation to euler representation. + """ + rot_mat = euler_angles_to_matrix(euler_angles, convention="XYZ") + rot_6d = matrix_to_rotation_6d(rot_mat) + return rot_6d + + +def matrix_to_rotation_6d(matrix): + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + Returns: + 6D rotation representation, of size (*, 6) + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + batch_dim = tf.shape(matrix)[:-2] + return tf.reshape(matrix[..., :2, :], tf.concat([batch_dim, [6]], axis=0)) + + +def euler_angles_to_matrix(euler_angles, convention): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, tf.unstack(euler_angles, axis=-1)) + ] + + # TensorFlow doesn't have a native functools.reduce or torch.matmul, so we use a loop + result = matrices[0] + for mat in matrices[1:]: + result = tf.matmul(result, mat) + + return result + + +def _axis_angle_rotation(axis, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + + cos = tf.cos(angle) + sin = tf.sin(angle) + one = tf.ones_like(angle) + zero = tf.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return tf.reshape(tf.stack(R_flat, axis=-1), tf.concat([tf.shape(angle), [3, 3]], axis=0)) + + + diff --git a/robomimic/utils/train_utils.py b/robomimic/utils/train_utils.py index 16202fc0..3f65c57a 100644 --- a/robomimic/utils/train_utils.py +++ b/robomimic/utils/train_utils.py @@ -602,7 +602,6 @@ def run_epoch(model, data_loader, epoch, validate=False, num_steps=None, obs_nor data_loader_iter = iter(data_loader) for _ in LogUtils.custom_tqdm(range(num_steps)): - # load next batch from data loader try: t = time.time() @@ -613,7 +612,7 @@ def run_epoch(model, data_loader, epoch, validate=False, num_steps=None, obs_nor t = time.time() batch = next(data_loader_iter) timing_stats["Data_Loading"].append(time.time() - t) - + # process batch for training t = time.time() input_batch = model.process_batch_for_training(batch)