Skip to content

Commit

Permalink
Add kvcache config for Mistral (#1766)
Browse files Browse the repository at this point in the history
* Add kvcache config for Mistral

* Update configs.py

* Update configs.py
  • Loading branch information
mgoin authored Oct 28, 2023
1 parent 3b7c340 commit 955ae11
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ class Config:
multiply_batch_by_num_att_heads=False,
)

# Mistral has a config/model definition "MistralForCausalLM" but is based off Llama2.
# It contains these additions to Llama2-7b:
# * Sliding Window Attention
# * GQA (Grouped Query Attention)
# * Byte-fallback BPE tokenizer
MISTRAL_CONFIG = KeyValueCacheConfig(
model_name="mistral",
additional_transforms=AdditionalTransformsLLAMA,
key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="hidden_size",
transpose_value_input=None,
transpose_key_input=None,
multiply_batch_by_num_att_heads=False,
)

# Reusing the CodeGen transforms because it happens to match what we need for GPTNeo
additional_transforms_gpt_neo = AdditionalTransformsCodeGen

Expand All @@ -160,6 +175,7 @@ def get_kv_cache_config(
BLOOM_CONFIG,
MPT_CONFIG,
LLAMA_CONFIG,
MISTRAL_CONFIG,
GPT_NEO_CONFIG,
],
) -> KeyValueCacheConfig:
Expand Down

0 comments on commit 955ae11

Please sign in to comment.