Makes FSDP reshard_after_forward configurable (#822)
# What does this PR do?
This PR addresses #644,
making FSDP `reshard_after_forward_policy` configurable by introducing:
1) `data_parallel` as a group within the `job_config`
2) `data_parallel.reshard_after_forward_policy` as an option within
`job_config`
3) routing `reshard_after_forward_policy` from job config to
`apply_fsdp` in
[parallelize_llama.py](torchtitan/parallelisms/parallelize_llama.py)
4) applying the defaults to existing job configs
5) adding an integration test that exercises this configuration, i.e.,
setting it to `always`

# Why do we need this PR?
Rather than hardcoding `reshard_after_forward_policy`, this lets users
explore memory/communication tradeoffs in more complex sharding
strategies.
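
In the current diff the option is spelled `training.fsdp_reshard_after_forward`. A minimal usage sketch, setting the policy either in the training config TOML or as a command-line override (both spellings are taken from the diff below):

```
# in the train config TOML, under [training]
fsdp_reshard_after_forward = "always"   # default / never / always

# equivalent command-line override
--training.fsdp_reshard_after_forward always
```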

# Next steps
As this PR introduces the `data_parallel` group, a natural next step is to
move other data-parallel-related parameters into this group as well.
Specifically, this includes `training.data_parallel_replicate_degree`
and `training.data_parallel_shard_degree`. I will convert these to
`data_parallel.replicate_degree` and `data_parallel.shard_degree` in a
future PR.

# Tests
Ran the Llama3 8B model with each option (`default` / `never` / `always`) on a
machine with 8 H100s. We expect:
1) similar convergence between runs (using the same seed)
2) `never` > `default` > `always` in terms of memory usage
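
These runs can be reproduced roughly as follows. This is a sketch only: the `torchrun` invocation and the `train_configs/llama3_8b.toml` path are assumptions about the usual repo layout; only the `--training.fsdp_reshard_after_forward` override comes from this PR:

```
torchrun --nproc_per_node 8 train.py \
  --job.config_file ./train_configs/llama3_8b.toml \
  --training.fsdp_reshard_after_forward always   # or: default / never
```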

## Default
```
...
[rank0]:2025-02-05 11:44:41,670 - root - INFO - Applied FSDP to the model
[rank0]:2025-02-05 11:44:41,886 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-02-05 11:44:41,889 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250205-1144
[rank0]:2025-02-05 11:44:41,889 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2025-02-05 11:44:41,889 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:NCCL version 2.21.5+cuda12.4
[rank0]:2025-02-05 11:44:54,986 - root - INFO - step:  1  loss: 12.2302  memory: 42.08GiB(44.30%)  tps: 626  mfu: 3.66%
[rank0]:2025-02-05 11:44:54,986 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 11:45:08,305 - root - INFO - step: 10  loss:  9.8873  memory: 49.59GiB(52.19%)  tps: 5,536  mfu: 32.42%
[rank0]:2025-02-05 11:45:22,895 - root - INFO - step: 20  loss:  8.4177  memory: 49.59GiB(52.19%)  tps: 5,615  mfu: 32.88%
[rank0]:2025-02-05 11:45:37,506 - root - INFO - step: 30  loss:  7.6932  memory: 49.59GiB(52.19%)  tps: 5,607  mfu: 32.84%
[rank0]:2025-02-05 11:45:52,129 - root - INFO - step: 40  loss:  7.3155  memory: 49.59GiB(52.19%)  tps: 5,603  mfu: 32.81%
[rank0]:2025-02-05 11:46:06,747 - root - INFO - step: 50  loss:  7.0608  memory: 49.59GiB(52.19%)  tps: 5,605  mfu: 32.82%
[rank0]:2025-02-05 11:46:06,748 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-02-05 11:46:08,750 - root - INFO - Training completed
...
```

## Always
```
...
[rank0]:2025-02-05 11:49:18,382 - root - INFO - Applied FSDP to the model
[rank0]:2025-02-05 11:49:18,617 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-02-05 11:49:18,619 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250205-1149
[rank0]:2025-02-05 11:49:18,620 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2025-02-05 11:49:18,620 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:NCCL version 2.21.5+cuda12.4
[rank0]:2025-02-05 11:49:32,445 - root - INFO - step:  1  loss: 12.2302  memory: 41.89GiB(44.09%)  tps: 593  mfu: 3.47%
[rank0]:2025-02-05 11:49:32,445 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 11:49:45,525 - root - INFO - step: 10  loss:  9.8874  memory: 49.41GiB(52.01%)  tps: 5,637  mfu: 33.01%
[rank0]:2025-02-05 11:50:00,097 - root - INFO - step: 20  loss:  8.4177  memory: 49.41GiB(52.01%)  tps: 5,623  mfu: 32.93%
[rank0]:2025-02-05 11:50:14,719 - root - INFO - step: 30  loss:  7.6932  memory: 49.41GiB(52.01%)  tps: 5,603  mfu: 32.81%
[rank0]:2025-02-05 11:50:29,348 - root - INFO - step: 40  loss:  7.3146  memory: 49.41GiB(52.01%)  tps: 5,600  mfu: 32.80%
[rank0]:2025-02-05 11:50:43,995 - root - INFO - step: 50  loss:  7.0576  memory: 49.41GiB(52.01%)  tps: 5,593  mfu: 32.75%
[rank0]:2025-02-05 11:48:19,830 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-02-05 11:48:21,832 - root - INFO - Training completed
...
```

## Never
```
...
[rank0]:2025-02-05 11:46:54,548 - root - INFO - Applied FSDP to the model
[rank0]:2025-02-05 11:46:55,574 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-02-05 11:46:55,610 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250205-1146
[rank0]:2025-02-05 11:46:55,611 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2025-02-05 11:46:55,611 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:NCCL version 2.21.5+cuda12.4
[rank0]:2025-02-05 11:47:10,242 - root - INFO - step:  1  loss: 12.2302  memory: 54.66GiB(57.54%)  tps: 560  mfu: 3.28%
[rank0]:2025-02-05 11:47:10,242 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 11:47:23,033 - root - INFO - step: 10  loss:  9.8872  memory: 62.18GiB(65.45%)  tps: 5,765  mfu: 33.76%
[rank0]:2025-02-05 11:47:37,183 - root - INFO - step: 20  loss:  8.4177  memory: 62.18GiB(65.45%)  tps: 5,790  mfu: 33.91%
[rank0]:2025-02-05 11:47:51,391 - root - INFO - step: 30  loss:  7.6932  memory: 62.18GiB(65.45%)  tps: 5,766  mfu: 33.77%
[rank0]:2025-02-05 11:48:05,612 - root - INFO - step: 40  loss:  7.3141  memory: 62.18GiB(65.45%)  tps: 5,761  mfu: 33.74%
[rank0]:2025-02-05 11:48:19,830 - root - INFO - step: 50  loss:  7.0576  memory: 62.18GiB(65.45%)  tps: 5,763  mfu: 33.75%
[rank0]:2025-02-05 11:48:19,830 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-02-05 11:48:21,832 - root - INFO - Training completed
...
```

As a sanity check, we ran one more time with an invalid policy to verify
that it exits gracefully:
```
[rank0]:2025-02-05 11:53:42,600 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-02-05 11:53:42,601 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/home/allencwang/workspace/torchtitan/train.py", line 433, in <module>
[rank0]:[rank0]:     main(config)
[rank0]:[rank0]:   File "/home/allencwang/.conda/envs/titan/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank0]:[rank0]:     return f(*args, **kwargs)
[rank0]:[rank0]:   File "/home/allencwang/workspace/torchtitan/train.py", line 172, in main
[rank0]:[rank0]:     models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
[rank0]:[rank0]:   File "/home/allencwang/workspace/torchtitan/torchtitan/parallelisms/parallelize_llama.py", line 87, in parallelize_llama
[rank0]:[rank0]:     apply_fsdp(
[rank0]:[rank0]:   File "/home/allencwang/workspace/torchtitan/torchtitan/parallelisms/parallelize_llama.py", line 356, in apply_fsdp
[rank0]:[rank0]:     raise ValueError(
[rank0]:[rank0]: ValueError: Invalid reshard_after_forward_policy: invalid.
[rank0]:[rank0]:[W205 11:53:42.876753845 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
```

---------

Co-authored-by: tianyu-l <[email protected]>
allenwang28 and tianyu-l authored Feb 6, 2025
1 parent 37c4b81 commit 5940dde
Showing 4 changed files with 60 additions and 7 deletions.
10 changes: 10 additions & 0 deletions tests/integration_tests.py
@@ -408,6 +408,16 @@ def build_test_list():
             "test_generate",
             ngpu=2,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.fsdp_reshard_after_forward always",
+                ],
+            ],
+            "Test always resharding after forward pass",
+            "fsdp_reshard_always",
+            ngpu=2,
+        ),
     ]
     return integration_tests_flavors

17 changes: 17 additions & 0 deletions torchtitan/config_manager.py
@@ -285,6 +285,23 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
+        self.parser.add_argument(
+            "--training.fsdp_reshard_after_forward",
+            type=str,
+            default="default",
+            choices=["default", "always", "never"],
+            help="""
+            `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
+            within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
+            trading off memory and communication. See torch's `fully_shard` API for more documentation
+            on `reshard_after_forward`.
+            The supported policies include "default", "always" and "never":
+            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
+              scenarios.
+            - "always" will enable `reshard_after_forward` for all forward passes.
+            - "never" will disable `reshard_after_forward` for all forward passes.
+            """,
+        )
         self.parser.add_argument(
             "--experimental.enable_async_tensor_parallel",
             action="store_true",
39 changes: 32 additions & 7 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -86,6 +86,7 @@ def parallelize_llama(
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
+            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
         )

     if parallel_dims.dp_replicate_enabled:
@@ -308,24 +309,48 @@ def apply_fsdp(
     reduce_dtype: torch.dtype,
     pp_enabled: bool,
     cpu_offload: bool = False,
+    reshard_after_forward_policy: str = "default",
 ):
     """
-    Apply data parallelism to the model. FSDP2 is used here.
+    Apply data parallelism (via FSDP2) to the model.
+
+    Args:
+        model (nn.Module): The model to apply data parallelism to.
+        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
+        param_dtype (torch.dtype): The data type to use for model parameters.
+        reduce_dtype (torch.dtype): The data type to use for reduction operations.
+        pp_enabled (bool): Whether pipeline parallelism is enabled.
+        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
+        reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default".
+            Other options: "never", "always".
+            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
+            - "always" will enable `reshard_after_forward` for all forward passes.
+            - "never" will disable `reshard_after_forward` for all forward passes.
     """
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()

     for layer_id, transformer_block in model.layers.items():
-        if pp_enabled:
-            # For PP, do not reshard after forward to avoid per-microbatch
-            # all-gathers, which can be expensive and non-overlapped
+        if reshard_after_forward_policy == "always":
+            reshard_after_forward = True
+        elif reshard_after_forward_policy == "never":
             reshard_after_forward = False
+        elif reshard_after_forward_policy == "default":
+            if pp_enabled:
+                # For PP, do not reshard after forward to avoid per-microbatch
+                # all-gathers, which can be expensive and non-overlapped
+                reshard_after_forward = False
+            else:
+                # As an optimization, do not reshard after forward for the last
+                # transformer block since FSDP would prefetch it immediately
+                reshard_after_forward = int(layer_id) < len(model.layers) - 1
         else:
-            # As an optimization, do not reshard after forward for the last
-            # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+            raise ValueError(
+                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
+            )
         fully_shard(
             transformer_block,
             **fsdp_config,
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -39,6 +39,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
+fsdp_reshard_after_forward = "default" # default / never / always
 tensor_parallel_degree = 1
 compile = false
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
