Makes FSDP `reshard_after_forward` configurable (#822)
# What does this PR do?

This PR addresses #644, making FSDP `reshard_after_forward_policy` configurable by:

1) introducing `data_parallel` as a group within the `job_config`
2) introducing `data_parallel.reshard_after_forward_policy` as an option within `job_config`
3) routing `reshard_after_forward_policy` from the job config to `apply_fsdp` in [parallelize_llama.py](torchtitan/parallelisms/parallelize_llama.py)
4) applying the defaults to existing job configs
5) adding an integration test that exercises this configuration, i.e. setting it to `always`

Sketches of the new config surface and of the routing follow this list.
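For illustration, a minimal sketch of what the new group could look like in a job config TOML. The group and option names come straight from this description; the file name and the default value shown are assumptions, not the PR's literal diff.

```toml
# Hypothetical excerpt from a job config (e.g. train_configs/llama3_8b.toml).
[data_parallel]
# Assumed value set, per the test runs below: "default" | "always" | "never"
reshard_after_forward_policy = "default"
```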
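And a hedged sketch of what the routing inside `apply_fsdp` could look like. This is modeled on FSDP2's `fully_shard` API and torchtitan's per-block wrapping, not copied from the PR's diff; `dp_mesh` and `fsdp_kwargs` are illustrative stand-ins.

```python
# Illustrative sketch, not the PR's literal code: translate the configured
# policy into FSDP2's per-module `reshard_after_forward` flag.
from torch.distributed._composable.fsdp import fully_shard


def apply_fsdp(model, dp_mesh, reshard_after_forward_policy: str = "default", **fsdp_kwargs):
    n_layers = len(model.layers)
    for layer_id, block in enumerate(model.layers):
        if reshard_after_forward_policy == "always":
            reshard_after_forward = True
        elif reshard_after_forward_policy == "never":
            reshard_after_forward = False
        elif reshard_after_forward_policy == "default":
            # Assumed default heuristic: keep the last block's parameters
            # gathered, since backward begins there right after forward ends.
            reshard_after_forward = layer_id < n_layers - 1
        else:
            raise ValueError(
                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
            )
        fully_shard(block, mesh=dp_mesh, reshard_after_forward=reshard_after_forward, **fsdp_kwargs)
    # Wrap the root module as well so any remaining parameters are sharded.
    fully_shard(model, mesh=dp_mesh, reshard_after_forward=True, **fsdp_kwargs)
```

The `ValueError` branch is what produces the graceful exit shown in the sanity check at the end of this description.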
# Why do we need this PR?

Rather than hardcoding the `reshard_after_forward_policy`, this allows users to explore memory/communication tradeoffs, particularly in more complex sharding strategies.

# Next steps

As this PR introduces the `data_parallel` group, a natural next step is to move other data-parallel-related parameters into this group as well. Specifically, this includes `training.data_parallel_replicate_degree` and `training.data_parallel_shard_degree`; I will convert these to `data_parallel.replicate_degree` and `data_parallel.shard_degree` in a future PR.

# Tests

Ran the Llama3 8B model with all three options (`default` / `never` / `always`) on a machine with 8 H100s. We should expect:

1) similar convergence between runs (using the same seed)
2) `never` > `default` > `always` in terms of memory usage

A summary of the three runs follows the raw logs below.

## Default

```
...
[rank0]:2025-02-05 11:44:41,670 - root - INFO - Applied FSDP to the model
[rank0]:2025-02-05 11:44:41,886 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-02-05 11:44:41,889 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250205-1144
[rank0]:2025-02-05 11:44:41,889 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2025-02-05 11:44:41,889 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:NCCL version 2.21.5+cuda12.4
[rank0]:2025-02-05 11:44:54,986 - root - INFO - step: 1  loss: 12.2302  memory: 42.08GiB(44.30%)  tps: 626  mfu: 3.66%
[rank0]:2025-02-05 11:44:54,986 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 11:45:08,305 - root - INFO - step: 10  loss: 9.8873  memory: 49.59GiB(52.19%)  tps: 5,536  mfu: 32.42%
[rank0]:2025-02-05 11:45:22,895 - root - INFO - step: 20  loss: 8.4177  memory: 49.59GiB(52.19%)  tps: 5,615  mfu: 32.88%
[rank0]:2025-02-05 11:45:37,506 - root - INFO - step: 30  loss: 7.6932  memory: 49.59GiB(52.19%)  tps: 5,607  mfu: 32.84%
[rank0]:2025-02-05 11:45:52,129 - root - INFO - step: 40  loss: 7.3155  memory: 49.59GiB(52.19%)  tps: 5,603  mfu: 32.81%
[rank0]:2025-02-05 11:46:06,747 - root - INFO - step: 50  loss: 7.0608  memory: 49.59GiB(52.19%)  tps: 5,605  mfu: 32.82%
[rank0]:2025-02-05 11:46:06,748 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-02-05 11:46:08,750 - root - INFO - Training completed
...
```

## Always

```
...
[rank0]:2025-02-05 11:49:18,382 - root - INFO - Applied FSDP to the model
[rank0]:2025-02-05 11:49:18,617 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-02-05 11:49:18,619 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250205-1149
[rank0]:2025-02-05 11:49:18,620 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2025-02-05 11:49:18,620 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:NCCL version 2.21.5+cuda12.4
[rank0]:2025-02-05 11:49:32,445 - root - INFO - step: 1  loss: 12.2302  memory: 41.89GiB(44.09%)  tps: 593  mfu: 3.47%
[rank0]:2025-02-05 11:49:32,445 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 11:49:45,525 - root - INFO - step: 10  loss: 9.8874  memory: 49.41GiB(52.01%)  tps: 5,637  mfu: 33.01%
[rank0]:2025-02-05 11:50:00,097 - root - INFO - step: 20  loss: 8.4177  memory: 49.41GiB(52.01%)  tps: 5,623  mfu: 32.93%
[rank0]:2025-02-05 11:50:14,719 - root - INFO - step: 30  loss: 7.6932  memory: 49.41GiB(52.01%)  tps: 5,603  mfu: 32.81%
[rank0]:2025-02-05 11:50:29,348 - root - INFO - step: 40  loss: 7.3146  memory: 49.41GiB(52.01%)  tps: 5,600  mfu: 32.80%
[rank0]:2025-02-05 11:50:43,995 - root - INFO - step: 50  loss: 7.0576  memory: 49.41GiB(52.01%)  tps: 5,593  mfu: 32.75%
[rank0]:2025-02-05 11:48:19,830 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-02-05 11:48:21,832 - root - INFO - Training completed
...
```

## Never

```
...
[rank0]:2025-02-05 11:46:54,548 - root - INFO - Applied FSDP to the model
[rank0]:2025-02-05 11:46:55,574 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-02-05 11:46:55,610 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250205-1146
[rank0]:2025-02-05 11:46:55,611 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2025-02-05 11:46:55,611 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:NCCL version 2.21.5+cuda12.4
[rank0]:2025-02-05 11:47:10,242 - root - INFO - step: 1  loss: 12.2302  memory: 54.66GiB(57.54%)  tps: 560  mfu: 3.28%
[rank0]:2025-02-05 11:47:10,242 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 11:47:23,033 - root - INFO - step: 10  loss: 9.8872  memory: 62.18GiB(65.45%)  tps: 5,765  mfu: 33.76%
[rank0]:2025-02-05 11:47:37,183 - root - INFO - step: 20  loss: 8.4177  memory: 62.18GiB(65.45%)  tps: 5,790  mfu: 33.91%
[rank0]:2025-02-05 11:47:51,391 - root - INFO - step: 30  loss: 7.6932  memory: 62.18GiB(65.45%)  tps: 5,766  mfu: 33.77%
[rank0]:2025-02-05 11:48:05,612 - root - INFO - step: 40  loss: 7.3141  memory: 62.18GiB(65.45%)  tps: 5,761  mfu: 33.74%
[rank0]:2025-02-05 11:48:19,830 - root - INFO - step: 50  loss: 7.0576  memory: 62.18GiB(65.45%)  tps: 5,763  mfu: 33.75%
[rank0]:2025-02-05 11:48:19,830 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-02-05 11:48:21,832 - root - INFO - Training completed
...
```
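Condensing the step-50 log lines above into a single view:

| Policy | Memory (steady state) | TPS (step 50) | Loss (step 50) |
|--------|-----------------------|---------------|----------------|
| `never` | 62.18GiB (65.45%) | 5,763 | 7.0576 |
| `default` | 49.59GiB (52.19%) | 5,605 | 7.0608 |
| `always` | 49.41GiB (52.01%) | 5,593 | 7.0576 |

This matches expectation 2): `never` skips the post-forward reshard, so parameters stay gathered into backward, costing ~12.6GiB more memory than `default` in exchange for roughly 3% higher throughput, while `always` trims a little memory relative to `default` at a slight throughput cost. Losses track closely across runs, matching expectation 1).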
As a sanity check, ran one more time with an invalid policy to verify a graceful exit:

```
[rank0]:2025-02-05 11:53:42,600 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-02-05 11:53:42,601 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/home/allencwang/workspace/torchtitan/train.py", line 433, in <module>
[rank0]:[rank0]:     main(config)
[rank0]:[rank0]:   File "/home/allencwang/.conda/envs/titan/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank0]:[rank0]:     return f(*args, **kwargs)
[rank0]:[rank0]:   File "/home/allencwang/workspace/torchtitan/train.py", line 172, in main
[rank0]:[rank0]:     models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
[rank0]:[rank0]:   File "/home/allencwang/workspace/torchtitan/torchtitan/parallelisms/parallelize_llama.py", line 87, in parallelize_llama
[rank0]:[rank0]:     apply_fsdp(
[rank0]:[rank0]:   File "/home/allencwang/workspace/torchtitan/torchtitan/parallelisms/parallelize_llama.py", line 356, in apply_fsdp
[rank0]:[rank0]:     raise ValueError(
[rank0]:[rank0]: ValueError: Invalid reshard_after_forward_policy: invalid.
[rank0]:[rank0]:[W205 11:53:42.876753845 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
```
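For completeness, a hypothetical sketch of what the integration test from item 5 might look like, in the style of torchtitan's `test_runner.py` override entries. The `OverrideDefinitions` field order, the flag spelling (inferred from the new config group), and the test name are all assumptions; the actual test added by this PR may differ.

```python
# Hypothetical integration-test entry (style of torchtitan's test_runner.py):
# exercises the non-default policy end to end on the debug model.
OverrideDefinitions(
    [
        ["--data_parallel.reshard_after_forward_policy always"],
    ],
    "FSDP with reshard_after_forward_policy set to always",
    "fsdp_reshard_always",
)
```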
---------

Co-authored-by: tianyu-l <[email protected]>