diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 587e1b19b..3c757f9b6 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -174,7 +174,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 8d0a2fdfe..023a961f5 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -216,7 +216,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 0a544eeac..c25002613 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -187,7 +187,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 5fa3638b2..8a30ca75d 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -184,7 +184,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index b1a5ae308..2f3832d23 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -244,7 +244,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 9a451f403..e47c08d92 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -178,7 +178,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index ab5ede3cf..5373d0536 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -219,7 +219,7 @@ def watch() -> None: beta=args.beta, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index cb589d83e..df43e49ac 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -227,7 +227,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 072421bc6..0c23e6e8e 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -180,7 +180,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 915bb20a2..6d4f55d14 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -246,7 +246,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name)