diff --git a/dolomite_engine/checkpointing.py b/dolomite_engine/checkpointing.py
index 558d454..9462e08 100644
--- a/dolomite_engine/checkpointing.py
+++ b/dolomite_engine/checkpointing.py
@@ -18,6 +18,7 @@
     set_optimizer_state_dict,
 )
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR
 
@@ -44,6 +45,34 @@
 _KILLSWITCH = "KILLSWITCH"
 
 
+class _ModelSaver(Stateful):
+    def __init__(self, model_container: ModelContainer) -> None:
+        self.model_container = model_container
+
+    def state_dict(self):
+        result = []
+        for model in self.model_container:
+            state_dict = get_model_state_dict(model)
+
+            if model.has_teacher_model():
+                state_dict = _filter_out_teacher_state_dict(state_dict)
+
+            result.append(state_dict)
+
+        return result
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        for i, model in enumerate(self.model_container):
+            has_teacher_model = model.has_teacher_model()
+            if has_teacher_model:
+                log_rank_0(
+                    logging.WARN,
+                    "the model will use non-strict loading of state dict during distillation, this has potential of incorrect behavior",
+                )
+
+            set_model_state_dict(model, state_dict[i], options=StateDictOptions(strict=not has_teacher_model))
+
+
 def save_checkpoint(
     args: TrainingArgs,
     model_container: ModelContainer,
@@ -75,6 +104,8 @@ def save_checkpoint(
     save_path = _get_base_path(args.save_args.save_path, iteration)
     os.makedirs(save_path, exist_ok=True)
 
+    model_saver = _ModelSaver(model_container)
+
     if optimizer_container is None:
         optimizer_container = [None] * len(model_container)
 
@@ -88,15 +119,11 @@ def save_checkpoint(
         args.distributed_args.num_pipeline_stages
     )
 
+    dcp.save(model_saver, checkpoint_id=_get_model_path(save_path))
+
     for pipeline_stage, model, optimizer, lr_scheduler in zip(
         pipeline_stage_ids_on_current_rank, model_container, optimizer_container, lr_scheduler_container
     ):
-        model_state_dict = get_model_state_dict(model)
-        if model.has_teacher_model():
-            model_state_dict = _filter_out_teacher_state_dict(model_state_dict)
-
-        dcp.save(model_state_dict, checkpoint_id=_get_model_path(save_path, pipeline_stage=pipeline_stage))
-
         if save_optimizer:
             if optimizer is None:
                 log_rank_0(
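
For context, a minimal sketch of the torch.distributed.checkpoint `Stateful` protocol that the `_ModelSaver` above builds on: DCP invokes `state_dict()` on `Stateful` values when `dcp.save` runs and `load_state_dict()` when `dcp.load` runs, so the full state dict is only materialized at checkpoint time. This follows the pattern documented for DCP, where the `Stateful` object is passed as a value in the dict handed to save/load; `ModelState`, `build_model`, and `CHECKPOINT_DIR` below are illustrative placeholders and not part of this change.

# Sketch of the Stateful save/load protocol (placeholder names, not this PR's code).
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict
from torch.distributed.checkpoint.stateful import Stateful


class ModelState(Stateful):
    """Wraps a model so DCP materializes its state dict lazily at save time."""

    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model

    def state_dict(self) -> dict:
        # Called by dcp.save() when the checkpoint is written.
        return get_model_state_dict(self.model)

    def load_state_dict(self, state_dict: dict) -> None:
        # Called by dcp.load() with the deserialized state dict; restores in place.
        set_model_state_dict(self.model, state_dict)


model = build_model()  # placeholder for however the model is constructed

# Documented pattern: the Stateful object is a value in the dict handed to DCP.
dcp.save({"model": ModelState(model)}, checkpoint_id=CHECKPOINT_DIR)
dcp.load({"model": ModelState(model)}, checkpoint_id=CHECKPOINT_DIR)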