Fix an issue where policies built with LRSchedulerFactoryLinear were not picklable #992

Merged (4 commits) on Nov 14, 2023
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -188,3 +188,4 @@ MLP
backpropagation
dataclass
superset
+picklable
2 changes: 1 addition & 1 deletion test/offline/test_cql.py
@@ -27,7 +27,7 @@ def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Pendulum-v1")
    parser.add_argument("--reward-threshold", type=float, default=None)
-    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=1e-3)
16 changes: 11 additions & 5 deletions tianshou/highlevel/params/lr_scheduler.py
@@ -21,8 +21,14 @@ def __init__(self, sampling_config: SamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
-        max_update_num = (
-            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
-            * self.sampling_config.num_epochs
-        )
-        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+        return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute)
+
+    class _LRLambda:
Collaborator: imo you can define a function instead of creating a temporary class

Collaborator (Author): The class is there to cache the second parameter of the function (max_update_num).
+        def __init__(self, sampling_config: SamplingConfig):
+            self.max_update_num = (
+                np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect)
+                * sampling_config.num_epochs
+            )
+
+        def compute(self, epoch: int) -> float:
+            return 1.0 - epoch / self.max_update_num
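To illustrate the point raised in the review exchange above: a lambda defined inside create_scheduler cannot be pickled, whereas a bound method of a picklable object can, which is what makes the scheduler (and therefore any policy holding it) picklable again. The following is a minimal standalone sketch, not tianshou code; LinearDecay is a hypothetical stand-in for the PR's _LRLambda:

```python
import pickle

import torch
from torch.optim.lr_scheduler import LambdaLR


class LinearDecay:
    """Picklable stand-in for _LRLambda: the cached value lives on the instance."""

    def __init__(self, max_update_num: float):
        self.max_update_num = max_update_num

    def compute(self, epoch: int) -> float:
        return 1.0 - epoch / self.max_update_num


optim = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)

# A scheduler built around a local lambda cannot be pickled ...
broken = LambdaLR(optim, lr_lambda=lambda epoch: 1.0 - epoch / 100)
try:
    pickle.dumps(broken)
except (pickle.PicklingError, AttributeError) as exc:
    print("not picklable:", exc)

# ... while one built around a bound method of a picklable object can.
fixed = LambdaLR(optim, lr_lambda=LinearDecay(100.0).compute)
pickle.dumps(fixed)  # succeeds
```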
5 changes: 4 additions & 1 deletion tianshou/highlevel/persistence.py
@@ -61,12 +61,15 @@ class PolicyPersistence:

    class Mode(Enum):
        POLICY_STATE_DICT = "policy_state_dict"
-        """Persist only the policy's state dictionary"""
+        """Persist only the policy's state dictionary. Note that for a policy to be restored from
+        such a dictionary, it is necessary to first create a structurally equivalent object which can
+        accept the respective state."""
        POLICY = "policy"
        """Persist the entire policy. This is larger but has the advantage of the policy being loadable
        without requiring an environment to be instantiated.
        It has the potential disadvantage that upon breaking code changes in the policy implementation
        (e.g. renamed/moved class), it will no longer be loadable.
+        Note that a precondition is that the policy be picklable in its entirety.
        """

        def get_filename(self) -> str:
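The distinction between the two modes can be sketched with a plain nn.Module standing in for a policy. This illustrates the general idea only; it does not use tianshou's PolicyPersistence API, and the file names are arbitrary:

```python
import torch
from torch import nn

# Stand-in for a policy: any nn.Module suffices for illustrating the two modes.
policy = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

# POLICY_STATE_DICT: only parameters are stored; loading requires first creating
# a structurally equivalent module that can accept the state.
torch.save(policy.state_dict(), "policy_state_dict.pt")
rebuilt = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
rebuilt.load_state_dict(torch.load("policy_state_dict.pt"))

# POLICY: the entire object is pickled; nothing has to be rebuilt at load time,
# but every attribute (e.g. an attached LR scheduler) must itself be picklable,
# which is exactly what the lambda-based scheduler broke.
torch.save(policy, "policy.pt")
restored = torch.load("policy.pt", weights_only=False)
```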