
Commit 9cd3903
container
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Oct 24, 2024
1 parent 8d486cf commit 9cd3903
Showing 1 changed file with 29 additions and 27 deletions.
dolomite_engine/checkpointing.py (56 changes: 29 additions & 27 deletions)

```diff
@@ -234,33 +234,35 @@ def load_checkpoint_for_training(
 
     dcp.load({"state": _ModelSaver(model_container)}, checkpoint_id=_get_model_path(load_path))
 
-    for pipeline_stage, model, optimizer, lr_scheduler in zip(
-        pipeline_stage_ids_on_current_rank, model_container, optimizer_container, lr_scheduler_container
-    ):
-        has_teacher_model = model.has_teacher_model()
-        if has_teacher_model:
-            log_rank_0(
-                logging.WARN,
-                "the model will use non-strict loading of state dict during distillation, this has potential of incorrect behavior",
-            )
-
-        model_state_dict = get_model_state_dict(model)
-        dcp.load(model_state_dict, checkpoint_id=_get_model_path(load_path, pipeline_stage=pipeline_stage))
-        set_model_state_dict(model, model_state_dict, options=StateDictOptions(strict=not has_teacher_model))
-        del model_state_dict
-
-        if load_optimizer:
-            # TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
-            optimizer_state_dict = get_optimizer_state_dict(model, optimizer)
-            dcp.load(optimizer_state_dict, checkpoint_id=_get_optimizer_path(load_path, pipeline_stage=pipeline_stage))
-            set_optimizer_state_dict(model, optimizer, optim_state_dict=optimizer_state_dict)
-            del optimizer_state_dict
-
-        if load_lr_scheduler:
-            assert load_optimizer, "load_lr_scheduler requires loading of optimizer"
-
-            lr_scheduler.load_state_dict(torch.load(_get_lr_scheduler_path(load_path, pipeline_stage=pipeline_stage)))
-        elif args.load_args.resume_learning_rate:
+    # for pipeline_stage, model, optimizer, lr_scheduler in zip(
+    #     pipeline_stage_ids_on_current_rank, model_container, optimizer_container, lr_scheduler_container
+    # ):
+    #     has_teacher_model = model.has_teacher_model()
+    #     if has_teacher_model:
+    #         log_rank_0(
+    #             logging.WARN,
+    #             "the model will use non-strict loading of state dict during distillation, this has potential of incorrect behavior",
+    #         )
+
+    #     model_state_dict = get_model_state_dict(model)
+    #     dcp.load(model_state_dict, checkpoint_id=_get_model_path(load_path, pipeline_stage=pipeline_stage))
+    #     set_model_state_dict(model, model_state_dict, options=StateDictOptions(strict=not has_teacher_model))
+    #     del model_state_dict
+
+    #     if load_optimizer:
+    #         # TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
+    #         optimizer_state_dict = get_optimizer_state_dict(model, optimizer)
+    #         dcp.load(optimizer_state_dict, checkpoint_id=_get_optimizer_path(load_path, pipeline_stage=pipeline_stage))
+    #         set_optimizer_state_dict(model, optimizer, optim_state_dict=optimizer_state_dict)
+    #         del optimizer_state_dict
+
+    if load_lr_scheduler:
+        assert load_optimizer, "load_lr_scheduler requires loading of optimizer"
+
+        for lr_scheduler in lr_scheduler_container:
+            lr_scheduler.load_state_dict(torch.load(_get_lr_scheduler_path(load_path)))
+    elif args.load_args.resume_learning_rate:
+        for optimizer, lr_scheduler in zip(optimizer_container, lr_scheduler_container):
             _resume_learning_rate(
                 args,
                 optimizer=optimizer,
```
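For readers unfamiliar with the single-call load path above: `dcp.load` can drive any object implementing `torch.distributed.checkpoint`'s `Stateful` protocol, which is presumably what `_ModelSaver` does for the whole `model_container` (its actual definition lives elsewhere in dolomite_engine and is not part of this diff). The sketch below is a hypothetical illustration of that pattern under that assumption, not the real implementation; `ModelContainerSaver`, `load_model_container`, and the `model_{i}` key layout are invented for the example.

```python
# Hypothetical sketch of the Stateful pattern that a wrapper like _ModelSaver
# presumably follows (assumption; the real class is not shown in this commit).
# dcp.load() calls state_dict() on the wrapper to discover the target tensors,
# loads the checkpoint into them in place, then calls load_state_dict() so the
# wrapper can push the loaded values back into the models.
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful


class ModelContainerSaver(Stateful):  # hypothetical stand-in for _ModelSaver
    def __init__(self, model_container: list[torch.nn.Module]) -> None:
        self.model_container = model_container

    def state_dict(self) -> dict:
        # Flatten the per-stage state dicts under prefixed keys so one
        # checkpoint covers every model in the container.
        return {f"model_{i}": get_model_state_dict(m) for i, m in enumerate(self.model_container)}

    def load_state_dict(self, state_dict: dict) -> None:
        # Push the loaded tensors back into each model. strict=False would be
        # the choice for the distillation case flagged in the old code path.
        for i, m in enumerate(self.model_container):
            set_model_state_dict(m, state_dict[f"model_{i}"], options=StateDictOptions(strict=True))


def load_model_container(model_container: list[torch.nn.Module], checkpoint_path: str) -> None:
    # Mirrors the new load path in the diff; checkpoint_path plays the role of
    # _get_model_path(load_path).
    dcp.load({"state": ModelContainerSaver(model_container)}, checkpoint_id=checkpoint_path)
```

Note that the lr_scheduler states are handled separately in the diff via plain `torch.load`, so they are not part of this wrapper.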
