From f94437a5ec5e2a0bbccbfed2b6e7ff181144973b Mon Sep 17 00:00:00 2001 From: Jiahao Li Date: Sat, 25 Jan 2025 02:52:15 +0800 Subject: [PATCH] fix: reset before collect on watch (#1240) As title. - [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s) - [x] I have provided a description of the changes in this Pull Request - [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md - [ ] If applicable, I have added tests to cover my changes. - [ ] I have reformatted the code using `poe format` - [ ] I have checked style and types with `poe lint` and `poe type-check` - [ ] (Optional) I ran tests locally with `poe test` (or a subset of them with `poe test-reduced`) ,and they pass - [ ] (Optional) I have tested that documentation builds correctly with `poe doc-build` --- examples/atari/atari_c51.py | 2 +- examples/atari/atari_dqn.py | 2 +- examples/atari/atari_fqf.py | 2 +- examples/atari/atari_iqn.py | 2 +- examples/atari/atari_ppo.py | 2 +- examples/atari/atari_qrdqn.py | 2 +- examples/atari/atari_rainbow.py | 2 +- examples/atari/atari_sac.py | 2 +- examples/vizdoom/vizdoom_c51.py | 2 +- examples/vizdoom/vizdoom_ppo.py | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 587e1b19b..3c757f9b6 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -174,7 +174,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 8d0a2fdfe..023a961f5 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -216,7 +216,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 0a544eeac..c25002613 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -187,7 +187,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 5fa3638b2..8a30ca75d 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -184,7 +184,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index b1a5ae308..2f3832d23 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -244,7 +244,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 9a451f403..e47c08d92 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -178,7 +178,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index ab5ede3cf..5373d0536 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -219,7 +219,7 @@ def watch() -> None: beta=args.beta, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index cb589d83e..df43e49ac 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -227,7 +227,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 072421bc6..0c23e6e8e 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -180,7 +180,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 915bb20a2..6d4f55d14 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -246,7 +246,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name)