Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update base for Update on "enable TritonFusedRMSNorm with local_map a…
…nnotation" **Summary** This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. **Test Plan** Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings: ``` [job] dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 5 enable_tensorboard = false save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" lr = 3e-4 [training] batch_size = 4 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 4 pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" [checkpoint] enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy ``` 1. with `norm_type = "rmsnorm"` ``` [rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96% [rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82% [rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09% [rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19% [rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98% [rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16% [rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10% [rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17% [rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04% [rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.08% [rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09% [rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06% [rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20% [rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% ``` 2. with `norm_type = "fused_rmsnorm"` ``` [rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64% [rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95% [rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79% [rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92% [rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68% [rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91% [rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90% [rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87% [rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 65.13GiB(82.29%) wps: 638 mfu: 11.84% [rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81% [rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78% [rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85% ``` [ghstack-poisoned]
- Loading branch information