diff --git a/dolomite_engine/checkpointing.py b/dolomite_engine/checkpointing.py
index 9462e08..f5dae98 100644
--- a/dolomite_engine/checkpointing.py
+++ b/dolomite_engine/checkpointing.py
@@ -73,6 +73,28 @@ def load_state_dict(self, state_dict: dict) -> None:
             set_model_state_dict(model, state_dict[i], options=StateDictOptions(strict=not has_teacher_model))
 
 
+class _OptimizerSaver(Stateful):
+    def __init__(self, model_container: ModelContainer, optimizer_container: OptimizerContainer) -> None:
+        self.model_container = model_container
+        self.optimizer_container = optimizer_container
+
+    def state_dict(self):
+        if self.optimizer_container is None:
+            return []
+
+        return [
+            get_optimizer_state_dict(model, optimizer)
+            for model, optimizer in zip(self.model_container, self.optimizer_container)
+        ]
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        if self.optimizer_container is None:
+            return
+
+        for i, (model, optimizer) in enumerate(zip(self.model_container, self.optimizer_container)):
+            set_optimizer_state_dict(model, optimizer, optim_state_dict=state_dict[i])
+
+
 def save_checkpoint(
     args: TrainingArgs,
     model_container: ModelContainer,
@@ -104,50 +126,51 @@ def save_checkpoint(
     save_path = _get_base_path(args.save_args.save_path, iteration)
     os.makedirs(save_path, exist_ok=True)
 
-    model_saver = _ModelSaver(model_container)
+    savers = {"model": _ModelSaver(model_container)}
 
-    if optimizer_container is None:
-        optimizer_container = [None] * len(model_container)
-
-    if lr_scheduler_container is None:
-        lr_scheduler_container = [None] * len(model_container)
-
-    assert len(model_container) == len(optimizer_container)
-    assert len(model_container) == len(lr_scheduler_container)
+    if save_optimizer:
+        if optimizer_container is None:
+            log_rank_0(
+                logging.WARN,
+                "optimizer_container is not passed to save_checkpoint but save_optimizer is set to True. "
+                "Therefore, the function will not save the optimizer",
+            )
+        else:
+            savers["optimizer"] = _OptimizerSaver(model_container, optimizer_container)
 
     _, pipeline_stage_ids_on_current_rank = get_pipeline_num_stages_and_stage_ids_on_current_rank(
         args.distributed_args.num_pipeline_stages
     )
 
-    dcp.save(model_saver, checkpoint_id=_get_model_path(save_path))
-
-    for pipeline_stage, model, optimizer, lr_scheduler in zip(
-        pipeline_stage_ids_on_current_rank, model_container, optimizer_container, lr_scheduler_container
-    ):
-        if save_optimizer:
-            if optimizer is None:
-                log_rank_0(
-                    logging.WARN,
-                    "optimizer_container is not passed to save_checkpoint but save_optimizer is set to True. "
-                    "Therefore, the function will not save the optimizer",
-                )
-            else:
-                # TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
-                dcp.save(
-                    get_optimizer_state_dict(model, optimizer),
-                    checkpoint_id=_get_optimizer_path(save_path, pipeline_stage=pipeline_stage),
-                )
-
-        if lr_scheduler is None:
-            log_rank_0(
-                logging.WARN,
-                "lr_scheduler is not passed to save_checkpoint. Therefore, the function will not save the lr_scheduler",
-            )
-        else:
-            run_rank_n(torch.save)(
-                lr_scheduler.state_dict(),
-                _get_lr_scheduler_path(save_path, pipeline_stage=pipeline_stage),
-            )
+    dcp.save(savers, checkpoint_id=_get_model_path(save_path))
+
+    # for pipeline_stage, model, optimizer, lr_scheduler in zip(
+    #     pipeline_stage_ids_on_current_rank, model_container, optimizer_container, lr_scheduler_container
+    # ):
+    #     if save_optimizer:
+    #         if optimizer is None:
+    #             log_rank_0(
+    #                 logging.WARN,
+    #                 "optimizer_container is not passed to save_checkpoint but save_optimizer is set to True. "
+    #                 "Therefore, the function will not save the optimizer",
+    #             )
+    #         else:
+    #             # TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
+    #             dcp.save(
+    #                 get_optimizer_state_dict(model, optimizer),
+    #                 checkpoint_id=_get_optimizer_path(save_path, pipeline_stage=pipeline_stage),
+    #             )
+
+    #     if lr_scheduler is None:
+    #         log_rank_0(
+    #             logging.WARN,
+    #             "lr_scheduler is not passed to save_checkpoint. Therefore, the function will not save the lr_scheduler",
+    #         )
+    #     else:
+    #         run_rank_n(torch.save)(
+    #             lr_scheduler.state_dict(),
+    #             _get_lr_scheduler_path(save_path, pipeline_stage=pipeline_stage),
+    #         )
 
     rng_state = {
         "random_rng_state": random.getstate(),