
Commit

container
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Oct 24, 2024
1 parent ddf7ced commit b7e72bc
Showing 1 changed file with 61 additions and 38 deletions.
99 changes: 61 additions & 38 deletions dolomite_engine/checkpointing.py
@@ -73,6 +73,28 @@ def load_state_dict(self, state_dict: dict) -> None:
            set_model_state_dict(model, state_dict[i], options=StateDictOptions(strict=not has_teacher_model))


class _OptimizerSaver(Stateful):
    def __init__(self, model_container: ModelContainer, optimizer_container: OptimizerContainer) -> None:
        self.model_container = model_container
        self.optimizer_container = optimizer_container

    def state_dict(self):
        if self.optimizer_container is None:
            return []

        return [
            get_optimizer_state_dict(model, optimizer)
            for model, optimizer in zip(self.model_container, self.optimizer_container)
        ]

    def load_state_dict(self, state_dict: dict) -> None:
        if self.optimizer_container is None:
            return

        for i, (model, optimizer) in enumerate(zip(self.model_container, self.optimizer_container)):
            set_optimizer_state_dict(model, optimizer, optim_state_dict=state_dict[i])


def save_checkpoint(
    args: TrainingArgs,
    model_container: ModelContainer,
@@ -104,50 +126,51 @@ def save_checkpoint(
    save_path = _get_base_path(args.save_args.save_path, iteration)
    os.makedirs(save_path, exist_ok=True)

    model_saver = _ModelSaver(model_container)
    savers = {"model": _ModelSaver(model_container)}

    if optimizer_container is None:
        optimizer_container = [None] * len(model_container)

    if lr_scheduler_container is None:
        lr_scheduler_container = [None] * len(model_container)

    assert len(model_container) == len(optimizer_container)
    assert len(model_container) == len(lr_scheduler_container)
    if save_optimizer:
        if optimizer_container is None:
            log_rank_0(
                logging.WARN,
                "optimizer_container is not passed to save_checkpoint but save_optimizer is set to True. "
                "Therefore, the function will not save the optimizer",
            )
        else:
            savers["optimizer"] = _OptimizerSaver(model_container, optimizer_container)

    _, pipeline_stage_ids_on_current_rank = get_pipeline_num_stages_and_stage_ids_on_current_rank(
        args.distributed_args.num_pipeline_stages
    )

    dcp.save(model_saver, checkpoint_id=_get_model_path(save_path))

    for pipeline_stage, model, optimizer, lr_scheduler in zip(
        pipeline_stage_ids_on_current_rank, model_container, optimizer_container, lr_scheduler_container
    ):
        if save_optimizer:
            if optimizer is None:
                log_rank_0(
                    logging.WARN,
                    "optimizer_container is not passed to save_checkpoint but save_optimizer is set to True. "
                    "Therefore, the function will not save the optimizer",
                )
            else:
                # TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
                dcp.save(
                    get_optimizer_state_dict(model, optimizer),
                    checkpoint_id=_get_optimizer_path(save_path, pipeline_stage=pipeline_stage),
                )

        if lr_scheduler is None:
            log_rank_0(
                logging.WARN,
                "lr_scheduler is not passed to save_checkpoint. Therefore, the function will not save the lr_scheduler",
            )
        else:
            run_rank_n(torch.save)(
                lr_scheduler.state_dict(),
                _get_lr_scheduler_path(save_path, pipeline_stage=pipeline_stage),
            )
    dcp.save(savers, checkpoint_id=_get_model_path(save_path))

    # for pipeline_stage, model, optimizer, lr_scheduler in zip(
    #     pipeline_stage_ids_on_current_rank, model_container, optimizer_container, lr_scheduler_container
    # ):
    #     if save_optimizer:
    #         if optimizer is None:
    #             log_rank_0(
    #                 logging.WARN,
    #                 "optimizer_container is not passed to save_checkpoint but save_optimizer is set to True. "
    #                 "Therefore, the function will not save the optimizer",
    #             )
    #         else:
    #             # TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
    #             dcp.save(
    #                 get_optimizer_state_dict(model, optimizer),
    #                 checkpoint_id=_get_optimizer_path(save_path, pipeline_stage=pipeline_stage),
    #             )

    #     if lr_scheduler is None:
    #         log_rank_0(
    #             logging.WARN,
    #             "lr_scheduler is not passed to save_checkpoint. Therefore, the function will not save the lr_scheduler",
    #         )
    #     else:
    #         run_rank_n(torch.save)(
    #             lr_scheduler.state_dict(),
    #             _get_lr_scheduler_path(save_path, pipeline_stage=pipeline_stage),
    #         )

    rng_state = {
        "random_rng_state": random.getstate(),
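For context on how the new Stateful savers round-trip, here is a minimal sketch of the matching load path. It is not part of this commit: the load_checkpoint wrapper below is hypothetical, while _ModelSaver, _OptimizerSaver, and _get_model_path are the helpers shown in the diff above. dcp.load calls state_dict() on each Stateful entry to build the target layout, reads the checkpoint into it, and then calls load_state_dict() to apply the restored values.

import torch.distributed.checkpoint as dcp


# Hypothetical load-side counterpart (not part of this commit), assuming the
# same saver classes and _get_model_path helper shown in the diff above.
def load_checkpoint(load_path: str, model_container, optimizer_container) -> None:
    savers = {
        "model": _ModelSaver(model_container),
        "optimizer": _OptimizerSaver(model_container, optimizer_container),
    }
    # dcp.load calls state_dict() on each Stateful value to get the target
    # structure, loads the checkpoint shards into it, and then calls
    # load_state_dict() so the wrappers push the restored values back into
    # the models and optimizers in place.
    dcp.load(savers, checkpoint_id=_get_model_path(load_path))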
