
jamba liger fused linear+xentropy #102

Open · winglian wants to merge 1 commit into main
Conversation

winglian

Summary

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@ByronHsu
Collaborator

Awesome! Please make sure you add both convergence tests (with logits and without logits) and unit tests. We are very focused on testing.

@ByronHsu
Collaborator

#63

@yundai424 linked an issue Aug 26, 2024 that may be closed by this pull request
@yubofredwang

I added the following additional monkey patch for Jamba.

    from transformers.models.jamba import modeling_jamba
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
        modeling_jamba.JambaRMSNorm = LigerRMSNorm
    if cross_entropy:
        modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
    if swiglu:
        modeling_jamba.JambaMLP = LigerSwiGLUMLP
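
For reference, a minimal sketch of how these patches could be wrapped to mirror the existing apply_liger_kernel_to_* helpers; the wrapper name and the import paths below are my assumptions, not part of this PR:

    # Sketch only: wrapper name and import paths are assumed, mirroring
    # the existing apply_liger_kernel_to_* helpers in Liger Kernel.
    from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from liger_kernel.transformers.rope import liger_rotary_pos_emb
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP


    def apply_liger_kernel_to_jamba(
        rope: bool = True,
        cross_entropy: bool = True,
        rms_norm: bool = True,
        swiglu: bool = True,
    ) -> None:
        # Patch the transformers Jamba module in place so that models
        # constructed afterwards pick up the Liger implementations.
        from transformers.models.jamba import modeling_jamba

        if rope:
            modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
        if rms_norm:
            modeling_jamba.JambaRMSNorm = LigerRMSNorm
        if cross_entropy:
            modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
        if swiglu:
            modeling_jamba.JambaMLP = LigerSwiGLUMLP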

However, the convergence test seems to be failing for some values in the tensor:

    E           Mismatch at index (0, 5): tensor1[(0, 5)] = 1.1513792276382446, tensor2[(0, 5)] = 1.1512681245803833
    E           Mismatch at index (0, 27): tensor1[(0, 27)] = 0.6227690577507019, tensor2[(0, 27)] = 0.6227344870567322
    E           Mismatch at index (0, 28): tensor1[(0, 28)] = 0.7790964841842651, tensor2[(0, 28)] = 0.7790292501449585
    E           Mismatch at index (0, 29): tensor1[(0, 29)] = 0.524261474609375, tensor2[(0, 29)] = 0.5243569612503052
    E           Mismatch at index (0, 30): tensor1[(0, 30)] = 0.8967938423156738, tensor2[(0, 30)] = 0.8968125581741333

I tracked this down to LigerRMSNorm, but I need more time to investigate why there is a difference.
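
A minimal repro sketch I would use to isolate this, comparing the two norms directly on the same input (hidden size, eps, dtype, and shapes below are arbitrary choices, not values from the failing test):

    # Sketch of an isolation repro: compare HF JambaRMSNorm vs. LigerRMSNorm
    # on identical inputs. Hidden size, eps, dtype, and shapes are arbitrary
    # assumptions, not taken from the failing convergence test.
    import torch
    from transformers.models.jamba.modeling_jamba import JambaRMSNorm
    from liger_kernel.transformers.rms_norm import LigerRMSNorm

    hidden_size, eps = 4096, 1e-6
    x = torch.randn(2, 128, hidden_size, dtype=torch.bfloat16, device="cuda")

    hf_norm = JambaRMSNorm(hidden_size, eps=eps).to("cuda", torch.bfloat16)
    liger_norm = LigerRMSNorm(hidden_size, eps=eps).to("cuda", torch.bfloat16)
    liger_norm.weight.data.copy_(hf_norm.weight.data)  # identical parameters

    out_hf = hf_norm(x)
    out_liger = liger_norm(x)
    print((out_hf - out_liger).abs().max())  # surfaces the numeric divergence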

Review comment from a Collaborator on these lines of the patch:

    if rope:
        modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
        modeling_jamba.JambaRMSNorm = LigerRMSNorm

nit: the comment is wrong

Review comment from a Collaborator on these lines of the patch:

    if cross_entropy:
        modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
    if swiglu:
        modeling_jamba.JambaMLP = LigerSwiGLUMLP

where is lce_forward?
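
For context: in the existing Liger model patches, lce_forward is (as I understand it) a replacement causal-LM forward that feeds the last hidden states and labels to the fused linear + cross-entropy loss instead of materializing the full logits tensor. A minimal sketch of what that could look like for Jamba follows; the import path, the loss call signature, and the simplified forward signature are my assumptions, not code from this PR:

    # Sketch only: LigerFusedLinearCrossEntropyLoss is assumed to take
    # (lm_head weight, flattened hidden states, flattened labels).
    from liger_kernel.transformers.fused_linear_cross_entropy import (
        LigerFusedLinearCrossEntropyLoss,
    )


    def lce_forward(self, input_ids=None, labels=None, **kwargs):
        # Run only the Jamba backbone; skip materializing the full logits tensor.
        outputs = self.model(input_ids=input_ids, **kwargs)
        hidden_states = outputs[0]

        loss = None
        if labels is not None:
            # Shift so each position predicts the next token, then fuse the
            # lm_head projection and cross entropy into one kernel.
            shift_hidden = hidden_states[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            lce = LigerFusedLinearCrossEntropyLoss()
            loss = lce(
                self.lm_head.weight,
                shift_hidden.view(-1, self.config.hidden_size),
                shift_labels.view(-1),
            )
        # Returning only the loss to keep the sketch short; a real forward
        # would build the usual CausalLMOutputWithPast.
        return loss

The patch would then presumably set modeling_jamba.JambaForCausalLM.forward = lce_forward when fused linear cross entropy is enabled.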

@yubofredwang

yubofredwang commented Sep 5, 2024

Hi @winglian, I created a PR against the main branch of your fork. Do you want to merge it first and then update this PR to be based on that? winglian#1

Or I can create a separate PR to linkedin:main #214

@ByronHsu thoughts?

@winglian
Author

winglian commented Sep 5, 2024

@yubofredwang if your PR captures all the changes, I'm happy to have your PR supersede mine. thanks!

@ByronHsu
Collaborator

ByronHsu commented Sep 6, 2024

@yubofredwang there are a few conflicts


Successfully merging this pull request may close these issues.

[feat] Add jamba support