
Any plans to add models from the llava series? #514

Open
jp1924 opened this issue Jan 6, 2025 · 12 comments
jp1924 (Contributor) commented Jan 6, 2025

🚀 The feature, motivation and pitch

Do you have any plans to add models from the llava series provided by transformers?
If possible, I would like to add llava, llava-next, and llava-onevision. Is that feasible?

Alternatives

No response

Additional context

No response

jp1924 (Contributor, Author) commented Jan 13, 2025

@Tcc0403 Is it okay if I work on this?

Tcc0403 (Collaborator) commented Jan 13, 2025

Sure!

jp1924 (Contributor, Author) commented Jan 13, 2025

Thanks! @Tcc0403

Before opening a PR, I have two points I'd like to discuss.

I've been working on a separate fork here, and two main issues have come up:

  1. There's a discrepancy between huggingface llava's forward implementation and liger-kernel's forward implementation.
  2. LLaVA loads its language model and vision encoder through AutoModel.

Let's look at issue 1 first. Here's how loss is implemented in huggingface llava:

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()

Huggingface llava takes the pre-computed logits from the language model and calculates the loss from them.
Liger-kernel, on the other hand, calculates the loss by passing the lm_head weight together with the language model's hidden states into its fused linear cross-entropy kernel, so the full logits tensor is never materialized.

This creates an implementation discrepancy when applying liger-kernel to the existing llava.
You can see how the code differs here.
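To make the discrepancy concrete, here is a rough sketch of the two patterns. The liger-kernel call is based on LigerFusedLinearCrossEntropyLoss as I understand its interface; the function names and shapes here are illustrative, not a proposed implementation:

import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Pattern used by huggingface llava today: the language model materializes the
# full (batch, seq, vocab) logits, and llava computes cross entropy on the
# shifted logits/labels afterwards.
def hf_style_loss(hidden_states, lm_head, labels):
    logits = lm_head(hidden_states)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )

# Pattern used by liger-kernel's patched causal-LM forwards: the lm_head weight
# and the hidden states are handed to a fused kernel, so the full logits tensor
# is never materialized.
def liger_style_loss(hidden_states, lm_head, labels):
    shift_hidden = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    lce = LigerFusedLinearCrossEntropyLoss()
    return lce(
        lm_head.weight,
        shift_hidden.view(-1, shift_hidden.size(-1)),
        shift_labels.view(-1),
    )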

For issue 2, unlike Qwen2-VL and Mllama, llava doesn't have fixed language model and vision encoder components:

class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
    def __init__(self, config: LlavaConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config)

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

        self.post_init()

The __init__ loads whatever backbones the config specifies through AutoModel.

This makes it unclear which language model and vision encoder should be used in the test cases.
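If we did add it, one possible way to deal with the AutoModel indirection would be to dispatch on the text backbone the config actually selects and reuse the existing per-model patch functions. This is only a sketch: apply_liger_kernel_to_llava is a hypothetical name, and the set of supported backbones below is just an example.

from liger_kernel.transformers import (
    apply_liger_kernel_to_gemma2,
    apply_liger_kernel_to_llama,
    apply_liger_kernel_to_mistral,
)

# Hypothetical entry point: look at which text backbone the llava config
# selects and delegate to the matching, already-supported patch function.
_TEXT_BACKBONE_PATCHES = {
    "llama": apply_liger_kernel_to_llama,
    "mistral": apply_liger_kernel_to_mistral,
    "gemma2": apply_liger_kernel_to_gemma2,
}

def apply_liger_kernel_to_llava(config, **kwargs):
    model_type = config.text_config.model_type
    if model_type not in _TEXT_BACKBONE_PATCHES:
        raise NotImplementedError(
            f"Liger kernels are not available for text backbone '{model_type}'"
        )
    _TEXT_BACKBONE_PATCHES[model_type](**kwargs)

That would also make the test matrix explicit: only the backbones listed in the dispatch table would need coverage.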

Additionally, llava's code allows training the language model without visual data.
Combined with issue 1, this means models not officially supported by liger-kernel could potentially use liger-kernel through llava as a workaround.

This could lead to unexpected side effects.

These issues can all be addressed if needed, but I'd like to hear your thoughts on how to proceed. If these aspects of LLaVA conflict with Liger-Kernel's design philosophy, it might be better not to add LLaVA at all.

jp1924 (Contributor, Author) commented Jan 14, 2025

Actually, the cleanest solution would be for huggingface llava to forward the labels to the language model and use the loss the language model returns, instead of computing the loss itself.
That would compose naturally with liger-kernel, but since it requires a change in transformers, it might get complicated. What do you think?
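A minimal sketch of what I mean inside LlavaForConditionalGeneration.forward (illustrative only; the real change would have to land in transformers):

        # Forward the labels to the language model and let it compute the loss
        # itself (which becomes the liger fused loss when the backbone is
        # patched), instead of taking the logits and doing the shift/CE here.
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            labels=labels,
            return_dict=return_dict,
        )
        loss = outputs.loss
        logits = outputs.logits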

jp1924 (Contributor, Author) commented Jan 14, 2025

cc @ByronHsu

y-rok commented Jan 15, 2025

Hello, and thank you for considering adding Llava models. I’m new to using the Liger kernel.

As you mentioned, Llava models can be used with various language and vision models. However, when using Llama 3 and the CLIP vision encoder, I expected that calling the following before loading the model would work (i.e., using the Liger kernel without the cross-entropy patches):

apply_liger_kernel_to_llama(rope=True, swiglu=True, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=True)

However, I observed a degradation in model performance when I applied this configuration.

Do you have any idea why this might be happening? If you could provide some insights, I’d be happy to assist in adding Llava models.
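For context, the overall pattern I am describing is roughly the following; the checkpoint name is just an example, and the point is only that the llama modeling code is patched before the llava model is instantiated:

from liger_kernel.transformers import apply_liger_kernel_to_llama
from transformers import LlavaForConditionalGeneration

# Patch the llama modeling classes first; any llama backbone created afterwards
# (including the one inside llava, loaded via AutoModelForCausalLM) picks up
# the liger kernels.
apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    cross_entropy=False,
    fused_linear_cross_entropy=False,
    rms_norm=True,
)

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")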

jp1924 mentioned this issue Jan 16, 2025
jp1924 (Contributor, Author) commented Jan 20, 2025

@y-rok Ah... sorry for the late reply.
Could you explain your situation in more detail? It isn't clear to me how the performance degradation is occurring.
Also, Liger does not currently support LLaVA; are you using your own modified code?
And which version are you using?

I've never encountered this issue before, so I need to understand your setup to figure out what the problem is.

y-rok commented Jan 21, 2025

@jp1924
Thank you for your response. Since I was aware that Liger does not support LLAVA, I decided to apply the Liger kernel only to the language model. Specifically, I used the LLaMA 3.2 1B-Instruct model along with the CLIP vision encoder and called the following function before loading the model.

apply_liger_kernel_to_llama(rope=True, swiglu=True, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=True)

Even though I disabled cross-entropy, I experienced a drop in performance. I trained the model on 30% of the LLaVA-OneVision dataset and evaluated it on ChartQA and OCR-Bench. Moreover, the same issue occurred even when I enabled only the RoPE kernel and disabled everything else. As far as I understand, enabling RoPE should only swap in the RoPE kernel, so performance should remain similar, but that was not the case. Do you have any insights into why this might be happening?

jp1924 (Contributor, Author) commented Jan 22, 2025

@y-rok
The issue you mention is something I haven't encountered before, so I can't suggest a solution off the top of my head.

Could you provide code that reproduces the problem so I can take a look?
I didn't face any issues when I trained a Llava model with gemma2 9b, so maybe try switching the model?

Is the issue occurring in stage 1 or stage 2?

y-rok commented Jan 22, 2025

@jp1924
The code is quite simple, as I mentioned. I only called apply_liger_kernel_to_llama before loading the model with from_pretrained(). After I double-check everything, I’ll share my results again.

By the way, I encountered this issue during Stage 2.

Tcc0403 (Collaborator) commented Jan 29, 2025

@y-rok Might be the same issue as #544. Could you try liger-kernel-nightly to see whether the issue remains with the liger rope patch?

y-rok commented Feb 3, 2025

@Tcc0403
Sorry for the late reply. I will let you know the results after I give it a try.
