Skip to content

Commit

Permalink
[Model] Support Cohere2ForCausalLM (Cohere R7B) (#11203)
Browse files Browse the repository at this point in the history
  • Loading branch information
janimo authored Dec 16, 2024
1 parent b3b1526 commit bddbbcb
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 4 deletions.
4 changes: 2 additions & 2 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ Text Generation (``--task generate``)
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
- ✅︎
- ✅︎
* - :code:`CohereForCausalLM`
* - :code:`CohereForCausalLM`, :code:`Cohere2ForCausalLM`
- Command-R
- :code:`CohereForAI/c4ai-command-r-v01`, etc.
- :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc.
- ✅︎
- ✅︎
* - :code:`DbrxForCausalLM`
Expand Down
2 changes: 2 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class _HfExamplesInfo:
# ChatGLMModel supports multimodal
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
trust_remote_code=True),
"Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501
trust_remote_code=True),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
"DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
trust_remote_code=True),
Expand Down
4 changes: 4 additions & 0 deletions tests/models/test_initialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import patch

import pytest
import transformers
from transformers import PretrainedConfig

from vllm import LLM
Expand All @@ -11,6 +12,9 @@
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
if (model_arch == "Cohere2ForCausalLM"
and transformers.__version__ < "4.48.0"):
pytest.skip(reason="Model introduced in HF >= 4.48.0")
if not model_info.is_available_online:
pytest.skip("Model is not available online")

Expand Down
19 changes: 17 additions & 2 deletions vllm/model_executor/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

Expand Down Expand Up @@ -171,12 +171,26 @@ def __init__(
rope_scaling=self.rope_scaling,
is_neox_style=False,
)

sliding_window = getattr(config, "sliding_window", None)
# Model v2 has sliding windows, v1 does not
self.v1 = sliding_window is None

layer_idx = extract_layer_index(prefix)
layer_has_sliding_window = (
getattr(config, "sliding_window_pattern", False)
and (layer_idx + 1) % self.config.sliding_window_pattern != 0)

self.sliding_window = (sliding_window
if layer_has_sliding_window else None)

self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=self.sliding_window,
prefix=f"{prefix}.attn")
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
Expand Down Expand Up @@ -206,7 +220,8 @@ def forward(
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
if self.v1 or self.sliding_window:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
# ChatGLMModel supports multimodal
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
Expand Down

0 comments on commit bddbbcb

Please sign in to comment.