Skip to content

Commit

Permalink
Update base for Update on "enable TritonFusedRMSNorm with local_map a…
Browse files Browse the repository at this point in the history
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
  • Loading branch information
XilunWu committed Jun 13, 2024
1 parent 18752a7 commit 0e47fe3
Showing 0 changed files with 0 additions and 0 deletions.

0 comments on commit 0e47fe3

Please sign in to comment.