[feat] FusedLinearCrossEntropy support for Gemma2 #127

Open
yundai424 opened this issue Aug 27, 2024 · 4 comments · May be fixed by #320


@yundai424
Collaborator

🚀 The feature, motivation and pitch

FLCE (FusedLinearCrossEntropy) needs special handling for the final logit soft capping in Gemma2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L1054
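A minimal sketch of what the linked soft capping does, assuming it is the usual three-step cap (divide the logits by a cap value, apply tanh, multiply back by the cap). It uses plain Python lists instead of torch tensors so it runs anywhere, and `softcap_logits` is a hypothetical helper name. The relevant point for FLCE is that this elementwise transform sits between the lm_head matmul and the cross-entropy, which is exactly the pair FLCE fuses.

```python
import math
from typing import List


def softcap_logits(logits: List[float], cap: float) -> List[float]:
    """Gemma2-style final logit soft capping: cap * tanh(logit / cap).

    Hypothetical helper for illustration; the real model code works on
    torch tensors, not Python lists.
    """
    # Step 1: divide by the cap; step 2: tanh; step 3: multiply by the cap.
    return [cap * math.tanh(x / cap) for x in logits]


if __name__ == "__main__":
    # Large-magnitude logits saturate toward +/- cap,
    # while logits much smaller than the cap pass through almost unchanged.
    print(softcap_logits([0.5, 10.0, 100.0, -1000.0], cap=30.0))
```

Because tanh is bounded in (-1, 1), every capped logit ends up in [-cap, cap].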

Alternatives

No response

Additional context

No response

@troy1729

troy1729 commented Aug 28, 2024

#take @yundai424 I'd like to attempt adding this support.

I'm thinking of the following approach:

  • Introduce an optional dict parameter in the forward() method here, with a softcap key and a non-linearity key (tanh in the case of Gemma2)
  • Perform the three steps from modeling_gemma2 after this matmul operation if the optional dict parameter is not None
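A rough sketch of the proposed optional-dict approach, assuming the dict carries a cap value and a non-linearity callable. `chunk_logits`, `softcap_config`, and its `"softcap"`/`"nonlinearity"` keys are hypothetical names, and nested lists stand in for the tensors in the actual kernel. This covers only the forward pass; the backward pass needs its own handling.

```python
import math
from typing import Any, Dict, List, Optional

Matrix = List[List[float]]


def matmul_transpose(a: Matrix, b: Matrix) -> Matrix:
    """Return a @ b^T, where b stores one row per vocabulary entry (like an lm_head weight)."""
    return [[sum(x * y for x, y in zip(row, w_row)) for w_row in b] for row in a]


def chunk_logits(
    hidden_chunk: Matrix,
    weight: Matrix,
    softcap_config: Optional[Dict[str, Any]] = None,
) -> Matrix:
    """Logits for one chunk of hidden states, optionally soft-capped.

    softcap_config is a hypothetical stand-in for the proposed optional dict,
    e.g. {"softcap": 30.0, "nonlinearity": math.tanh}.
    """
    logits = matmul_transpose(hidden_chunk, weight)
    if softcap_config is None:
        # Default path: behaviour is unchanged when no dict is passed.
        return logits
    cap = softcap_config["softcap"]
    act = softcap_config["nonlinearity"]
    # Divide by the cap, apply the non-linearity, multiply by the cap.
    return [[cap * act(z / cap) for z in row] for row in logits]


if __name__ == "__main__":
    hidden = [[1.0, 2.0], [-3.0, 0.5]]
    weight = [[10.0, 0.0], [0.0, 20.0], [5.0, 5.0]]
    print(chunk_logits(hidden, weight))
    print(chunk_logits(hidden, weight, {"softcap": 30.0, "nonlinearity": math.tanh}))
```

Keeping the dict optional means models without soft capping take the existing path untouched.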

Can you assign it to me if this sounds okay?

@qingquansong
Collaborator

@troy1729 Sounds reasonable to me. Assigned; feel free to start the implementation and ping us to discuss or review any issues. Thank you!

@troy1729

Hi @qingquansong, I've made the changes but still need to add tests, so I've kept the PR as a draft.
This might be a silly question, but we would want a Triton kernel implementation for tanh (or any other non-linearity), wouldn't we? Right now I've added a torch.tanh callable. Sorry if this is obvious, but I thought I'd clarify.

@qingquansong
Collaborator

qingquansong commented Aug 29, 2024

Hey @troy1729, thanks for the question (it's not a silly one) and for the fast kickoff! I think:

  1. Having Triton functions that operate on a single element/block is useful in certain cases, such as the silu function we have for swiglu, which can be fused with other operations. Ultimately we want to reduce element-wise operation overhead (e.g. geglu/swiglu or ReLU + matmul) rather than call a single activation kernel on its own, which would be no better than calling torch.tanh, especially after torch.compile. Also see point 3 below: a standalone activation kernel would not help much when fusing it with other operations, especially in the backward pass (though isolated forward/backward functions could be helpful).

  2. The soft cap is mainly scaling plus range capping: tanh maps the scaled logits into (-1, 1) while keeping both positive and negative information, so other torch activations may not be suitable here (though I agree we may have cases in the future that call other torch activation functions).

  3. Think about how backprop is computed given this activation applied to the logits. It's not as straightforward as just adding the activation: you'll need to compute the grad_input (the gradient of the hidden states) and the grad_weights through it.

  4. Another option is to put the soft cap inside the regular Liger CE loss kernel (its backward would also need to handle the option when enabled). Then the FLCE kernel, which calls that kernel chunk by chunk, wouldn't need to worry about the backprop of the soft cap itself.
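For reference, one way to write out the chain rule discussed in points 3 and 4, as a sketch assuming mean-reduced cross-entropy with no ignore_index, label smoothing, or z-loss. The notation is illustrative and not taken from the Liger code: $H \in \mathbb{R}^{N \times d}$ are the hidden states, $W \in \mathbb{R}^{V \times d}$ the lm_head weight, $y_i$ the labels, and $c$ the cap.

```latex
\begin{aligned}
Z &= H W^{\top}, \qquad T = \tanh(Z / c), \qquad S = c\,T \\
\mathcal{L} &= \frac{1}{N} \sum_{i=1}^{N} \Big( \log \sum_{j=1}^{V} e^{S_{ij}} - S_{i y_i} \Big) \\
\frac{\partial \mathcal{L}}{\partial S_{ij}} &= \frac{1}{N} \big( \operatorname{softmax}(S_i)_j - \mathbb{1}[j = y_i] \big) \\
\frac{\partial S_{ij}}{\partial Z_{ij}} &= c \cdot \frac{1}{c} \big( 1 - \tanh^2(Z_{ij}/c) \big) = 1 - T_{ij}^2 \\
G &= \frac{\partial \mathcal{L}}{\partial Z} = \frac{\partial \mathcal{L}}{\partial S} \odot (1 - T \odot T) \\
\frac{\partial \mathcal{L}}{\partial H} &= G\,W \quad (\text{grad\_input}), \qquad
\frac{\partial \mathcal{L}}{\partial W} = G^{\top} H \quad (\text{grad\_weights})
\end{aligned}
```

Compared with the uncapped case, the only change is the extra elementwise factor $1 - T \odot T$ on each chunk's logit gradient before the two matmuls; under option 4 that factor would be applied inside the CE kernel's backward instead.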

In sum, my suggestion: implement only the tanh option for now, and follow the geglu backward to see how the tanh gradient is computed with the chain rule, then derive the equation and implement it here.
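Following that suggestion, here is a minimal pure-Python sketch of the tanh chain rule for a single row of logits, with a central-difference check. It assumes plain cross-entropy on one row (no ignore_index, label smoothing, or batch reduction), uses `math` instead of torch/Triton, and the function names are hypothetical. `grad_z` is the per-row logit gradient that would then feed the grad_input and grad_weights matmuls.

```python
import math
from typing import List, Tuple


def softcap_ce_row(z: List[float], target: int, cap: float) -> Tuple[float, List[float]]:
    """Cross-entropy on soft-capped logits for one row, plus dLoss/dz for the raw logits z.

    Hypothetical helper for illustration only; not Liger's API.
    """
    # Forward: s = cap * tanh(z / cap); keep t = tanh(z / cap) for the backward pass.
    t = [math.tanh(v / cap) for v in z]
    s = [cap * ti for ti in t]
    # Numerically stable log-sum-exp over the capped logits.
    m = max(s)
    lse = m + math.log(sum(math.exp(si - m) for si in s))
    loss = lse - s[target]
    # dLoss/ds = softmax(s) - one_hot(target)
    grad_s = [math.exp(si - lse) for si in s]
    grad_s[target] -= 1.0
    # Chain rule through the soft cap: ds/dz = 1 - tanh(z / cap)^2
    grad_z = [g * (1.0 - ti * ti) for g, ti in zip(grad_s, t)]
    return loss, grad_z


def finite_diff_grad(z: List[float], target: int, cap: float, eps: float = 1e-6) -> List[float]:
    """Central-difference estimate of dLoss/dz, used only to check the analytic gradient."""
    grad = []
    for k in range(len(z)):
        z_plus = list(z)
        z_plus[k] += eps
        z_minus = list(z)
        z_minus[k] -= eps
        loss_plus, _ = softcap_ce_row(z_plus, target, cap)
        loss_minus, _ = softcap_ce_row(z_minus, target, cap)
        grad.append((loss_plus - loss_minus) / (2.0 * eps))
    return grad


if __name__ == "__main__":
    z = [12.0, -40.0, 3.5, 55.0]
    loss, analytic = softcap_ce_row(z, target=2, cap=30.0)
    numeric = finite_diff_grad(z, target=2, cap=30.0)
    print(f"loss={loss:.6f}")
    for a, n in zip(analytic, numeric):
        print(f"analytic={a:+.6f}  numeric={n:+.6f}")
```

Checking an analytic gradient against finite differences like this is a cheap way to catch sign errors or a missing $1 - \tanh^2$ factor before porting the formula into a kernel.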

@Tcc0403 linked a pull request on Oct 22, 2024 that will close this issue