
Enable optional checkpoint at loading #819

Merged: 12 commits, Feb 7, 2025
16 changes: 16 additions & 0 deletions tests/integration_tests.py
@@ -418,6 +418,22 @@ def build_test_list():
"test_generate",
ngpu=2,
),
OverrideDefinitions(
Contributor:
There are two tests missing:

  1. unit tests (on CPU). Passing the test here doesn't mean the behavior is correct. We should add a unit test similar to https://github.com/pytorch/torchtitan/blob/690f299d37c5f6d34273762c0d650888a754d3c0/tests/unit_tests/test_dataset_checkpointing.py
  2. The test here only covers the cmd line arg override, but it could be problematic if specified in toml. We need to add a test similar to
    def test_parse_pp_split_points(self):
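A config-parsing test along those lines might look roughly like the sketch below; the toml fixture, the test name, and the --job.config_file flag follow the existing test_parse_pp_split_points pattern and are assumptions here, not code from this PR:

import tempfile

from torchtitan.config_manager import JobConfig


def test_parse_exclude_from_loading():
    # Write a minimal toml that sets the new option, then check it parses into a list.
    toml_keys = ["optimizer", "dataloader"]
    with tempfile.NamedTemporaryFile(mode="w", suffix=".toml") as f:
        f.write('[checkpoint]\nexclude_from_loading = "%s"\n' % ",".join(toml_keys))
        f.flush()
        config = JobConfig()
        config.parse_args(["--job.config_file", f.name])
    assert config.checkpoint.exclude_from_loading == toml_keys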

Contributor Author:
Added the test here with comments. In the optional checkpoint test, we save at [dp:4] and load at [dp:2, tp:2]; the dataloader should be excluded in loading, otherwise it would raise an error for the dp_degree mismatch.

    [
        [
            "--checkpoint.enable_checkpoint",
            "--training.steps 10",
        ],
        [
            "--checkpoint.enable_checkpoint",
            "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
            "--training.tensor_parallel_degree 2",
            "--training.steps 20",
        ],
    ],
    "Optional checkpoint",
    "optional_checkpoint",
),
Contributor Author:
Added the integration test here, especially to check that optionally excluding the dataloader from checkpoint loading avoids the dp_degree mismatch error between saving and loading.

]
return integration_tests_flavors
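For context, a minimal illustration of the mismatch the second test step avoids; the per-dp-rank keying of the saved dataloader state is an assumption made purely for illustration:

# Illustrative only: assume the dataloader checkpoints its state keyed by dp rank.
state_saved_at_dp4 = {f"dp_rank_{r}": b"..." for r in range(4)}   # first run, dp:4
ranks_expected_at_dp2 = {f"dp_rank_{r}" for r in range(2)}        # resume, dp:2 tp:2
# Ranks 2 and 3 have no consumer after resharding, so the resumed run cannot line
# up the saved state with its own dp world size, hence the dp_degree mismatch error.
missing = set(state_saved_at_dp4) - ranks_expected_at_dp2
print(missing)  # {'dp_rank_2', 'dp_rank_3'}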

21 changes: 15 additions & 6 deletions torchtitan/checkpoint.py
@@ -170,11 +170,8 @@ def __init__(
which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).

3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
Contributor Author (@mori360, Feb 4, 2025):
lr_scheduler flattening was handled in #794

Contributor:
We should add a comment here to say the lr_scheduler resharding assumes that all lr_schedulers are the same.

resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
optimizers do, so it's hard to write a generic 'flattener' utility.

TODO: This is currently unsolved and needs a fix.
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers by the ssumption that
Contributor:
Suggested change
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers by the ssumption that
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers with the assumption that

all lr_schedulers have the same state_dict.
"""
self.states = states
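A rough sketch of what flattening under that assumption amounts to (illustrative names only; the actual handling lives in this PR and #794): since every lr_scheduler is assumed to have the same state_dict, saving one copy and restoring it to all schedulers on load is sufficient.

class FlattenedLRSchedulers:
    """Illustrative wrapper: checkpoint one scheduler's state and restore it to all."""

    def __init__(self, lr_schedulers):
        self.lr_schedulers = lr_schedulers

    def state_dict(self):
        # Valid only under the assumption that every scheduler's state is identical.
        return self.lr_schedulers[0].state_dict()

    def load_state_dict(self, state_dict):
        for scheduler in self.lr_schedulers:
            scheduler.load_state_dict(state_dict)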

@@ -203,6 +200,11 @@ def __init__(

self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = (
[item.strip() for item in ckpt_config.exclude_from_loading]
Contributor:
can we do this strip in the definition of string_list?

if ckpt_config.exclude_from_loading
else []
Contributor:
why this branch? Isn't it already a list after split in string_list?

)

self.mp = None
if async_mode == AsyncMode.DISABLED:
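If the strip and empty-string handling moved into string_list, as the two comments above suggest, one possible shape is the following sketch (an illustration, not the code in this PR):

def string_list(raw_arg: str) -> list[str]:
    # Strip whitespace around each item and drop empty entries,
    # so "" -> [] and "a, b," -> ["a", "b"].
    return [item.strip() for item in raw_arg.split(",") if item.strip()]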
@@ -435,10 +437,17 @@ def load(self, step: int = -1) -> bool:
}
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
shadow_states = {
Contributor:
can you explain more about the naming? I'd call it states_to_load

k: v for k, v in states.items() if k not in self.exclude_from_loading
}
for exclude_key in self.exclude_from_loading:
if exclude_key != "" and exclude_key not in states:
Contributor:
we should filter "" (and any empty space) out in string_list

raise ValueError(f"{exclude_key} not found in state_dict, skipping")
Contributor:
what do you mean by "skipping" when you raise an exception. Technically it should be "failing"?

dcp.load(
states,
shadow_states,
checkpoint_id=self._create_checkpoint_id(step),
)
states.update(shadow_states)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
20 changes: 20 additions & 0 deletions torchtitan/config_manager.py
@@ -511,6 +511,16 @@ def __init__(self):
default=-1,
help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
)
self.parser.add_argument(
Contributor Author:
Currently checkpoint.exclude only supports excluding at loading; shall we use an argument like exclude_from_loading?

Contributor:
yes, exclude_from_loading is more explicit.

"--checkpoint.exclude_from_loading",
type=string_list,
Contributor:
shall we still do .strip and empty check in string_list?

default="",
Contributor:
The default should be []? If default is "", you'll always end up with [""] after string_split.
See https://docs.python.org/3.3/library/stdtypes.html

help="""
Exclude specific keys from being loaded from the checkpoint.
Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
This will load the model only, excluding the specified keys.
""",
)
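The point the reviewer raises about the default is plain Python string semantics, easy to confirm:

# str.split never returns an empty list, so default="" would become [""] downstream
# instead of meaning "no keys excluded".
assert "".split(",") == [""]
assert "optimizer,lr_scheduler,dataloader".split(",") == [
    "optimizer",
    "lr_scheduler",
    "dataloader",
]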
# activation checkpointing configs
self.parser.add_argument(
"--activation_checkpoint.mode",
@@ -618,6 +628,13 @@ def parse_args(self, args_list: list = sys.argv[1:]):
exp["pipeline_parallel_split_points"] = string_list(
exp["pipeline_parallel_split_points"]
)
if (
"checkpoint" in args_dict
and "exclude_from_loading" in args_dict["checkpoint"]
and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str)
):
ckpt = args_dict["checkpoint"]
ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"])

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
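To make the toml path concrete, a hedged example of what this normalization handles; importing string_list from config_manager and the exact dict shape are assumptions based on the surrounding code:

from torchtitan.config_manager import string_list

# A toml file containing
#   [checkpoint]
#   exclude_from_loading = "lr_scheduler,dataloader"
# surfaces here as a plain string under args_dict["checkpoint"], which the branch
# above converts to a list before the cmd-line override step.
args_dict = {"checkpoint": {"exclude_from_loading": "lr_scheduler,dataloader"}}
args_dict["checkpoint"]["exclude_from_loading"] = string_list(
    args_dict["checkpoint"]["exclude_from_loading"]
)
assert args_dict["checkpoint"]["exclude_from_loading"] == ["lr_scheduler", "dataloader"]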
@@ -665,6 +682,9 @@ def parse_args_from_command_line(
# since the inferred type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
elif arg == "checkpoint.exclude_from_loading":
# same as above for checkpoint.exclude_from_loading
Contributor:
Suggested change
# same as above for checkpoint.exclude_from_loading
# similar to the case above

aux_parser.add_argument("--" + arg, type=string_list)
else:
aux_parser.add_argument("--" + arg, type=type(val))
