Delete delayed scaling #812

Merged: 3 commits into pytorch:main on Jan 31, 2025

Conversation

@mori360 (Contributor) commented on Jan 30, 2025

torchao plans to deprecate delayed scaling; this PR deletes it from torchtitan.

Fixes #654.
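
For context, here is a minimal sketch of what float8 setup looks like once delayed scaling is gone and dynamic scaling is the only remaining path. It assumes torchao's public `convert_to_float8_training` / `Float8LinearConfig` API; the toy model and filter function are illustrative, not the exact torchtitan code this PR touches.

```python
# Sketch: float8 training with dynamic scaling only (delayed scaling removed).
# Assumes torchao's convert_to_float8_training / Float8LinearConfig API;
# the model and module_filter_fn below are hypothetical examples.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))

# Dynamic scaling recomputes scales from the live tensor at each cast, so
# there is no amax history to record or synchronize between steps.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=False)

def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # Swap only Linear layers whose dims are multiples of 16, a shape
    # requirement of the float8 kernels.
    return (
        isinstance(mod, nn.Linear)
        and mod.in_features % 16 == 0
        and mod.out_features % 16 == 0
    )

convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
# With delayed scaling deleted, the per-step
# sync_float8_amax_and_scale_history(model) call also disappears
# from the training loop.
```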

Here are the logs from runs with `enable_float8_linear = true`:

1. compile = false
[rank0]:2025-01-31 10:12:50,551 - root - INFO - Float8 training active
[rank0]:2025-01-31 10:12:50,571 - root - INFO - Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=False
[rank0]:2025-01-31 10:12:50,572 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-01-31 10:12:50,572 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2025-01-31 10:12:50,635 - root - INFO - Applied FSDP to the model
[rank0]:2025-01-31 10:12:50,835 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-01-31 10:12:50,835 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:2025-01-31 10:12:50,837 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250131-1012
[rank0]:2025-01-31 10:12:50,837 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 5 (warmup 200)
[rank0]:2025-01-31 10:12:50,837 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:2025-01-31 10:13:02,460 - root - INFO - step:  1  loss: 12.2581  memory: 74.27GiB(78.18%)  tps: 705  mfu: 4.13%
[rank0]:2025-01-31 10:13:02,460 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-31 10:13:04,973 - root - INFO - step:  2  loss: 12.0754  memory: 81.77GiB(86.07%)  tps: 3,262  mfu: 19.10%
[rank0]:2025-01-31 10:13:07,033 - root - INFO - step:  3  loss: 11.7432  memory: 81.77GiB(86.07%)  tps: 3,980  mfu: 23.30%
[rank0]:2025-01-31 10:13:09,089 - root - INFO - step:  4  loss: 11.3079  memory: 81.77GiB(86.07%)  tps: 3,986  mfu: 23.34%
[rank0]:2025-01-31 10:13:11,146 - root - INFO - step:  5  loss: 10.9303  memory: 81.77GiB(86.07%)  tps: 3,985  mfu: 23.33%
[rank0]:2025-01-31 10:13:11,147 - root - INFO - Saving a full checkpoint at last step, step 5.
[rank0]:2025-01-31 10:13:31,549 - root - INFO - Finished saving the checkpoint (or staging if async is enabled)in 20.40 seconds.
[rank0]:2025-01-31 10:13:31,549 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-01-31 10:13:33,551 - root - INFO - Training completed
2. compile = true
[rank0]:2025-01-31 10:18:55,527 - root - INFO - Float8 training active
[rank0]:2025-01-31 10:18:55,547 - root - INFO - Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=False
[rank0]:2025-01-31 10:18:55,548 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-01-31 10:18:55,549 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2025-01-31 10:18:55,591 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2025-01-31 10:18:55,656 - root - INFO - Applied FSDP to the model
[rank0]:2025-01-31 10:18:56,530 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-01-31 10:18:56,532 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250131-1018
[rank0]:2025-01-31 10:18:56,533 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 5 (warmup 200)
[rank0]:2025-01-31 10:18:56,533 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0] 
[rank0]:[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0] Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
[rank0]:[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0] Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
[rank0]:[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0] 
[rank0]:/data/users/yifanmao/pytorch/torch/_inductor/lowering.py:1903: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2025-01-31 10:19:15,619 - root - INFO - step:  1  loss: 12.2476  memory: 40.21GiB(42.32%)  tps: 429  mfu: 2.51%
[rank0]:2025-01-31 10:19:15,619 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-31 10:19:16,747 - root - INFO - step:  2  loss: 12.0860  memory: 47.77GiB(50.28%)  tps: 7,267  mfu: 42.55%
[rank0]:2025-01-31 10:19:17,852 - root - INFO - step:  3  loss: 11.7620  memory: 47.77GiB(50.28%)  tps: 7,420  mfu: 43.45%
[rank0]:2025-01-31 10:19:18,953 - root - INFO - step:  4  loss: 11.3075  memory: 47.77GiB(50.28%)  tps: 7,449  mfu: 43.62%
[rank0]:2025-01-31 10:19:20,054 - root - INFO - step:  5  loss: 10.9359  memory: 47.77GiB(50.28%)  tps: 7,448  mfu: 43.61%
[rank0]:2025-01-31 10:19:20,054 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-01-31 10:19:22,056 - root - INFO - Training completed
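
As a side note on the compiled run: the "Compiling each TransformerBlock with torch.compile" line corresponds to compiling the model block by block rather than as one graph. A rough sketch of that pattern follows; the `layers` container name and the loop are assumptions for illustration, not the exact torchtitan code.

```python
# Sketch: compile each transformer block individually, as suggested by the
# "Compiling each TransformerBlock with torch.compile" log line above.
# The `layers` attribute name is an assumption for illustration.
import torch

def apply_compile(model: torch.nn.Module) -> None:
    # Per-block compilation keeps any graph breaks local to a single layer
    # and lets structurally identical blocks reuse compiled artifacts.
    for name, block in model.layers.named_children():
        model.layers.register_module(name, torch.compile(block))
```

In the logs, the compiled run settles around 7,400 tps at 47.77GiB peak memory versus roughly 3,980 tps at 81.77GiB eager, so both throughput and memory improve.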

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 30, 2025
@mori360 marked this pull request as ready for review on January 30, 2025, 23:12
@mori360 requested review from @tianyu-l and @vkuzo on January 30, 2025, 23:12
@tianyu-l (Contributor) left a comment

Changes look good to me.

Since we don't have float8 testing in CI yet, can you include results from successful example runs, including ones with torch.compile enabled and disabled?

@tianyu-l linked an issue on Jan 31, 2025 that may be closed by this pull request
@mori360 (Contributor, Author) commented on Jan 31, 2025

> Changes look good to me.
>
> Since we don't have float8 testing in CI yet, can you include results from successful example runs, including ones with torch.compile enabled and disabled?

Thanks for the comments. The logs, with compile enabled and disabled, are attached in the description.

@tianyu-l (Contributor) left a comment

LGTM, thanks!

@mori360 merged commit 2271b63 into pytorch:main on Jan 31, 2025
6 checks passed
Labels: CLA Signed
Successfully merging this pull request may close: #654 (meta device issue with float8 delayed scale)
4 participants