Enable optional checkpoint at loading #819
@@ -418,6 +418,22 @@ def build_test_list():
             "test_generate",
             ngpu=2,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--training.steps 10",
+                ],
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
+                    "--training.tensor_parallel_degree 2",
+                    "--training.steps 20",
+                ],
+            ],
+            "Optional checkpoint",
+            "optional_checkpoint",
+        ),
     ]
     return integration_tests_flavors

Review comment (on the new test): add an integration test here, especially one covering that optionally excluding the dataloader from checkpoint loading avoids the dp_degree mismatch error before and after the checkpoint.
@@ -170,11 +170,8 @@ def __init__(
         which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
         support described in (1).

-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
+        3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers by the assumption
+        that all lr_schedulers have the same state_dict.
         """
         self.states = states

Review comment (on the removed TODO): lr_scheduler flatten at #794.
Review comment (on the new wording): "We should add a comment here to say the …"
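To make the assumption concrete, here is a minimal sketch of what flattening lr_schedulers under "all lr_schedulers have the same state_dict" could look like; the wrapper class below is illustrative only, not the implementation from #794:

```python
from typing import Any, Dict, List


class FlattenedLRSchedulers:
    """Illustrative wrapper: expose many lr_schedulers as one checkpointable object,
    assuming every one of them produces an identical state_dict."""

    def __init__(self, schedulers: List[Any]) -> None:
        self.schedulers = schedulers

    def state_dict(self) -> Dict[str, Any]:
        # Since all schedulers are assumed identical, saving one of them is enough.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Apply the single saved state to every scheduler on load.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(state_dict)
```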
@@ -203,6 +200,11 @@ def __init__(
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
+        self.exclude_from_loading = (
+            [item.strip() for item in ckpt_config.exclude_from_loading]
+            if ckpt_config.exclude_from_loading
+            else []
+        )

         self.mp = None
         if async_mode == AsyncMode.DISABLED:

Review comment (on the list comprehension): "can we do this …"
Review comment (on the `else []` branch): why this branch? Isn't it already a list after `string_list`?
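For context on the two questions above, here is a small sketch of what an up-front normalization could look like if whitespace and empty entries were filtered while parsing. This is a hypothetical helper for illustration, not code from the PR:

```python
def normalize_exclude_from_loading(raw) -> list[str]:
    # Accept either a comma-separated string or an already-split list,
    # strip whitespace, and drop empty entries such as the [""] produced
    # by "".split(",").
    items = raw.split(",") if isinstance(raw, str) else list(raw)
    return [item.strip() for item in items if item.strip()]


# normalize_exclude_from_loading("") == []
# normalize_exclude_from_loading("lr_scheduler, dataloader,") == ["lr_scheduler", "dataloader"]
```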
@@ -435,10 +437,17 @@ def load(self, step: int = -1) -> bool:
         }
         logger.info(f"Loading the checkpoint at step {step}.")
         begin = time.monotonic()
+        shadow_states = {
+            k: v for k, v in states.items() if k not in self.exclude_from_loading
+        }
+        for exclude_key in self.exclude_from_loading:
+            if exclude_key != "" and exclude_key not in states:
+                raise ValueError(f"{exclude_key} not found in state_dict, skipping")
         dcp.load(
-            states,
+            shadow_states,
             checkpoint_id=self._create_checkpoint_id(step),
         )
+        states.update(shadow_states)
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
         )

Review comment (on the `shadow_states` name): can you explain more about the naming? I'd call it …
Review comment (on the `exclude_key != ""` check): "we should filter …"
Review comment (on the `raise`): what do you mean by "skipping" when you raise a ValueError?
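The new loading path amounts to: copy the state dict minus the excluded keys, load only that subset, then merge the loaded values back so excluded entries keep their freshly initialized state. Below is a simplified, self-contained sketch of that flow, with a plain dict standing in for the checkpoint that dcp.load would read from disk:

```python
def load_excluding(states: dict, saved: dict, exclude_from_loading: list[str]) -> dict:
    """Illustrative version of the partial-load logic; `saved` stands in for the
    on-disk checkpoint that dcp.load would populate."""
    for exclude_key in exclude_from_loading:
        if exclude_key and exclude_key not in states:
            raise ValueError(f"{exclude_key} not found in state_dict")

    # Only the non-excluded entries participate in loading.
    states_to_load = {k: v for k, v in states.items() if k not in exclude_from_loading}
    for k in states_to_load:
        states_to_load[k] = saved[k]  # stand-in for dcp.load(...)

    # Excluded entries keep their freshly initialized values.
    states.update(states_to_load)
    return states


current = {"model": 0, "optimizer": 0, "dataloader": 0}
checkpoint = {"model": 1, "optimizer": 1, "dataloader": 1}
print(load_excluding(current, checkpoint, ["dataloader"]))
# {'model': 1, 'optimizer': 1, 'dataloader': 0}
```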
@@ -511,6 +511,16 @@ def __init__(self):
             default=-1,
             help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
         )
+        self.parser.add_argument(
+            "--checkpoint.exclude_from_loading",
+            type=string_list,
+            default="",
+            help="""
+                Exclude specific keys from being loaded from the checkpoint.
+                Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
+                This will load the model only, excluding the specified keys.
+            """,
+        )
         # activation checkpointing configs
         self.parser.add_argument(
             "--activation_checkpoint.mode",

Review comment (on the argument name): currently checkpoint.exclude only supports excluding at loading; shall we use an argument name like `exclude_from_loading`? Reply: yes, exclude_from_loading is more explicit.
Review comment (on `type=string_list`): "shall we still do …"
Review comment (on `default=""`): the default should be … (see torchtitan/config_manager.py, line 305 at 690f299). If default is "", you'll always end up with [""] after string_split. See https://docs.python.org/3.3/library/stdtypes.html
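To illustrate the reviewer's point about default="": assuming string_list is essentially a comma split (as its use with comma-separated values suggests), splitting an empty string yields [""] rather than an empty list:

```python
def string_list(raw_arg: str) -> list[str]:
    # Assumed equivalent of torchtitan's string_list helper: a plain comma split.
    return raw_arg.split(",")


print(string_list("optimizer,lr_scheduler,dataloader"))  # ['optimizer', 'lr_scheduler', 'dataloader']
print(string_list(""))  # [''] -- not [], which is why a list default is suggested
```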
@@ -618,6 +628,13 @@ def parse_args(self, args_list: list = sys.argv[1:]):
                 exp["pipeline_parallel_split_points"] = string_list(
                     exp["pipeline_parallel_split_points"]
                 )
+        if (
+            "checkpoint" in args_dict
+            and "exclude_from_loading" in args_dict["checkpoint"]
+            and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str)
+        ):
+            ckpt = args_dict["checkpoint"]
+            ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"])
+
         # override args dict with cmd_args
         cmd_args_dict = self._args_to_two_level_dict(cmd_args)
@@ -665,6 +682,9 @@ def parse_args_from_command_line(
                 # since the inferred type is just 'list' and it ends up flattening
                 # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
                 aux_parser.add_argument("--" + arg, type=string_list)
+            elif arg == "checkpoint.exclude_from_loading":
+                # same as above for checkpoint.exclude_from_loading
+                aux_parser.add_argument("--" + arg, type=string_list)
             else:
                 aux_parser.add_argument("--" + arg, type=type(val))
Review comment: There are two tests missing: see torchtitan/tests/unit_tests/test_job_config.py, line 48 at 690f299.

Review comment: Add a test here with comments. In the optional checkpoint, we save at [dp:4] and load at [dp:2, tp:2]; the dataloader should be excluded in loading, otherwise it would raise an error for the dp_degree mismatch.
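A sketch of what the missing config-parsing test might look like, assuming the JobConfig.parse_args API shown in the diff above; the test name and assertions are illustrative, not the actual test requested by the reviewer:

```python
from torchtitan.config_manager import JobConfig


def test_parse_exclude_from_loading():
    config = JobConfig()
    config.parse_args(
        [
            "--checkpoint.enable_checkpoint",
            "--checkpoint.exclude_from_loading",
            "optimizer,lr_scheduler,dataloader",
        ]
    )
    # The comma-separated CLI value should come back as a list of keys.
    assert config.checkpoint.exclude_from_loading == [
        "optimizer",
        "lr_scheduler",
        "dataloader",
    ]
```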