Commit ddf7ced: container
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Oct 24, 2024
1 parent d85c577 commit ddf7ced
Showing 1 changed file with 33 additions and 6 deletions.
dolomite_engine/checkpointing.py
@@ -18,6 +18,7 @@
     set_optimizer_state_dict,
 )
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR
 
@@ -44,6 +45,34 @@
 _KILLSWITCH = "KILLSWITCH"
 
 
+class _ModelSaver(Stateful):
+    def __init__(self, model_container: ModelContainer) -> None:
+        self.model_container = model_container
+
+    def state_dict(self) -> list[dict]:
+        result = []
+        for model in self.model_container:
+            state_dict = get_model_state_dict(model)
+
+            if model.has_teacher_model():
+                state_dict = _filter_out_teacher_state_dict(state_dict)
+
+            result.append(state_dict)
+
+        return result
+
+    def load_state_dict(self, state_dict: list[dict]) -> None:
+        for i, model in enumerate(self.model_container):
+            has_teacher_model = model.has_teacher_model()
+            if has_teacher_model:
+                log_rank_0(
+                    logging.WARN,
+                    "the model will use non-strict loading of the state dict during distillation; this can lead to incorrect behavior",
+                )
+
+            set_model_state_dict(model, state_dict[i], options=StateDictOptions(strict=not has_teacher_model))
+
+
 def save_checkpoint(
     args: TrainingArgs,
     model_container: ModelContainer,
@@ -75,6 +104,8 @@ def save_checkpoint(
     save_path = _get_base_path(args.save_args.save_path, iteration)
     os.makedirs(save_path, exist_ok=True)
 
+    model_saver = _ModelSaver(model_container)
+
     if optimizer_container is None:
         optimizer_container = [None] * len(model_container)
 
@@ -88,15 +119,11 @@
             args.distributed_args.num_pipeline_stages
         )
 
+    dcp.save(model_saver, checkpoint_id=_get_model_path(save_path))
+
     for pipeline_stage, model, optimizer, lr_scheduler in zip(
         pipeline_stage_ids_on_current_rank, model_container, optimizer_container, lr_scheduler_container
     ):
-        model_state_dict = get_model_state_dict(model)
-        if model.has_teacher_model():
-            model_state_dict = _filter_out_teacher_state_dict(model_state_dict)
-
-        dcp.save(model_state_dict, checkpoint_id=_get_model_path(save_path, pipeline_stage=pipeline_stage))
-
         if save_optimizer:
             if optimizer is None:
                 log_rank_0(
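Editor's note: the pattern this commit adopts is the `Stateful` protocol from `torch.distributed.checkpoint` (DCP): any object implementing it that appears in the state dict passed to `dcp.save` has its `state_dict()` called at save time, and `dcp.load` calls its `load_state_dict()` to restore in place. Below is a minimal single-process sketch of that round trip; the `_ToySaver` class, the `nn.Linear` models, and the `/tmp/toy-ckpt` path are illustrative stand-ins and not part of this repository.

# Minimal sketch of the Stateful protocol that _ModelSaver implements.
# Assumes a recent PyTorch; in current releases dcp treats an
# uninitialized process group as a single-rank job.
import torch.distributed.checkpoint as dcp
from torch import nn
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict
from torch.distributed.checkpoint.stateful import Stateful


class _ToySaver(Stateful):  # illustrative stand-in for _ModelSaver
    def __init__(self, models: list[nn.Module]) -> None:
        self.models = models

    def state_dict(self) -> list[dict]:
        # invoked by dcp.save when it encounters this object
        return [get_model_state_dict(m) for m in self.models]

    def load_state_dict(self, state_dict: list[dict]) -> None:
        # invoked by dcp.load; restores the weights in place
        for m, sd in zip(self.models, state_dict):
            set_model_state_dict(m, sd)


models = [nn.Linear(4, 4), nn.Linear(4, 4)]
saver = _ToySaver(models)

# Stateful objects go inside the top-level dict; dcp calls their
# state_dict()/load_state_dict() for us.
dcp.save({"models": saver}, checkpoint_id="/tmp/toy-ckpt")
dcp.load({"models": saver}, checkpoint_id="/tmp/toy-ckpt")

The net effect in the commit is the same round trip: a single `dcp.save(model_saver, ...)` call now covers every model in the container, replacing the per-pipeline-stage `dcp.save` inside the loop that this diff deletes.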
