Commit e746a4b

container
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Oct 24, 2024
1 parent 8fbd1f5 commit e746a4b
Showing 1 changed file with 4 additions and 4 deletions.
dolomite_engine/checkpointing.py (4 additions, 4 deletions)

@@ -27,7 +27,7 @@
 from .data import ResumableDataLoader
 from .enums import Mode
 from .hf_models import fix_unsharded_state_dict
-from .model_wrapper import ModelWrapper, get_model_container
+from .model_wrapper import ModelWrapper
 from .optimization import get_scheduler_container
 from .utils import ExperimentsTracker, ProcessGroupManager, load_yaml, log_rank_0, run_rank_n, string_to_torch_dtype
@@ -401,8 +401,8 @@ def _resume_learning_rate(

     # we create lr scheduler again here since optimizer is loaded from disk and lr scheduler is now out of sync
     # this helps to resume phase 2
-    lr_scheduler_tmp = get_scheduler(
-        optimizer=optimizer,
+    lr_scheduler_tmp = get_scheduler_container(
+        optimizer_container=OptimizerContainer([optimizer]),
         num_warmup_steps=args.lr_scheduler_args.num_warmup_steps,
         num_constant_steps=args.lr_scheduler_args.num_constant_steps,
         num_decay_steps=args.lr_scheduler_args.num_decay_steps,
@@ -411,7 +411,7 @@
         lr_decay_factor=args.lr_scheduler_args.lr_decay_factor,
         extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args,
         last_epoch=-1 if iteration is None else iteration - 1,
-    )
+    )[0]

     for grp, lr_ in zip(optimizer.param_groups, initial_lr):
         grp["initial_lr"] = lr_
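For readers skimming the hunk: the new call follows the engine's container pattern. The lone optimizer is wrapped in an OptimizerContainer, get_scheduler_container builds one scheduler per wrapped optimizer, and the trailing [0] unwraps the single scheduler again. Below is a minimal runnable sketch of that shape; the simplified OptimizerContainer and get_scheduler_container here are illustrative stand-ins with an assumed contract, not the actual dolomite_engine implementations, which take many more arguments.

    # Sketch only: stand-ins that mirror the wrap/build/unwrap shape of the
    # patched call site, not the real dolomite_engine classes.
    import torch

    class OptimizerContainer(list):
        """Assumed: a thin list-like wrapper over one or more optimizers."""

    def get_scheduler_container(optimizer_container, num_warmup_steps, last_epoch=-1):
        # Assumed contract: build one scheduler per wrapped optimizer and
        # return them in the same order, so index [0] recovers the single one.
        def warmup(step):
            return min(1.0, (step + 1) / max(1, num_warmup_steps))

        return [
            torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup, last_epoch=last_epoch)
            for opt in optimizer_container
        ]

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    # Same shape as the patched call in _resume_learning_rate: wrap the lone
    # optimizer, build the scheduler container, then unwrap with [0].
    iteration = None  # as in the diff: -1 means "start fresh"
    lr_scheduler_tmp = get_scheduler_container(
        optimizer_container=OptimizerContainer([optimizer]),
        num_warmup_steps=100,
        last_epoch=-1 if iteration is None else iteration - 1,
    )[0]

Keeping the container indirection even on this single-optimizer resume path presumably keeps _resume_learning_rate consistent with the engine's container-based API elsewhere, at the small cost of the trailing [0].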
