From 49c6d6fc15ef644e5c3b1003ad4e0d9ea5fcb9a9 Mon Sep 17 00:00:00 2001
From: yifanmao
Date: Thu, 6 Feb 2025 17:22:24 -0800
Subject: [PATCH] Enable optional checkpoint at loading (#819)

Add argument "--checkpoint.exclude_from_loading" to allow users to exclude
specific keys from being loaded from the checkpoint.

1. If exclude_from_loading contains "dataloader", users can load with a
   different dp_degree, since the dataloader state is excluded instead of
   being resharded.
2. If exclude_from_loading contains "lr_scheduler", the lr_scheduler starts
   counting from step 0.
---
 docs/checkpoint.md                  |  9 ++++
 tests/integration_tests.py          | 18 ++++++++
 tests/unit_tests/test_job_config.py | 71 +++++++++++++++++++++++++++++
 torchtitan/checkpoint.py            | 17 ++++---
 torchtitan/config_manager.py        | 23 +++++++++-
 5 files changed, 131 insertions(+), 7 deletions(-)

diff --git a/docs/checkpoint.md b/docs/checkpoint.md
index 05ef6f4d1..50ae42a81 100644
--- a/docs/checkpoint.md
+++ b/docs/checkpoint.md
@@ -64,6 +64,15 @@ Finally, once you have obtained the last checkpoint, you can use the following c
 python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt
 ```
 
+7. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING
+In some cases, you may want to partially load a previously trained checkpoint and change certain settings, such as the number of GPUs or the current step. To achieve this, use the `exclude_from_loading` parameter to specify which keys should be excluded from loading.
+This parameter takes a comma-separated list of keys to exclude.
+```
+[checkpoint]
+enable_checkpoint = true
+exclude_from_loading = "dataloader,lr_scheduler"
+```
+
 That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune.
 
 
diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 9c7394ec9..1bdd5df91 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -418,6 +418,24 @@ def build_test_list():
             "fsdp_reshard_always",
             ngpu=2,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--training.steps 10",
+                ],
+                # Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
+                # excluded during loading to avoid errors caused by a mismatched dp_degree.
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
+                    "--training.tensor_parallel_degree 2",
+                    "--training.steps 20",
+                ],
+            ],
+            "Optional checkpoint",
+            "optional_checkpoint",
+        ),
     ]
 
     return integration_tests_flavors
diff --git a/tests/unit_tests/test_job_config.py b/tests/unit_tests/test_job_config.py
index aed007add..ae9bb5635 100644
--- a/tests/unit_tests/test_job_config.py
+++ b/tests/unit_tests/test_job_config.py
@@ -116,6 +116,77 @@ def test_parse_pp_split_points(self):
             config.experimental.pipeline_parallel_split_points == cmdline_splits
         ), config.experimental.pipeline_parallel_split_points
 
+    def test_parse_exclude_from_loading(self):
+
+        toml_splits = ["optimizer", "dataloader"]
+        toml_split_str = ",".join(toml_splits)
+        cmdline_splits = ["optimizer", "lr_scheduler"]
+        cmdline_split_str = ",".join(cmdline_splits)
+        # no exclude_from_loading specified
+        config = JobConfig()
+        config.parse_args(
+            [
+                "--job.config_file",
+                "./train_configs/debug_model.toml",
+            ]
+        )
+        assert config.checkpoint.exclude_from_loading == []
+
+        # toml has no excluded keys, but cmdline keys are specified
+        config = JobConfig()
+        config.parse_args(
+            [
+                "--job.config_file",
+                "./train_configs/debug_model.toml",
+                "--checkpoint.exclude_from_loading",
+                f"{cmdline_split_str}",
+            ]
+        )
+        assert (
+            config.checkpoint.exclude_from_loading == cmdline_splits
+        ), config.checkpoint.exclude_from_loading
+
+        # toml has excluded keys, cmdline does not
+        with tempfile.NamedTemporaryFile() as fp:
+            with open(fp.name, "wb") as f:
+                tomli_w.dump(
+                    {
+                        "checkpoint": {
+                            "exclude_from_loading": toml_split_str,
+                        }
+                    },
+                    f,
+                )
+            config = JobConfig()
+            config.parse_args(["--job.config_file", fp.name])
+            assert (
+                config.checkpoint.exclude_from_loading == toml_splits
+            ), config.checkpoint.exclude_from_loading
+
+        # toml has excluded keys, cmdline overrides them
+        with tempfile.NamedTemporaryFile() as fp:
+            with open(fp.name, "wb") as f:
+                tomli_w.dump(
+                    {
+                        "checkpoint": {
+                            "exclude_from_loading": toml_split_str,
+                        }
+                    },
+                    f,
+                )
+            config = JobConfig()
+            config.parse_args(
+                [
+                    "--job.config_file",
+                    fp.name,
+                    "--checkpoint.exclude_from_loading",
+                    f"{cmdline_split_str}",
+                ]
+            )
+            assert (
+                config.checkpoint.exclude_from_loading == cmdline_splits
+            ), config.checkpoint.exclude_from_loading
+
     def test_print_help(self):
         config = JobConfig()
         parser = config.parser
diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index 7d1433830..367f863fd 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -170,11 +170,8 @@ def __init__(
            which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
            support described in (1).
 
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-           resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-           optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
+        3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers with the
+           assumption that all lr_schedulers have the same state_dict.
""" self.states = states @@ -203,6 +200,7 @@ def __init__( self.model_weights_only = ckpt_config.model_weights_only self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype] + self.exclude_from_loading = ckpt_config.exclude_from_loading self.mp = None if async_mode == AsyncMode.DISABLED: @@ -435,10 +433,17 @@ def load(self, step: int = -1) -> bool: } logger.info(f"Loading the checkpoint at step {step}.") begin = time.monotonic() + states_to_load = { + k: v for k, v in states.items() if k not in self.exclude_from_loading + } + for exclude_key in self.exclude_from_loading: + if exclude_key not in states: + raise ValueError(f"{exclude_key} not found in state_dict.") dcp.load( - states, + states_to_load, checkpoint_id=self._create_checkpoint_id(step), ) + states.update(states_to_load) logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." ) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2d3024912..3cc630c2e 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -26,7 +26,7 @@ def string_list(raw_arg): - return raw_arg.split(",") + return [s.strip() for s in raw_arg.split(",") if s.strip()] class JobConfig: @@ -529,6 +529,17 @@ def __init__(self): default=-1, help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.", ) + self.parser.add_argument( + "--checkpoint.exclude_from_loading", + type=string_list, + nargs="*", + default=[], + help=""" + Exclude specific keys from being loaded from the checkpoint. + Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'. + This will load the model only, excluding the specified keys. + """, + ) # activation checkpointing configs self.parser.add_argument( "--activation_checkpoint.mode", @@ -636,6 +647,13 @@ def parse_args(self, args_list: list = sys.argv[1:]): exp["pipeline_parallel_split_points"] = string_list( exp["pipeline_parallel_split_points"] ) + if ( + "checkpoint" in args_dict + and "exclude_from_loading" in args_dict["checkpoint"] + and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str) + ): + ckpt = args_dict["checkpoint"] + ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"]) # override args dict with cmd_args cmd_args_dict = self._args_to_two_level_dict(cmd_args) @@ -683,6 +701,9 @@ def parse_args_from_command_line( # since the inferred type is just 'list' and it ends up flattening # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...] aux_parser.add_argument("--" + arg, type=string_list) + elif arg == "checkpoint.exclude_from_loading": + # similar to the case above + aux_parser.add_argument("--" + arg, type=string_list) else: aux_parser.add_argument("--" + arg, type=type(val))