torch.compile each TransformerBlock instead of the whole model (#268)
This way we can temporarily enable 2-D parallel compile, and compiling at the TransformerBlock level might also make sense in the future with PP (we'll see). A minimal sketch of the per-block pattern follows the list below.

We should figure out:
1. the dynamic shape issue when turning on 2D parallel
2. the full-model compile issue for 2D parallel compile
3. cache reuse, which currently does not work; enable it later
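
For reference, a minimal, self-contained sketch of the per-block compile pattern this commit adopts. The TinyBlock/TinyModel classes below are illustrative stand-ins, not torchtitan's real TransformerBlock/Transformer:

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Illustrative stand-in for a TransformerBlock (attention omitted)."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.mlp(self.norm(x))


class TinyModel(nn.Module):
    """Illustrative stand-in for the Transformer: blocks stored in model.layers."""

    def __init__(self, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): TinyBlock(dim) for i in range(n_layers)})

    def forward(self, x):
        for block in self.layers.values():
            x = block(x)
        return x


model = TinyModel()

# Instead of compiling the whole model (model = torch.compile(model, dynamic=False)),
# compile each block individually and swap it back into the module tree.
for layer_id, block in model.layers.items():
    model.layers[layer_id] = torch.compile(block, dynamic=False)

out = model(torch.randn(2, 16, 64))  # each block is traced and compiled on first use

The compiled blocks are still ordinary nn.Module submodules, which is what allows activation-checkpoint wrappers and FSDP2 to be applied around them afterwards, as the parallelize_llama.py changes below do.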
wanchaol authored May 22, 2024
1 parent c5a9718 commit 60810a9
Showing 4 changed files with 46 additions and 24 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -37,6 +37,7 @@ We report our [Performance](docs/performance.md) verified on 64 A100 GPUs


### Coming soon

1. Async checkpointing
2. FP8 support
3. Context Parallel
11 changes: 10 additions & 1 deletion test_runner.py
@@ -122,12 +122,21 @@ def build_test_list(args):
OverrideDefinitions(
[
[
"--training.compile",
"--training.compile --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/1d_compile/",
],
],
"1D compile",
),
OverrideDefinitions(
[
[
"--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/2d_compile/",
],
],
"2D compile",
),
OverrideDefinitions(
[
[
44 changes: 35 additions & 9 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -318,7 +318,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
),
"output": col_parallel_strategy(
input_layouts=Shard(1),
output_layouts=(Shard(-1) if loss_parallel else Replicate()),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
"norm": SequenceParallel(),
@@ -360,20 +360,49 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

logger.info("Applied Tensor Parallelism to the model")

# apply AC + torch.compile
ac_config = job_config.activation_checkpoint
enable_compile = job_config.training.compile
for layer_id, transformer_block in model.layers.items():
if ac_config.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(transformer_block, ac_config)
if enable_compile:
# turn on per-transformer block compile after AC wrapping and before FSDP
# TODO: dynamic shapes have some issues, so we turn them off for now.
# TODO: inlining inbuilt nn modules does not work yet; enable it to accelerate
# compile time.
# torch._dynamo.config.inline_inbuilt_nn_modules = True
transformer_block = torch.compile(transformer_block, dynamic=False)
model.layers[layer_id] = transformer_block

if ac_config.mode in ("full", "selective"):
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
if (
enable_compile
and ac_config.mode == "selective"
and ac_config.selective_ac_option == "op"
):
# some temp flags for torch.compile enablement + SAC
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
True
)
if enable_compile:
if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
)
logger.info("Compiled each TransformerBlock with torch.compile")

# apply DP (FSDP2)
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
mp_policy = MixedPrecisionPolicy(
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)
ac_mode = job_config.activation_checkpoint.mode
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in model.layers.items():
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately.
# When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings
@@ -387,12 +416,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block

model = fully_shard(
model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")

return model
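
Taken together, the per-block wrapping order that parallelize_llama.py now applies is: activation checkpointing first, then torch.compile, then FSDP2. A condensed sketch of that order, reusing names defined in this file (checkpoint_wrapper, fully_shard, ac_config, fsdp_config, enable_compile, parallel_dims); it is an orientation fragment rather than standalone runnable code, and the per-block reshard_after_forward logic is simplified:

# 1. AC wrap first, so the compiled region captures the checkpointed forward.
for layer_id, transformer_block in model.layers.items():
    if ac_config.mode in ("full", "selective"):
        transformer_block = checkpoint_wrapper(transformer_block, ac_config)
    # 2. Compile the (possibly AC-wrapped) block; dynamic shapes disabled for now.
    if enable_compile:
        transformer_block = torch.compile(transformer_block, dynamic=False)
    model.layers[layer_id] = transformer_block

# 3. FSDP2 sharding is applied last, per block and then on the root module.
if parallel_dims.dp_enabled:
    for layer_id, transformer_block in model.layers.items():
        fully_shard(transformer_block, **fsdp_config, reshard_after_forward=True)
    model = fully_shard(
        model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
    )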
14 changes: 0 additions & 14 deletions train.py
@@ -245,20 +245,6 @@ def loss_fn(pred, labels):

metric_logger = build_metric_logger(job_config)

# torch.compile model for improved performance
if job_config.training.compile:
if (
job_config.activation_checkpoint.mode == "selective"
and job_config.activation_checkpoint.selective_ac_option == "op"
):
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
True
)
logger.info("Compiling model with torch.compile")
# Dynamic shapes have issues with distributed; turn dynamic off as Transformer
# training is static-shape. TODO: resolve dynamic shape issue and restore defaults
model = torch.compile(model, dynamic=False)

train_state = TrainState()

# train loop
