Skip to content

Commit

Permalink
Merge branch 'main' into fix/ic/ddp_recipe_load
Browse files Browse the repository at this point in the history
  • Loading branch information
anmarques authored Oct 30, 2023
2 parents 8e889cf + d8e9a45 commit b8eb710
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ More information on installation such as optional dependencies and requirements

### Recipes

To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparamters that should be applied by SparseML.
To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparameters that should be applied by SparseML.

`Recipes` are YAML-files formatted as a list of `modifiers`, which encode the instructions for SparseML. Example `modifiers` can be anything from setting the learning rate to encoding the hyperparameters of the gradual magnitude pruning algorithm. The SparseML system parses the `recipes` into a native format for each framework and applies the modifications to the model and training pipeline.

Expand Down
16 changes: 16 additions & 0 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ class Config:
multiply_batch_by_num_att_heads=False,
)

# Mistral has a config/model definition "MistralForCausalLM" but is based off Llama2.
# It contains these additions to Llama2-7b:
# * Sliding Window Attention
# * GQA (Grouped Query Attention)
# * Byte-fallback BPE tokenizer
MISTRAL_CONFIG = KeyValueCacheConfig(
model_name="mistral",
additional_transforms=AdditionalTransformsLLAMA,
key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="hidden_size",
transpose_value_input=None,
transpose_key_input=None,
multiply_batch_by_num_att_heads=False,
)

# Reusing the CodeGen transforms because it happens to match what we need for GPTNeo
additional_transforms_gpt_neo = AdditionalTransformsCodeGen

Expand All @@ -160,6 +175,7 @@ def get_kv_cache_config(
BLOOM_CONFIG,
MPT_CONFIG,
LLAMA_CONFIG,
MISTRAL_CONFIG,
GPT_NEO_CONFIG,
],
) -> KeyValueCacheConfig:
Expand Down

0 comments on commit b8eb710

Please sign in to comment.