Skip to content

Commit

Permalink
fix: reset before collect on watch (#1240)
Browse files Browse the repository at this point in the history
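As the title says: the `watch()` functions in the Atari and VizDoom examples now call `collector.collect(..., reset_before_collect=True)`.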

- [ ] I have added the correct label(s) to this Pull Request or linked
the relevant issue(s)
- [x] I have provided a description of the changes in this Pull Request
- [ ] I have added documentation for my changes and have listed relevant
changes in CHANGELOG.md
- [ ] If applicable, I have added tests to cover my changes.
- [ ] I have reformatted the code using `poe format` 
- [ ] I have checked style and types with `poe lint` and `poe
type-check`
- [ ] (Optional) I ran tests locally with `poe test`
(or a subset of them with `poe test-reduced`), and they pass
- [ ] (Optional) I have tested that documentation builds correctly with
`poe doc-build`
  • Loading branch information
li-plus authored Jan 24, 2025
1 parent b006cd5 commit f94437a
Show file tree
Hide file tree
Showing 10 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion examples/atari/atari_c51.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/atari/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/atari/atari_fqf.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/atari/atari_iqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/atari/atari_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/atari/atari_qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/atari/atari_rainbow.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def watch() -> None:
beta=args.beta,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/atari/atari_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/vizdoom/vizdoom_c51.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down
2 changes: 1 addition & 1 deletion examples/vizdoom/vizdoom_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def watch() -> None:
stack_num=args.frames_stack,
)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
- result = collector.collect(n_step=args.buffer_size)
+ result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
Expand Down

0 comments on commit f94437a

Please sign in to comment.