From 31b1f610ec4c133f3f3d1fda5dd0e5a885e2aa86 Mon Sep 17 00:00:00 2001 From: Jonathan Yang Date: Thu, 5 Oct 2023 16:18:48 -0700 Subject: [PATCH 1/4] Add RLDS dataloader --- robomimic/data/__init__.py | 0 robomimic/data/common_transformations.py | 176 +++++++ robomimic/data/dataset.py | 250 ++++++++++ robomimic/data/dataset_transformations.py | 78 +++ robomimic/scripts/config_gen/diffusion_gen.py | 46 +- robomimic/scripts/config_gen/helper.py | 17 + robomimic/scripts/train_rlds.py | 466 ++++++++++++++++++ robomimic/utils/data_utils.py | 61 +++ robomimic/utils/file_utils.py | 45 ++ 9 files changed, 1125 insertions(+), 14 deletions(-) create mode 100644 robomimic/data/__init__.py create mode 100644 robomimic/data/common_transformations.py create mode 100644 robomimic/data/dataset.py create mode 100644 robomimic/data/dataset_transformations.py create mode 100644 robomimic/scripts/train_rlds.py create mode 100644 robomimic/utils/data_utils.py diff --git a/robomimic/data/__init__.py b/robomimic/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/robomimic/data/common_transformations.py b/robomimic/data/common_transformations.py new file mode 100644 index 00000000..a2c3c5cd --- /dev/null +++ b/robomimic/data/common_transformations.py @@ -0,0 +1,176 @@ +import dlimp as dl +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +import tqdm +from typing import Any, Callable, Dict, Sequence, Union +from tensorflow_datasets.core.dataset_builder import DatasetBuilder + +import robomimic.utils.tensor_utils as TensorUtils +from robomimic.utils.data_utils import * + + +def add_next_obs(traj: Dict[str, Any], pad: bool = True) -> Dict[str, Any]: + """ + Given a trajectory with a key "observation", add the key "next_observation". If pad is False, discards the last + value of all other keys. Otherwise, the last transition will have "observation" == "next_observation". 
+    """
+    if not pad:
+        traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
+        traj_truncated["next_observation"] = tf.nest.map_structure(
+            lambda x: x[1:], traj["observation"]
+        )
+        return traj_truncated
+    else:
+        traj["next_observation"] = tf.nest.map_structure(
+            lambda x: tf.concat((x[1:], x[-1:]), axis=0), traj["observation"]
+        )
+        return traj
+
+
+def normalize_obs_and_actions(traj, config, metadata):
+    '''
+    For now, only normalize the appropriate action keys
+    '''
+    action_config = config.train.action_config
+    normal_keys = [key for key in config.train.action_keys
+                   if action_config[key].get('normalization', None) == 'normal']
+    min_max_keys = [key for key in config.train.action_keys
+                    if action_config[key].get('normalization', None) == 'min_max']
+
+    for key in normal_keys:
+        map_nested_dict_index(
+            traj,
+            key,
+            lambda x: (x - metadata[key]["mean"]) / metadata[key]["std"]
+        )
+
+    for key in min_max_keys:
+        map_nested_dict_index(
+            traj,
+            key,
+            lambda x: tf.clip_by_value(2 * (x - metadata[key]["min"])
+                                       / (metadata[key]["max"] - metadata[key]["min"]) - 1,
+                                       -1,
+                                       1)
+        )
+
+    return traj
+
+
+def random_dataset_sequence_transform(traj, sequence_length):
+    '''
+    Extract a random subsequence of length sequence_length from the given keys
+    '''
+    traj_len = len(traj["action"])
+    index_in_demo = tf.cast(tf.random.uniform(shape=[])
+                            * tf.cast(traj_len, dtype=tf.float32), dtype=tf.int32)
+    last_index = tf.math.minimum(traj_len, index_in_demo + sequence_length)
+    seq_end_pad = tf.math.maximum(0, index_in_demo + sequence_length - traj_len)
+    keys = ["observation", "action", "action_dict", "goal"]
+
+    def random_sequence_func(x):
+        sequence = x[index_in_demo: last_index]
+        padding = tf.repeat([x[-1]], repeats=[seq_end_pad], axis=0)
+        return tf.concat((sequence, padding), axis=0)
+
+    traj = dl.transforms.selective_tree_map(
+        traj,
+        match=keys,
+        map_fn=random_sequence_func
+    )
+    return traj
+
+
+def random_dataset_sequence_transform_v2(traj, frame_stack, seq_length,
+                                         pad_frame_stack, pad_seq_length):
+    '''
+    Extract a random subsequence of length frame_stack + seq_length - 1 from the given keys
+    '''
+    traj_len = tf.shape(traj["action"])[0]
+    seq_begin_pad, seq_end_pad = 0, 0
+    if pad_frame_stack:
+        seq_begin_pad = frame_stack - 1
+    if pad_seq_length:
+        seq_end_pad = seq_length - 1
+    index_in_demo = tf.random.uniform(shape=[],
+                                      maxval=traj_len + seq_end_pad - (seq_length - 1),
+                                      dtype=tf.int32)
+    pad_mask = tf.concat((tf.repeat(0, repeats=seq_begin_pad),
+                          tf.repeat(1, repeats=traj_len),
+                          tf.repeat(0, repeats=seq_end_pad)), axis=0)[:, None]
+    traj['pad_mask'] = tf.cast(pad_mask, dtype=tf.bool)
+    keys = ["observation", "action", "action_dict", "goal"]
+
+    def random_sequence_func(x):
+        begin_padding = tf.repeat([x[0]], repeats=[seq_begin_pad], axis=0)
+        end_padding = tf.repeat([x[-1]], repeats=[seq_end_pad], axis=0)
+        sequence = tf.concat((begin_padding, x, end_padding), axis=0)
+        return sequence[index_in_demo: index_in_demo + seq_length + frame_stack - 1]
+
+    traj = dl.transforms.selective_tree_map(
+        traj,
+        match=keys,
+        map_fn=random_sequence_func
+    )
+    return traj
+
+
+def relabel_goals_transform(traj, goal_mode):
+    traj_len = len(traj["action"])
+
+    if goal_mode == "last":
+        goal_idxs = tf.ones(traj_len) * (traj_len - 1)
+        goal_idxs = tf.cast(goal_idxs, tf.int32)
+    elif goal_mode == "uniform":
+        rand = tf.random.uniform([traj_len])
+        low = tf.cast(tf.range(traj_len) + 1, tf.float32)
+        high = tf.cast(traj_len, tf.float32)
+        goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)
+
+    traj["goal_observation"] = tf.nest.map_structure(
+        lambda x: tf.gather(x, goal_idxs), traj["observation"]
+    )
+    return traj
+
+
+def concatenate_action_transform(traj, action_keys):
+    '''
+    Concatenates the action_keys
+    '''
+    traj["action"] = tf.concat(
+        list(index_nested_dict(traj, key) for key in action_keys),
+        axis=-1
+    )
+
+    return traj
+
+
+def frame_stack_transform(traj, num_frames):
+    '''
+    Stacks the previous num_frames - 1 frames with the current frame
+    Converts the trajectory into size
+    traj_len x num_frames x ...
+    '''
+    traj_len = len(traj["action"])
+
+    #Pad beginning of observation num_frames - 1 times:
+    traj["observation"] = tf.nest.map_structure(
+        lambda x: tf.concat((tf.repeat([x[0]], repeats=[num_frames - 1], axis=0),
+                             x), axis=0),
+        traj["observation"])
+
+    def stack_func(x):
+        indices = tf.reshape(tf.range(traj_len), [-1, 1]) + tf.range(num_frames)
+        return tf.gather(x, indices)
+
+    #Gather windows of num_frames consecutive observations per timestep
+    traj["observation"] = tf.nest.map_structure(
+        stack_func,
+        traj["observation"]
+    )
+
+    return traj
diff --git a/robomimic/data/dataset.py b/robomimic/data/dataset.py
new file mode 100644
index 00000000..23873793
--- /dev/null
+++ b/robomimic/data/dataset.py
@@ -0,0 +1,250 @@
+from typing import Any, Callable, Dict, Sequence, Union, List, Optional, Tuple
+import dlimp as dl
+import numpy as np
+import tensorflow as tf
+import tensorflow_datasets as tfds
+#Don't use GPU for dataloading
+tf.config.set_visible_devices([], "GPU")
+import tqdm
+import logging
+from tensorflow_datasets.core.dataset_builder import DatasetBuilder
+from collections import OrderedDict
+from functools import partial
+import hashlib
+import json
+import pickle
+import torch
+
+import robomimic.utils.torch_utils as TorchUtils
+from .dataset_transformations import RLDS_TRAJECTORY_MAP_TRANSFORMS
+from .common_transformations import *
+from robomimic.utils.data_utils import *
+
+
+class RLDSTorchDataset:
+    def __init__(self, dataset_iterator, try_to_use_cuda=True):
+        self.dataset_iterator = dataset_iterator
+        self.device = TorchUtils.get_torch_device(try_to_use_cuda)
+        self.keys = ['obs', 'goal_obs', 'actions']
+
+    def __iter__(self):
+        for batch in self.dataset_iterator:
+            torch_batch = {}
+            for key in self.keys:
+                if key in batch.keys():
+                    torch_batch[key] = tree_map(
+                        batch[key],
+                        map_fn=lambda x: torch.Tensor(x).to(self.device)
+                    )
+            yield torch_batch
+
+
+def get_action_normalization_stats_rlds(obs_action_metadata, config):
+    action_config = config.train.action_config
+    normal_keys = [key for key in config.train.action_keys
+                   if action_config[key].get('normalization', None) == 'normal']
+    min_max_keys = [key for key in config.train.action_keys
+                    if action_config[key].get('normalization', None) == 'min_max']
+
+    stats = OrderedDict()
+    for key in config.train.action_keys:
+        if key in normal_keys:
+            normal_stats = {
+                'scale': obs_action_metadata[key]['std'].numpy().reshape(1, -1),
+                'offset': obs_action_metadata[key]['mean'].numpy().reshape(1, -1)
+            }
+            stats[key] = normal_stats
+        elif key in min_max_keys:
+            min_max_range = obs_action_metadata[key]['max'] - obs_action_metadata[key]['min']
+            min_max_stats = {
+                'scale': (min_max_range / 2).numpy().reshape(1, -1),
+                'offset': (obs_action_metadata[key]['min'] + min_max_range / 2).numpy().reshape(1, -1)
+            }
+            stats[key] = min_max_stats
+        else:
+            identity_stats = {
+                'scale': np.ones_like(obs_action_metadata[key]['std'].numpy()).reshape(1, -1),
+                'offset': np.zeros_like(obs_action_metadata[key]['mean'].numpy()).reshape(1, -1)
+            }
+            stats[key] = identity_stats
+    return stats
+
+
+def get_obs_normalization_stats_rlds(obs_action_metadata, config):
+    stats = OrderedDict()
+    for key, obs_action_stats in obs_action_metadata.items():
+        feature_type, feature_key = key.split('/')
+        if feature_type != 'observation':
+            continue
+        stats[feature_key] = {
+            'mean': obs_action_stats['mean'][None],
+            'std': obs_action_stats['std'][None],
+        }
+    return stats
+
+
+def get_obs_action_metadata(
+    builder: DatasetBuilder, dataset: tf.data.Dataset, keys: List[str],
+    load_if_exists=True
+) -> Dict[str, Dict[str, List[float]]]:
+    # get statistics file path --> embed unique hash that catches if dataset info changed
+    data_info_hash = hashlib.sha256(
+        (str(builder.info) + str(keys)).encode("utf-8")
+    ).hexdigest()
+    path = tf.io.gfile.join(
+        builder.info.data_dir, f"obs_action_stats_{data_info_hash}.pkl"
+    )
+
+    # check if stats already exist and load, otherwise compute
+    if tf.io.gfile.exists(path) and load_if_exists:
+        print(f"Loading existing statistics for normalization from {path}.")
+        with tf.io.gfile.GFile(path, "rb") as f:
+            metadata = pickle.load(f)
+    else:
+        print("Computing obs/action statistics for normalization...")
+        eps_by_key = {key: [] for key in keys}
+        for episode in tqdm.tqdm(dataset.take(30000)):
+            for key in keys:
+                eps_by_key[key].append(index_nested_dict(episode, key).numpy())
+        eps_by_key = {key: np.concatenate(values) for key, values in eps_by_key.items()}
+
+        metadata = {}
+        for key in keys:
+            metadata[key] = {
+                "mean": eps_by_key[key].mean(0),
+                "std": eps_by_key[key].std(0),
+                "max": eps_by_key[key].max(0),
+                "min": eps_by_key[key].min(0),
+            }
+        del eps_by_key
+        with tf.io.gfile.GFile(path, "wb") as f:
+            pickle.dump(metadata, f)
+        logging.info("Done!")
+
+    return {
+        k: {k2: tf.convert_to_tensor(v2, dtype=tf.float32) for k2, v2 in v.items()}
+        for k, v in metadata.items()
+    }
+
+
+def apply_common_transforms(
+    dataset: tf.data.Dataset,
+    config: dict,
+    *,
+    train: bool,
+    obs_action_metadata: Optional[dict] = None,
+):
+
+    #Decode images
+    dataset = dataset.frame_map(dl.transforms.decode_images)
+
+    #Normalize observations and actions
+    if obs_action_metadata is not None:
+        dataset = dataset.map(
+            partial(
+                normalize_obs_and_actions,
+                config=config,
+                metadata=obs_action_metadata,
+            ),
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+    #Relabel goals
+    if config.train.goal_mode == 'last' or config.train.goal_mode == 'uniform':
+        dataset = dataset.map(
+            partial(
+                relabel_goals_transform,
+                goal_mode=config.train.goal_mode
+            ),
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+    #Stack frames
+    '''
+    if config.train.frame_stack is not None and config.train.frame_stack > 1:
+        dataset = dataset.map(
+            partial(
+                frame_stack_transform,
+                num_frames=config.train.frame_stack
+            )
+        )
+    '''
+    #Concatenate actions
+    if config.train.action_keys is not None:
+        dataset = dataset.map(
+            partial(
+                concatenate_action_transform,
+                action_keys=config.train.action_keys
+            ),
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+    #Get a random subset of length seq_length
+    dataset = dataset.map(
+        partial(
+            random_dataset_sequence_transform_v2,
+            frame_stack=config.train.frame_stack,
+            seq_length=config.train.seq_length,
+            pad_frame_stack=config.train.pad_frame_stack,
+            pad_seq_length=config.train.pad_seq_length
+        ),
+        num_parallel_calls=tf.data.AUTOTUNE
+    )
+    #augmentation? #chunking?
+
+    return dataset
+
+
+def make_dataset(
+    config: dict,
+    train: bool = True,
+    shuffle: bool = True,
+    resize_size: Optional[Tuple[int, int]] = None,
+    normalization_metadata: Optional[Dict] = None,
+    **kwargs,
+):
+
+    data_info = config.train.data[0]
+    name = data_info['name']
+    data_dir = data_info['path']
+
+    builder = tfds.builder(name, data_dir=data_dir)
+    if "val" not in builder.info.splits:
+        split = "train[:95%]" if train else "train[95%:]"
+    else:
+        split = "train" if train else "val"
+
+    dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle)
+    if name in RLDS_TRAJECTORY_MAP_TRANSFORMS:
+        if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'] is not None:
+            dataset = dataset.map(RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'])
+
+    metadata_keys = ['action']
+    if config.train.action_keys is not None:
+        metadata_keys.extend(config.train.action_keys)
+    if config.all_obs_keys is not None:
+        metadata_keys.extend([f'observation/{k}'
+                              for k in config.all_obs_keys])
+    if normalization_metadata is None:
+        normalization_metadata = get_obs_action_metadata(
+            builder,
+            dataset,
+            keys=metadata_keys,
+            load_if_exists=False
+        )
+
+    dataset = apply_common_transforms(
+        dataset,
+        config=config,
+        train=train,
+        obs_action_metadata=normalization_metadata,
+        **kwargs,
+    )
+    if name in RLDS_TRAJECTORY_MAP_TRANSFORMS:
+        if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'] is not None:
+            dataset = dataset.map(RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'])
+    dataset = dataset.repeat().batch(config.train.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
+    dataset = dataset.as_numpy_iterator()
+    dataset = RLDSTorchDataset(dataset)
+
+    return builder, dataset, normalization_metadata
diff --git a/robomimic/data/dataset_transformations.py b/robomimic/data/dataset_transformations.py
new file mode 100644
index 00000000..99650ab8
--- /dev/null
+++ b/robomimic/data/dataset_transformations.py
@@ -0,0 +1,78 @@
+from typing import Any, Callable, Dict, Sequence, Union
+import tensorflow as tf
+
+
+def r2d2_dataset_pre_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
+    # every input feature is batched, ie has leading batch dimension
+    keep_keys = [
+        "observation",
+        "action",
+        "action_dict",
+        "language_instruction",
+        "is_terminal",
+        "is_last",
+        "_traj_index",
+    ]
+    traj = {k: v for k, v in traj.items() if k in keep_keys}
+    return traj
+
+
+def r2d2_dataset_post_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
+    #Set obs key
+    traj['obs'] = traj['observation']
+
+    #Set actions key
+    traj['actions'] = traj['action']
+
+    #Use one goal per sequence
+    if 'goal_observation' in traj.keys():
+        traj['goal_obs'] = traj['goal_observation'][0]
+    return traj
+
+
+def robomimic_dataset_pre_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
+    # every input feature is batched, ie has leading batch dimension
+    keep_keys = [
+        "observation",
+        "action",
+        "action_dict",
+        "language_instruction",
+        "is_terminal",
+        "is_last",
+        "_traj_index",
+    ]
+    traj = {k: v for k, v in traj.items() if k in keep_keys}
+    return traj
+
+
+def robomimic_dataset_post_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
+    #Set obs key
+    traj['obs'] = traj['observation']
+
+    #Set actions key
+    traj['actions'] = traj['action']
+
+    #Use one goal per sequence
+    if 'goal_observation' in traj.keys():
+        traj['goal_obs'] = traj['goal_observation'][0]
+
+    keep_keys = ['obs',
+                 'goal_obs',
+                 'actions',
+                 'action_dict']
+    traj = {k: v for k, v in traj.items() if k in keep_keys}
+    return traj
+
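The transforms above get wired up per dataset name in the `RLDS_TRAJECTORY_MAP_TRANSFORMS` registry defined next. As a hedged sketch of how an additional dataset could hook in (the name `my_dataset` and its feature keys are hypothetical, not part of this patch; either entry may be `None`, which `make_dataset` checks before mapping):

```python
from typing import Any, Dict


def my_dataset_pre_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
    # runs on raw RLDS trajectories, before normalization / goal relabeling:
    # keep only the features the training pipeline consumes
    keep_keys = ["observation", "action"]
    return {k: v for k, v in traj.items() if k in keep_keys}


def my_dataset_post_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
    # runs after the common transforms: rename to the robomimic batch keys
    traj["obs"] = traj.pop("observation")
    traj["actions"] = traj.pop("action")
    return traj


# hypothetical registration -- mirrors the 'r2d2' / 'robomimic_dataset' entries below
RLDS_TRAJECTORY_MAP_TRANSFORMS["my_dataset"] = {
    "pre": my_dataset_pre_transform,    # or None to skip this stage
    "post": my_dataset_post_transform,  # or None to skip this stage
}
```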
+RLDS_TRAJECTORY_MAP_TRANSFORMS = { + 'r2d2': { + 'pre': r2d2_dataset_pre_transform, + 'post': r2d2_dataset_post_transform, + }, + 'robomimic_dataset': { + 'pre': robomimic_dataset_pre_transform, + 'post': robomimic_dataset_post_transform, + } +} + diff --git a/robomimic/scripts/config_gen/diffusion_gen.py b/robomimic/scripts/config_gen/diffusion_gen.py index 93e53b32..8a86a46e 100644 --- a/robomimic/scripts/config_gen/diffusion_gen.py +++ b/robomimic/scripts/config_gen/diffusion_gen.py @@ -186,19 +186,37 @@ def make_generator_helper(args): hidename=True, ) elif args.env == "square": - generator.add_param( - key="train.data", - name="ds", - group=2, - values=[ - [ - {"path": "~/datasets/square/ph/square_ph_abs_tmp.hdf5"}, # replace with your own path + if args.data_format == 'hdf5': + generator.add_param( + key="train.data", + name="ds", + group=2, + values=[ + [ + {"path": "/iris/u/jyang27/dev/robomimic/datasets/square/ph/low_dim_v141.hdf5"}, # replace with your own path + ], ], - ], - value_names=[ - "square", - ], - ) + value_names=[ + "square", + ], + ) + elif args.data_format == 'rlds': + generator.add_param( + key="train.data", + name="ds", + group=2, + values=[ + [ + {"path": "/iris/u/jyang27/rlds_data", + "name": "robomimic_dataset"}, # replace with your own path + ], + ], + value_names=[ + "square", + ], + ) + else: + raise ValueError # update env config to use absolute action control generator.add_param( @@ -236,7 +254,7 @@ def make_generator_helper(args): name="", group=-1, values=[ - "~/expdata/{env}/{mod}/{algo_name_short}".format( + "/iris/u/jyang27/expdata/{env}/{mod}/{algo_name_short}".format( env=args.env, mod=args.mod, algo_name_short=algo_name_short, @@ -250,4 +268,4 @@ def make_generator_helper(args): parser = get_argparser() args = parser.parse_args() - make_generator(args, make_generator_helper) \ No newline at end of file + make_generator(args, make_generator_helper) diff --git a/robomimic/scripts/config_gen/helper.py b/robomimic/scripts/config_gen/helper.py index 48a3af07..7d9829a8 100644 --- a/robomimic/scripts/config_gen/helper.py +++ b/robomimic/scripts/config_gen/helper.py @@ -630,6 +630,17 @@ def set_env_settings(generator, args): else: raise ValueError + if args.data_format == 'rlds': + #If using rlds, overwrite data format to rlds + generator.add_param( + key="train.data_format", + name="", + group=-1, + values=[ + "rlds" + ], + ) + def set_mod_settings(generator, args): if args.mod == 'ld': @@ -840,6 +851,12 @@ def get_argparser(): default='im', ) + parser.add_argument( + "--data_format", + type=str, + default='hdf5', + ) + parser.add_argument( "--ckpt_mode", type=str, diff --git a/robomimic/scripts/train_rlds.py b/robomimic/scripts/train_rlds.py new file mode 100644 index 00000000..bc76c6dc --- /dev/null +++ b/robomimic/scripts/train_rlds.py @@ -0,0 +1,466 @@ +""" +The main entry point for training policies. + +Args: + config (str): path to a config json that will be used to override the default settings. + If omitted, default settings are used. This is the preferred way to run experiments. + + algo (str): name of the algorithm to run. Only needs to be provided if @config is not + provided. 
+
+    name (str): if provided, override the experiment name defined in the config
+
+    dataset (str): if provided, override the dataset path defined in the config
+
+    debug (bool): set this flag to run a quick training run for debugging purposes
+"""
+
+import argparse
+import json
+import numpy as np
+import time
+import os
+import shutil
+import psutil
+import sys
+import socket
+import traceback
+
+from collections import OrderedDict
+
+import torch
+from torch.utils.data import DataLoader
+
+import robomimic
+import robomimic.utils.train_utils as TrainUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.utils.file_utils as FileUtils
+from robomimic.config import config_factory
+from robomimic.algo import algo_factory, RolloutPolicy
+from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings
+from robomimic.data.dataset import (make_dataset, get_obs_normalization_stats_rlds,
+                                    get_action_normalization_stats_rlds)
+
+
+def train(config, device):
+    """
+    Train a model using the algorithm.
+    """
+
+    # first set seeds
+    np.random.seed(config.train.seed)
+    torch.manual_seed(config.train.seed)
+
+    # set num workers
+    torch.set_num_threads(1)
+
+    print("\n============= New Training Run with Config =============")
+    print(config)
+    print("")
+    log_dir, ckpt_dir, video_dir, vis_dir = TrainUtils.get_exp_dir(config)
+
+    if config.experiment.logging.terminal_output_to_txt:
+        # log stdout and stderr to a text file
+        logger = PrintLogger(os.path.join(log_dir, 'log.txt'))
+        sys.stdout = logger
+        sys.stderr = logger
+
+    # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
+    ObsUtils.initialize_obs_utils_with_config(config)
+
+    # Load the datasets
+    train_builder, train_loader, normalization_metadata = make_dataset(
+        config,
+        train=True,
+        shuffle=True
+    )
+    ds_format = config.train.data_format
+    assert ds_format == 'rlds'
+
+    if config.experiment.validate:
+        # use the held-out split for validation
+        valid_builder, valid_loader, _ = make_dataset(
+            config,
+            train=False,
+            shuffle=True,
+            normalization_metadata=normalization_metadata
+        )
+    else:
+        valid_loader = None
+
+
+    # load basic metadata from training file
+    print("\n============= Loaded Environment Metadata =============")
+    env_meta = FileUtils.get_env_metadata_from_dataset_rlds(train_builder)
+
+    # update env meta if applicable
+    from robomimic.utils.script_utils import deep_update
+    deep_update(env_meta, config.experiment.env_meta_update_dict)
+
+    shape_meta = FileUtils.get_shape_metadata_from_dataset_rlds(
+        train_builder,
+        action_keys=config.train.action_keys,
+        all_obs_keys=config.all_obs_keys,
+    )
+
+    if config.experiment.env is not None:
+        env_meta["env_name"] = config.experiment.env
+        print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)
+
+    # create environment
+    envs = OrderedDict()
+    if config.experiment.rollout.enabled:
+        # create environments for validation runs
+        env_names = [env_meta["env_name"]]
+
+        if config.experiment.additional_envs is not None:
+            for name in config.experiment.additional_envs:
+                env_names.append(name)
+
+        for env_name in env_names:
+            env = EnvUtils.create_env_from_metadata(
+                env_meta=env_meta,
+                env_name=env_name,
+                render=False,
+                render_offscreen=config.experiment.render_video,
+                use_image_obs=shape_meta["use_images"],
+            )
+            env = EnvUtils.wrap_env_from_config(env, config=config)  # apply environment wrapper, if applicable
+            envs[env.name] = env
+            print(envs[env.name])
+
+    print("")
+
+    # setup for a new training run
+    data_logger = DataLogger(
+        log_dir,
+        config,
+        log_tb=config.experiment.logging.log_tb,
+        log_wandb=config.experiment.logging.log_wandb,
+    )
+    model = algo_factory(
+        algo_name=config.algo_name,
+        config=config,
+        obs_key_shapes=shape_meta["all_shapes"],
+        ac_dim=shape_meta["ac_dim"],
+        device=device,
+    )
+
+    # save the config as a json file
+    with open(os.path.join(log_dir, '..', 'config.json'), 'w') as outfile:
+        json.dump(config, outfile, indent=4)
+
+    # if checkpoint is specified, load in model weights
+    ckpt_path = config.experiment.ckpt_path
+    if ckpt_path is not None:
+        print("LOADING MODEL WEIGHTS FROM " + ckpt_path)
+        from robomimic.utils.file_utils import maybe_dict_from_checkpoint
+        ckpt_dict = maybe_dict_from_checkpoint(ckpt_path=ckpt_path)
+        model.deserialize(ckpt_dict["model"])
+
+    print("\n============= Model Summary =============")
+    print(model)  # print model summary
+    print("")
+
+    # maybe retrieve statistics for normalizing observations
+    obs_normalization_stats = None
+    if config.train.hdf5_normalize_obs:
+        obs_normalization_stats = get_obs_normalization_stats_rlds(
+            normalization_metadata,
+            config
+        )
+
+    # maybe retrieve statistics for normalizing actions
+    action_normalization_stats = get_action_normalization_stats_rlds(
+        normalization_metadata,
+        config
+    )
+
+    # print all warnings before training begins
+    print("*" * 50)
+    print("Warnings generated by robomimic have been duplicated here (from above) for convenience. Please check them carefully.")
+    flush_warnings()
+    print("*" * 50)
+    print("")
+
+    # main training loop
+    best_valid_loss = None
+    best_return = {k: -np.inf for k in envs} if config.experiment.rollout.enabled else None
+    best_success_rate = {k: -1. for k in envs} if config.experiment.rollout.enabled else None
+    last_ckpt_time = time.time()
+
+    # number of learning steps per epoch (defaults to a full dataset pass)
+    train_num_steps = config.experiment.epoch_every_n_steps
+    valid_num_steps = config.experiment.validation_epoch_every_n_steps
+
+    for epoch in range(1, config.train.num_epochs + 1):  # epoch numbers start at 1
+        step_log = TrainUtils.run_epoch(
+            model=model,
+            data_loader=train_loader,
+            epoch=epoch,
+            num_steps=train_num_steps,
+            obs_normalization_stats=obs_normalization_stats
+        )
+        model.on_epoch_end(epoch)
+
+        # setup checkpoint path
+        epoch_ckpt_name = "model_epoch_{}".format(epoch)
+
+        # check for recurring checkpoint saving conditions
+        should_save_ckpt = False
+        if config.experiment.save.enabled:
+            time_check = (config.experiment.save.every_n_seconds is not None) and \
+                (time.time() - last_ckpt_time > config.experiment.save.every_n_seconds)
+            epoch_check = (config.experiment.save.every_n_epochs is not None) and \
+                (epoch > 0) and (epoch % config.experiment.save.every_n_epochs == 0)
+            epoch_list_check = (epoch in config.experiment.save.epochs)
+            should_save_ckpt = (time_check or epoch_check or epoch_list_check)
+        ckpt_reason = None
+        if should_save_ckpt:
+            last_ckpt_time = time.time()
+            ckpt_reason = "time"
+
+        print("Train Epoch {}".format(epoch))
+        print(json.dumps(step_log, sort_keys=True, indent=4))
+        for k, v in step_log.items():
+            if k.startswith("Time_"):
+                data_logger.record("Timing_Stats/Train_{}".format(k[5:]), v, epoch)
+            else:
+                data_logger.record("Train/{}".format(k), v, epoch)
+
+        # Evaluate the model on validation set
+        if config.experiment.validate:
+            with torch.no_grad():
+                step_log = TrainUtils.run_epoch(model=model, data_loader=valid_loader, epoch=epoch, validate=True, num_steps=valid_num_steps)
+            for k, v in step_log.items():
+                if k.startswith("Time_"):
+                    data_logger.record("Timing_Stats/Valid_{}".format(k[5:]), v, epoch)
+                else:
+                    data_logger.record("Valid/{}".format(k), v, epoch)
+
+            print("Validation Epoch {}".format(epoch))
+            print(json.dumps(step_log, sort_keys=True, indent=4))
+
+            # save checkpoint if we achieve a new best validation loss
+            valid_check = "Loss" in step_log
+            if valid_check and (best_valid_loss is None or (step_log["Loss"] <= best_valid_loss)):
+                best_valid_loss = step_log["Loss"]
+                if config.experiment.save.enabled and config.experiment.save.on_best_validation:
+                    epoch_ckpt_name += "_best_validation_{}".format(best_valid_loss)
+                    should_save_ckpt = True
+                    ckpt_reason = "valid" if ckpt_reason is None else ckpt_reason
+
+        # Evaluate the model by running rollouts
+
+        # do rollouts at fixed rate or if it's time to save a new ckpt
+        video_paths = None
+        rollout_check = (epoch % config.experiment.rollout.rate == 0) or (should_save_ckpt and ckpt_reason == "time")
+        if config.experiment.rollout.enabled and (epoch > config.experiment.rollout.warmstart) and rollout_check:
+
+            # wrap model as a RolloutPolicy to prepare for rollouts
+            rollout_model = RolloutPolicy(
+                model,
+                obs_normalization_stats=obs_normalization_stats,
+                action_normalization_stats=action_normalization_stats,
+            )
+
+            num_episodes = config.experiment.rollout.n
+            all_rollout_logs, video_paths = TrainUtils.rollout_with_stats(
+                policy=rollout_model,
+                envs=envs,
+                horizon=config.experiment.rollout.horizon,
+                use_goals=config.use_goals,
+                num_episodes=num_episodes,
+                render=False,
+                video_dir=video_dir if config.experiment.render_video else None,
+                epoch=epoch,
+                video_skip=config.experiment.get("video_skip", 5),
+                terminate_on_success=config.experiment.rollout.terminate_on_success,
+            )
+
+            # summarize results from rollouts to tensorboard and terminal
+            for env_name in all_rollout_logs:
+                rollout_logs = all_rollout_logs[env_name]
+                for k, v in rollout_logs.items():
+                    if k.startswith("Time_"):
+                        data_logger.record("Timing_Stats/Rollout_{}_{}".format(env_name, k[5:]), v, epoch)
+                    else:
+                        data_logger.record("Rollout/{}/{}".format(k, env_name), v, epoch, log_stats=True)
+
+                print("\nEpoch {} Rollouts took {}s (avg) with results:".format(epoch, rollout_logs["time"]))
+                print('Env: {}'.format(env_name))
+                print(json.dumps(rollout_logs, sort_keys=True, indent=4))
+
+            # checkpoint and video saving logic
+            updated_stats = TrainUtils.should_save_from_rollout_logs(
+                all_rollout_logs=all_rollout_logs,
+                best_return=best_return,
+                best_success_rate=best_success_rate,
+                epoch_ckpt_name=epoch_ckpt_name,
+                save_on_best_rollout_return=config.experiment.save.on_best_rollout_return,
+                save_on_best_rollout_success_rate=config.experiment.save.on_best_rollout_success_rate,
+            )
+            best_return = updated_stats["best_return"]
+            best_success_rate = updated_stats["best_success_rate"]
+            epoch_ckpt_name = updated_stats["epoch_ckpt_name"]
+            should_save_ckpt = (config.experiment.save.enabled and updated_stats["should_save_ckpt"]) or should_save_ckpt
+            if updated_stats["ckpt_reason"] is not None:
+                ckpt_reason = updated_stats["ckpt_reason"]
+
+        # check if we need to save model MSE
+        should_save_mse = False
+        if config.experiment.mse.enabled:
+            if config.experiment.mse.every_n_epochs is not None and epoch % config.experiment.mse.every_n_epochs == 0:
+                should_save_mse = True
+            if config.experiment.mse.on_save_ckpt and should_save_ckpt:
+                should_save_mse = True
+        if should_save_mse:
+            print("Computing MSE ...")
+            if config.experiment.mse.visualize:
+                save_vis_dir = os.path.join(vis_dir, epoch_ckpt_name)
+            else:
+                save_vis_dir = None
+            mse_log, vis_log = model.compute_mse_visualize(
+                train_loader,
+                valid_loader,
+                num_samples=config.experiment.mse.num_samples,
+                savedir=save_vis_dir,
+            )
+            for k, v in mse_log.items():
+                data_logger.record("{}".format(k), v, epoch)
+
+            for k, v in vis_log.items():
+                data_logger.record("{}".format(k), v, epoch, data_type='image')
+
+            print("MSE Log Epoch {}".format(epoch))
+            print(json.dumps(mse_log, sort_keys=True, indent=4))
+
+        # Only keep saved videos if the ckpt should be saved (but not because of validation score)
+        should_save_video = (should_save_ckpt and (ckpt_reason != "valid")) or config.experiment.keep_all_videos
+        if video_paths is not None and not should_save_video:
+            for env_name in video_paths:
+                os.remove(video_paths[env_name])
+
+        # Save model checkpoints based on conditions (success rate, validation loss, etc)
+        if should_save_ckpt:
+            TrainUtils.save_model(
+                model=model,
+                config=config,
+                env_meta=env_meta,
+                shape_meta=shape_meta,
+                ckpt_path=os.path.join(ckpt_dir, epoch_ckpt_name + ".pth"),
+                obs_normalization_stats=obs_normalization_stats,
+                action_normalization_stats=action_normalization_stats,
+            )
+
+        # Finally, log memory usage in MB
+        process = psutil.Process(os.getpid())
+        mem_usage = int(process.memory_info().rss / 1000000)
+        data_logger.record("System/RAM Usage (MB)", mem_usage, epoch)
+        print("\nEpoch {} Memory Usage: {} MB\n".format(epoch, mem_usage))
+
+    # terminate logging
+    data_logger.close()
+
+
+def main(args):
+
+    if args.config is not None:
+        ext_cfg = json.load(open(args.config, 'r'))
+        config = config_factory(ext_cfg["algo_name"])
+        # update config with external json
- this will throw errors if + # the external config has keys not present in the base algo config + with config.values_unlocked(): + config.update(ext_cfg) + else: + config = config_factory(args.algo) + + if args.dataset is not None: + config.train.data = args.dataset + + if args.name is not None: + config.experiment.name = args.name + + # get torch device + device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda) + + # maybe modify config for debugging purposes + if args.debug: + # shrink length of training to test whether this run is likely to crash + config.unlock() + config.lock_keys() + + # train and validate (if enabled) for 3 gradient steps, for 2 epochs + config.experiment.epoch_every_n_steps = 3 + config.experiment.validation_epoch_every_n_steps = 3 + config.train.num_epochs = 2 + + # if rollouts are enabled, try 2 rollouts at end of each epoch, with 10 environment steps + config.experiment.rollout.rate = 1 + config.experiment.rollout.n = 2 + config.experiment.rollout.horizon = 10 + + # send output to a temporary directory + config.train.output_dir = "/tmp/tmp_trained_models" + + # lock config to prevent further modifications and ensure missing keys raise errors + config.lock() + + # catch error during training and print it + res_str = "finished run successfully!" + try: + train(config, device=device) + except Exception as e: + res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc()) + print(res_str) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # External config file that overwrites default config + parser.add_argument( + "--config", + type=str, + default=None, + help="(optional) path to a config json that will be used to override the default settings. \ + If omitted, default settings are used. This is the preferred way to run experiments.", + ) + + # Algorithm Name + parser.add_argument( + "--algo", + type=str, + help="(optional) name of algorithm to run. Only needs to be provided if --config is not provided", + ) + + # Experiment Name (for tensorboard, saving models, etc.) 
+    parser.add_argument(
+        "--name",
+        type=str,
+        default=None,
+        help="(optional) if provided, override the experiment name defined in the config",
+    )
+
+    # Dataset path, to override the one in the config
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default=None,
+        help="(optional) if provided, override the dataset path defined in the config",
+    )
+
+    # debug mode
+    parser.add_argument(
+        "--debug",
+        action='store_true',
+        help="set this flag to run a quick training run for debugging purposes"
+    )
+
+    args = parser.parse_args()
+    main(args)
diff --git a/robomimic/utils/data_utils.py b/robomimic/utils/data_utils.py
new file mode 100644
index 00000000..7b071000
--- /dev/null
+++ b/robomimic/utils/data_utils.py
@@ -0,0 +1,61 @@
+import numpy as np
+from typing import Any, Callable, Dict, Sequence, Union
+
+
+def index_nested_dict(d: Dict[str, Any], index: str):
+    """
+    Indexes a nested dictionary, with forward slashes separating hierarchies
+    """
+    indices = index.split("/")
+    for i in indices:
+        if i not in d.keys():
+            raise ValueError(f"Index {index} not found")
+        d = d[i]
+    return d
+
+
+def set_nested_dict_index(d: Dict[str, Any], index: str, value):
+    """
+    Sets an index in a nested dictionary to a value
+    Indexes have forward slashes separating hierarchies
+    """
+    indices = index.split("/")
+    for i in indices[:-1]:
+        if i not in d.keys():
+            raise ValueError(f"Index {index} not found")
+        d = d[i]
+    d[indices[-1]] = value
+
+
+def map_nested_dict_index(d: Dict[str, Any], index: str, map_func):
+    """
+    Maps an index in a nested dictionary with a function
+    Indexes have forward slashes separating hierarchies
+    """
+    indices = index.split("/")
+    for i in indices[:-1]:
+        if i not in d.keys():
+            raise ValueError(f"Index {index} not found")
+        d = d[i]
+    d[indices[-1]] = map_func(d[indices[-1]])
+
+
+def tree_map(
+    x: Dict[str, Any],
+    map_fn: Callable,
+    *,
+    _keypath: str = "",
+) -> Dict[str, Any]:
+
+    if not isinstance(x, dict):
+        out = map_fn(x)
+        return out
+    out = {}
+    for key in x.keys():
+        if isinstance(x[key], dict):
+            out[key] = tree_map(
+                x[key], map_fn, _keypath=_keypath + key + "/"
+            )
+        else:
+            out[key] = map_fn(x[key])
+    return out
diff --git a/robomimic/utils/file_utils.py b/robomimic/utils/file_utils.py
index eb590bd8..efe5e312 100644
--- a/robomimic/utils/file_utils.py
+++ b/robomimic/utils/file_utils.py
@@ -12,10 +12,12 @@
 from tqdm import tqdm
 
 import torch
+import tensorflow_datasets as tfds
 
 import robomimic.utils.obs_utils as ObsUtils
 import robomimic.utils.env_utils as EnvUtils
 import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.data_utils as DataUtils
 from robomimic.config import config_factory
 from robomimic.algo import algo_factory
 from robomimic.algo import RolloutPolicy
@@ -108,6 +110,28 @@ def get_env_metadata_from_dataset(dataset_path, ds_format="robomimic"):
     return env_meta
 
 
+def get_env_metadata_from_dataset_rlds(builder):
+    """
+    Retrieves env metadata from an RLDS dataset.
+
+    Args:
+        builder (DatasetBuilder): TFDS builder for the dataset
+
+    Returns:
+        env_meta (dict): environment metadata. Contains 3 keys:
+
+        :`'env_name'`: name of environment
+        :`'type'`: type of environment, should be a value in EB.EnvType
+        :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
+    """
+    env_meta = builder.info.metadata['env_metadata']
+    #Fix JSON round-trip artifact that turns bools into np.bool_
+    env_meta = DataUtils.tree_map(env_meta,
+        lambda x: bool(x) if isinstance(x, np.bool_) else x
+    )
+    return env_meta
+
+
 def get_shape_metadata_from_dataset(dataset_path, action_keys, all_obs_keys=None, ds_format="robomimic", verbose=False):
     """
     Retrieves shape metadata from dataset.
@@ -208,6 +232,27 @@ def get_shape_metadata_from_dataset(dataset_path, action_keys, all_obs_keys=None
     return shape_meta
 
 
+def get_shape_metadata_from_dataset_rlds(builder, action_keys, all_obs_keys=None):
+    shape_meta = {}
+    info = builder.info
+    for key in action_keys:
+        assert len(DataUtils.index_nested_dict(
+            info.features['steps'], key).shape) == 1  # shape should be (D,)
+    action_dim = sum([DataUtils.index_nested_dict(
+        info.features['steps'], key).shape[0] for key in action_keys])
+    shape_meta["ac_dim"] = action_dim
+
+    if all_obs_keys is None:
+        all_obs_keys = info.features['steps']['observation'].keys()
+    shape_meta['all_obs_keys'] = all_obs_keys
+    shape_meta['all_shapes'] = OrderedDict({key: list(feature.shape) for key, feature
+                                            in info.features['steps']['observation'].items()})
+    shape_meta['use_images'] = np.any([isinstance(feature, tfds.features.Image)
+                                       for key, feature in info.features['steps']['observation'].items()])
+
+    return shape_meta
+
+
 def load_dict_from_checkpoint(ckpt_path):
     """
     Load checkpoint dictionary from a checkpoint file.

From cc2c0ac26a4e409ecf29f593b2cc43577b524191 Mon Sep 17 00:00:00 2001
From: Jonathan Yang
Date: Sat, 14 Oct 2023 01:37:37 -0700
Subject: [PATCH 2/4] Add real-world training

---
 robomimic/data/common_transformations.py      |   5 +-
 robomimic/data/dataset.py                     |  88 +++----
 robomimic/data/dataset_shapes.py              |  13 +
 robomimic/data/dataset_transformations.py     |  78 ++++--
 robomimic/envs/env_base.py                    |   1 +
 robomimic/envs/env_dummy.py                   |  98 ++++++++
 robomimic/scripts/config_gen/diffusion_gen.py | 223 +++++++++++++-----
 robomimic/scripts/config_gen/helper.py        | 203 ++++++++++------
 robomimic/scripts/train_rlds.py               |   1 -
 robomimic/utils/data_utils.py                 |  56 +++++
 robomimic/utils/env_utils.py                  |   3 +
 robomimic/utils/file_utils.py                 |  44 +++-
 robomimic/utils/hyperparam_utils.py           |   7 +-
 robomimic/utils/tensorflow_utils.py           |  95 ++++++++
 14 files changed, 706 insertions(+), 209 deletions(-)
 create mode 100644 robomimic/data/dataset_shapes.py
 create mode 100644 robomimic/envs/env_dummy.py
 create mode 100644 robomimic/utils/tensorflow_utils.py

diff --git a/robomimic/data/common_transformations.py b/robomimic/data/common_transformations.py
index a2c3c5cd..f9560505 100644
--- a/robomimic/data/common_transformations.py
+++ b/robomimic/data/common_transformations.py
@@ -34,9 +34,10 @@ def normalize_obs_and_actions(traj, config, metadata):
     '''
     action_config = config.train.action_config
     normal_keys = [key for key in config.train.action_keys
-                   if action_config[key].get('normalization', None) == 'normal']
+                   if key in action_config.keys() and action_config[key].get('normalization', None) == 'normal']
+
     min_max_keys = [key for key in config.train.action_keys
-                    if action_config[key].get('normalization', None) == 'min_max']
+                    if key in action_config.keys() and action_config[key].get('normalization', None) == 'min_max']
 
     for key in normal_keys:
         map_nested_dict_index(
diff --git a/robomimic/data/dataset.py b/robomimic/data/dataset.py
index 23873793..8f51aeaf 100644
--- a/robomimic/data/dataset.py
+++ b/robomimic/data/dataset.py
@@ -18,8 +18,8 @@
 
 import robomimic.utils.torch_utils as TorchUtils
 from .dataset_transformations import RLDS_TRAJECTORY_MAP_TRANSFORMS
-from .common_transformations import *
-from robomimic.utils.data_utils import *
+import robomimic.data.common_transformations as CommonTransforms
+import robomimic.utils.data_utils as DataUtils
 
 
 class RLDSTorchDataset:
@@ -33,7 +33,7 @@ def __iter__(self):
             torch_batch = {}
             for key in self.keys:
                 if key in batch.keys():
-                    torch_batch[key] = tree_map(
+                    torch_batch[key] = DataUtils.tree_map(
                         batch[key],
                         map_fn=lambda x: torch.Tensor(x).to(self.device)
                     )
@@ -51,21 +51,21 @@ def get_action_normalization_stats_rlds(obs_action_metadata, config):
     for key in config.train.action_keys:
         if key in normal_keys:
             normal_stats = {
-                'scale': obs_action_metadata[key]['std'].numpy().reshape(1, -1),
-                'offset': obs_action_metadata[key]['mean'].numpy().reshape(1, -1)
+                'scale': obs_action_metadata[key]['std'].reshape(1, -1),
+                'offset': obs_action_metadata[key]['mean'].reshape(1, -1)
             }
             stats[key] = normal_stats
         elif key in min_max_keys:
             min_max_range = obs_action_metadata[key]['max'] - obs_action_metadata[key]['min']
             min_max_stats = {
-                'scale': (min_max_range / 2).numpy().reshape(1, -1),
-                'offset': (obs_action_metadata[key]['min'] + min_max_range / 2).numpy().reshape(1, -1)
+                'scale': (min_max_range / 2).reshape(1, -1),
+                'offset': (obs_action_metadata[key]['min'] + min_max_range / 2).reshape(1, -1)
             }
             stats[key] = min_max_stats
         else:
             identity_stats = {
-                'scale': np.ones_like(obs_action_metadata[key]['std'].numpy()).reshape(1, -1),
-                'offset': np.zeros_like(obs_action_metadata[key]['mean'].numpy()).reshape(1, -1)
+                'scale': np.ones_like(obs_action_metadata[key]['std']).reshape(1, -1),
+                'offset': np.zeros_like(obs_action_metadata[key]['mean']).reshape(1, -1)
             }
             stats[key] = identity_stats
     return stats
@@ -104,11 +104,16 @@ def get_obs_action_metadata(
     else:
         print("Computing obs/action statistics for normalization...")
         eps_by_key = {key: [] for key in keys}
-        for episode in tqdm.tqdm(dataset.take(30000)):
+
+        n_samples = 10
+        dataset_iter = dataset.as_numpy_iterator()
+        for _ in tqdm.tqdm(range(n_samples)):
+            episode = next(dataset_iter)
             for key in keys:
-                eps_by_key[key].append(index_nested_dict(episode, key).numpy())
+                eps_by_key[key].append(DataUtils.index_nested_dict(episode, key))
         eps_by_key = {key: np.concatenate(values) for key, values in eps_by_key.items()}
-
+
         metadata = {}
         for key in keys:
             metadata[key] = {
@@ -117,15 +122,21 @@ def get_obs_action_metadata(
                 "max": eps_by_key[key].max(0),
                 "min": eps_by_key[key].min(0),
             }
-        del eps_by_key
         with tf.io.gfile.GFile(path, "wb") as f:
             pickle.dump(metadata, f)
         logging.info("Done!")
 
-    return {
-        k: {k2: tf.convert_to_tensor(v2, dtype=tf.float32) for k2, v2 in v.items()}
-        for k, v in metadata.items()
-    }
+    return metadata
+
+
+def decode_dataset(
+    dataset: tf.data.Dataset
+):
+    #Decode images
+    dataset = dataset.frame_map(
+        DataUtils.decode_images
+    )
+    return dataset
 
 
 def apply_common_transforms(
@@ -136,14 +147,11 @@ def apply_common_transforms(
-    #Decode images
-    dataset = dataset.frame_map(dl.transforms.decode_images)
-
     #Normalize observations and actions
     if obs_action_metadata is not None:
         dataset = dataset.map(
             partial(
-                normalize_obs_and_actions,
+                CommonTransforms.normalize_obs_and_actions,
                 config=config,
                 metadata=obs_action_metadata,
             ),
@@ -153,34 +161,25 @@ def apply_common_transforms(
     if config.train.goal_mode == 'last' or config.train.goal_mode == 'uniform':
         dataset = dataset.map(
             partial(
-                relabel_goals_transform,
+                CommonTransforms.relabel_goals_transform,
                 goal_mode=config.train.goal_mode
             ),
             num_parallel_calls=tf.data.AUTOTUNE
         )
-    #Stack frames
-    '''
-    if config.train.frame_stack is not None and config.train.frame_stack > 1:
-        dataset = dataset.map(
-            partial(
-                frame_stack_transform,
-                num_frames=config.train.frame_stack
-            )
-        )
-    '''
+
     #Concatenate actions
     if config.train.action_keys is not None:
         dataset = dataset.map(
             partial(
-                concatenate_action_transform,
+                CommonTransforms.concatenate_action_transform,
                 action_keys=config.train.action_keys
             ),
             num_parallel_calls=tf.data.AUTOTUNE
         )
-    #Get a random subset of length seq_length
+    #Get a random subset of length frame_stack + seq_length - 1
     dataset = dataset.map(
         partial(
-            random_dataset_sequence_transform_v2,
+            CommonTransforms.random_dataset_sequence_transform_v2,
             frame_stack=config.train.frame_stack,
             seq_length=config.train.seq_length,
             pad_frame_stack=config.train.pad_frame_stack,
@@ -213,13 +212,14 @@ def make_dataset(
         split = "train" if train else "val"
 
     dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle)
+    dataset = decode_dataset(dataset)
     if name in RLDS_TRAJECTORY_MAP_TRANSFORMS:
         if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'] is not None:
-            dataset = dataset.map(RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'])
-
-    metadata_keys = ['action']
-    if config.train.action_keys is not None:
-        metadata_keys.extend(config.train.action_keys)
+            dataset = dataset.map(partial(
+                RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'],
+                config=config)
+            )
+    metadata_keys = [k for k in config.train.action_keys]
     if config.all_obs_keys is not None:
         metadata_keys.extend([f'observation/{k}'
                               for k in config.all_obs_keys])
@@ -228,9 +228,8 @@ def make_dataset(
         builder,
         dataset,
         keys=metadata_keys,
-        load_if_exists=False
+        load_if_exists=True
     )
-
     dataset = apply_common_transforms(
         dataset,
         config=config,
@@ -240,7 +239,10 @@ def make_dataset(
     )
     if name in RLDS_TRAJECTORY_MAP_TRANSFORMS:
         if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'] is not None:
-            dataset = dataset.map(RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'])
+            dataset = dataset.map(partial(
+                RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'],
+                config=config)
+            )
     dataset = dataset.repeat().batch(config.train.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
     dataset = dataset.as_numpy_iterator()
     dataset = RLDSTorchDataset(dataset)
diff --git a/robomimic/data/dataset_shapes.py b/robomimic/data/dataset_shapes.py
new file mode 100644
index 00000000..c3ef7658
--- /dev/null
+++ b/robomimic/data/dataset_shapes.py
@@ -0,0 +1,13 @@
+import numpy as np
+
+r2d2_dataset_shapes = {
+    'action_dict/abs_pos': (3,),
+    'action_dict/abs_rot_6d': (6,),
+}
+
+
+DATASET_SHAPES = {
+    'r2_d2': r2d2_dataset_shapes,
+}
diff --git a/robomimic/data/dataset_transformations.py b/robomimic/data/dataset_transformations.py
index 99650ab8..6502d981 100644
--- a/robomimic/data/dataset_transformations.py
+++ b/robomimic/data/dataset_transformations.py
@@ -1,36 +1,77 @@
 from typing import Any, Callable, Dict, Sequence, Union
+import robomimic.utils.tensorflow_utils as TensorflowUtils
 import tensorflow as tf
 
 
-def r2d2_dataset_pre_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
+def r2d2_dataset_pre_transform(traj: Dict[str, Any],
+                               config: Dict[str, Any]) -> Dict[str, Any]:
     # every input feature is batched, ie has leading batch dimension
     keep_keys = [
-        "observation",
-        "action",
-        "action_dict",
-        "language_instruction",
-        "is_terminal",
-        "is_last",
-        "_traj_index",
+        'observation',
+        'action',
     ]
-    traj = {k: v for k, v in traj.items() if k in keep_keys}
-    return traj
+    ac_keys = ['cartesian_position', 'cartesian_velocity']
+    new_traj = {k: v for k, v in traj.items() if k in keep_keys}
+    new_traj['action_dict'] = {
+        'gripper_position': traj['action_dict']['gripper_position']
+    }
+    for key in ac_keys:
+        pos = traj['action_dict'][key][:, :3]
+        rot = traj['action_dict'][key][:, 3:6]
+
+        rot_6d = TensorflowUtils.euler_angles_to_rot_6d(
+            rot, convention="XYZ",
+        )
+        if key == 'cartesian_position':
+            prefix = 'abs_'
+        else:
+            prefix = 'rel_'
+
+        new_traj['action_dict'].update({
+            prefix + 'pos': pos,
+            prefix + 'rot_euler': rot,
+            prefix + 'rot_6d': rot_6d
+        })
+    return new_traj
 
 
-def r2d2_dataset_post_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
+def r2d2_dataset_post_transform(traj: Dict[str, Any],
+                                config: Dict[str, Any]) -> Dict[str, Any]:
+    new_traj = {'observation': {}}
+    for key in config.all_obs_keys:
+        nested_keys = key.split('/')
+        value = traj['observation']
+        assign = new_traj['observation']
+        for i, nk in enumerate(nested_keys):
+            if i == len(nested_keys) - 1:
+                assign[nk] = value[nk]
+                break
+            value = value[nk]
+            if nk not in assign.keys():
+                assign[nk] = dict()
+            assign = assign[nk]
     #Set obs key
-    traj['obs'] = traj['observation']
+    new_traj['obs'] = new_traj['observation']
 
     #Set actions key
-    traj['actions'] = traj['action']
+    new_traj['actions'] = traj['action']
 
     #Use one goal per sequence
     if 'goal_observation' in traj.keys():
-        traj['goal_obs'] = traj['goal_observation'][0]
-    return traj
+        new_traj['goal_obs'] = traj['goal_observation'][0]
 
+    keep_keys = ['obs',
+                 'goal_obs',
+                 'actions',
+                 'action_dict']
+    new_traj = {k: v for k, v in new_traj.items() if k in keep_keys}
+    return new_traj
 
-def robomimic_dataset_pre_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
+
+def robomimic_dataset_pre_transform(traj: Dict[str, Any],
+                                    config: Dict[str, Any]) -> Dict[str, Any]:
     # every input feature is batched, ie has leading batch dimension
     keep_keys = [
         "observation",
@@ -45,7 +86,8 @@ def robomimic_dataset_pre_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
     return traj
 
 
-def robomimic_dataset_post_transform(traj: Dict[str, Any]) -> Dict[str, Any]:
+def robomimic_dataset_post_transform(traj: Dict[str, Any],
+                                     config: Dict[str, Any]) -> Dict[str, Any]:
     #Set obs key
     traj['obs'] = traj['observation']
@@ -66,7 +108,7 @@
 
 RLDS_TRAJECTORY_MAP_TRANSFORMS = {
-    'r2d2': {
+    'r2_d2': {
         'pre': r2d2_dataset_pre_transform,
         'post': r2d2_dataset_post_transform,
     },
diff --git a/robomimic/envs/env_base.py b/robomimic/envs/env_base.py
index ee3184c2..3b5e8cf3 100644
--- a/robomimic/envs/env_base.py
+++ b/robomimic/envs/env_base.py
@@ -14,6 +14,7 @@ class EnvType:
     ROBOSUITE_TYPE = 1
     GYM_TYPE = 2
     IG_MOMART_TYPE = 3
+    DUMMY_TYPE = 4
 
 
 class EnvBase(abc.ABC):
diff --git a/robomimic/envs/env_dummy.py b/robomimic/envs/env_dummy.py
new file mode 100644
index 00000000..d539a68c
--- /dev/null
+++ b/robomimic/envs/env_dummy.py
@@ -0,0 +1,98 @@
+import json
+import numpy as np
+from copy import deepcopy
+import robomimic.envs.env_base as EB
+
+
+class EnvDummy(EB.EnvBase):
+    """Dummy env used for real-world cases when the env doesn't exist"""
+    def __init__(
+        self,
+ env_name, + render=False, + render_offscreen=False, + use_image_obs=False, + postprocess_visual_obs=True, + **kwargs, + ): + self._env_name = env_name + + def step(self, action): + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def reset_to(self, state): + raise NotImplementedError + + def render(self, mode="human", height=None, width=None, camera_name=None, **kwargs): + raise NotImplementedError + + def get_observation(self, obs=None): + raise NotImplementedError + + def get_state(self): + raise NotImplementedError + + def get_reward(self): + raise NotImplementedError + + def get_goal(self): + raise NotImplementedError + + def set_goal(self, **kwargs): + raise NotImplementedError + + def is_done(self): + raise NotImplementedError + + def is_success(self): + raise NotImplementedError + + @property + def action_dimension(self): + raise NotImplementedError + + @property + def name(self): + """ + Returns name of environment name (str). + """ + return self._env_name + + @property + def type(self): + """ + Returns environment type (int) for this kind of environment. + This helps identify this env class. + """ + return EB.EnvType.GYM_TYPE + + def serialize(self): + """ + Save all information needed to re-instantiate this environment in a dictionary. + This is the same as @env_meta - environment metadata stored in hdf5 datasets, + and used in utils/env_utils.py. + """ + return dict(env_name=self.name, type=self.type) + + @classmethod + def create_for_data_processing(cls, env_name, camera_names, camera_height, camera_width, reward_shaping, **kwargs): + raise NotImplementedError + + @property + def rollout_exceptions(self): + """ + Return tuple of exceptions to except when doing rollouts. This is useful to ensure + that the entire training run doesn't crash because of a bad policy that causes unstable + simulation computations. + """ + raise NotImplementedError + + def __repr__(self): + """ + Pretty-print env description. 
+ """ + return f'{self.name} Dummy Env' + diff --git a/robomimic/scripts/config_gen/diffusion_gen.py b/robomimic/scripts/config_gen/diffusion_gen.py index 8a86a46e..3fe11a7b 100644 --- a/robomimic/scripts/config_gen/diffusion_gen.py +++ b/robomimic/scripts/config_gen/diffusion_gen.py @@ -1,4 +1,5 @@ from robomimic.scripts.config_gen.helper import * +import os def make_generator_helper(args): algo_name_short = "diffusion_policy" @@ -27,6 +28,16 @@ def make_generator_helper(args): values=[1000], ) + generator.add_param( + key="train.data_format", + name="df", + group=1123, + values=[args.data_format], + value_names=[ + "data_format", + ] + ) + # use ddim by default generator.add_param( key="algo.ddim.enabled", @@ -51,74 +62,158 @@ def make_generator_helper(args): if args.env == "r2d2": generator.add_param( - key="train.data", - name="ds", - group=2, - values=[ - [{"path": p} for p in scan_datasets("~/Downloads/example_pen_in_cup", postfix="trajectory_im128.h5")], - ], - value_names=[ - "pen-in-cup", - ], - ) - generator.add_param( - key="train.action_keys", - name="ac_keys", - group=-1, - values=[ - [ - "action/abs_pos", - "action/abs_rot_6d", - "action/gripper_position", - ], - ], - value_names=[ - "abs", - ], - hidename=True, - ) - generator.add_param( - key="observation.modalities.obs.rgb", - name="cams", - group=130, - values=[ - # ["camera/image/hand_camera_left_image"], - # ["camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image"], - ["camera/image/hand_camera_left_image", "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_2_left_image"], - # [ - # "camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image", - # "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_1_right_image", - # "camera/image/varied_camera_2_left_image", "camera/image/varied_camera_2_right_image", - # ], - ], + key="train.data_format", + name="df", + group=1123, + values=[args.data_format], value_names=[ - # "wrist", - # "wrist-stereo", - "3cams", - # "3cams-stereo", + "data_format", ] - ) + ) + if args.data_format == 'hdf5': + generator.add_param( + key="train.data", + name="ds", + group=2, + values=[ + [{"path": p} for p in scan_datasets("~/Downloads/example_pen_in_cup", postfix="trajectory_im128.h5")], + ], + value_names=[ + "pen-in-cup", + ], + ) + generator.add_param( + key="train.action_keys", + name="ac_keys", + group=-1, + values=[ + [ + "action/abs_pos", + "action/abs_rot_6d", + "action/gripper_position", + ], + ], + value_names=[ + "abs", + ], + hidename=True, + ) + generator.add_param( + key="observation.modalities.obs.rgb", + name="cams", + group=130, + values=[ + # ["camera/image/hand_camera_left_image"], + # ["camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image"], + ["camera/image/hand_camera_left_image", "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_2_left_image"], + # [ + # "camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image", + # "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_1_right_image", + # "camera/image/varied_camera_2_left_image", "camera/image/varied_camera_2_right_image", + # ], + ], + value_names=[ + # "wrist", + # "wrist-stereo", + "3cams", + # "3cams-stereo", + ] + ) + generator.add_param( + key="observation.modalities.obs.low_dim", + name="ldkeys", + group=2498, + values=[ + ["robot_state/cartesian_position", "robot_state/gripper_position"], + # [ + # "robot_state/cartesian_position", 
"robot_state/gripper_position", + # "camera/extrinsics/hand_camera_left", "camera/extrinsics/hand_camera_left_gripper_offset", + # "camera/extrinsics/hand_camera_right", "camera/extrinsics/hand_camera_right_gripper_offset", + # "camera/extrinsics/varied_camera_1_left", "camera/extrinsics/varied_camera_1_right", + # "camera/extrinsics/varied_camera_2_left", "camera/extrinsics/varied_camera_2_right", + # ] + ], + value_names=[ + "proprio", + # "proprio-extrinsics", + ] + ) - generator.add_param( - key="observation.modalities.obs.low_dim", - name="ldkeys", - group=2498, - values=[ - ["robot_state/cartesian_position", "robot_state/gripper_position"], - # [ - # "robot_state/cartesian_position", "robot_state/gripper_position", - # "camera/extrinsics/hand_camera_left", "camera/extrinsics/hand_camera_left_gripper_offset", - # "camera/extrinsics/hand_camera_right", "camera/extrinsics/hand_camera_right_gripper_offset", - # "camera/extrinsics/varied_camera_1_left", "camera/extrinsics/varied_camera_1_right", - # "camera/extrinsics/varied_camera_2_left", "camera/extrinsics/varied_camera_2_right", - # ] - ], - value_names=[ - "proprio", - # "proprio-extrinsics", - ] - ) + elif args.data_format == 'rlds': + generator.add_param( + key="train.data", + name="ds", + group=2, + values=[ + [ + {"path": "/iris/u/jyang27/rlds_data", + "name": "r2_d2"}, # replace with your own path + ], + ], + value_names=[ + "pen-in-cup" + ], + ) + generator.add_param( + key="train.action_keys", + name="ac_keys", + group=-1, + values=[ + [ + "action_dict/abs_pos", + "action_dict/abs_rot_6d", + "action_dict/gripper_position", + ], + ], + value_names=[ + "abs", + ], + hidename=True, + ) + generator.add_param( + key="observation.modalities.obs.rgb", + name="cams", + group=130, + values=[ + # ["camera/image/hand_camera_left_image"], + # ["camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image"], + ["wrist_image_left", "exterior_image_1_left", "exterior_image_2_left"], + # [ + # "camera/image/hand_camera_left_image", "camera/image/hand_camera_right_image", + # "camera/image/varied_camera_1_left_image", "camera/image/varied_camera_1_right_image", + # "camera/image/varied_camera_2_left_image", "camera/image/varied_camera_2_right_image", + # ], + ], + value_names=[ + # "wrist", + # "wrist-stereo", + "3cams", + # "3cams-stereo", + ] + ) + generator.add_param( + key="observation.modalities.obs.low_dim", + name="ldkeys", + group=2498, + values=[ + ["cartesian_position", "gripper_position"], + # [ + # "robot_state/cartesian_position", "robot_state/gripper_position", + # "camera/extrinsics/hand_camera_left", "camera/extrinsics/hand_camera_left_gripper_offset", + # "camera/extrinsics/hand_camera_right", "camera/extrinsics/hand_camera_right_gripper_offset", + # "camera/extrinsics/varied_camera_1_left", "camera/extrinsics/varied_camera_1_right", + # "camera/extrinsics/varied_camera_2_left", "camera/extrinsics/varied_camera_2_right", + # ] + ], + value_names=[ + "proprio", + # "proprio-extrinsics", + ] + ) + + else: + raise ValueError generator.add_param( key="observation.encoder.rgb.core_kwargs.backbone_class", name="backbone", diff --git a/robomimic/scripts/config_gen/helper.py b/robomimic/scripts/config_gen/helper.py index 7d9829a8..d4dfd4e6 100644 --- a/robomimic/scripts/config_gen/helper.py +++ b/robomimic/scripts/config_gen/helper.py @@ -183,56 +183,143 @@ def set_env_settings(generator, args): "r2d2" ], ) - - # here, we list how each action key should be treated (normalized etc) - generator.add_param( - 
key="train.action_config", - name="", - group=-1, - values=[ - { - "action/cartesian_position":{ - "normalization": "min_max", - }, - "action/abs_pos":{ - "normalization": "min_max", - }, - "action/abs_rot_6d":{ - "normalization": "min_max", - "format": "rot_6d", - "convert_at_runtime": "rot_euler", - }, - "action/abs_rot_euler":{ - "normalization": "min_max", - "format": "rot_euler", - }, - "action/gripper_position":{ - "normalization": "min_max", - }, - "action/cartesian_velocity":{ - "normalization": None, - }, - "action/rel_pos":{ - "normalization": None, - }, - "action/rel_rot_6d":{ - "format": "rot_6d", - "normalization": None, - "convert_at_runtime": "rot_euler", - }, - "action/rel_rot_euler":{ - "format": "rot_euler", - "normalization": None, - }, - "action/gripper_velocity":{ - "normalization": None, - }, - } - ], - ) + + if args.data_format == 'hdf5': + # here, we list how each action key should be treated (normalized etc) + generator.add_param( + key="train.action_config", + name="", + group=-1, + values=[ + { + "action/cartesian_position":{ + "normalization": "min_max", + }, + "action/abs_pos":{ + "normalization": "min_max", + }, + "action/abs_rot_6d":{ + "normalization": "min_max", + "format": "rot_6d", + "convert_at_runtime": "rot_euler", + }, + "action/abs_rot_euler":{ + "normalization": "min_max", + "format": "rot_euler", + }, + "action/gripper_position":{ + "normalization": "min_max", + }, + "action/cartesian_velocity":{ + "normalization": None, + }, + "action/rel_pos":{ + "normalization": None, + }, + "action/rel_rot_6d":{ + "format": "rot_6d", + "normalization": None, + "convert_at_runtime": "rot_euler", + }, + "action/rel_rot_euler":{ + "format": "rot_euler", + "normalization": None, + }, + "action/gripper_velocity":{ + "normalization": None, + }, + } + ], + ) + generator.add_param( + key="train.shuffled_obs_key_groups", + name="", + group=-1, + values=[[[ + ( + "camera/image/varied_camera_1_left_image", + "camera/image/varied_camera_1_right_image", + "camera/extrinsics/varied_camera_1_left", + "camera/extrinsics/varied_camera_1_right", + ), + ( + "camera/image/varied_camera_2_left_image", + "camera/image/varied_camera_2_right_image", + "camera/extrinsics/varied_camera_2_left", + "camera/extrinsics/varied_camera_2_right", + ), + ]]], + ) + elif args.data_format == 'rlds': + # here, we list how each action key should be treated (normalized etc) + generator.add_param( + key="train.action_config", + name="", + group=-1, + values=[ + { + "action_dict/cartesian_position":{ + "normalization": "min_max", + }, + "action_dict/abs_pos":{ + "normalization": "min_max", + }, + "action_dict/abs_rot_6d":{ + "normalization": "min_max", + "format": "rot_6d", + "convert_at_runtime": "rot_euler", + }, + "action_dict/abs_rot_euler":{ + "normalization": "min_max", + "format": "rot_euler", + }, + "action_dict/gripper_position":{ + "normalization": "min_max", + }, + "action_dict/cartesian_velocity":{ + "normalization": None, + }, + "action_dict/rel_pos":{ + "normalization": None, + }, + "action_dict/rel_rot_6d":{ + "format": "rot_6d", + "normalization": None, + "convert_at_runtime": "rot_euler", + }, + "action_dict/rel_rot_euler":{ + "format": "rot_euler", + "normalization": None, + }, + "action_dict/gripper_velocity":{ + "normalization": None, + }, + } + ], + ) + generator.add_param( + key="train.shuffled_obs_key_groups", + name="", + group=-1, + values=[[[ + ( + "wrist_image_left", + "exterior_image_1_left", + "exterior_image_2_left", + ), + ( + "wrist_image_right", + 
"exterior_image_1_right", + "exterior_image_2_right", + ), + ]]], + ) + + else: + raise ValueError generator.add_param( key="train.dataset_keys", - name="", + name="", group=-1, values=[[]], ) @@ -252,26 +339,6 @@ def set_env_settings(generator, args): "rel", ], ) - # observation key groups to swap - generator.add_param( - key="train.shuffled_obs_key_groups", - name="", - group=-1, - values=[[[ - ( - "camera/image/varied_camera_1_left_image", - "camera/image/varied_camera_1_right_image", - "camera/extrinsics/varied_camera_1_left", - "camera/extrinsics/varied_camera_1_right", - ), - ( - "camera/image/varied_camera_2_left_image", - "camera/image/varied_camera_2_right_image", - "camera/extrinsics/varied_camera_2_left", - "camera/extrinsics/varied_camera_2_right", - ), - ]]], - ) elif args.env == "kitchen": generator.add_param( key="train.action_config", diff --git a/robomimic/scripts/train_rlds.py b/robomimic/scripts/train_rlds.py index bc76c6dc..5d586ddd 100644 --- a/robomimic/scripts/train_rlds.py +++ b/robomimic/scripts/train_rlds.py @@ -95,7 +95,6 @@ def train(config, device): # load basic metadata from training file print("\n============= Loaded Environment Metadata =============") env_meta = FileUtils.get_env_metadata_from_dataset_rlds(train_builder) - # update env meta if applicable from robomimic.utils.script_utils import deep_update deep_update(env_meta, config.experiment.env_meta_update_dict) diff --git a/robomimic/utils/data_utils.py b/robomimic/utils/data_utils.py index 7b071000..5fa8bbcb 100644 --- a/robomimic/utils/data_utils.py +++ b/robomimic/utils/data_utils.py @@ -1,4 +1,6 @@ import numpy as np +import tensorflow as tf +from functools import partial from typing import Any, Callable, Dict, Sequence, Union @@ -59,3 +61,57 @@ def tree_map( else: out[key] = map_fn(x[key]) return out + + +def selective_tree_map( + x: Dict[str, Any], + match: Union[str, Sequence[str], Callable[[str, Any], bool]], + map_fn: Callable, + *, + _keypath: str = "", +) -> Dict[str, Any]: + """Maps a function over a nested dictionary, only applying it leaves that match a criterion. + + Args: + x (Dict[str, Any]): The dictionary to map over. + match (str or Sequence[str] or Callable[[str, Any], bool]): If a string or list of strings, `map_fn` will only + be applied to leaves whose key path contains one of `match`. If a function, `map_fn` will only be applied to + leaves for which `match(key_path, value)` returns True. + map_fn (Callable): The function to apply. 
+ """ + if not callable(match): + if isinstance(match, str): + match = [match] + match_fn = lambda keypath, value: any([s in keypath for s in match]) + else: + match_fn = match + + out = {} + for key in x: + if isinstance(x[key], dict): + out[key] = selective_tree_map( + x[key], match_fn, map_fn, _keypath=_keypath + key + "/" + ) + elif match_fn(_keypath + key, x[key]): + out[key] = map_fn(x[key]) + else: + out[key] = x[key] + return out + + +def decode_images( + x: Dict[str, Any], match: Union[str, Sequence[str]] = "image" +) -> Dict[str, Any]: + if isinstance(match, str): + match = [match] + + def match_fn(keypath, value): + image_in_keypath = any([s in keypath for s in match]) + return image_in_keypath + + return selective_tree_map( + x, + match=match_fn, + map_fn=partial(tf.io.decode_image, expand_animations=False), + ) + diff --git a/robomimic/utils/env_utils.py b/robomimic/utils/env_utils.py index 6514de32..8b56e922 100644 --- a/robomimic/utils/env_utils.py +++ b/robomimic/utils/env_utils.py @@ -41,6 +41,9 @@ def get_env_class(env_meta=None, env_type=None, env=None): elif env_type == EB.EnvType.IG_MOMART_TYPE: from robomimic.envs.env_ig_momart import EnvGibsonMOMART return EnvGibsonMOMART + elif env_type == EB.EnvType.DUMMY_TYPE: + from robomimic.envs.env_dummy import EnvDummy + return EnvDummy raise Exception("code should never reach this point") diff --git a/robomimic/utils/file_utils.py b/robomimic/utils/file_utils.py index efe5e312..41d0fffb 100644 --- a/robomimic/utils/file_utils.py +++ b/robomimic/utils/file_utils.py @@ -124,11 +124,22 @@ def get_env_metadata_from_dataset_rlds(builder): :`'type'`: type of environment, should be a value in EB.EnvType :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor """ - env_meta = builder.info.metadata['env_metadata'] - #Fix weird json property that turns bool into _bool - DataUtils.tree_map(env_meta, - lambda x: bool(x) if isinstance(x, bool) else x - ) + if builder.info.metadata is None: + env_meta = None + else: + env_meta = builder.info.metadata.get('env_metadata', None) + if env_meta is not None: + #Fix weird json property that turns bool into _bool + DataUtils.tree_map(env_meta, + lambda x: bool(x) if isinstance(x, bool) else x + ) + else: + env_meta = { + 'env_name': 'rlds', + 'type': 4, + 'env_kwargs': {}, + } + return env_meta @@ -233,20 +244,31 @@ def get_shape_metadata_from_dataset(dataset_path, action_keys, all_obs_keys=None def get_shape_metadata_from_dataset_rlds(builder, action_keys, all_obs_keys=None): + from robomimic.data.dataset_shapes import DATASET_SHAPES + shape_meta = {} info = builder.info + name = builder.name + action_dim = 0 for key in action_keys: - assert len(DataUtils.index_nested_dict( - info.features['steps'], key).shape) == 1 # shape should be (D) - action_dim = sum([DataUtils.index_nested_dict( - info.features['steps'], key).shape[0] for key in action_keys]) + if key in DATASET_SHAPES[name].keys(): + action_dim += DATASET_SHAPES[name][key][0] + else: + key_shape = DataUtils.index_nested_dict( + info.features['steps'], key).shape + assert len(key_shape) == 1 + action_dim += key_shape[0] shape_meta["ac_dim"] = action_dim if all_obs_keys is None: all_obs_keys = info.features['steps']['observation'].keys() shape_meta['all_obs_keys'] = all_obs_keys - shape_meta['all_shapes'] = OrderedDict({key: list(feature.shape) for key, feature - in info.features['steps']['observation'].items()}) + shape_meta['all_shapes'] = OrderedDict() + for key, feature in 
+        shape = feature.shape
+        # heuristic: if the last dimension is the smallest, treat it as a channel
+        # dimension and move it to the front (robomimic expects channel-first shapes)
+        if feature.shape[-1] == min(feature.shape):
+            shape = [shape[-1]] + list(shape[:-1])
+        shape_meta['all_shapes'][key] = shape
 
     shape_meta['use_images'] = np.any([isinstance(feature, tfds.features.Image)
                     for key, feature in info.features['steps']['observation'].items()])
diff --git a/robomimic/utils/hyperparam_utils.py b/robomimic/utils/hyperparam_utils.py
index 0ff6397e..cf422bae 100644
--- a/robomimic/utils/hyperparam_utils.py
+++ b/robomimic/utils/hyperparam_utils.py
@@ -294,8 +294,11 @@ def _script_from_jsons(self, json_paths):
         for path in json_paths:
             # write python command to file
             import robomimic
-            cmd = "python {}/scripts/train.py --config {}\n".format(robomimic.__path__[0], path)
-
+            data_format = self.parameters.get('train.data_format', None)
+            if data_format is None or data_format.values[0] != 'rlds':
+                cmd = "python {}/scripts/train.py --config {}\n".format(robomimic.__path__[0], path)
+            else:
+                cmd = "python {}/scripts/train_rlds.py --config {}\n".format(robomimic.__path__[0], path)
             print()
             print(cmd)
             f.write(cmd)
diff --git a/robomimic/utils/tensorflow_utils.py b/robomimic/utils/tensorflow_utils.py
new file mode 100644
index 00000000..6407a06e
--- /dev/null
+++ b/robomimic/utils/tensorflow_utils.py
@@ -0,0 +1,95 @@
+import tensorflow as tf
+
+
+def euler_angles_to_rot_6d(euler_angles, convention="XYZ"):
+    """
+    Converts a tensor of Euler angles to the 6D rotation representation.
+    """
+    rot_mat = euler_angles_to_matrix(euler_angles, convention=convention)
+    rot_6d = matrix_to_rotation_6d(rot_mat)
+    return rot_6d
+
+
+def matrix_to_rotation_6d(matrix):
+    """
+    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+    by dropping the last row. Note that 6D representation is not unique.
+    Args:
+        matrix: batch of rotation matrices of size (*, 3, 3)
+    Returns:
+        6D rotation representation, of size (*, 6)
+    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+    On the Continuity of Rotation Representations in Neural Networks.
+    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+    Retrieved from http://arxiv.org/abs/1812.07035
+    """
+    batch_dim = tf.shape(matrix)[:-2]
+    return tf.reshape(matrix[..., :2, :], tf.concat([batch_dim, [6]], axis=0))
+
+
+def euler_angles_to_matrix(euler_angles, convention):
+    """
+    Convert rotations given as Euler angles in radians to rotation matrices.
+
+    Args:
+        euler_angles: Euler angles in radians as tensor of shape (..., 3).
+        convention: Convention string of three uppercase letters from
+            {"X", "Y", "Z"}.
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
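+
+    Example (illustrative):
+        >>> euler_angles_to_matrix(tf.constant([0.0, 0.0, 1.5708]), "XYZ")
+        # ~ [[0, -1, 0], [1, 0, 0], [0, 0, 1]], a 90 degree rotation about Z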
+ """ + if euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, tf.unstack(euler_angles, axis=-1)) + ] + + # TensorFlow doesn't have a native functools.reduce or torch.matmul, so we use a loop + result = matrices[0] + for mat in matrices[1:]: + result = tf.matmul(result, mat) + + return result + + +def _axis_angle_rotation(axis, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = tf.cos(angle) + sin = tf.sin(angle) + one = tf.ones_like(angle) + zero = tf.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return tf.reshape(tf.stack(R_flat, axis=-1), tf.concat([tf.shape(angle), [3, 3]], axis=0)) + + + From f8d0a15f775c7364b4d2db61a36e0c1c3baa0151 Mon Sep 17 00:00:00 2001 From: Jonathan Yang Date: Wed, 18 Oct 2023 21:20:14 -0700 Subject: [PATCH 3/4] Remove unncesssary comments --- robomimic/data/dataset_transformations.py | 1 - robomimic/scripts/train.py | 481 ---------------------- robomimic/utils/file_utils.py | 2 +- robomimic/utils/train_utils.py | 3 +- 4 files changed, 2 insertions(+), 485 deletions(-) delete mode 100644 robomimic/scripts/train.py diff --git a/robomimic/data/dataset_transformations.py b/robomimic/data/dataset_transformations.py index 6502d981..b89575ad 100644 --- a/robomimic/data/dataset_transformations.py +++ b/robomimic/data/dataset_transformations.py @@ -39,7 +39,6 @@ def r2d2_dataset_pre_transform(traj: Dict[str, Any], def r2d2_dataset_post_transform(traj: Dict[str, Any], config: Dict[str, Any]) -> Dict[str, Any]: - import pdb; pdb.set_trace() new_traj = {'observation': {}} for key in config.all_obs_keys: nested_keys = key.split('/') diff --git a/robomimic/scripts/train.py b/robomimic/scripts/train.py deleted file mode 100644 index ffbc4666..00000000 --- a/robomimic/scripts/train.py +++ /dev/null @@ -1,481 +0,0 @@ -""" -The main entry point for training policies. - -Args: - config (str): path to a config json that will be used to override the default settings. - If omitted, default settings are used. This is the preferred way to run experiments. - - algo (str): name of the algorithm to run. Only needs to be provided if @config is not - provided. 
- - name (str): if provided, override the experiment name defined in the config - - dataset (str): if provided, override the dataset path defined in the config - - debug (bool): set this flag to run a quick training run for debugging purposes -""" - -import argparse -import json -import numpy as np -import time -import os -import shutil -import psutil -import sys -import socket -import traceback - -from collections import OrderedDict - -import torch -from torch.utils.data import DataLoader - -import robomimic -import robomimic.utils.train_utils as TrainUtils -import robomimic.utils.torch_utils as TorchUtils -import robomimic.utils.obs_utils as ObsUtils -import robomimic.utils.env_utils as EnvUtils -import robomimic.utils.file_utils as FileUtils -from robomimic.config import config_factory -from robomimic.algo import algo_factory, RolloutPolicy -from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings - - -def train(config, device): - """ - Train a model using the algorithm. - """ - - # first set seeds - np.random.seed(config.train.seed) - torch.manual_seed(config.train.seed) - - # set num workers - torch.set_num_threads(1) - - print("\n============= New Training Run with Config =============") - print(config) - print("") - log_dir, ckpt_dir, video_dir, vis_dir = TrainUtils.get_exp_dir(config) - - if config.experiment.logging.terminal_output_to_txt: - # log stdout and stderr to a text file - logger = PrintLogger(os.path.join(log_dir, 'log.txt')) - sys.stdout = logger - sys.stderr = logger - - # read config to set up metadata for observation modalities (e.g. detecting rgb observations) - ObsUtils.initialize_obs_utils_with_config(config) - - # make sure the dataset exists - eval_dataset_cfg = config.train.data[0] - dataset_path = os.path.expanduser(eval_dataset_cfg["path"]) - ds_format = config.train.data_format - if not os.path.exists(dataset_path): - raise Exception("Dataset at provided path {} not found!".format(dataset_path)) - - # load basic metadata from training file - print("\n============= Loaded Environment Metadata =============") - env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path, ds_format=ds_format) - - # update env meta if applicable - from robomimic.utils.script_utils import deep_update - deep_update(env_meta, config.experiment.env_meta_update_dict) - - shape_meta = FileUtils.get_shape_metadata_from_dataset( - dataset_path=dataset_path, - action_keys=config.train.action_keys, - all_obs_keys=config.all_obs_keys, - ds_format=ds_format, - verbose=True - ) - - if config.experiment.env is not None: - env_meta["env_name"] = config.experiment.env - print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30) - - # create environment - envs = OrderedDict() - if config.experiment.rollout.enabled: - # create environments for validation runs - env_names = [env_meta["env_name"]] - - if config.experiment.additional_envs is not None: - for name in config.experiment.additional_envs: - env_names.append(name) - - for env_name in env_names: - env = EnvUtils.create_env_from_metadata( - env_meta=env_meta, - env_name=env_name, - render=False, - render_offscreen=config.experiment.render_video, - use_image_obs=shape_meta["use_images"], - ) - env = EnvUtils.wrap_env_from_config(env, config=config) # apply environment warpper, if applicable - envs[env.name] = env - print(envs[env.name]) - - print("") - - # setup for a new training run - data_logger = DataLogger( - log_dir, - config, - log_tb=config.experiment.logging.log_tb, 
- log_wandb=config.experiment.logging.log_wandb, - ) - model = algo_factory( - algo_name=config.algo_name, - config=config, - obs_key_shapes=shape_meta["all_shapes"], - ac_dim=shape_meta["ac_dim"], - device=device, - ) - - # save the config as a json file - with open(os.path.join(log_dir, '..', 'config.json'), 'w') as outfile: - json.dump(config, outfile, indent=4) - - # if checkpoint is specified, load in model weights - ckpt_path = config.experiment.ckpt_path - if ckpt_path is not None: - print("LOADING MODEL WEIGHTS FROM " + ckpt_path) - from robomimic.utils.file_utils import maybe_dict_from_checkpoint - ckpt_dict = maybe_dict_from_checkpoint(ckpt_path=ckpt_path) - model.deserialize(ckpt_dict["model"]) - - print("\n============= Model Summary =============") - print(model) # print model summary - print("") - - # load training data - trainset, validset = TrainUtils.load_data_for_training( - config, obs_keys=shape_meta["all_obs_keys"]) - train_sampler = trainset.get_dataset_sampler() - print("\n============= Training Dataset =============") - print(trainset) - print("") - if validset is not None: - print("\n============= Validation Dataset =============") - print(validset) - print("") - - # maybe retreve statistics for normalizing observations - obs_normalization_stats = None - if config.train.hdf5_normalize_obs: - obs_normalization_stats = trainset.get_obs_normalization_stats() - - # maybe retreve statistics for normalizing actions - action_normalization_stats = trainset.get_action_normalization_stats() - - # initialize data loaders - train_loader = DataLoader( - dataset=trainset, - sampler=train_sampler, - batch_size=config.train.batch_size, - shuffle=(train_sampler is None), - num_workers=config.train.num_data_workers, - drop_last=True - ) - - if config.experiment.validate: - # cap num workers for validation dataset at 1 - num_workers = min(config.train.num_data_workers, 1) - valid_sampler = validset.get_dataset_sampler() - valid_loader = DataLoader( - dataset=validset, - sampler=valid_sampler, - batch_size=config.train.batch_size, - shuffle=(valid_sampler is None), - num_workers=num_workers, - drop_last=True - ) - else: - valid_loader = None - - # print all warnings before training begins - print("*" * 50) - print("Warnings generated by robomimic have been duplicated here (from above) for convenience. Please check them carefully.") - flush_warnings() - print("*" * 50) - print("") - - # main training loop - best_valid_loss = None - best_return = {k: -np.inf for k in envs} if config.experiment.rollout.enabled else None - best_success_rate = {k: -1. 
for k in envs} if config.experiment.rollout.enabled else None - last_ckpt_time = time.time() - - # number of learning steps per epoch (defaults to a full dataset pass) - train_num_steps = config.experiment.epoch_every_n_steps - valid_num_steps = config.experiment.validation_epoch_every_n_steps - - for epoch in range(1, config.train.num_epochs + 1): # epoch numbers start at 1 - step_log = TrainUtils.run_epoch( - model=model, - data_loader=train_loader, - epoch=epoch, - num_steps=train_num_steps, - obs_normalization_stats=obs_normalization_stats, - ) - model.on_epoch_end(epoch) - - # setup checkpoint path - epoch_ckpt_name = "model_epoch_{}".format(epoch) - - # check for recurring checkpoint saving conditions - should_save_ckpt = False - if config.experiment.save.enabled: - time_check = (config.experiment.save.every_n_seconds is not None) and \ - (time.time() - last_ckpt_time > config.experiment.save.every_n_seconds) - epoch_check = (config.experiment.save.every_n_epochs is not None) and \ - (epoch > 0) and (epoch % config.experiment.save.every_n_epochs == 0) - epoch_list_check = (epoch in config.experiment.save.epochs) - should_save_ckpt = (time_check or epoch_check or epoch_list_check) - ckpt_reason = None - if should_save_ckpt: - last_ckpt_time = time.time() - ckpt_reason = "time" - - print("Train Epoch {}".format(epoch)) - print(json.dumps(step_log, sort_keys=True, indent=4)) - for k, v in step_log.items(): - if k.startswith("Time_"): - data_logger.record("Timing_Stats/Train_{}".format(k[5:]), v, epoch) - else: - data_logger.record("Train/{}".format(k), v, epoch) - - # Evaluate the model on validation set - if config.experiment.validate: - with torch.no_grad(): - step_log = TrainUtils.run_epoch(model=model, data_loader=valid_loader, epoch=epoch, validate=True, num_steps=valid_num_steps) - for k, v in step_log.items(): - if k.startswith("Time_"): - data_logger.record("Timing_Stats/Valid_{}".format(k[5:]), v, epoch) - else: - data_logger.record("Valid/{}".format(k), v, epoch) - - print("Validation Epoch {}".format(epoch)) - print(json.dumps(step_log, sort_keys=True, indent=4)) - - # save checkpoint if achieve new best validation loss - valid_check = "Loss" in step_log - if valid_check and (best_valid_loss is None or (step_log["Loss"] <= best_valid_loss)): - best_valid_loss = step_log["Loss"] - if config.experiment.save.enabled and config.experiment.save.on_best_validation: - epoch_ckpt_name += "_best_validation_{}".format(best_valid_loss) - should_save_ckpt = True - ckpt_reason = "valid" if ckpt_reason is None else ckpt_reason - - # Evaluate the model by by running rollouts - - # do rollouts at fixed rate or if it's time to save a new ckpt - video_paths = None - rollout_check = (epoch % config.experiment.rollout.rate == 0) or (should_save_ckpt and ckpt_reason == "time") - if config.experiment.rollout.enabled and (epoch > config.experiment.rollout.warmstart) and rollout_check: - - # wrap model as a RolloutPolicy to prepare for rollouts - rollout_model = RolloutPolicy( - model, - obs_normalization_stats=obs_normalization_stats, - action_normalization_stats=action_normalization_stats, - ) - - num_episodes = config.experiment.rollout.n - all_rollout_logs, video_paths = TrainUtils.rollout_with_stats( - policy=rollout_model, - envs=envs, - horizon=config.experiment.rollout.horizon, - use_goals=config.use_goals, - num_episodes=num_episodes, - render=False, - video_dir=video_dir if config.experiment.render_video else None, - epoch=epoch, - video_skip=config.experiment.get("video_skip", 5), - 
terminate_on_success=config.experiment.rollout.terminate_on_success, - ) - - # summarize results from rollouts to tensorboard and terminal - for env_name in all_rollout_logs: - rollout_logs = all_rollout_logs[env_name] - for k, v in rollout_logs.items(): - if k.startswith("Time_"): - data_logger.record("Timing_Stats/Rollout_{}_{}".format(env_name, k[5:]), v, epoch) - else: - data_logger.record("Rollout/{}/{}".format(k, env_name), v, epoch, log_stats=True) - - print("\nEpoch {} Rollouts took {}s (avg) with results:".format(epoch, rollout_logs["time"])) - print('Env: {}'.format(env_name)) - print(json.dumps(rollout_logs, sort_keys=True, indent=4)) - - # checkpoint and video saving logic - updated_stats = TrainUtils.should_save_from_rollout_logs( - all_rollout_logs=all_rollout_logs, - best_return=best_return, - best_success_rate=best_success_rate, - epoch_ckpt_name=epoch_ckpt_name, - save_on_best_rollout_return=config.experiment.save.on_best_rollout_return, - save_on_best_rollout_success_rate=config.experiment.save.on_best_rollout_success_rate, - ) - best_return = updated_stats["best_return"] - best_success_rate = updated_stats["best_success_rate"] - epoch_ckpt_name = updated_stats["epoch_ckpt_name"] - should_save_ckpt = (config.experiment.save.enabled and updated_stats["should_save_ckpt"]) or should_save_ckpt - if updated_stats["ckpt_reason"] is not None: - ckpt_reason = updated_stats["ckpt_reason"] - - # check if we need to save model MSE - should_save_mse = False - if config.experiment.mse.enabled: - if config.experiment.mse.every_n_epochs is not None and epoch % config.experiment.mse.every_n_epochs == 0: - should_save_mse = True - if config.experiment.mse.on_save_ckpt and should_save_ckpt: - should_save_mse = True - if should_save_mse: - print("Computing MSE ...") - if config.experiment.mse.visualize: - save_vis_dir = os.path.join(vis_dir, epoch_ckpt_name) - else: - save_vis_dir = None - mse_log, vis_log = model.compute_mse_visualize( - trainset, - validset, - num_samples=config.experiment.mse.num_samples, - savedir=save_vis_dir, - ) - for k, v in mse_log.items(): - data_logger.record("{}".format(k), v, epoch) - - for k, v in vis_log.items(): - data_logger.record("{}".format(k), v, epoch, data_type='image') - - - print("MSE Log Epoch {}".format(epoch)) - print(json.dumps(mse_log, sort_keys=True, indent=4)) - - # Only keep saved videos if the ckpt should be saved (but not because of validation score) - should_save_video = (should_save_ckpt and (ckpt_reason != "valid")) or config.experiment.keep_all_videos - if video_paths is not None and not should_save_video: - for env_name in video_paths: - os.remove(video_paths[env_name]) - - # Save model checkpoints based on conditions (success rate, validation loss, etc) - if should_save_ckpt: - TrainUtils.save_model( - model=model, - config=config, - env_meta=env_meta, - shape_meta=shape_meta, - ckpt_path=os.path.join(ckpt_dir, epoch_ckpt_name + ".pth"), - obs_normalization_stats=obs_normalization_stats, - action_normalization_stats=action_normalization_stats, - ) - - # Finally, log memory usage in MB - process = psutil.Process(os.getpid()) - mem_usage = int(process.memory_info().rss / 1000000) - data_logger.record("System/RAM Usage (MB)", mem_usage, epoch) - print("\nEpoch {} Memory Usage: {} MB\n".format(epoch, mem_usage)) - - # terminate logging - data_logger.close() - - -def main(args): - - if args.config is not None: - ext_cfg = json.load(open(args.config, 'r')) - config = config_factory(ext_cfg["algo_name"]) - # update config with external json 
- this will throw errors if - # the external config has keys not present in the base algo config - with config.values_unlocked(): - config.update(ext_cfg) - else: - config = config_factory(args.algo) - - if args.dataset is not None: - config.train.data = args.dataset - - if args.name is not None: - config.experiment.name = args.name - - # get torch device - device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda) - - # maybe modify config for debugging purposes - if args.debug: - # shrink length of training to test whether this run is likely to crash - config.unlock() - config.lock_keys() - - # train and validate (if enabled) for 3 gradient steps, for 2 epochs - config.experiment.epoch_every_n_steps = 3 - config.experiment.validation_epoch_every_n_steps = 3 - config.train.num_epochs = 2 - - # if rollouts are enabled, try 2 rollouts at end of each epoch, with 10 environment steps - config.experiment.rollout.rate = 1 - config.experiment.rollout.n = 2 - config.experiment.rollout.horizon = 10 - - # send output to a temporary directory - config.train.output_dir = "/tmp/tmp_trained_models" - - # lock config to prevent further modifications and ensure missing keys raise errors - config.lock() - - # catch error during training and print it - res_str = "finished run successfully!" - try: - train(config, device=device) - except Exception as e: - res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc()) - print(res_str) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - # External config file that overwrites default config - parser.add_argument( - "--config", - type=str, - default=None, - help="(optional) path to a config json that will be used to override the default settings. \ - If omitted, default settings are used. This is the preferred way to run experiments.", - ) - - # Algorithm Name - parser.add_argument( - "--algo", - type=str, - help="(optional) name of algorithm to run. Only needs to be provided if --config is not provided", - ) - - # Experiment Name (for tensorboard, saving models, etc.) 
- parser.add_argument( - "--name", - type=str, - default=None, - help="(optional) if provided, override the experiment name defined in the config", - ) - - # Dataset path, to override the one in the config - parser.add_argument( - "--dataset", - type=str, - default=None, - help="(optional) if provided, override the dataset path defined in the config", - ) - - # debug mode - parser.add_argument( - "--debug", - action='store_true', - help="set this flag to run a quick training run for debugging purposes" - ) - - args = parser.parse_args() - main(args) diff --git a/robomimic/utils/file_utils.py b/robomimic/utils/file_utils.py index 41d0fffb..baaf6c03 100644 --- a/robomimic/utils/file_utils.py +++ b/robomimic/utils/file_utils.py @@ -251,7 +251,7 @@ def get_shape_metadata_from_dataset_rlds(builder, action_keys, all_obs_keys=None name = builder.name action_dim = 0 for key in action_keys: - if key in DATASET_SHAPES[name].keys(): + if name in DATASET_SHAPES.keys() and key in DATASET_SHAPES[name].keys(): action_dim += DATASET_SHAPES[name][key][0] else: key_shape = DataUtils.index_nested_dict( diff --git a/robomimic/utils/train_utils.py b/robomimic/utils/train_utils.py index 16202fc0..3f65c57a 100644 --- a/robomimic/utils/train_utils.py +++ b/robomimic/utils/train_utils.py @@ -602,7 +602,6 @@ def run_epoch(model, data_loader, epoch, validate=False, num_steps=None, obs_nor data_loader_iter = iter(data_loader) for _ in LogUtils.custom_tqdm(range(num_steps)): - # load next batch from data loader try: t = time.time() @@ -613,7 +612,7 @@ def run_epoch(model, data_loader, epoch, validate=False, num_steps=None, obs_nor t = time.time() batch = next(data_loader_iter) timing_stats["Data_Loading"].append(time.time() - t) - + # process batch for training t = time.time() input_batch = model.process_batch_for_training(batch) From cae90b9ad64ffc259cf350dd888d42f79e976ba5 Mon Sep 17 00:00:00 2001 From: Jonathan Yang Date: Thu, 26 Oct 2023 19:54:00 -0700 Subject: [PATCH 4/4] Move dataset decoding back --- robomimic/data/dataset.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/robomimic/data/dataset.py b/robomimic/data/dataset.py index 8f51aeaf..f2a01edb 100644 --- a/robomimic/data/dataset.py +++ b/robomimic/data/dataset.py @@ -35,7 +35,7 @@ def __iter__(self): if key in batch.keys(): torch_batch[key] = DataUtils.tree_map( batch[key], - map_fn=lambda x: torch.Tensor(x).to(self.device) + map_fn=lambda x: torch.tensor(x).to(self.device) ) yield torch_batch @@ -105,7 +105,7 @@ def get_obs_action_metadata( print("Computing obs/action statistics for normalization...") eps_by_key = {key: [] for key in keys} - i, n_samples = 0, 10 + i, n_samples = 0, 500 dataset_iter = dataset.as_numpy_iterator() for _ in tqdm.tqdm(range(n_samples)): episode = next(dataset_iter) @@ -132,6 +132,7 @@ def get_obs_action_metadata( def decode_dataset( dataset: tf.data.Dataset ): + #Decode images dataset = dataset.frame_map( DataUtils.decode_images @@ -174,7 +175,7 @@ def apply_common_transforms( CommonTransforms.concatenate_action_transform, action_keys=config.train.action_keys ), - num_parallel_calls=tf.data.AUTOTUNE + num_parallel_calls=tf.data.AUTOTUNE ) #Get a random subset of length frame_stack + seq_length - 1 dataset = dataset.map( @@ -191,6 +192,20 @@ def apply_common_transforms( return dataset +def decode_trajectory(builder, obs_keys, episode): + steps = episode + new_steps = dict() + new_steps['action_dict'] = dict() + new_steps['observation'] = dict() + for key in 
steps["action_dict"]: + new_steps['action_dict'][key] = builder.info.features["steps"]['action_dict'][ + key + ].decode_batch_example(steps["action_dict"][key]) + for key in obs_keys: + new_steps['observation'][key] = builder.info.features["steps"]['observation'][ + key + ].decode_batch_example(steps["observation"][key]) + return new_steps def make_dataset( config: dict, @@ -206,18 +221,19 @@ def make_dataset( data_dir = data_info['path'] builder = tfds.builder(name, data_dir=data_dir) + if "val" not in builder.info.splits: split = "train[:95%]" if train else "train[95%:]" else: split = "train" if train else "val" - dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle) - dataset = decode_dataset(dataset) + dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, + num_parallel_reads=8) if name in RLDS_TRAJECTORY_MAP_TRANSFORMS: if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'] is not None: dataset = dataset.map(partial( RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'], - config=config) + config=config), ) metadata_keys = [k for k in config.train.action_keys] if config.all_obs_keys is not None: @@ -241,8 +257,9 @@ def make_dataset( if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'] is not None: dataset = dataset.map(partial( RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'], - config=config) - ) + config=config), + ) + dataset = decode_dataset(dataset) dataset = dataset.repeat().batch(config.train.batch_size).prefetch(tf.data.experimental.AUTOTUNE) dataset = dataset.as_numpy_iterator() dataset = RLDSTorchDataset(dataset)