drop fsdp-1 saving method #51

Merged 1 commit on Oct 23, 2024
123 changes: 37 additions & 86 deletions dolomite_engine/checkpointing.py
@@ -18,9 +18,6 @@
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
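
Note: with the FSDP-1 imports gone, the checkpointing code relies entirely on the DCP state-dict helpers. A minimal sketch of the imports this implies; the module paths are standard PyTorch 2.x conventions and are an assumption here, not copied from this diff:

```python
# Sketch of the DCP-side imports the file keeps using after dropping FSDP-1.
# Module paths follow PyTorch 2.x; they are assumed, not taken verbatim from this PR.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)
```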

@@ -69,44 +66,22 @@ def save_checkpoint(
save_path = _get_base_path(args.save_args.save_path, iteration)
os.makedirs(save_path, exist_ok=True)

if args.distributed_args.fsdp_algorithm == 1:
dp_rank = ProcessGroupManager.get_data_parallel_rank()

# TODO add support for local state dict
with FSDP.state_dict_type(
model,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
model_state_dict = model.state_dict()
if model.has_teacher_model():
model_state_dict = _filter_out_teacher_state_dict(model_state_dict)

if dp_rank == 0:
torch.save(model_state_dict, f"{_get_model_path(save_path)}.pt")

if save_optimizer:
optimizer_state_dict = FSDP.optim_state_dict(model=model, optim=optimizer)
if dp_rank == 0:
torch.save(optimizer_state_dict, f"{_get_optimizer_path(save_path)}.pt")
else:
model_state_dict = get_model_state_dict(model)
if model.has_teacher_model():
model_state_dict = _filter_out_teacher_state_dict(model_state_dict)

dcp.save(model_state_dict, checkpoint_id=_get_model_path(save_path))

if save_optimizer:
if optimizer is None:
log_rank_0(
logging.WARN,
"optimizer is not passed to save_checkpoint but save_optimizer is set to True. "
"Therefore, the function will not save the optimizer",
)
else:
# TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
dcp.save(get_optimizer_state_dict(model, optimizer), checkpoint_id=_get_optimizer_path(save_path))
model_state_dict = get_model_state_dict(model)
if model.has_teacher_model():
model_state_dict = _filter_out_teacher_state_dict(model_state_dict)

dcp.save(model_state_dict, checkpoint_id=_get_model_path(save_path))

if save_optimizer:
if optimizer is None:
log_rank_0(
logging.WARN,
"optimizer is not passed to save_checkpoint but save_optimizer is set to True. "
"Therefore, the function will not save the optimizer",
)
else:
# TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
dcp.save(get_optimizer_state_dict(model, optimizer), checkpoint_id=_get_optimizer_path(save_path))

if lr_scheduler is None:
log_rank_0(
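
With the `fsdp_algorithm == 1` branch removed, `save_checkpoint` always routes through `dcp.save` on state dicts obtained from `get_model_state_dict` / `get_optimizer_state_dict`. A standalone sketch of that flow, assuming torch.distributed is initialized and the call is made collectively on every rank; `save_sharded_checkpoint` and its arguments are illustrative names, not code from this repo:

```python
import os

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
)


def save_sharded_checkpoint(model, optimizer, save_path: str) -> None:
    """Illustrative DCP save flow; must be called collectively on all ranks."""
    os.makedirs(save_path, exist_ok=True)

    # Each rank contributes only its own shards; dcp.save coordinates the write.
    dcp.save(get_model_state_dict(model), checkpoint_id=os.path.join(save_path, "model"))
    dcp.save(
        get_optimizer_state_dict(model, optimizer),
        checkpoint_id=os.path.join(save_path, "optimizer"),
    )
```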
@@ -204,38 +179,17 @@ def load_checkpoint_for_training(
"the model will use non-strict loading of state dict during distillation, this has potential of incorrect behavior",
)

if args.distributed_args.fsdp_algorithm == 1:
# TODO add support for local state dict
with FSDP.state_dict_type(
model,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
model.load_state_dict(
torch.load(f"{_get_model_path(load_path)}.pt", map_location="cpu"), strict=not has_teacher_model
)

if load_optimizer:
optimizer.load_state_dict(
FSDP.optim_state_dict_to_load(
model=model,
optim=optimizer,
optim_state_dict=torch.load(f"{_get_optimizer_path(load_path)}.pt", map_location="cpu"),
)
)
else:
model_state_dict = get_model_state_dict(model)
dcp.load(model_state_dict, checkpoint_id=_get_model_path(load_path))
set_model_state_dict(model, model_state_dict, options=StateDictOptions(strict=not has_teacher_model))
del model_state_dict
model_state_dict = get_model_state_dict(model)
dcp.load(model_state_dict, checkpoint_id=_get_model_path(load_path))
set_model_state_dict(model, model_state_dict, options=StateDictOptions(strict=not has_teacher_model))
del model_state_dict

if load_optimizer:
# TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
optimizer_state_dict = get_optimizer_state_dict(model, optimizer)
dcp.load(optimizer_state_dict, checkpoint_id=_get_optimizer_path(load_path))
set_optimizer_state_dict(model, optimizer, optim_state_dict=optimizer_state_dict)
del optimizer_state_dict
if load_optimizer:
# TODO add options=StateDictOptions(flatten_optimizer_state_dict=True))
optimizer_state_dict = get_optimizer_state_dict(model, optimizer)
dcp.load(optimizer_state_dict, checkpoint_id=_get_optimizer_path(load_path))
set_optimizer_state_dict(model, optimizer, optim_state_dict=optimizer_state_dict)
del optimizer_state_dict

if load_lr_scheduler:
assert load_optimizer, "load_lr_scheduler requires loading of optimizer"
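
The training resume path mirrors the save path: allocate the current (sharded) state dict as a template, fill it in place with `dcp.load`, then hand it back through `set_model_state_dict` / `set_optimizer_state_dict`. A minimal sketch, assuming a collective call on every rank; `load_for_training` and `checkpoint_dir` are illustrative names:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)


def load_for_training(model, optimizer, checkpoint_dir: str, strict: bool = True) -> None:
    """Illustrative DCP load flow for resuming training."""
    # The current state dict acts as a template, so each rank only reads its own shards.
    model_state_dict = get_model_state_dict(model)
    dcp.load(model_state_dict, checkpoint_id=f"{checkpoint_dir}/model")
    set_model_state_dict(model, model_state_dict, options=StateDictOptions(strict=strict))

    optimizer_state_dict = get_optimizer_state_dict(model, optimizer)
    dcp.load(optimizer_state_dict, checkpoint_id=f"{checkpoint_dir}/optimizer")
    set_optimizer_state_dict(model, optimizer, optim_state_dict=optimizer_state_dict)
```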
@@ -320,21 +274,18 @@ def load_checkpoint_for_inference(
if use_meta:
model = model.to_empty(device="cpu")

if args_from_checkpoint.distributed_args.fsdp_algorithm == 1:
state = torch.load(f"{_get_model_path(_get_base_path(load_path, iteration))}.pt", map_location="cpu")
else:
state = {}
_load_state_dict(
state,
storage_reader=FileSystemReader(_get_model_path(_get_base_path(load_path, iteration))),
planner=_EmptyStateDictLoadPlanner(),
no_dist=True,
)
state = {}
_load_state_dict(
state,
storage_reader=FileSystemReader(_get_model_path(_get_base_path(load_path, iteration))),
planner=_EmptyStateDictLoadPlanner(),
no_dist=True,
)

if checkpoint_tp_world_size > 1:
state = fix_unsharded_state_dict(
model.config, state, tensor_parallel_size=checkpoint_tp_world_size, prefix="model."
)
if checkpoint_tp_world_size > 1:
state = fix_unsharded_state_dict(
model.config, state, tensor_parallel_size=checkpoint_tp_world_size, prefix="model."
)

was_compiled_model = args_from_checkpoint.distributed_args.torch_compile
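
For inference, the checkpoint is read on a single process into a plain dict via the private `_load_state_dict` helper with `no_dist=True`, then post-processed (e.g. `fix_unsharded_state_dict` for tensor-parallel checkpoints). A self-contained sketch of that read; note that `_load_state_dict` and `_EmptyStateDictLoadPlanner` are private PyTorch APIs, and `checkpoint_dir` is an illustrative name:

```python
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict


def read_full_state_dict(checkpoint_dir: str) -> dict:
    """Read a DCP checkpoint into an ordinary dict on a single process."""
    state: dict = {}
    # no_dist=True skips collectives, so no torch.distributed initialization is needed.
    _load_state_dict(
        state,
        storage_reader=FileSystemReader(checkpoint_dir),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    return state
```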
