Enable optional checkpoint at loading #819
Conversation
4491e62 to 58466d5 (Compare)
@@ -169,12 +169,6 @@ def __init__(
into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).

3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
lr_scheduler flatten at #794
We should add a comment here to say the lr_scheduler resharding assumes that all lr_schedulers are the same.
@@ -511,6 +511,16 @@ def __init__(self):
    default=-1,
    help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
)
self.parser.add_argument(
Currently checkpoint.exclude only supports excluding at loading; shall we use an argument name like exclude_from_loading?
yes, exclude_from_loading is more explicit.
    ],
    "Optional checkpoint",
    "optional_checkpoint",
),
Add an integration test here, especially one showing that excluding the dataloader from checkpoint loading avoids the dp_degree mismatch error when the degree changes between save and load.
LGTM, but please change the comments
torchtitan/checkpoint.py (Outdated)

shadow_states = {k: v for k, v in states.items() if k not in self.exclude}
for exclude_key in self.exclude:
    if exclude_key not in states:
        logger.warning(f"{exclude_key} not found in state_dict, skipping")
We should just raise an exception. So a better way to do this is:

if not set(self.exclude).issubset(set(states.keys())):
    raise ValueError("...")
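For illustration, a minimal sketch of how that validation could look inside CheckpointManager.load, reusing the names from the quoted diff (self.exclude, states, shadow_states); this is a sketch of the reviewer's suggestion, not the final implementation:

# Validate the user-provided exclude list up front and fail loudly,
# instead of warning and silently skipping unknown keys.
missing = set(self.exclude) - set(states.keys())
if missing:
    raise ValueError(
        f"Excluded keys not found in state_dict: {sorted(missing)}"
    )
shadow_states = {k: v for k, v in states.items() if k not in self.exclude}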
In general looks good. Had several comments on details.
Plus we need to document the usage in https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md
including the proper use cases mentioned in #809 (comment)
torchtitan/checkpoint.py (Outdated)

optimizers do, so it's hard to write a generic 'flattener' utility.

TODO: This is currently unsolved and needs a fix.
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers by the ssumption that
Suggested change:
- 3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers by the ssumption that
+ 3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers with the assumption that
torchtitan/checkpoint.py (Outdated)

@@ -203,6 +200,11 @@ def __init__(

self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = (
    [item.strip() for item in ckpt_config.exclude_from_loading]
Can we do this strip in the definition of string_list?
torchtitan/checkpoint.py (Outdated)

self.exclude_from_loading = (
    [item.strip() for item in ckpt_config.exclude_from_loading]
    if ckpt_config.exclude_from_loading
    else []
Why this branch? Isn't it already a list after the split in string_list?
torchtitan/config_manager.py (Outdated)

self.parser.add_argument(
    "--checkpoint.exclude_from_loading",
    type=string_list,
    default="",
The default should be [], as in torchtitan/config_manager.py line 305 at 690f299 (default=[]). If the default is "", you'll always end up with [""] after the split in string_list. See https://docs.python.org/3.3/library/stdtypes.html
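For illustration, a minimal sketch of a string_list helper that does the strip and empty-entry filtering discussed above, so that both default=[] and an empty command-line value end up as []; this is a sketch, not necessarily the exact helper in config_manager.py:

def string_list(raw: str) -> list[str]:
    # Split a comma-separated CLI value into a list of names,
    # stripping whitespace and dropping empty entries so "" -> [].
    return [s.strip() for s in raw.split(",") if s.strip()]

With that, --checkpoint.exclude_from_loading lr_scheduler,dataloader parses to ["lr_scheduler", "dataloader"], and the consumer no longer needs an exclude_key != "" guard.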
tests/integration_tests.py (Outdated)

@@ -418,6 +418,22 @@ def build_test_list():
        "test_generate",
        ngpu=2,
    ),
    OverrideDefinitions(
There are two tests missing:
- Unit tests (on CPU). Passing the test here doesn't mean the behavior is correct. We should add a unit test similar to https://github.com/pytorch/torchtitan/blob/690f299d37c5f6d34273762c0d650888a754d3c0/tests/unit_tests/test_dataset_checkpointing.py
- The test here only covers the cmdline arg override, but it could be problematic if specified in toml. We need to add a test similar to def test_parse_pp_split_points(self).
Add a test here with comments. In the optional checkpoint scenario, we save at [dp:4] and load at [dp:2, tp:2]; the dataloader should be excluded at loading, otherwise it would raise a dp_degree mismatch error.
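A sketch of what such a test entry could look like, following the OverrideDefinitions pattern quoted above; the parallelism flag names and the ngpu value are assumptions based on neighboring tests, not part of this PR:

OverrideDefinitions(
    [
        # First run: save a checkpoint while training with pure data
        # parallelism (dp=4 on 4 GPUs).
        [
            "--checkpoint.enable_checkpoint",
        ],
        # Second run: resume at [dp:2, tp:2]. The dataloader state was saved
        # under dp=4, so exclude it (and lr_scheduler) from loading to avoid
        # the dp_degree mismatch error.
        [
            "--checkpoint.enable_checkpoint",
            "--training.tensor_parallel_degree 2",
            "--checkpoint.exclude_from_loading lr_scheduler,dataloader",
        ],
    ],
    "Optional checkpoint",
    "optional_checkpoint",
    ngpu=4,
),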
torchtitan/checkpoint.py (Outdated)

    k: v for k, v in states.items() if k not in self.exclude_from_loading
}
for exclude_key in self.exclude_from_loading:
    if exclude_key != "" and exclude_key not in states:
We should filter "" (and any empty space) out in string_list.
torchtitan/checkpoint.py (Outdated)

}
for exclude_key in self.exclude_from_loading:
    if exclude_key != "" and exclude_key not in states:
        raise ValueError(f"{exclude_key} not found in state_dict, skipping")
What do you mean by "skipping" when you raise an exception? Technically it should be "failing".
torchtitan/checkpoint.py (Outdated)

@@ -435,10 +437,17 @@ def load(self, step: int = -1) -> bool:
}
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
shadow_states = {
Can you explain more about the naming? I'd call it states_to_load.
torchtitan/config_manager.py (Outdated)

@@ -665,6 +682,9 @@ def parse_args_from_command_line(
        # since the inferred type is just 'list' and it ends up flattening
        # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
        aux_parser.add_argument("--" + arg, type=string_list)
    elif arg == "checkpoint.exclude_from_loading":
        # same as above for checkpoint.exclude_from_loading
Suggested change:
- # same as above for checkpoint.exclude_from_loading
+ # similar to the case above
Can you publish the PR?
lgtm, thank you!
please address remaining comments before merging.
torchtitan/checkpoint.py (Outdated)

@@ -435,10 +433,17 @@ def load(self, step: int = -1) -> bool:
}
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
state_to_load = {
Suggested change:
- state_to_load = {
+ states_to_load = {
@@ -511,6 +511,17 @@ def __init__(self):
    default=-1,
    help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
)
self.parser.add_argument(
    "--checkpoint.exclude_from_loading",
    type=string_list,
Shall we still do .strip and the empty check in string_list?
tests/unit_tests/test_job_config.py (Outdated)

toml_splits = ["optimizer", "lr_scheduler", "dataloader"]
toml_split_str = ",".join(toml_splits)
cmdline_splits = ["optimizer", "lr_scheduler", "dataloader"]
cmdline_split_str = ",".join(cmdline_splits)
The point of having two sets is that we can test overriding, e.g. "toml has split points, cmdline overrides them". So we need to make them different to test robustness.
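A sketch of how the test could be adjusted so the two lists differ and the cmdline value demonstrably overrides the toml value, modeled on test_parse_pp_split_points; the JobConfig/parse_args usage and attribute path are assumptions based on the existing tests in this file, and the method is assumed to live in the existing test class with import tempfile and from torchtitan.config_manager import JobConfig at module level:

def test_parse_exclude_from_loading(self):
    # Deliberately different lists, so the test proves the cmdline value
    # overrides the toml value rather than just that parsing works.
    toml_splits = ["optimizer", "dataloader"]
    toml_split_str = ",".join(toml_splits)
    cmdline_splits = ["optimizer", "lr_scheduler", "dataloader"]
    cmdline_split_str = ",".join(cmdline_splits)

    # Write a toml config containing the first list.
    with tempfile.NamedTemporaryFile("w", suffix=".toml", delete=False) as f:
        f.write(f'[checkpoint]\nexclude_from_loading = "{toml_split_str}"\n')
        config_file = f.name

    # toml only: the comma-separated string should parse into the toml list.
    config = JobConfig()
    config.parse_args([f"--job.config_file={config_file}"])
    assert config.checkpoint.exclude_from_loading == toml_splits

    # toml + cmdline: the cmdline list should win.
    config = JobConfig()
    config.parse_args(
        [
            f"--job.config_file={config_file}",
            f"--checkpoint.exclude_from_loading={cmdline_split_str}",
        ]
    )
    assert config.checkpoint.exclude_from_loading == cmdline_splits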
Add argument "--checkpoint.exclude" to allow users to exclude specific keys from being loaded from the checkpoint.