diff --git a/README.md b/README.md
index a8d1fcc4c..1c32a9d82 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@
 We report our [Performance](docs/performance.md) verified on 64 A100 GPUs
 
 ### Coming soon
+
 1. Async checkpointing
 2. FP8 support
 3. Context Parallel
diff --git a/test_runner.py b/test_runner.py
index dfd4a987e..3bd2770f7 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -122,12 +122,21 @@ def build_test_list(args):
         OverrideDefinitions(
             [
                 [
-                    "--training.compile",
+                    "--training.compile --model.norm_type=rmsnorm",
                     f"--job.dump_folder {args.output_dir}/1d_compile/",
                 ],
             ],
             "1D compile",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
+                    f"--job.dump_folder {args.output_dir}/2d_compile/",
+                ],
+            ],
+            "2D compile",
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 61cf79fe3..5c69ac4ef 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -318,7 +318,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 ),
                 "output": col_parallel_strategy(
                     input_layouts=Shard(1),
-                    output_layouts=(Shard(-1) if loss_parallel else Replicate()),
+                    output_layouts=Shard(-1) if loss_parallel else Replicate(),
                     use_local_output=not loss_parallel,
                 ),
                 "norm": SequenceParallel(),
@@ -360,6 +360,40 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
 
         logger.info("Applied Tensor Parallelism to the model")
 
+    # apply AC + torch.compile
+    ac_config = job_config.activation_checkpoint
+    enable_compile = job_config.training.compile
+    for layer_id, transformer_block in model.layers.items():
+        if ac_config.mode in ("full", "selective"):
+            transformer_block = checkpoint_wrapper(transformer_block, ac_config)
+        if enable_compile:
+            # turn on per-transformer block compile after AC wrapping and before FSDP
+            # TODO: dynamic shapes have some issues, so we turn them off for now.
+            # TODO: inline inbuilt nn modules does not work yet; enable it to reduce
+            # compile time.
+            # torch._dynamo.config.inline_inbuilt_nn_modules = True
+            transformer_block = torch.compile(transformer_block, dynamic=False)
+        model.layers[layer_id] = transformer_block
+
+    if ac_config.mode in ("full", "selective"):
+        logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+    if (
+        enable_compile
+        and ac_config.mode == "selective"
+        and ac_config.selective_ac_option == "op"
+    ):
+        # some temp flags for torch.compile enablement + SAC
+        torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
+            True
+        )
+    if enable_compile:
+        if job_config.model.norm_type == "fused_rmsnorm":
+            raise NotImplementedError(
+                "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
+            )
+        logger.info("Compiled each TransformerBlock with torch.compile")
+
+    # apply DP (FSDP2)
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -367,13 +401,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )
-        ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
         for layer_id, transformer_block in model.layers.items():
-            if job_config.activation_checkpoint.mode in ("full", "selective"):
-                transformer_block = checkpoint_wrapper(
-                    transformer_block, job_config.activation_checkpoint
-                )
             # As an optimization, do not reshard after forward for the last
             # transformer block since FSDP would prefetch it immediately.
             # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings
@@ -387,12 +416,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 reshard_after_forward=reshard_after_forward,
             )
             model.layers[layer_id] = transformer_block
-
         model = fully_shard(
             model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
         )
-        if ac_mode in ("full", "selective"):
-            logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")
 
     return model
diff --git a/train.py b/train.py
index e13acb3d6..6a8512a4c 100644
--- a/train.py
+++ b/train.py
@@ -245,20 +245,6 @@ def loss_fn(pred, labels):
 
     metric_logger = build_metric_logger(job_config)
 
-    # torch.compile model for improved performance
-    if job_config.training.compile:
-        if (
-            job_config.activation_checkpoint.mode == "selective"
-            and job_config.activation_checkpoint.selective_ac_option == "op"
-        ):
-            torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
-                True
-            )
-        logger.info("Compiling model with torch.compile")
-        # Dynamic shape have issues with distributed, turn dynamic off as Transformer
-        # training is static_shape TODO: resolve dynamic shape issue and restore defaults
-        model = torch.compile(model, dynamic=False)
-
     train_state = TrainState()
 
     # train loop
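The ordering this diff establishes is: wrap each transformer block with activation checkpointing, then compile that block individually with `torch.compile(dynamic=False)`, and only afterwards apply FSDP sharding. Below is a minimal, self-contained sketch of that per-block wrapping order; `ToyBlock`, `ToyModel`, and the two flags are illustrative stand-ins, not torchtitan code, and the FSDP step is omitted so the snippet runs on a single process.

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,  # per-block activation checkpointing wrapper
)


class ToyBlock(nn.Module):
    """Stand-in for a TransformerBlock: norm -> attention -> norm -> MLP."""

    def __init__(self, dim: int = 64) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


class ToyModel(nn.Module):
    """Keeps blocks in an nn.ModuleDict so they can be swapped in place,
    mirroring how model.layers is indexed by layer_id in the diff."""

    def __init__(self, n_layers: int = 2, dim: int = 64) -> None:
        super().__init__()
        self.layers = nn.ModuleDict({str(i): ToyBlock(dim) for i in range(n_layers)})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = block(x)
        return x


model = ToyModel()
enable_ac = True        # stand-in for activation_checkpoint.mode in ("full", "selective")
enable_compile = True   # stand-in for training.compile

for layer_id, block in model.layers.items():
    if enable_ac:
        # 1) activation checkpointing wraps the block first
        block = checkpoint_wrapper(block)
    if enable_compile:
        # 2) then each block is compiled individually (dynamic=False, as in the diff);
        #    FSDP's fully_shard would follow this step in the real code path
        block = torch.compile(block, dynamic=False)
    model.layers[layer_id] = block

x = torch.randn(2, 16, 64, requires_grad=True)
out = model(x)
out.sum().backward()
print(out.shape)
```

In the real code path, `fully_shard` is then applied to each (possibly compiled) block and to the whole model, as shown in the FSDP hunks above; the diff also rejects `fused_rmsnorm` under compile and only sets the experimental SAC flag when per-op selective checkpointing is enabled.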