
fused_linear_cross_entropy: Move float32 cast into kernel #238

Open · wants to merge 4 commits into base: main
Conversation

hansonw
Contributor

hansonw commented Sep 9, 2024

Summary

Another small optimization :) The logits_chunk.float() allocation can be surprisingly large, e.g. Cohere models have 256K vocabs, so each logit chunk in float32 could be something like 1024 * 256K * 4 bytes = 1 GB of VRAM (even more if the chunk size is larger).
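As a rough sketch of the arithmetic (the chunk and vocab sizes below are illustrative numbers, not the exact values used in the real kernel launch):

```python
chunk_size, vocab_size = 1024, 256 * 1024  # e.g. a Cohere-sized vocabulary

# Extra VRAM the temporary logits_chunk.float() copy would materialize:
fp32_bytes = chunk_size * vocab_size * 4   # 4 bytes per float32 element
bf16_bytes = chunk_size * vocab_size * 2   # 2 bytes per bfloat16 element

print(f"float32 copy: {fp32_bytes / 2**30:.1f} GiB")        # 1.0 GiB saved
print(f"bf16 chunk (already resident): {bf16_bytes / 2**30:.1f} GiB")  # 0.5 GiB
```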

I don't think any explicit casting is actually required inside the Triton kernel: the intermediate softmax variables like m, d, etc. are already float32 by default, so with type promotion the calculations should all happen in float32 regardless.

However, I added explicit .cast(tl.float32) casts on all of the X_ptr loads to make this more obvious to the reader. Either way, the actual liger_cross_entropy_kernel runs so quickly that I couldn't measure any performance difference; this change is purely about saving the float32 allocation. (It might be slightly more efficient without the explicit casts, but I was not able to measure anything - even with a 1K x 256K logit matrix the kernel basically runs instantly.)
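For illustration only, here is a minimal standalone Triton kernel, not the actual liger_cross_entropy_kernel, showing the pattern: the bfloat16 logits are cast to float32 right after the load, so the reduction runs in float32 without ever materializing a float32 copy of the chunk in global memory.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _logsumexp_kernel(X_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols
    # Same idea as the PR's change: cast the loaded bf16 values to float32 in
    # registers, so the max/sum reduction below runs entirely in float32.
    x = tl.load(X_ptr + row * n_cols + offsets, mask=mask, other=float("-inf"))
    x = x.cast(tl.float32)
    m = tl.max(x, axis=0)              # running max (analogous to m)
    d = tl.sum(tl.exp(x - m), axis=0)  # normalizer  (analogous to d)
    tl.store(out_ptr + row, m + tl.log(d))


def logsumexp_rows(x: torch.Tensor) -> torch.Tensor:
    # x can stay in bfloat16; no x.float() materialization on the host side.
    out = torch.empty(x.shape[0], dtype=torch.float32, device=x.device)
    BLOCK_SIZE = triton.next_power_of_2(x.shape[1])
    _logsumexp_kernel[(x.shape[0],)](x, out, x.shape[1], BLOCK_SIZE=BLOCK_SIZE)
    return out
```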

Testing Done

  • Hardware Type: A100 80GB
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@ByronHsu
Collaborator

ByronHsu commented Sep 10, 2024

Thanks! This makes sense. I tried a similar thing before but saw divergence compared with casting on the torch side (not sure why, maybe I did it wrong). Also, the bfloat16 convergence test is not actually exercised at the moment due to #176. After that fix is merged, we can run the convergence tests with bf16 to see if there is any gap.

@hansonw
Contributor Author

hansonw commented Sep 10, 2024

I added a test_float32_internal() unit test that runs the kernel twice, once with a bfloat16 input and once with a float32-upcast copy, and verifies that the resulting bfloat16 outputs are exactly identical 🙂 The test also passes without the explicit .cast(tl.float32) calls, so those could potentially be removed as long as the test is in place.
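Roughly, the test has this shape (run_kernel below is just a placeholder for the actual liger_cross_entropy_kernel launch, not the real test code):

```python
import torch


def run_kernel(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for launching liger_cross_entropy_kernel on x; it stands in
    # for "compute in float32 internally, write results back in x's dtype".
    return torch.log_softmax(x.float(), dim=-1).to(x.dtype)


def test_float32_internal():
    torch.manual_seed(0)
    x_bf16 = torch.randn(8, 128, dtype=torch.bfloat16)

    out_from_bf16 = run_kernel(x_bf16)                             # in-kernel upcast
    out_from_fp32 = run_kernel(x_bf16.float()).to(torch.bfloat16)  # pre-upcast path

    # The two paths should give bit-identical bfloat16 results, i.e. casting
    # inside the kernel loses nothing versus upcasting the whole chunk up front.
    assert torch.equal(out_from_bf16, out_from_fp32)
```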

@ByronHsu
Collaborator

cool! I will take a deeper look today or tomorrow. This is exciting!

@kostum123

Can we merge this? In its current form, the Liger kernel is broken.

@lancerts
Collaborator

lancerts commented Oct 1, 2024

@hansonw can we resolve the conflict? ty

@ByronHsu
Collaborator

ByronHsu commented Oct 2, 2024

We can merge this once the conflict is resolved. thanks!!
