From e907dbcf5d32b4dc17370ed27b031ccbbe2e6d4a Mon Sep 17 00:00:00 2001 From: Abhi Date: Sun, 7 Jan 2024 20:28:00 +0000 Subject: [PATCH 1/2] small changes to help with memory management --- env.yaml | 145 ++++++++++++++++++++++++++++++++++ requirements.txt | 135 +++++++++++++++++++++++++++---- robomimic/data/rtx_dataset.py | 11 ++- setup.py | 2 +- 4 files changed, 276 insertions(+), 17 deletions(-) create mode 100644 env.yaml diff --git a/env.yaml b/env.yaml new file mode 100644 index 00000000..aa75fa02 --- /dev/null +++ b/env.yaml @@ -0,0 +1,145 @@ +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: linux-64 +_libgcc_mutex=0.1=main +_openmp_mutex=5.1=1_gnu +absl-py=1.4.0=pypi_0 +array-record=0.5.0=pypi_0 +astunparse=1.6.3=pypi_0 +attrs=23.2.0=pypi_0 +black=23.12.1=pypi_0 +ca-certificates=2023.12.12=h06a4308_0 +cachetools=5.3.2=pypi_0 +certifi=2023.11.17=pypi_0 +charset-normalizer=3.3.2=pypi_0 +click=8.1.7=pypi_0 +contourpy=1.2.0=pypi_0 +cycler=0.12.1=pypi_0 +detr=0.0.0=dev_0 +diffusers=0.11.1=pypi_0 +dlimp=0.0.1=dev_0 +dm-reverb=0.14.0=pypi_0 +dm-tree=0.1.8=pypi_0 +egl-probe=1.0.2=pypi_0 +etils=1.5.2=pypi_0 +filelock=3.13.1=pypi_0 +flake8=7.0.0=pypi_0 +flake8-bugbear=23.12.2=pypi_0 +flake8-comprehensions=3.14.0=pypi_0 +flatbuffers=23.5.26=pypi_0 +fonttools=4.47.0=pypi_0 +fsspec=2023.12.2=pypi_0 +gast=0.5.4=pypi_0 +google-auth=2.26.1=pypi_0 +google-auth-oauthlib=1.2.0=pypi_0 +google-pasta=0.2.0=pypi_0 +googleapis-common-protos=1.62.0=pypi_0 +grpcio=1.60.0=pypi_0 +h5py=3.10.0=pypi_0 +huggingface-hub=0.20.2=pypi_0 +idna=3.6=pypi_0 +imageio=2.33.1=pypi_0 +imageio-ffmpeg=0.4.9=pypi_0 +importlib-metadata=7.0.1=pypi_0 +importlib-resources=6.1.1=pypi_0 +jinja2=3.1.2=pypi_0 +keras=2.15.0=pypi_0 +kiwisolver=1.4.5=pypi_0 +lazy-loader=0.3=pypi_0 +ld_impl_linux-64=2.38=h1181459_1 +libclang=16.0.6=pypi_0 +libcst=1.1.0=pypi_0 +libffi=3.3=he6710b0_2 +libgcc-ng=11.2.0=h1234567_1 +libgomp=11.2.0=h1234567_1 +libstdcxx-ng=11.2.0=h1234567_1 +markdown=3.5.1=pypi_0 +markupsafe=2.1.3=pypi_0 +matplotlib=3.8.2=pypi_0 +mccabe=0.7.0=pypi_0 +ml-dtypes=0.2.0=pypi_0 +moreorless=0.4.0=pypi_0 +mpmath=1.3.0=pypi_0 +mypy-extensions=1.0.0=pypi_0 +ncurses=6.4=h6a678d5_0 +networkx=3.2.1=pypi_0 +numpy=1.26.3=pypi_0 +nvidia-cublas-cu12=12.1.3.1=pypi_0 +nvidia-cuda-cupti-cu12=12.1.105=pypi_0 +nvidia-cuda-nvrtc-cu12=12.1.105=pypi_0 +nvidia-cuda-runtime-cu12=12.1.105=pypi_0 +nvidia-cudnn-cu12=8.9.2.26=pypi_0 +nvidia-cufft-cu12=11.0.2.54=pypi_0 +nvidia-curand-cu12=10.3.2.106=pypi_0 +nvidia-cusolver-cu12=11.4.5.107=pypi_0 +nvidia-cusparse-cu12=12.1.0.106=pypi_0 +nvidia-nccl-cu12=2.18.1=pypi_0 +nvidia-nvjitlink-cu12=12.3.101=pypi_0 +nvidia-nvtx-cu12=12.1.105=pypi_0 +oauthlib=3.2.2=pypi_0 +opencv-python=4.9.0.80=pypi_0 +openssl=1.1.1w=h7f8727e_0 +opt-einsum=3.3.0=pypi_0 +packaging=23.2=pypi_0 +pathspec=0.12.1=pypi_0 +pillow=10.2.0=pypi_0 +pip=23.3.1=py39h06a4308_0 +platformdirs=4.1.0=pypi_0 +plotly=5.18.0=pypi_0 +portpicker=1.6.0=pypi_0 +promise=2.3=pypi_0 +protobuf=3.20.3=pypi_0 +psutil=5.9.7=pypi_0 +pyasn1=0.5.1=pypi_0 +pyasn1-modules=0.3.0=pypi_0 +pycodestyle=2.11.1=pypi_0 +pycosat=0.6.6=pypi_0 +pyflakes=3.2.0=pypi_0 +pyparsing=3.1.1=pypi_0 +python=3.9.0=hdb3f193_2 +python-dateutil=2.8.2=pypi_0 +pyyaml=6.0.1=pypi_0 +readline=8.2=h5eee18b_0 +regex=2023.12.25=pypi_0 +requests=2.31.0=pypi_0 +requests-oauthlib=1.3.1=pypi_0 +rlds=0.1.8=pypi_0 +robomimic=0.3.0=dev_0 +rsa=4.9=pypi_0 +scikit-image=0.22.0=pypi_0 +scipy=1.11.4=pypi_0 +setuptools=68.2.2=py39h06a4308_0 +six=1.16.0=pypi_0 +sqlite=3.41.2=h5eee18b_0 +stdlibs=2023.12.15=pypi_0 +sympy=1.12=pypi_0 +tenacity=8.2.3=pypi_0 +tensorboard=2.15.1=pypi_0 +tensorboard-data-server=0.7.2=pypi_0 +tensorboardx=2.6.2.2=pypi_0 +tensorflow=2.15.0=pypi_0 +tensorflow-datasets=4.9.3=pypi_0 +tensorflow-estimator=2.15.0=pypi_0 +tensorflow-io-gcs-filesystem=0.35.0=pypi_0 +tensorflow-metadata=1.14.0=pypi_0 +termcolor=2.4.0=pypi_0 +tifffile=2023.12.9=pypi_0 +tk=8.6.12=h1ccaba5_0 +toml=0.10.2=pypi_0 +tomli=2.0.1=pypi_0 +torch=2.1.2=pypi_0 +torchvision=0.16.2=pypi_0 +tqdm=4.66.1=pypi_0 +trailrunner=1.4.0=pypi_0 +triton=2.1.0=pypi_0 +typing-extensions=4.9.0=pypi_0 +typing-inspect=0.9.0=pypi_0 +tzdata=2023d=h04d1e81_0 +urllib3=2.1.0=pypi_0 +usort=1.0.7=pypi_0 +werkzeug=3.0.1=pypi_0 +wheel=0.41.2=py39h06a4308_0 +wrapt=1.14.1=pypi_0 +xz=5.4.5=h5eee18b_0 +zipp=3.17.0=pypi_0 +zlib=1.2.13=h5eee18b_0 diff --git a/requirements.txt b/requirements.txt index c451769d..7c32df4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,122 @@ -numpy>=1.13.3 -h5py -psutil -tqdm -termcolor -tensorboard -tensorboardX -imageio -imageio-ffmpeg -matplotlib -egl_probe>=1.0.1 -torch -torchvision +absl-py==1.4.0 +array-record==0.5.0 +astunparse==1.6.3 +attrs==23.2.0 +black==23.12.1 +cachetools==5.3.2 +certifi==2023.11.17 +charset-normalizer==3.3.2 +click==8.1.7 +contourpy==1.2.0 +cycler==0.12.1 +-e git+https://github.com/tonyzhaozh/act@73071e16a6595662d753415b90c0abb64815009c#egg=detr&subdirectory=../../act/detr diffusers==0.11.1 -pytorch3d==0.7.3 +-e git+https://github.com/kvablack/dlimp@ad72ce3a9b414db2185bc0b38461d4101a65477a#egg=dlimp +dm-reverb==0.14.0 +dm-tree==0.1.8 +egl-probe==1.0.2 +etils==1.5.2 +filelock==3.13.1 +flake8==7.0.0 +flake8-bugbear==23.12.2 +flake8-comprehensions==3.14.0 +flatbuffers==23.5.26 +fonttools==4.47.0 +fsspec==2023.12.2 +gast==0.5.4 +google-auth==2.26.1 +google-auth-oauthlib==1.2.0 +google-pasta==0.2.0 +googleapis-common-protos==1.62.0 +grpcio==1.60.0 +h5py==3.10.0 +huggingface-hub==0.20.2 +idna==3.6 +imageio==2.33.1 +imageio-ffmpeg==0.4.9 +importlib-metadata==7.0.1 +importlib-resources==6.1.1 +Jinja2==3.1.2 +keras==2.15.0 +kiwisolver==1.4.5 +lazy_loader==0.3 +libclang==16.0.6 +libcst==1.1.0 +Markdown==3.5.1 +MarkupSafe==2.1.3 +matplotlib==3.8.2 +mccabe==0.7.0 +ml-dtypes==0.2.0 +moreorless==0.4.0 +mpmath==1.3.0 +mypy-extensions==1.0.0 +networkx==3.2.1 +numpy==1.26.3 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.3.101 +nvidia-nvtx-cu12==12.1.105 +oauthlib==3.2.2 +opencv-python==4.9.0.80 +opt-einsum==3.3.0 +packaging==23.2 +pathspec==0.12.1 +pillow==10.2.0 +platformdirs==4.1.0 +plotly==5.18.0 +portpicker==1.6.0 +promise==2.3 +protobuf==3.20.3 +psutil==5.9.7 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycodestyle==2.11.1 +pycosat==0.6.6 +pyflakes==3.2.0 +pyparsing==3.1.1 +python-dateutil==2.8.2 +PyYAML==6.0.1 +regex==2023.12.25 +requests==2.31.0 +requests-oauthlib==1.3.1 +rlds==0.1.8 +-e git+https://github.com/ARISE-Initiative/robomimic.git@dbe18cc3f2623a6e73ad1353e55de6e1266aabe1#egg=robomimic +rsa==4.9 +scikit-image==0.22.0 +scipy==1.11.4 +six==1.16.0 +stdlibs==2023.12.15 +sympy==1.12 +tenacity==8.2.3 +tensorboard==2.15.1 +tensorboard-data-server==0.7.2 +tensorboardX==2.6.2.2 +tensorflow==2.15.0 +tensorflow-datasets==4.9.3 +tensorflow-estimator==2.15.0 +tensorflow-io-gcs-filesystem==0.35.0 +tensorflow-metadata==1.14.0 +termcolor==2.4.0 +tifffile==2023.12.9 +toml==0.10.2 +tomli==2.0.1 +torch==2.1.2 +torchvision==0.16.2 +tqdm==4.66.1 +trailrunner==1.4.0 +triton==2.1.0 +typing-inspect==0.9.0 +typing_extensions==4.9.0 +urllib3==2.1.0 +usort==1.0.7 +Werkzeug==3.0.1 +wrapt==1.14.1 +zipp==3.17.0 diff --git a/robomimic/data/rtx_dataset.py b/robomimic/data/rtx_dataset.py index 3e38ba78..45e8d8bf 100644 --- a/robomimic/data/rtx_dataset.py +++ b/robomimic/data/rtx_dataset.py @@ -672,6 +672,7 @@ def step_map_fn(traj: Dict[str, Any]) -> Dict[str, Any]: ).build(validate_expected_tensor_spec=False) trajectory_dataset = trajectory_transform.transform_episodic_rlds_dataset(dataset) + # combined_dataset = tf.data.Dataset.sample_from_datasets([trajectory_dataset]) # combined_dataset = combined_dataset.batch(2) @@ -682,8 +683,14 @@ def step_map_fn(traj: Dict[str, Any]) -> Dict[str, Any]: dataset = trajectory_dataset # shuffle, repeat, pre-fetch, batch # dataset = dataset.cache() # optionally keep full dataset in memory - dataset = dataset.shuffle(1000) # set shuffle buffer size - dataset = dataset.repeat().batch(config.train.batch_size).prefetch(tf.data.experimental.AUTOTUNE) + dataset = dataset.shuffle(10000) # set shuffle buffer size + dataset = dataset.repeat().batch(config.train.batch_size)#.prefetch(tf.data.experimental.AUTOTUNE) + + # memory management + # options = tf.data.Options() + # options.autotune.ram_budget = 1024 * 1024 * 1024 + # dataset = dataset.with_options(options) + dataset = dataset.as_numpy_iterator() dataset = RLDSTorchDataset(dataset) diff --git a/setup.py b/setup.py index 053e2eb3..60009289 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ "torch", "torchvision", "diffusers==0.11.1", - "pytorch3d", + # "pytorch3d", "tensorflow_datasets", "tensorflow", ], From 9fdbd9cdef6e75863ad3948f9e27b1ab760b94f1 Mon Sep 17 00:00:00 2001 From: Abhiram824 Date: Mon, 8 Jan 2024 20:39:33 +0000 Subject: [PATCH 2/2] ongoing changes to making dataloader similar to octo codebase --- robomimic/data/rtx_dataset_octo.py | 275 +++++++++++++++++++++++++++++ robomimic/scripts/train_rlds.py | 10 +- 2 files changed, 279 insertions(+), 6 deletions(-) create mode 100644 robomimic/data/rtx_dataset_octo.py diff --git a/robomimic/data/rtx_dataset_octo.py b/robomimic/data/rtx_dataset_octo.py new file mode 100644 index 00000000..c3430582 --- /dev/null +++ b/robomimic/data/rtx_dataset_octo.py @@ -0,0 +1,275 @@ +from functools import partial +import inspect +import json +import tensorflow as tf + +import tensorflow_datasets as tfds +#Don't use GPU for dataloading +tf.config.set_visible_devices([], "GPU") +from typing import Callable, Mapping, Optional, Sequence, Tuple, Union +from .dataset_transformations import RLDS_TRAJECTORY_MAP_TRANSFORMS +from typing import Any, Dict, List, Union, Tuple, Optional +import tree +import hashlib +import pickle +import torch +import robomimic.utils.torch_utils as TorchUtils +from .dataset_transformations import RLDS_TRAJECTORY_MAP_TRANSFORMS +import robomimic.data.common_transformations as CommonTransforms +import robomimic.utils.data_utils as DataUtils +from tensorflow_datasets.core.dataset_builder import DatasetBuilder +import tqdm +from torch.utils.data import DataLoader + + + + +from absl import logging +import dlimp as dl +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + + +def subsample(traj: dict, subsample_length: int) -> dict: + """Subsamples trajectories to the given length.""" + traj_len = tf.shape(traj["actions"])[0] + if traj_len > subsample_length: + indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length] + traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) + return traj + + +def apply_trajectory_transforms(dataset, train, subsample_length = None, num_parallel_calls: int = tf.data.AUTOTUNE): + if train and subsample_length is not None: + + #TODO CHANGE THIS TO FOLLOW SAME APPROACH AS REVERB + dataset = dataset.filter( + lambda x: tf.shape(x["actions"])[0] >= subsample_length + ) + + dataset = dataset.traj_map( + partial(subsample, subsample_length=subsample_length), + num_parallel_calls, + ) + return dataset + + +class RLDSTorchDataset(torch.utils.data.IterableDataset): + def __init__(self, dataset_iterator, try_to_use_cuda=True): + self.dataset_iterator = dataset_iterator + self.device = TorchUtils.get_torch_device(try_to_use_cuda) + self.keys = ['obs', 'goal_obs', 'actions'] + + def __iter__(self): + for batch in self.dataset_iterator.as_numpy_iterator(): + torch_batch = {} + for key in self.keys: + if key in batch.keys(): + torch_batch[key] = DataUtils.tree_map( + batch[key], + map_fn=lambda x: torch.tensor(x).to(self.device) + ) + yield torch_batch + + +def decode_images( + obs: dict, +): + for key in obs["obs"]: + if "image" in key: + image = obs["obs"][key] + assert image.dtype == tf.string + image_decoded = tf.io.decode_image( + image, expand_animations=False, dtype=tf.uint8 + ) + obs["obs"][key] = image_decoded + return obs + + + +def apply_frame_transforms( + dataset: dl.DLataset, + *, + num_parallel_calls: int = tf.data.AUTOTUNE, +) -> dl.DLataset: + + # decode + resize images (and depth images) + dataset = dataset.frame_map( + decode_images, + num_parallel_calls, + ) + return dataset + +def get_obs_action_metadata( + builder: DatasetBuilder, dataset: tf.data.Dataset, keys: List[str], + load_if_exists=True +) -> Dict[str, Dict[str, List[float]]]: + # get statistics file path --> embed unique hash that catches if dataset info changed + data_info_hash = hashlib.sha256( + (str(builder.info) + str(keys)).encode("utf-8") + ).hexdigest() + path = tf.io.gfile.join( + builder.info.data_dir, f"obs_action_stats_{data_info_hash}.pkl" + ) + + # check if stats already exist and load, otherwise compute + if tf.io.gfile.exists(path) and load_if_exists: + print(f"Loading existing statistics for normalization from {path}.") + with tf.io.gfile.GFile(path, "rb") as f: + metadata = pickle.load(f) + else: + print("Computing obs/action statistics for normalization...") + eps_by_key = {key: [] for key in keys} + + i, n_samples = 0, 50 + dataset_iter = dataset.as_numpy_iterator() + for _ in tqdm.tqdm(range(n_samples)): + episode = next(dataset_iter) + i = i + 1 + for key in keys: + eps_by_key[key].append(DataUtils.index_nested_dict(episode, key)) + eps_by_key = {key: np.concatenate(values) for key, values in eps_by_key.items()} + + metadata = {} + # breakpoint() + for key in keys: + # #print(key) + # #print(eps_by_key[key]) + # breakpoint() + if "image" not in key: + metadata[key] = { + "mean": eps_by_key[key].mean(0), + "std": eps_by_key[key].std(0), + "max": eps_by_key[key].max(0), + "min": eps_by_key[key].min(0), + } + else: + metadata[key] = { + "mean": np.frombuffer(eps_by_key[key], dtype=np.uint8).mean(0), + "std": np.frombuffer(eps_by_key[key], dtype=np.uint8).std(0), + "max": np.frombuffer(eps_by_key[key], dtype=np.uint8).max(0), + "min": np.frombuffer(eps_by_key[key], dtype=np.uint8).min(0), + } + # breakpoint() + # with tf.io.gfile.GFile(path, "wb") as f: + # pickle.dump(metadata, f) + logging.info("Done!") + + return metadata + + +def make_dataset_from_rlds( + config:dict, + train: bool, + shuffle: bool = True, + num_parallel_reads: int = tf.data.AUTOTUNE, + num_parallel_calls: int = tf.data.AUTOTUNE, +) -> Tuple[dl.DLataset, dict]: + + data_info = config.train.data[0] + name = data_info['name'] + data_dir = data_info['path'] + builder = tfds.builder(name, data_dir=data_dir) + + + + + # construct the dataset + if "val" not in builder.info.splits: + split = "train[:95%]" if train else "train[95%:]" + else: + split = "train" if train else "val" + + dataset = dl.DLataset.from_rlds( + builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads + ) + + metadata_keys = [k for k in config.train.action_keys] + if config.all_obs_keys is not None: + metadata_keys.extend([f'observation/{k}' + for k in config.all_obs_keys]) + + normalization_metadata = get_obs_action_metadata( + builder, + dataset, + keys=metadata_keys, + load_if_exists=True#False + ) + + if name in RLDS_TRAJECTORY_MAP_TRANSFORMS: + if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'] is not None: + dataset = dataset.traj_map( + partial( + RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['pre'], + config=config + ), + num_parallel_calls + ) + + if normalization_metadata is not None: + dataset = dataset.traj_map( + partial( + CommonTransforms.normalize_obs_and_actions, + config=config, + metadata=normalization_metadata, + ), + num_parallel_calls, + ) + if config.train.action_keys != None: + dataset = dataset.traj_map( + partial( + CommonTransforms.concatenate_action_transform, + action_keys=config.train.action_keys + ), + num_parallel_calls, + ) + + if name in RLDS_TRAJECTORY_MAP_TRANSFORMS: + if RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'] is not None: + dataset = dataset.traj_map( + partial( + RLDS_TRAJECTORY_MAP_TRANSFORMS[name]['post'], + config=config + ), + num_parallel_calls + ) + + return builder, dataset, normalization_metadata + + +def make_single_dataset( + config: dict, + *, + train: bool, + traj_transform_kwargs: dict = None, + shuffle_buffer_size=1000, +) -> dl.DLataset: + + if traj_transform_kwargs is None: + traj_transform_kwargs = { + "subsample_length": config.train.seq_length + config.train.frame_stack - 1 + } + + builder, dataset, normalization_metdata = make_dataset_from_rlds( + config=config, + train=train, + ) + + + + dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train) + dataset = apply_frame_transforms(dataset) + + # this seems to reduce memory usage without affecting speed + + dataset = dataset.shuffle(shuffle_buffer_size) + dataset = dataset.repeat().batch(config.train.batch_size) + dataset = dataset.with_ram_budget(1) + + + pytorch_dataset = RLDSTorchDataset(dataset) + + + return builder, pytorch_dataset, normalization_metdata + diff --git a/robomimic/scripts/train_rlds.py b/robomimic/scripts/train_rlds.py index 7f4da7c6..e8091247 100644 --- a/robomimic/scripts/train_rlds.py +++ b/robomimic/scripts/train_rlds.py @@ -40,8 +40,9 @@ from robomimic.config import config_factory from robomimic.algo import algo_factory, RolloutPolicy from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings -from robomimic.data.rtx_dataset import (make_dataset, get_obs_normalization_stats_rlds, +from robomimic.data.rtx_dataset import (get_obs_normalization_stats_rlds, get_action_normalization_stats_rlds) +from robomimic.data.rtx_dataset_octo import make_single_dataset def train(config, device): """ @@ -70,10 +71,9 @@ def train(config, device): ObsUtils.initialize_obs_utils_with_config(config) # Load the datasets - train_builder, train_loader, normalization_metadata = make_dataset( + train_builder, train_loader, normalization_metadata = make_single_dataset( config, train=True, - shuffle=True ) ds_format = config.train.data_format assert ds_format == 'rlds' @@ -81,11 +81,9 @@ def train(config, device): if config.experiment.validate: # cap num workers for validation dataset at 1 num_workers = min(config.train.num_data_workers, 1) - valid_builder, valid_loader, _ = make_dataset( + valid_builder, valid_loader, _ = make_single_dataset( config, train=True, - shuffle=True, - normalization_metadata=normalization_metadata ) else: