diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a19949a63..782d8cdba 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -188,3 +188,4 @@ MLP backpropagation dataclass superset +picklable diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 7c8266807..1a7c86ba8 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -27,7 +27,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) - parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--seed", type=int, default=1) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--actor-lr", type=float, default=1e-3) parser.add_argument("--critic-lr", type=float, default=1e-3) diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 286df4cff..0b0cf359a 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -21,8 +21,14 @@ def __init__(self, sampling_config: SamplingConfig): self.sampling_config = sampling_config def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: - max_update_num = ( - np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) - * self.sampling_config.num_epochs - ) - return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute) + + class _LRLambda: + def __init__(self, sampling_config: SamplingConfig): + self.max_update_num = ( + np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect) + * sampling_config.num_epochs + ) + + def compute(self, epoch: int) -> float: + return 1.0 - epoch / self.max_update_num diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py index 686a26c05..4fc4f9c1e 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -61,12 +61,15 @@ class PolicyPersistence: class Mode(Enum): POLICY_STATE_DICT = "policy_state_dict" - """Persist only the policy's state dictionary""" + """Persist only the policy's state dictionary. Note that for a policy to be restored from + such a dictionary, it is necessary to first create a structurally equivalent object which can + accept the respective state.""" POLICY = "policy" """Persist the entire policy. This is larger but has the advantage of the policy being loadable without requiring an environment to be instantiated. It has the potential disadvantage that upon breaking code changes in the policy implementation (e.g. renamed/moved class), it will no longer be loadable. + Note that a precondition is that the policy be picklable in its entirety. """ def get_filename(self) -> str: