From 037e302d46021f3574c53137738a6572f75a7364 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Sat, 28 Oct 2023 16:56:03 -0600 Subject: [PATCH] Add kvcache config for Mistral (#1766) * Add kvcache config for Mistral * Update configs.py * Update configs.py --- .../exporters/transforms/kv_cache/configs.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py index ff9189b1c41..d617075e7cb 100644 --- a/src/sparseml/exporters/transforms/kv_cache/configs.py +++ b/src/sparseml/exporters/transforms/kv_cache/configs.py @@ -138,6 +138,21 @@ class Config: multiply_batch_by_num_att_heads=False, ) +# Mistral has a config/model definition "MistralForCausalLM" but is based off Llama2. +# It contains these additions to Llama2-7b: +# * Sliding Window Attention +# * GQA (Grouped Query Attention) +# * Byte-fallback BPE tokenizer +MISTRAL_CONFIG = KeyValueCacheConfig( + model_name="mistral", + additional_transforms=AdditionalTransformsLLAMA, + key_num_attention_heads="num_attention_heads", + key_num_embedding_hidden_size="hidden_size", + transpose_value_input=None, + transpose_key_input=None, + multiply_batch_by_num_att_heads=False, +) + # Reusing the CodeGen transforms because it happens to match what we need for GPTNeo additional_transforms_gpt_neo = AdditionalTransformsCodeGen @@ -160,6 +175,7 @@ def get_kv_cache_config( BLOOM_CONFIG, MPT_CONFIG, LLAMA_CONFIG, + MISTRAL_CONFIG, GPT_NEO_CONFIG, ], ) -> KeyValueCacheConfig: