Move function make_optimizer_and_scheduler to policy #401

Draft · wants to merge 3 commits into base: main

25 changes: 25 additions & 0 deletions lerobot/common/policies/act/modeling_act.py
@@ -160,6 +160,31 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

return loss_dict

def make_optimizer_and_scheduler(self, cfg):
"""Create the optimizer and learning rate scheduler for ACT"""
optimizer_params_dicts = [
{
"params": [
p
for n, p in self.named_parameters()
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in self.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": cfg.training.lr_backbone,
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
)
lr_scheduler = None
return optimizer, lr_scheduler


class ACTTemporalEnsembler:
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
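
Note: for reference, here is a minimal standalone sketch of the parameter-group split that the ACT method above builds, using a toy module and made-up hyperparameter values (not the repo's config). Parameters under `model.backbone` get their own learning rate; everything else falls back to the optimizer default.

import torch
from torch import nn

# Toy stand-in for ACTPolicy: a `model.backbone` submodule plus a separate head.
class ToyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Module()
        self.model.backbone = nn.Linear(4, 4)
        self.model.head = nn.Linear(4, 2)

policy = ToyPolicy()
param_dicts = [
    {  # everything except the backbone -> uses the optimizer default lr
        "params": [p for n, p in policy.named_parameters()
                   if not n.startswith("model.backbone") and p.requires_grad]
    },
    {  # backbone only -> gets its own lr (stand-in for cfg.training.lr_backbone)
        "params": [p for n, p in policy.named_parameters()
                   if n.startswith("model.backbone") and p.requires_grad],
        "lr": 1e-5,
    },
]
optimizer = torch.optim.AdamW(param_dicts, lr=1e-4, weight_decay=1e-6)
assert optimizer.param_groups[0]["lr"] == 1e-4  # default lr
assert optimizer.param_groups[1]["lr"] == 1e-5  # backbone lr
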
19 changes: 19 additions & 0 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -156,6 +156,25 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}

def make_optimizer_and_scheduler(self, cfg):
"""Create the optimizer and learning rate scheduler for Diffusion policy"""
optimizer = torch.optim.Adam(
self.diffusion.parameters(),
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
from diffusers.optimization import get_scheduler

lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.training.offline_steps,
)
return optimizer, lr_scheduler


def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
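
Note: the torch.optim.Adam call above passes the cfg values positionally; after `params`, Adam's positional parameters are `lr`, `betas`, `eps`, and `weight_decay`. A keyword-argument sketch with placeholder values (not the repo's defaults), including the matching diffusers scheduler call:

import torch
from torch import nn
from diffusers.optimization import get_scheduler

model = nn.Linear(4, 2)  # stand-in for self.diffusion

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,              # cfg.training.lr
    betas=(0.95, 0.999),  # cfg.training.adam_betas
    eps=1e-8,             # cfg.training.adam_eps
    weight_decay=1e-6,    # cfg.training.adam_weight_decay
)
lr_scheduler = get_scheduler(
    "cosine",                    # cfg.training.lr_scheduler
    optimizer=optimizer,
    num_warmup_steps=500,        # cfg.training.lr_warmup_steps
    num_training_steps=200_000,  # cfg.training.offline_steps
)
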
6 changes: 6 additions & 0 deletions lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -534,6 +534,12 @@ def update(self):
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)

def make_optimizer_and_scheduler(self, cfg):
"""Create the optimizer and learning rate scheduler for TD-MPC"""
optimizer = torch.optim.Adam(self.parameters(), cfg.training.lr)
lr_scheduler = None
return optimizer, lr_scheduler


class TDMPCTOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
6 changes: 6 additions & 0 deletions lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -152,6 +152,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

return loss_dict

def make_optimizer_and_scheduler(self, cfg):
"""Create the optimizer and learning rate scheduler for VQ-BeT"""
optimizer = VQBeTOptimizer(self, cfg)
scheduler = VQBeTScheduler(optimizer, cfg)
return optimizer, scheduler


class SpatialSoftmax(nn.Module):
"""
55 changes: 1 addition & 54 deletions lerobot/scripts/train.py
@@ -51,59 +51,6 @@
from lerobot.scripts.eval import eval_policy


def make_optimizer_and_scheduler(cfg, policy):
if cfg.policy.name == "act":
optimizer_params_dicts = [
{
"params": [
p
for n, p in policy.named_parameters()
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in policy.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": cfg.training.lr_backbone,
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
)
lr_scheduler = None
elif cfg.policy.name == "diffusion":
optimizer = torch.optim.Adam(
policy.diffusion.parameters(),
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
from diffusers.optimization import get_scheduler

lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.training.offline_steps,
)
elif policy.name == "tdmpc":
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg)
else:
raise NotImplementedError()

return optimizer, lr_scheduler


def update_policy(
policy,
batch,
@@ -334,7 +281,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
assert isinstance(policy, nn.Module)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)
grad_scaler = GradScaler(enabled=cfg.use_amp)

step = 0 # number of policy updates (forward + backward + optim)
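
Note: with the call site above, each policy now owns its optimizer construction and the training script only needs to handle a possibly-None scheduler. A simplified sketch of that usage (the real update_policy in train.py also handles the GradScaler/AMP path; `train_offline` and its arguments are hypothetical names for illustration):

def train_offline(policy, cfg, dataloader):
    """Minimal sketch of how train.py consumes the per-policy method."""
    optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)
    for batch in dataloader:
        loss = policy.forward(batch)["loss"]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # ACT and TD-MPC return lr_scheduler=None, so guard the scheduler step.
        if lr_scheduler is not None:
            lr_scheduler.step()
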
3 changes: 1 addition & 2 deletions tests/scripts/save_policy_to_safetensors.py
@@ -22,7 +22,6 @@
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.utils import DEFAULT_CONFIG_PATH


@@ -40,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
optimizer, _ = policy.make_optimizer_and_scheduler(cfg)

dataloader = torch.utils.data.DataLoader(
dataset,
3 changes: 1 addition & 2 deletions tests/test_policies.py
@@ -37,7 +37,6 @@
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel

@@ -214,7 +213,7 @@ def test_act_backbone_lr():

dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
Comment on lines -217 to +216

Collaborator:

Could we add optimizer, _ = policy.make_optimizer_and_scheduler(cfg) to another place in our unit tests?

It feels like this code logic should be tested for all policies, not just act. Thanks ;)

Collaborator (Author):

@Cadene Should we change test_act_backbone_lr to a more general function for all policies like:

@pytest.mark.parametrize(
    "env_name,policy_name",
    [
        ("pusht", "tdmpc"),
        ("pusht", "diffusion"),
        ("pusht", "vqbet"),
        ("aloha", "act")
    ],
)
def test_policy_backbone_lr(env_name, policy_name):
    """
    Test that the ACT policy can be instantiated with a different learning rate for the backbone.
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
            "training.lr_backbone=0.001",
            "training.lr=0.01",
        ],
    )
....

assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone