diff --git a/README.md b/README.md
index b381598c3fa..69c5caf8c3c 100644
--- a/README.md
+++ b/README.md
@@ -128,7 +128,7 @@ More information on installation such as optional dependencies and requirements
 
 ### Recipes
 
-To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparamters that should be applied by SparseML.
+To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparameters that should be applied by SparseML.
 
 `Recipes` are YAML-files formatted as a list of `modifiers`, which encode the instructions for SparseML. Example `modifiers` can be anything from setting the learning rate to encoding the hyperparameters of the gradual magnitude pruning algorithm. The SparseML system parses the `recipes` into a native format for each framework and applies the modifications to the model and training pipeline.
 
diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py
index ff9189b1c41..d617075e7cb 100644
--- a/src/sparseml/exporters/transforms/kv_cache/configs.py
+++ b/src/sparseml/exporters/transforms/kv_cache/configs.py
@@ -138,6 +138,21 @@ class Config:
     multiply_batch_by_num_att_heads=False,
 )
 
+# Mistral has a config/model definition "MistralForCausalLM" but is based off Llama2.
+# It contains these additions to Llama2-7b:
+# * Sliding Window Attention
+# * GQA (Grouped Query Attention)
+# * Byte-fallback BPE tokenizer
+MISTRAL_CONFIG = KeyValueCacheConfig(
+    model_name="mistral",
+    additional_transforms=AdditionalTransformsLLAMA,
+    key_num_attention_heads="num_attention_heads",
+    key_num_embedding_hidden_size="hidden_size",
+    transpose_value_input=None,
+    transpose_key_input=None,
+    multiply_batch_by_num_att_heads=False,
+)
+
 # Reusing the CodeGen transforms because it happens to match what we need for GPTNeo
 additional_transforms_gpt_neo = AdditionalTransformsCodeGen
 
@@ -160,6 +175,7 @@ def get_kv_cache_config(
         BLOOM_CONFIG,
         MPT_CONFIG,
         LLAMA_CONFIG,
+        MISTRAL_CONFIG,
         GPT_NEO_CONFIG,
     ],
 ) -> KeyValueCacheConfig:
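
Note on the README hunk above: the recipe described there is consumed through SparseML's framework integrations. A minimal PyTorch sketch of that flow, assuming `recipe_path`, `model`, `optimizer`, and `steps_per_epoch` are supplied by the caller (this mirrors the `ScheduledModifierManager` pattern from SparseML's docs and is not code introduced by this diff):

    from sparseml.pytorch.optim import ScheduledModifierManager

    # Parse the YAML recipe into framework-native modifiers.
    manager = ScheduledModifierManager.from_yaml(recipe_path)

    # Wrap the optimizer so the recipe's modifiers (pruning, LR control, ...)
    # are applied on each training step.
    optimizer = manager.modify(model, optimizer, steps_per_epoch)

    # ... run the usual training loop with model/optimizer as normal ...

    # Remove the modifier hooks once training completes.
    manager.finalize(model)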
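
Note on the configs.py hunk: a hypothetical sketch of the lookup that the new `MISTRAL_CONFIG` entry participates in. The helper name `select_kv_cache_config` and the `model_type` parameter are illustrative assumptions; the actual resolution happens inside `get_kv_cache_config` and may differ in its details:

    from typing import List, Optional

    from sparseml.exporters.transforms.kv_cache.configs import KeyValueCacheConfig

    def select_kv_cache_config(
        model_type: str,  # e.g. "mistral", as read from the model's config.json (assumed)
        supported_configs: List[KeyValueCacheConfig],
    ) -> Optional[KeyValueCacheConfig]:
        # Each supported config advertises the model family it handles via
        # model_name. Registering MISTRAL_CONFIG (which reuses the Llama
        # transforms) is what makes a "mistral" lookup succeed.
        for config in supported_configs:
            if config.model_name == model_type:
                return config
        return None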