diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 3427c4829..70654d880 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -160,6 +160,31 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
 
         return loss_dict
 
+    def make_optimizer_and_scheduler(self, cfg):
+        """Create the optimizer and learning rate scheduler for ACT"""
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in self.named_parameters()
+                    if not n.startswith("model.backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p
+                    for n, p in self.named_parameters()
+                    if n.startswith("model.backbone") and p.requires_grad
+                ],
+                "lr": cfg.training.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
+        lr_scheduler = None
+        return optimizer, lr_scheduler
+
 
 class ACTTemporalEnsembler:
     def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 308a8be3c..6d276fa45 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -156,6 +156,25 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
 
+    def make_optimizer_and_scheduler(self, cfg):
+        """Create the optimizer and learning rate scheduler for Diffusion policy"""
+        optimizer = torch.optim.Adam(
+            self.diffusion.parameters(),
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
+        )
+        from diffusers.optimization import get_scheduler
+
+        lr_scheduler = get_scheduler(
+            cfg.training.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
+        )
+        return optimizer, lr_scheduler
+
 
 def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
     """
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index d97c4824c..169f67a0a 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -534,6 +534,12 @@ def update(self):
         # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
         update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
 
+    def make_optimizer_and_scheduler(self, cfg):
+        """Create the optimizer and learning rate scheduler for TD-MPC"""
+        optimizer = torch.optim.Adam(self.parameters(), cfg.training.lr)
+        lr_scheduler = None
+        return optimizer, lr_scheduler
+
 
 class TDMPCTOLD(nn.Module):
     """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 87cf59f19..18cf4491e 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -152,6 +152,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
 
         return loss_dict
 
+    def make_optimizer_and_scheduler(self, cfg):
+        """Create the optimizer and learning rate scheduler for VQ-BeT"""
+        optimizer = VQBeTOptimizer(self, cfg)
+        scheduler = VQBeTScheduler(optimizer, cfg)
+        return optimizer, scheduler
+
 
 class SpatialSoftmax(nn.Module):
     """
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 45807503f..0c048cfb6 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -51,59 +51,6 @@
 from lerobot.scripts.eval import eval_policy
 
 
-def make_optimizer_and_scheduler(cfg, policy):
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("model.backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if n.startswith("model.backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        from diffusers.optimization import get_scheduler
-
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    elif cfg.policy.name == "vqbet":
-        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
-
-        optimizer = VQBeTOptimizer(policy, cfg)
-        lr_scheduler = VQBeTScheduler(optimizer, cfg)
-    else:
-        raise NotImplementedError()
-
-    return optimizer, lr_scheduler
-
-
 def update_policy(
     policy,
     batch,
@@ -334,7 +281,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     assert isinstance(policy, nn.Module)
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)
     grad_scaler = GradScaler(enabled=cfg.use_amp)
 
     step = 0  # number of policy updates (forward + backward + optim)
diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py
index 5236b7ae5..033638774 100644
--- a/tests/scripts/save_policy_to_safetensors.py
+++ b/tests/scripts/save_policy_to_safetensors.py
@@ -22,7 +22,6 @@
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils.utils import init_hydra_config, set_global_seed
-from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.utils import DEFAULT_CONFIG_PATH
 
 
@@ -40,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     dataset = make_dataset(cfg)
     policy = make_policy(cfg, dataset_stats=dataset.stats)
     policy.train()
-    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
+    optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
 
     dataloader = torch.utils.data.DataLoader(
         dataset,
diff --git a/tests/test_policies.py b/tests/test_policies.py
index d90f00716..76a056d24 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -37,7 +37,6 @@
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config, seeded_context
-from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
 
@@ -214,7 +213,7 @@ def test_act_backbone_lr():
 
     dataset = make_dataset(cfg)
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
-    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
+    optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
     assert len(optimizer.param_groups) == 2
     assert optimizer.param_groups[0]["lr"] == cfg.training.lr
     assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
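
A minimal usage sketch of the relocated API, not part of the patch above. It assumes a Hydra config loaded the same way the tests do (the "policy=act" / "env=aloha" overrides are illustrative, any valid pair works); it only shows that each policy now builds its own optimizer and optional scheduler in place of the removed make_optimizer_and_scheduler(cfg, policy) helper from lerobot/scripts/train.py.

# Illustrative sketch only; the override values below are assumptions, not part of the patch.
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH

cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["policy=act", "env=aloha"])
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)

# Old call site: optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# New call site: the policy itself owns optimizer/scheduler construction.
optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)

# ACT and TD-MPC return lr_scheduler as None, so train.py keeps guarding its use.
if lr_scheduler is not None:
    lr_scheduler.step()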