diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 1c06c666d..1d7a84850 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -11,6 +11,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401 diff --git a/src/liger_kernel/transformers/model/llava.py b/src/liger_kernel/transformers/model/llava.py new file mode 100644 index 000000000..292102031 --- /dev/null +++ b/src/liger_kernel/transformers/model/llava.py @@ -0,0 +1,379 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC +from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING +from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast +from transformers.models.llava.modeling_llava import logger +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import is_torchdynamo_compiling +from transformers.utils import replace_return_docstrings +from transformers.utils.deprecation import deprecate_kwarg + +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss + + +@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) +@deprecate_kwarg("num_logits_to_keep", new_name="logits_to_keep", version="4.50") +@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) +def lce_forward_deprecated( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, +) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + legacy_processing = False + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing + # not very reliable, but we don't expect one to actually pass 500+ images for one prompt + # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True + legacy_processing = ( + (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length + ) or (input_ids.shape[-1] == 1 and pixel_values is not None) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + if legacy_processing and image_features is not None: + logger.warning_once( + "Expanding inputs for image tokens in LLaVa should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
+ ) + # prefill stage vs decoding stage (legacy behavior copied) + if input_ids.shape[1] != 1: + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + else: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + + # TODO: @raushan retain only the new behavior after v4.47 + elif image_features is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and (labels is not None): + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device) + shift_hidden_states = hidden_states[..., :-1, :][ + shift_attention_mask.to(hidden_states.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels) + + if not return_dict: + # NOTE: This part has not been tested. + output = outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) +@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, + **lm_kwargs, +) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? 
ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=image_sizes, + ) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and (labels is not None): + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device) + shift_hidden_states = hidden_states[..., :-1, :][ + shift_attention_mask.to(hidden_states.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels) + + if not return_dict: + # NOTE: This part has not been tested. + output = outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index e493386f0..460ccbdc7 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated +from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward +from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated @@ -212,6 +214,85 @@ def apply_liger_kernel_to_llama( _patch_rms_norm_module(decoder_layer.post_attention_layernorm) +def apply_liger_kernel_to_llava( + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + model: PreTrainedModel = None, + **kwargs, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Llava models. + Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa. + However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur. + NOTE: Llava is not available in transformers<4.36.0 + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. 
+ """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.llava import modeling_llava + + if cross_entropy: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + if transformer_version >= version.parse("4.49.0"): + modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward + else: # if version < 4.49.0 + logger.warning( + "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526" + ) + modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated + + if model is not None: + text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type + text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None) + vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None) + + kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} + if text_liger_fn: + accept_params = inspect.signature(text_liger_fn).parameters + remain_params = set(kwargs) - (set(accept_params) & set(kwargs)) + text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params} + + if remain_params: + logger.warning( + f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n" + f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}" + ) + text_kwargs["model"] = model.language_model + text_liger_fn(**text_kwargs) + elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN: + logger.warning(f"{text_model_name} is not supported by Liger kernel.") + + if vision_liger_fn: + accept_params = inspect.signature(vision_liger_fn).parameters + remain_params = set(kwargs) - (set(accept_params) & set(kwargs)) + vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params} + + if remain_params: + logger.warning( + f"These parameters are not supported by {vision_model_name}. 
Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n" + f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}" + ) + vision_kwargs["model"] = model.vision_tower + vision_liger_fn(**vision_kwargs) + elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN: + logger.warning(f"{vision_model_name} is not supported by Liger kernel.") + + def apply_liger_kernel_to_mllama( rope: bool = True, cross_entropy: bool = False, @@ -959,6 +1040,7 @@ def apply_liger_kernel_to_olmo2( "qwen2_vl": apply_liger_kernel_to_qwen2_vl, "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl, "phi3": apply_liger_kernel_to_phi3, + "llava": apply_liger_kernel_to_llava, } diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py index 9295dbed2..b88dfee79 100644 --- a/test/convergence/bf16/test_mini_models.py +++ b/test/convergence/bf16/test_mini_models.py @@ -22,6 +22,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_granite from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llava from liger_kernel.transformers import apply_liger_kernel_to_mistral from liger_kernel.transformers import apply_liger_kernel_to_mixtral from liger_kernel.transformers import apply_liger_kernel_to_mllama @@ -37,6 +38,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_granite from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mistral from test.utils import revert_liger_kernel_to_mixtral from test.utils import revert_liger_kernel_to_mllama @@ -93,6 +95,15 @@ except ImportError: OLMO2_AVAILABLE = False +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -533,6 +544,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + 
vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_model(model_name="mini_llama3"): """ @@ -623,6 +693,25 @@ def run_mini_model( 1e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ], + ), pytest.param( "mini_granite3", 32, diff --git a/test/convergence/bf16/test_mini_models_multimodal.py b/test/convergence/bf16/test_mini_models_multimodal.py index 1be54366b..dd9ded294 100644 --- a/test/convergence/bf16/test_mini_models_multimodal.py +++ b/test/convergence/bf16/test_mini_models_multimodal.py @@ -8,6 +8,7 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerFast +from liger_kernel.transformers import apply_liger_kernel_to_llava from liger_kernel.transformers import apply_liger_kernel_to_mllama from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl @@ -15,8 +16,11 @@ from test.utils import UNTOKENIZED_DATASET_PATH from test.utils import MiniModelConfig from test.utils import assert_verbose_allclose +from test.utils import load_image_processing_config +from test.utils import load_processor_config from test.utils import load_tokenizer_config from test.utils import multimodal_collate_fn +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mllama from test.utils import revert_liger_kernel_to_qwen2_5_vl from test.utils import revert_liger_kernel_to_qwen2_vl @@ -61,6 +65,18 @@ except ImportError: MLLAMA_AVAILABLE = False +try: + from transformers import CLIPImageProcessor + from transformers import CLIPVisionConfig + from transformers import LlamaConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + from transformers.models.llava.processing_llava import LlavaProcessor + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -229,6 +245,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_llava, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + 
tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_processor(model_name): if model_name == "mini_qwen2_vl": @@ -284,6 +359,39 @@ def create_processor(model_name): fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) image_processor = MllamaImageProcessor(size={"height": 560, "width": 560}) return MllamaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer) + elif model_name == "mini_llava": + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/tokenizer_config.json", + ) + ) + image_processor_config = load_image_processing_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/preprocessor_config.json", + ) + ) + processor_config = load_processor_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/processor_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = CLIPImageProcessor(**image_processor_config) + + return LlavaProcessor(**processor_config, image_processor=image_processor, tokenizer=fast_tokenizer) else: raise ValueError(f"Processor not available for model {model_name}") @@ -407,7 +515,6 @@ def run_mini_model_multimodal( print(f"Step {i}, Loss: {output.loss.item()}") loss_list.append(output.loss.item()) - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) return {"loss": loss_list, "logits": output.logits, "model": model} @@ -473,6 +580,25 @@ def run_mini_model_multimodal( ), ], ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ], + ), ], ) def test_mini_model_multimodal( diff --git a/test/convergence/bf16/test_mini_models_with_logits.py b/test/convergence/bf16/test_mini_models_with_logits.py index d822584ae..90dee025c 100644 --- a/test/convergence/bf16/test_mini_models_with_logits.py +++ b/test/convergence/bf16/test_mini_models_with_logits.py @@ -22,6 +22,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers 
import apply_liger_kernel_to_granite from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llava from liger_kernel.transformers import apply_liger_kernel_to_mistral from liger_kernel.transformers import apply_liger_kernel_to_mixtral from liger_kernel.transformers import apply_liger_kernel_to_mllama @@ -37,6 +38,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_granite from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mistral from test.utils import revert_liger_kernel_to_mixtral from test.utils import revert_liger_kernel_to_mllama @@ -93,6 +95,15 @@ except ImportError: OLMO2_AVAILABLE = False +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -534,6 +545,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_model(model_name="mini_llama3"): """ @@ -622,6 +692,25 @@ def run_mini_model( 1e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ], + ), pytest.param( "mini_granite3", 32, diff --git 
a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py index fd07e7a09..1699aef61 100644 --- a/test/convergence/fp32/test_mini_models.py +++ b/test/convergence/fp32/test_mini_models.py @@ -22,6 +22,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_granite from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llava from liger_kernel.transformers import apply_liger_kernel_to_mistral from liger_kernel.transformers import apply_liger_kernel_to_mixtral from liger_kernel.transformers import apply_liger_kernel_to_mllama @@ -37,6 +38,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_granite from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mistral from test.utils import revert_liger_kernel_to_mixtral from test.utils import revert_liger_kernel_to_mllama @@ -92,6 +94,15 @@ except ImportError: OLMO2_AVAILABLE = False +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -532,6 +543,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_model(model_name="mini_llama3"): """ @@ -610,6 +680,22 @@ def run_mini_model( "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 
5e-3, 1e-5), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ), pytest.param( "mini_mllama", 32, diff --git a/test/convergence/fp32/test_mini_models_multimodal.py b/test/convergence/fp32/test_mini_models_multimodal.py index 3ccee328e..47cb904dc 100644 --- a/test/convergence/fp32/test_mini_models_multimodal.py +++ b/test/convergence/fp32/test_mini_models_multimodal.py @@ -8,6 +8,7 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerFast +from liger_kernel.transformers import apply_liger_kernel_to_llava from liger_kernel.transformers import apply_liger_kernel_to_mllama from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl @@ -15,8 +16,11 @@ from test.utils import UNTOKENIZED_DATASET_PATH from test.utils import MiniModelConfig from test.utils import assert_verbose_allclose +from test.utils import load_image_processing_config +from test.utils import load_processor_config from test.utils import load_tokenizer_config from test.utils import multimodal_collate_fn +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mllama from test.utils import revert_liger_kernel_to_qwen2_5_vl from test.utils import revert_liger_kernel_to_qwen2_vl @@ -60,6 +64,18 @@ except ImportError: MLLAMA_AVAILABLE = False +try: + from transformers import CLIPImageProcessor + from transformers import CLIPVisionConfig + from transformers import LlamaConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + from transformers.models.llava.processing_llava import LlavaProcessor + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -228,6 +244,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_llava, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + 
projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_processor(model_name): if model_name == "mini_qwen2_vl": @@ -283,6 +358,39 @@ def create_processor(model_name): fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) image_processor = MllamaImageProcessor(size={"height": 560, "width": 560}) return MllamaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer) + elif model_name == "mini_llava": + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/tokenizer_config.json", + ) + ) + image_processor_config = load_image_processing_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/preprocessor_config.json", + ) + ) + processor_config = load_processor_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/processor_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = CLIPImageProcessor(**image_processor_config) + + return LlavaProcessor(**processor_config, image_processor=image_processor, tokenizer=fast_tokenizer) else: raise ValueError(f"Processor not available for model {model_name}") @@ -466,6 +574,22 @@ def run_mini_model_multimodal( reason="Mllama not available in this version of transformers", ), ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ), ], ) def test_mini_model_multimodal( diff --git a/test/convergence/fp32/test_mini_models_with_logits.py b/test/convergence/fp32/test_mini_models_with_logits.py index 74ad760f3..05155bcf0 100644 --- a/test/convergence/fp32/test_mini_models_with_logits.py +++ b/test/convergence/fp32/test_mini_models_with_logits.py @@ -22,6 +22,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_granite from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llava from liger_kernel.transformers import apply_liger_kernel_to_mistral from liger_kernel.transformers import apply_liger_kernel_to_mixtral from liger_kernel.transformers import apply_liger_kernel_to_mllama @@ -37,6 +38,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_granite from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mistral from test.utils import revert_liger_kernel_to_mixtral from test.utils import revert_liger_kernel_to_mllama @@ -92,6 +94,15 @@ except ImportError: OLMO2_AVAILABLE = False +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True 
+except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -533,6 +544,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_model(model_name="mini_llama3"): """ @@ -609,6 +679,22 @@ def run_mini_model( "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ), pytest.param( "mini_mllama", 32, diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json new file mode 100644 index 000000000..c32625c74 --- /dev/null +++ b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json @@ -0,0 +1,28 @@ +{ + "crop_size": { + "height": 336, + "width": 336 + }, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_rescale": true, + "do_resize": true, + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_processor_type": "CLIPImageProcessor", + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "processor_class": "LlavaProcessor", + "resample": 3, + "rescale_factor": 0.00392156862745098, + "size": { + "shortest_edge": 336 + } +} \ No newline at end of file diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json new file mode 100644 index 000000000..8fbb221c7 --- /dev/null +++ 
b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json @@ -0,0 +1,7 @@ +{ + "image_token": "", + "num_additional_image_tokens": 1, + "patch_size": 14, + "processor_class": "LlavaProcessor", + "vision_feature_select_strategy": "default" +} \ No newline at end of file diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json new file mode 100644 index 000000000..f9c6572a8 --- /dev/null +++ b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json @@ -0,0 +1,66 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "add_prefix_space": null, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_special_tokens": { + "image_token": "" + }, + "image_token": "", + "legacy": false, + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_last_empty_assistant = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message.role == 'user' %}{{ '### User:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'image' %}{{ '' }}{% elif content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{{ '\n\n' }}{% elif message.role == 'system' %}{{ '### System:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'image' %}{{ '' }}{% elif content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{{ '\n\n' }}{% elif message.role == 'assistant' %}{{ '### Assistant:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{% else %}{{ '' }}{% endif %}{% endfor %}{% if not add_generation_prompt %}{{ eos_token }}{% elif add_generation_prompt %}{{ '### Assistant:\n' }}{% else %}{# Do nothing #}{% endif %}", + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "padding_side": "left", + "processor_class": "LlavaProcessor", + "sp_model_kwargs": {}, + "tokenizer_class": "LlamaTokenizer", + "trust_remote_code": false, + "unk_token": "", + "use_default_system_prompt": false, + "return_token_type_ids": false +} \ No newline at end of file diff --git a/test/utils.py b/test/utils.py index 64a261bef..de43a2d53 100644 --- a/test/utils.py +++ b/test/utils.py @@ -180,6 +180,20 @@ def load_tokenizer_config(config_path: str) -> dict: return tokenizer_config +def load_image_processing_config(config_path: str) -> dict: + """Load and process image processing configuration from a JSON file.""" + with open(config_path) 
as reader: + image_processing_config = json.load(reader) + return image_processing_config + + +def load_processor_config(config_path: str) -> dict: + """Load and process processor configuration from a JSON file.""" + with open(config_path) as reader: + processor_config = json.load(reader) + return processor_config + + def train_bpe_tokenizer(special_tokens: List[str], unk_token: str = "<|unk|>"): """ Train a tokenizer using the BPE algorithm. @@ -368,6 +382,18 @@ def revert_liger_kernel_to_olmo2(model_config: MiniModelConfig): print("Liger kernel patches have been reverted.") +def revert_liger_kernel_to_llava(model_config: MiniModelConfig): + """ + Revert all Liger kernel patches applied to llava. + """ + + from transformers.models.llava import modeling_llava + + importlib.reload(modeling_llava) + model_config.model_class = modeling_llava.LlavaForConditionalGeneration + print("Liger kernel patches have been reverted.") + + class HFAlignmentLoss: def __init__( self,
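
A minimal usage sketch of the new `apply_liger_kernel_to_llava` entry point added above, not part of the diff itself. It assumes transformers >= 4.49 (so the non-deprecated `lce_forward` path is patched in) and uses the `llava-hf/llava-1.5-7b-hf` checkpoint from the docstring examples purely for illustration; the `rms_norm`/`swiglu` kwargs are only meaningful because they are forwarded to the patcher registered for the attached Llama text backbone.

```python
# Usage sketch (assumptions: transformers>=4.49, a CUDA/bf16-capable device, and the
# llava-hf/llava-1.5-7b-hf checkpoint taken from the docstring examples above).
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

# Patching replaces LlavaForConditionalGeneration.forward at the class level, so it also
# takes effect for the already-loaded instance. Passing `model=` additionally dispatches to
# the patch functions registered for the attached language model and vision tower; extra
# kwargs such as rms_norm/swiglu are forwarded there, and unsupported keys are filtered out
# with a warning.
apply_liger_kernel_to_llava(
    fused_linear_cross_entropy=True,  # default; mutually exclusive with cross_entropy=True
    model=model,
    rms_norm=True,
    swiglu=True,
)

# With fused_linear_cross_entropy=True the patched forward does not materialize logits
# during training: it returns a loss only when the model is in training mode and `labels`
# are provided, so supply labels in the training-step forward call.
```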