From 72aad159d081e0243d3423faf526a9f4353650dd Mon Sep 17 00:00:00 2001 From: Wanchao Date: Fri, 22 Mar 2024 09:21:03 -0700 Subject: [PATCH] refactor config manager and support cmd overrides (#157) This PR supports explicit cmd overrides, to allow infra layers to override certain options (the most important one is dump_folder) --- test/test_job_config.py | 12 ++ torchtrain/config_manager.py | 193 +++++++++++-------- torchtrain/datasets/tokenizer.py | 1 - torchtrain/parallelisms/parallelize_llama.py | 2 +- 4 files changed, 130 insertions(+), 78 deletions(-) diff --git a/test/test_job_config.py b/test/test_job_config.py index 5dcf7490..6e3a4ec9 100644 --- a/test/test_job_config.py +++ b/test/test_job_config.py @@ -28,3 +28,15 @@ def test_empty_config_file(self): config = JobConfig() config.parse_args(["--job.config_file", fp.name]) assert config.job.description + + def test_job_config_file_cmd_overrides(self): + config = JobConfig() + config.parse_args( + [ + "--job.config_file", + "./train_configs/debug_model.toml", + "--job.dump_folder", + "/tmp/test_tt/", + ] + ) + assert config.job.dump_folder == "/tmp/test_tt/" diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py index 5005bef6..815831e0 100644 --- a/torchtrain/config_manager.py +++ b/torchtrain/config_manager.py @@ -6,13 +6,15 @@ import argparse import sys from collections import defaultdict -from typing import Union +from typing import Tuple, Union try: import tomllib except ModuleNotFoundError: import tomli as tomllib +from torchtrain.logging_utils import logger + class JobConfig: """ @@ -21,98 +23,75 @@ class JobConfig: - Default config is loaded from a toml file. If no toml file is provided, then the default config is loaded from argparse defaults. - if toml file has missing keys, they are filled with argparse defaults. - """ + - if additional explicit cmd args are provided in addition to the toml + file, they will override the toml config and the argparse defaults - def parse_args(self, args_list: list = sys.argv[1:]): - args = JobConfig.init_args_from_command_line(args_list) - config_file = getattr(args, "job.config_file", None) - args_dict = self._args_to_two_level_dict(args) - if config_file is not None: - with open(config_file, "rb") as f: - for k, v in tomllib.load(f).items(): - # to prevent overwrite of non-specified keys - args_dict[k] |= v - for k, v in args_dict.items(): - class_type = type(k.title(), (), v) - setattr(self, k, class_type()) - self._validate_config() + precedence order: cmdline > toml > argparse default - def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict: - args_dict = defaultdict(defaultdict) - for k, v in vars(args).items(): - first_level_key, second_level_key = k.split(".", 1) - args_dict[first_level_key][second_level_key] = v - return args_dict + Arg parsing semantics: - def _validate_config(self): - # TODO: Add more mandatory validations - assert self.model.name and self.model.flavor and self.model.tokenizer_path - return True + Each argument starts with _ which is the section name in the toml file + followed by name of the option in the toml file. For ex, + model.name translates to: + [model] + name + in the toml file + """ - @staticmethod - def init_args_from_command_line( - args_list: list = sys.argv[1:], - ) -> argparse.Namespace: - """ - Each argument starts with _ which is the section name in the toml file - followed by name of the option in the toml file. 
For ex, - model.name translates to: - [model] - name - in the toml file - """ - parser = argparse.ArgumentParser(description="TorchTrain arg parser.") - parser.add_argument( + def __init__(self): + # main parser + self.parser = argparse.ArgumentParser(description="TorchTrain arg parser.") + self.parser.add_argument( "--job.config_file", type=str, default=None, help="job config file", ) - # misc configs - parser.add_argument( + # job level configs + self.parser.add_argument( "--job.dump_folder", type=str, default="./torchtrain/outputs", help="folder to dump job outputs", ) - parser.add_argument( + self.parser.add_argument( "--job.description", type=str, default="default job", help="description of the job", ) # profiling configs - parser.add_argument( + self.parser.add_argument( "--profiling.run_profiler", action="store_true", help="enable pytorch profiler", ) - parser.add_argument( + self.parser.add_argument( "--profiling.save_traces_folder", type=str, default="profiling/traces", help="trace file location", ) - parser.add_argument( + self.parser.add_argument( "--profiling.profile_every_x_iter", type=int, default=10, help="collect profiler traces every x iterations", ) # metrics configs - parser.add_argument( + self.parser.add_argument( "--metrics.enable_tensorboard", action="store_true", help="whether to log metrics to TensorBoard", ) - parser.add_argument( + self.parser.add_argument( "--metrics.log_freq", type=int, default=10, help="how often to log metrics to TensorBoard", ) - parser.add_argument( + self.parser.add_argument( "--metrics.save_tb_folder", type=str, default="tb", @@ -120,19 +99,19 @@ def init_args_from_command_line( ) # model configs - parser.add_argument( + self.parser.add_argument( "--model.name", type=str, default="llama", help="which model to train", ) - parser.add_argument( + self.parser.add_argument( "--model.flavor", type=str, default="debugmodel", help="which model config to train", ) - parser.add_argument( + self.parser.add_argument( "--model.tokenizer_path", type=str, default="./torchtrain/datasets/tokenizer/tokenizer.model", @@ -140,18 +119,18 @@ def init_args_from_command_line( ) # optimizer configs - parser.add_argument( + self.parser.add_argument( "--optimizer.name", type=str, default="AdamW", help="optimizer to use" ) - parser.add_argument( + self.parser.add_argument( "--optimizer.lr", type=float, default=8e-4, help="learning rate to use" ) # training configs - parser.add_argument( + self.parser.add_argument( "--training.dataset", type=str, default="alpaca", help="dataset to use" ) - parser.add_argument( + self.parser.add_argument( "--training.dataset_path", type=str, help=( @@ -159,60 +138,60 @@ def init_args_from_command_line( "loaded from this path instead of downloaded.", ), ) - parser.add_argument( + self.parser.add_argument( "--training.batch_size", type=int, default=8, help="batch size" ) - parser.add_argument( + self.parser.add_argument( "--training.seq_len", type=int, default=2048, help="sequence length" ) - parser.add_argument( + self.parser.add_argument( "--training.warmup_steps", type=int, default=200, help="steps for lr scheduler warmup", ) - parser.add_argument( + self.parser.add_argument( "--training.max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping", ) - parser.add_argument( + self.parser.add_argument( "--training.steps", type=int, default=10000, help="how many train steps to run", ) - parser.add_argument( + self.parser.add_argument( "--training.data_parallel_degree", type=int, default=-1, help="Data 
Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.", ) - parser.add_argument( + self.parser.add_argument( "--training.sequence_parallel_degree", type=int, default=1, help="Sequence Parallelism degree. 1 means disabled.", ) - parser.add_argument( + self.parser.add_argument( "--training.enable_loss_parallel", default=True, action="store_true", help="whether to enable loss parallel when sequence parallel is enabled", ) - parser.add_argument( + self.parser.add_argument( "--training.pipeline_parallel_degree", type=int, default=1, help="Pipeline Parallelism degree (default of 1 means disabled)", ) - parser.add_argument( + self.parser.add_argument( "--training.compile", action="store_true", help="Whether to compile the model.", ) - parser.add_argument( + self.parser.add_argument( "--training.checkpoint_interval", type=int, default=3600, @@ -221,7 +200,7 @@ def init_args_from_command_line( "steps depending on --training.checkpoint-internval-type." ), ) - parser.add_argument( + self.parser.add_argument( "--training.checkpoint_interval_type", type=str, default="steps", @@ -230,7 +209,7 @@ def init_args_from_command_line( "The default value is step." ), ) - parser.add_argument( + self.parser.add_argument( "--training.checkpoint_folder", type=str, default="", @@ -239,7 +218,7 @@ def init_args_from_command_line( "is an empty string, checkpointing is disabled." ), ) - parser.add_argument( + self.parser.add_argument( "--training.fp8_linear", type=str, default="", @@ -249,7 +228,7 @@ def init_args_from_command_line( ], # TODO: add "delayed" option back in when supported help="Type of fp8 linear quantization to apply to the model", ) - parser.add_argument( + self.parser.add_argument( "--training.gc_freq", type=int, default=50, @@ -257,13 +236,13 @@ def init_args_from_command_line( ) # activation checkpointing - parser.add_argument( + self.parser.add_argument( "--activation_checkpoint.mode", type=str, default="selective", help=" ['none', 'full', 'selective'] = type of activation checkpointing to use", ) - parser.add_argument( + self.parser.add_argument( "--activation_checkpoint.selective_ac_option", type=str, default="2", # 2 = checkpoint every other layer @@ -271,13 +250,13 @@ def init_args_from_command_line( ) # communications library settings - parser.add_argument( + self.parser.add_argument( "--comm.init_timeout_seconds", type=int, default=300, help="Timeout for communication operations, during initialization and first train step.", ) - parser.add_argument( + self.parser.add_argument( "--comm.train_timeout_seconds", type=int, default=5, @@ -286,10 +265,72 @@ def init_args_from_command_line( "usually a tighter bound than during initialization." 
), ) - parser.add_argument( + self.parser.add_argument( "--comm.trace_buf_size", type=int, default=20000, help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled", ) - return parser.parse_args(args_list) + + def parse_args(self, args_list: list = sys.argv[1:]): + args, cmd_args = self.parse_args_from_command_line(args_list) + config_file = getattr(args, "job.config_file", None) + # build up a two level dict + args_dict = self._args_to_two_level_dict(args) + if config_file is not None: + try: + with open(config_file, "rb") as f: + for k, v in tomllib.load(f).items(): + # to prevent overwrite of non-specified keys + args_dict[k] |= v + except (FileNotFoundError, tomllib.TOMLDecodeError) as e: + logger.info( + f"Error while loading the configuration file: {config_file}" + ) + logger.info(f"Error details: {str(e)}") + raise e + + # override args dict with cmd_args + cmd_args_dict = self._args_to_two_level_dict(cmd_args) + for section, section_args in cmd_args_dict.items(): + for k, v in section_args.items(): + args_dict[section][k] = v + + for k, v in args_dict.items(): + class_type = type(k.title(), (), v) + setattr(self, k, class_type()) + self._validate_config() + + def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict: + args_dict = defaultdict(defaultdict) + for k, v in vars(args).items(): + first_level_key, second_level_key = k.split(".", 1) + args_dict[first_level_key][second_level_key] = v + return args_dict + + def _validate_config(self) -> bool: + # TODO: Add more mandatory validations + assert self.model.name and self.model.flavor and self.model.tokenizer_path + return True + + def parse_args_from_command_line( + self, args_list + ) -> Tuple[argparse.Namespace, argparse.Namespace]: + """ + Parse command line arguments and return the parsed args and the command line only args + """ + args = self.parser.parse_args(args_list) + + # aux parser to parse the command line only args, with no defaults from main parser + aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS) + for arg, val in vars(args).items(): + if isinstance(val, bool): + aux_parser.add_argument( + "--" + arg, action="store_true" if val else "store_false" + ) + else: + aux_parser.add_argument("--" + arg, type=type(val)) + + cmd_args, _ = aux_parser.parse_known_args(args_list) + + return args, cmd_args diff --git a/torchtrain/datasets/tokenizer.py b/torchtrain/datasets/tokenizer.py index 2f2f6fee..f3d500be 100644 --- a/torchtrain/datasets/tokenizer.py +++ b/torchtrain/datasets/tokenizer.py @@ -54,7 +54,6 @@ class SentencePieceTokenizer(TokenizerIf): """ def __init__(self, tokenizer_path: str): - super().__init__(tokenizer_path) # reload tokenizer self.sp_model = SentencePieceProcessor(model_file=tokenizer_path) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index e9fd1cc8..56e4235a 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -44,6 +44,7 @@ torch.ops.c10d_functional.reduce_scatter_tensor.default, } + # Uses PTD FSDP AC wrapper # currently selective per op and per layer checkpointing are supported def checkpoint_wrapper(module, config): @@ -84,7 +85,6 @@ def selective_checkpointing_context_fn(): ) elif config.mode == "selective" and config.selective_ac_option.isdigit(): - """enables selective checkpointing of candidate layers. Usage: 'selective_ac_option' with a positive 'int' value in config controls which layers to checkpoint.
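
A minimal usage sketch of the override precedence this PR introduces (cmdline > toml > argparse default), based on the test added in test/test_job_config.py; the toml path and the dump folder value are the ones used by that test and are illustrative only:

    from torchtrain.config_manager import JobConfig

    config = JobConfig()
    config.parse_args(
        [
            "--job.config_file",
            "./train_configs/debug_model.toml",
            "--job.dump_folder",
            "/tmp/test_tt/",
        ]
    )
    # The explicitly passed cmd arg wins over both the toml value and the
    # argparse default: parse_args_from_command_line re-parses the command
    # line with an aux parser whose argument_default is argparse.SUPPRESS,
    # so only options actually given on the command line end up in cmd_args,
    # and those are applied last on top of the merged toml/default dict.
    assert config.job.dump_folder == "/tmp/test_tt/"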