Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rlds mem #129

Open
wants to merge 2 commits into
base: r2d2_rlds
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions env.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
absl-py=1.4.0=pypi_0
array-record=0.5.0=pypi_0
astunparse=1.6.3=pypi_0
attrs=23.2.0=pypi_0
black=23.12.1=pypi_0
ca-certificates=2023.12.12=h06a4308_0
cachetools=5.3.2=pypi_0
certifi=2023.11.17=pypi_0
charset-normalizer=3.3.2=pypi_0
click=8.1.7=pypi_0
contourpy=1.2.0=pypi_0
cycler=0.12.1=pypi_0
detr=0.0.0=dev_0
diffusers=0.11.1=pypi_0
dlimp=0.0.1=dev_0
dm-reverb=0.14.0=pypi_0
dm-tree=0.1.8=pypi_0
egl-probe=1.0.2=pypi_0
etils=1.5.2=pypi_0
filelock=3.13.1=pypi_0
flake8=7.0.0=pypi_0
flake8-bugbear=23.12.2=pypi_0
flake8-comprehensions=3.14.0=pypi_0
flatbuffers=23.5.26=pypi_0
fonttools=4.47.0=pypi_0
fsspec=2023.12.2=pypi_0
gast=0.5.4=pypi_0
google-auth=2.26.1=pypi_0
google-auth-oauthlib=1.2.0=pypi_0
google-pasta=0.2.0=pypi_0
googleapis-common-protos=1.62.0=pypi_0
grpcio=1.60.0=pypi_0
h5py=3.10.0=pypi_0
huggingface-hub=0.20.2=pypi_0
idna=3.6=pypi_0
imageio=2.33.1=pypi_0
imageio-ffmpeg=0.4.9=pypi_0
importlib-metadata=7.0.1=pypi_0
importlib-resources=6.1.1=pypi_0
jinja2=3.1.2=pypi_0
keras=2.15.0=pypi_0
kiwisolver=1.4.5=pypi_0
lazy-loader=0.3=pypi_0
ld_impl_linux-64=2.38=h1181459_1
libclang=16.0.6=pypi_0
libcst=1.1.0=pypi_0
libffi=3.3=he6710b0_2
libgcc-ng=11.2.0=h1234567_1
libgomp=11.2.0=h1234567_1
libstdcxx-ng=11.2.0=h1234567_1
markdown=3.5.1=pypi_0
markupsafe=2.1.3=pypi_0
matplotlib=3.8.2=pypi_0
mccabe=0.7.0=pypi_0
ml-dtypes=0.2.0=pypi_0
moreorless=0.4.0=pypi_0
mpmath=1.3.0=pypi_0
mypy-extensions=1.0.0=pypi_0
ncurses=6.4=h6a678d5_0
networkx=3.2.1=pypi_0
numpy=1.26.3=pypi_0
nvidia-cublas-cu12=12.1.3.1=pypi_0
nvidia-cuda-cupti-cu12=12.1.105=pypi_0
nvidia-cuda-nvrtc-cu12=12.1.105=pypi_0
nvidia-cuda-runtime-cu12=12.1.105=pypi_0
nvidia-cudnn-cu12=8.9.2.26=pypi_0
nvidia-cufft-cu12=11.0.2.54=pypi_0
nvidia-curand-cu12=10.3.2.106=pypi_0
nvidia-cusolver-cu12=11.4.5.107=pypi_0
nvidia-cusparse-cu12=12.1.0.106=pypi_0
nvidia-nccl-cu12=2.18.1=pypi_0
nvidia-nvjitlink-cu12=12.3.101=pypi_0
nvidia-nvtx-cu12=12.1.105=pypi_0
oauthlib=3.2.2=pypi_0
opencv-python=4.9.0.80=pypi_0
openssl=1.1.1w=h7f8727e_0
opt-einsum=3.3.0=pypi_0
packaging=23.2=pypi_0
pathspec=0.12.1=pypi_0
pillow=10.2.0=pypi_0
pip=23.3.1=py39h06a4308_0
platformdirs=4.1.0=pypi_0
plotly=5.18.0=pypi_0
portpicker=1.6.0=pypi_0
promise=2.3=pypi_0
protobuf=3.20.3=pypi_0
psutil=5.9.7=pypi_0
pyasn1=0.5.1=pypi_0
pyasn1-modules=0.3.0=pypi_0
pycodestyle=2.11.1=pypi_0
pycosat=0.6.6=pypi_0
pyflakes=3.2.0=pypi_0
pyparsing=3.1.1=pypi_0
python=3.9.0=hdb3f193_2
python-dateutil=2.8.2=pypi_0
pyyaml=6.0.1=pypi_0
readline=8.2=h5eee18b_0
regex=2023.12.25=pypi_0
requests=2.31.0=pypi_0
requests-oauthlib=1.3.1=pypi_0
rlds=0.1.8=pypi_0
robomimic=0.3.0=dev_0
rsa=4.9=pypi_0
scikit-image=0.22.0=pypi_0
scipy=1.11.4=pypi_0
setuptools=68.2.2=py39h06a4308_0
six=1.16.0=pypi_0
sqlite=3.41.2=h5eee18b_0
stdlibs=2023.12.15=pypi_0
sympy=1.12=pypi_0
tenacity=8.2.3=pypi_0
tensorboard=2.15.1=pypi_0
tensorboard-data-server=0.7.2=pypi_0
tensorboardx=2.6.2.2=pypi_0
tensorflow=2.15.0=pypi_0
tensorflow-datasets=4.9.3=pypi_0
tensorflow-estimator=2.15.0=pypi_0
tensorflow-io-gcs-filesystem=0.35.0=pypi_0
tensorflow-metadata=1.14.0=pypi_0
termcolor=2.4.0=pypi_0
tifffile=2023.12.9=pypi_0
tk=8.6.12=h1ccaba5_0
toml=0.10.2=pypi_0
tomli=2.0.1=pypi_0
torch=2.1.2=pypi_0
torchvision=0.16.2=pypi_0
tqdm=4.66.1=pypi_0
trailrunner=1.4.0=pypi_0
triton=2.1.0=pypi_0
typing-extensions=4.9.0=pypi_0
typing-inspect=0.9.0=pypi_0
tzdata=2023d=h04d1e81_0
urllib3=2.1.0=pypi_0
usort=1.0.7=pypi_0
werkzeug=3.0.1=pypi_0
wheel=0.41.2=py39h06a4308_0
wrapt=1.14.1=pypi_0
xz=5.4.5=h5eee18b_0
zipp=3.17.0=pypi_0
zlib=1.2.13=h5eee18b_0
135 changes: 121 additions & 14 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,122 @@
numpy>=1.13.3
h5py
psutil
tqdm
termcolor
tensorboard
tensorboardX
imageio
imageio-ffmpeg
matplotlib
egl_probe>=1.0.1
torch
torchvision
absl-py==1.4.0
array-record==0.5.0
astunparse==1.6.3
attrs==23.2.0
black==23.12.1
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
contourpy==1.2.0
cycler==0.12.1
-e git+https://github.com/tonyzhaozh/act@73071e16a6595662d753415b90c0abb64815009c#egg=detr&subdirectory=../../act/detr
diffusers==0.11.1
pytorch3d==0.7.3
-e git+https://github.com/kvablack/dlimp@ad72ce3a9b414db2185bc0b38461d4101a65477a#egg=dlimp
dm-reverb==0.14.0
dm-tree==0.1.8
egl-probe==1.0.2
etils==1.5.2
filelock==3.13.1
flake8==7.0.0
flake8-bugbear==23.12.2
flake8-comprehensions==3.14.0
flatbuffers==23.5.26
fonttools==4.47.0
fsspec==2023.12.2
gast==0.5.4
google-auth==2.26.1
google-auth-oauthlib==1.2.0
google-pasta==0.2.0
googleapis-common-protos==1.62.0
grpcio==1.60.0
h5py==3.10.0
huggingface-hub==0.20.2
idna==3.6
imageio==2.33.1
imageio-ffmpeg==0.4.9
importlib-metadata==7.0.1
importlib-resources==6.1.1
Jinja2==3.1.2
keras==2.15.0
kiwisolver==1.4.5
lazy_loader==0.3
libclang==16.0.6
libcst==1.1.0
Markdown==3.5.1
MarkupSafe==2.1.3
matplotlib==3.8.2
mccabe==0.7.0
ml-dtypes==0.2.0
moreorless==0.4.0
mpmath==1.3.0
mypy-extensions==1.0.0
networkx==3.2.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opencv-python==4.9.0.80
opt-einsum==3.3.0
packaging==23.2
pathspec==0.12.1
pillow==10.2.0
platformdirs==4.1.0
plotly==5.18.0
portpicker==1.6.0
promise==2.3
protobuf==3.20.3
psutil==5.9.7
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycodestyle==2.11.1
pycosat==0.6.6
pyflakes==3.2.0
pyparsing==3.1.1
python-dateutil==2.8.2
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
requests-oauthlib==1.3.1
rlds==0.1.8
-e git+https://github.com/ARISE-Initiative/robomimic.git@dbe18cc3f2623a6e73ad1353e55de6e1266aabe1#egg=robomimic
rsa==4.9
scikit-image==0.22.0
scipy==1.11.4
six==1.16.0
stdlibs==2023.12.15
sympy==1.12
tenacity==8.2.3
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tensorflow==2.15.0
tensorflow-datasets==4.9.3
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.35.0
tensorflow-metadata==1.14.0
termcolor==2.4.0
tifffile==2023.12.9
toml==0.10.2
tomli==2.0.1
torch==2.1.2
torchvision==0.16.2
tqdm==4.66.1
trailrunner==1.4.0
triton==2.1.0
typing-inspect==0.9.0
typing_extensions==4.9.0
urllib3==2.1.0
usort==1.0.7
Werkzeug==3.0.1
wrapt==1.14.1
zipp==3.17.0
11 changes: 9 additions & 2 deletions robomimic/data/rtx_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ def step_map_fn(traj: Dict[str, Any]) -> Dict[str, Any]:
).build(validate_expected_tensor_spec=False)

trajectory_dataset = trajectory_transform.transform_episodic_rlds_dataset(dataset)


# combined_dataset = tf.data.Dataset.sample_from_datasets([trajectory_dataset])
# combined_dataset = combined_dataset.batch(2)
Expand All @@ -682,8 +683,14 @@ def step_map_fn(traj: Dict[str, Any]) -> Dict[str, Any]:
dataset = trajectory_dataset
# shuffle, repeat, pre-fetch, batch
# dataset = dataset.cache() # optionally keep full dataset in memory
dataset = dataset.shuffle(1000) # set shuffle buffer size
dataset = dataset.repeat().batch(config.train.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(10000) # set shuffle buffer size
dataset = dataset.repeat().batch(config.train.batch_size)#.prefetch(tf.data.experimental.AUTOTUNE)

# memory management
# options = tf.data.Options()
# options.autotune.ram_budget = 1024 * 1024 * 1024
# dataset = dataset.with_options(options)

dataset = dataset.as_numpy_iterator()
dataset = RLDSTorchDataset(dataset)

Expand Down
Loading