diff --git a/dolomite_engine/checkpointing.py b/dolomite_engine/checkpointing.py
index abe5749..e7c11c7 100644
--- a/dolomite_engine/checkpointing.py
+++ b/dolomite_engine/checkpointing.py
@@ -27,7 +27,7 @@
 from .data import ResumableDataLoader
 from .enums import Mode
 from .hf_models import fix_unsharded_state_dict
-from .model_wrapper import ModelWrapper, get_model_container
+from .model_wrapper import ModelWrapper
 from .optimization import get_scheduler_container
 from .utils import ExperimentsTracker, ProcessGroupManager, load_yaml, log_rank_0, run_rank_n, string_to_torch_dtype
 
@@ -401,8 +401,8 @@ def _resume_learning_rate(
 
     # we create lr scheduler again here since optimizer is loaded from disk and lr scheduler is now out of sync
     # this helps to resume phase 2
-    lr_scheduler_tmp = get_scheduler(
-        optimizer=optimizer,
+    lr_scheduler_tmp = get_scheduler_container(
+        optimizer_container=OptimizerContainer([optimizer]),
         num_warmup_steps=args.lr_scheduler_args.num_warmup_steps,
         num_constant_steps=args.lr_scheduler_args.num_constant_steps,
         num_decay_steps=args.lr_scheduler_args.num_decay_steps,
@@ -411,7 +411,7 @@ def _resume_learning_rate(
         lr_decay_factor=args.lr_scheduler_args.lr_decay_factor,
         extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args,
         last_epoch=-1 if iteration is None else iteration - 1,
-    )
+    )[0]
     for grp, lr_ in zip(optimizer.param_groups, initial_lr):
         grp["initial_lr"] = lr_