Add Flex Attention Monkey Patch for LLAMA #540

Open. Wants to merge 9 commits into base: main.
58 changes: 58 additions & 0 deletions src/liger_kernel/transformers/model/llama.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F

from torch.nn import CrossEntropyLoss
from torch.nn.attention.flex_attention import flex_attention
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
@@ -19,6 +20,8 @@
if TYPE_CHECKING:
    from transformers.cache_utils import Cache

flex_attention = torch.compile(flex_attention)


@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -249,3 +252,58 @@ def lce_forward(
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py#L12
def flex_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
        # Optionally soft-cap the raw score, then add HuggingFace's additive causal mask.
        if softcap is not None:
            score = softcap * torch.tanh(score / softcap)
        if causal_mask is not None:
            score = score + causal_mask[b][0][q_idx][kv_idx]
        return score

    # def causal_mask_fn(b, h, q_idx, kv_idx):
    #     return q_idx >= kv_idx

    # TODO: Construct block attention mask that leverages sparsity
    # sparse_causal_mask = create_block_mask(
    #     causal_mask_fn, B=None, H=None, Q_LEN=query.shape[-2], KV_LEN=key.shape[-2], device=query.device, BLOCK_SIZE=1
    # )

    attn_output, attention_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        # block_mask=sparse_causal_mask,
        enable_gqa=True,  # key/value may carry fewer heads than query (grouped-query attention)
        scale=scaling,
        return_lse=True,  # second output is the per-query log-sum-exp, not full attention probabilities
        kernel_options={
            "BLOCK_M": 32,
            "BLOCK_N": 32,
            "BLOCK_M1": 16,
            "BLOCK_N1": 32,
            "BLOCK_M2": 32,
            "BLOCK_N2": 16,
        },
    )

    attention_weights = attention_weights.to(value.dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attention_weights
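
For intuition, here is a minimal, self-contained sketch of the `flex_attention` call this forward builds on: a `score_mod` callback is applied to every attention score before softmax, which is how the softcap and the additive causal mask above are expressed. The tensor sizes, the softcap value, and the `q_idx >= kv_idx` mask below are illustrative stand-ins (the patched forward instead adds the mask HuggingFace already provides); running it needs PyTorch >= 2.5, and without `torch.compile` it falls back to the slower eager reference path.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, S, D = 1, 8, 128, 64  # batch, heads, sequence length, head dim (illustrative)
q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)

softcap = 50.0  # hypothetical value; stock Llama does not use logit softcapping


def score_mod(score, b, h, q_idx, kv_idx):
    # Soft-cap the raw score, then push future positions to -inf so softmax
    # assigns them zero weight (a stand-in for the additive causal mask above).
    score = softcap * torch.tanh(score / softcap)
    return torch.where(q_idx >= kv_idx, score, torch.full_like(score, float("-inf")))


out, lse = flex_attention(q, k, v, score_mod=score_mod, enable_gqa=True, return_lse=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])
print(lse.shape)  # torch.Size([1, 8, 128]); log-sum-exp per query position
```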
37 changes: 37 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -17,6 +17,7 @@
from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
from liger_kernel.transformers.model.llama import flex_attention_forward as llama_flex_attention_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
@@ -37,7 +38,9 @@

logger = logging.getLogger(__name__)
SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION = "4.48.0"
TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Flex attention is not supported."


def _bind_method_to_module(module, method_name: str, new_method: Callable):
@@ -68,6 +71,7 @@ def apply_liger_kernel_to_llama(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = True,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -115,6 +119,15 @@
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

if flex_attn:
# Patch HuggingFace's default attention implementations ("sdpa" and "flex_attention") to use Liger's `llama_flex_attention_forward`
if transformer_version >= version.parse(FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION):
modeling_llama.ALL_ATTENTION_FUNCTIONS.update(
{"sdpa": llama_flex_attention_forward, "flex_attention": llama_flex_attention_forward}
)
else:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
@@ -141,6 +154,7 @@ def apply_liger_kernel_to_mllama(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not supported by HuggingFace
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
@@ -194,6 +208,8 @@
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -244,6 +260,7 @@ def apply_liger_kernel_to_mistral(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Mistral models
@@ -278,6 +295,8 @@
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
if swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -304,6 +323,7 @@ def apply_liger_kernel_to_mixtral(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
@@ -349,6 +369,8 @@
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
if swiglu:
modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -376,6 +398,7 @@ def apply_liger_kernel_to_gemma(
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma
@@ -424,6 +447,8 @@
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -450,6 +475,7 @@ def apply_liger_kernel_to_gemma2(
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma2
@@ -500,6 +526,8 @@
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
if geglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -528,6 +556,7 @@ def apply_liger_kernel_to_qwen2(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not supported by HuggingFace
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
@@ -574,6 +603,8 @@

if swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -602,6 +633,7 @@ def apply_liger_kernel_to_qwen2_vl(
layer_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not supported by HuggingFace
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
@@ -641,6 +673,8 @@
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
if swiglu:
modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -673,6 +707,7 @@ def apply_liger_kernel_to_phi3(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
@@ -716,6 +751,8 @@
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
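
A usage sketch of the new flag, assuming this branch is installed, a CUDA device is available, and transformers >= 4.48.0 is in place; the checkpoint name and generation settings below are illustrative only:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Llama modeling code before the model is instantiated so that the
# class-level replacements (forward, attention registry) take effect.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,
    flex_attn=True,  # new flag from this PR; requires transformers >= 4.48.0
)

model_id = "meta-llama/Llama-3.2-1B"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

inputs = tokenizer("Flex attention replaces SDPA when", return_tensors="pt").to("cuda")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```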
13 changes: 9 additions & 4 deletions test/convergence/test_mini_models.py
@@ -414,6 +414,7 @@ def run_mini_model(
kwargs = {
"rope": True,
"rms_norm": True,
"flex_attn": False,
}

model_supports_layer_norm = "qwen2_vl" in model_name
@@ -428,6 +429,10 @@
kwargs["fused_linear_cross_entropy"] = True
kwargs["cross_entropy"] = False

model_supports_flex_attn = "llama3" in model_name # excluding mllama
if model_supports_flex_attn:
kwargs["flex_attn"] = True

MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
else:
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
@@ -457,14 +462,14 @@ def run_mini_model(
@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
[
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
pytest.param(
("mini_llama3", 32, 1e-4, torch.float32, 2e-1, 1e-1, 1e-4, 1e-5, 5e-2, 1e-5),
pytest.param( # Set larger loss tol. ref: https://github.com/pytorch-labs/attention-gym/blob/41ef5bff15033269b8cd8a012214a637345c1ddd/examples/benchmark.py#L151
"mini_llama3",
32,
1e-4,
torch.bfloat16,
1e-3,
1e-2,
2e-1,
1e-1,
1e-1,
1e-2,
1e-2,
13 changes: 9 additions & 4 deletions test/convergence/test_mini_models_with_logits.py
@@ -414,6 +414,7 @@ def run_mini_model(
kwargs = {
"rope": True,
"rms_norm": True,
"flex_attn": False,
}

model_supports_layer_norm = "qwen2_vl" in model_name
Expand All @@ -428,6 +429,10 @@ def run_mini_model(
kwargs["fused_linear_cross_entropy"] = False
kwargs["cross_entropy"] = True

model_supports_flex_attn = "llama3" in model_name # excluding mllama
if model_supports_flex_attn:
kwargs["flex_attn"] = True

MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
else:
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
@@ -456,16 +461,16 @@
@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
[
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
("mini_llama3", 32, 1e-4, torch.float32, 2e-1, 1e-1, 1e-2, 1e-3, 5e-3, 1e-5),
pytest.param(
"mini_llama3",
32,
1e-4,
torch.bfloat16,
1e-3,
1e-2,
2e-1,
1e-1,
2e-1,
1e-1,
1e-2,
1e-2,
1e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
5 changes: 5 additions & 0 deletions test/utils.py
@@ -265,8 +265,13 @@ def revert_liger_kernel_to_llama(model_config: MiniModelConfig):
Revert all Liger kernel patches applied to Llama.
"""

from transformers import (
modeling_utils, # `ALL_ATTENTION_FUNCTIONS` is a mutable module-level mapping shared by every importer, so reloading modeling_llama alone would not reset the patched entries.
)
from transformers.models.llama import modeling_llama

# Reload both modules in the correct order: modeling_utils first, then modeling_llama
importlib.reload(modeling_utils)
importlib.reload(modeling_llama)
model_config.model_class = modeling_llama.LlamaForCausalLM
print("Liger kernel patches have been reverted.")
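
A minimal sketch of why the reload order above matters, assuming transformers >= 4.48 where `ALL_ATTENTION_FUNCTIONS` is defined in `transformers.modeling_utils`; the lambda below is a throwaway stand-in for the Liger patch:

```python
import importlib

from transformers import modeling_utils
from transformers.models.llama import modeling_llama

# The patch mutates the registry through modeling_llama's reference...
modeling_llama.ALL_ATTENTION_FUNCTIONS.update({"sdpa": lambda *args, **kwargs: None})

# ...and the change is visible through modeling_utils as well, because both
# module attributes point at the same mutable object.
assert modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] is modeling_llama.ALL_ATTENTION_FUNCTIONS["sdpa"]

# Reloading modeling_llama alone would re-import the *same* mutated object, so
# the patch would survive. Reloading modeling_utils first rebuilds the registry;
# reloading modeling_llama afterwards re-binds to the fresh copy.
importlib.reload(modeling_utils)
importlib.reload(modeling_llama)
```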