Commit

enable optional checkpoint at loading
mori360 committed Feb 4, 2025
1 parent 26abff7 commit 4491e62
Showing 2 changed files with 21 additions and 7 deletions.
18 changes: 11 additions & 7 deletions torchtitan/checkpoint.py
@@ -169,12 +169,6 @@ def __init__(
into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).
3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
optimizers do, so it's hard to write a generic 'flattener' utility.
TODO: This is currently unsolved and needs a fix.
"""
self.states = states

@@ -203,6 +197,11 @@ def __init__(

self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude = (
            [item.strip() for item in ckpt_config.exclude]  # string_list already splits on commas
if ckpt_config.exclude
else []
)

self.mp = None
if async_mode == AsyncMode.DISABLED:
@@ -435,10 +434,15 @@ def load(self, step: int = -1) -> bool:
}
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
shadow_states = {k: v for k, v in states.items() if k not in self.exclude}
for exclude_key in self.exclude:
if exclude_key not in states:
logger.warning(f"{exclude_key} not found in state_dict, skipping")
dcp.load(
states,
shadow_states,
checkpoint_id=self._create_checkpoint_id(step),
)
states.update(shadow_states)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
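
As a note for readers, the filter/load/merge pattern in this hunk can be exercised on its own. Below is a minimal sketch, not the committed code: the helper name `load_with_exclusions` is hypothetical, `states` is assumed to be a dict of loadable objects, and `dcp` is torch.distributed.checkpoint as imported in this file.

import torch.distributed.checkpoint as dcp

def load_with_exclusions(states: dict, exclude: list[str], checkpoint_id: str) -> None:
    # Drop excluded keys so dcp.load never reads or overwrites them.
    states_to_load = {k: v for k, v in states.items() if k not in exclude}
    # Flag exclusions that do not match any key in the state dict.
    for key in exclude:
        if key not in states:
            print(f"warning: {key} not found in state_dict, skipping")
    # dcp.load populates states_to_load in place from the checkpoint.
    dcp.load(states_to_load, checkpoint_id=checkpoint_id)
    # Excluded entries keep their freshly initialized in-memory values.
    states.update(states_to_load)

The merge back via states.update is what lets the excluded components (optimizer, lr_scheduler, etc.) start from their constructor defaults while the rest resumes from the checkpoint.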
10 changes: 10 additions & 0 deletions torchtitan/config_manager.py
@@ -511,6 +511,16 @@ def __init__(self):
default=-1,
help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
)
self.parser.add_argument(
"--checkpoint.exclude",
type=string_list,
default="",
help="""
Exclude specific keys from being loaded from the checkpoint.
Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
This can be used to load only the model while reinitializing the excluded components.
""",
)
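
For context, string_list is torchtitan's argparse converter for comma-separated flags; its real definition lives elsewhere in config_manager.py. A minimal stand-in (an assumption, not copied from the repo) looks like:

def string_list(raw_arg: str) -> list[str]:
    # Split a comma-separated CLI value into a list of strings.
    return raw_arg.split(",")

So --checkpoint.exclude optimizer,lr_scheduler arrives in ckpt_config.exclude as ['optimizer', 'lr_scheduler'], which is why checkpoint.py iterates the value directly and strips whitespace per item.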
# activation checkpointing configs
self.parser.add_argument(
"--activation_checkpoint.mode",
