diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index cae4a88de1638..3bef3f3226062 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -118,9 +118,9 @@ Text Generation (``--task generate``)
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
     - ✅︎
     - ✅︎
-  * - :code:`CohereForCausalLM`
+  * - :code:`CohereForCausalLM`, :code:`Cohere2ForCausalLM`
     - Command-R
-    - :code:`CohereForAI/c4ai-command-r-v01`, etc.
+    - :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc.
     - ✅︎
     - ✅︎
   * - :code:`DbrxForCausalLM`
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 6a8b1742ceae3..fac8c4b2e9b19 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -53,6 +53,8 @@ class _HfExamplesInfo:
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
                                          trust_remote_code=True),
+    "Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024",  # noqa: E501
+                                          trust_remote_code=True),
     "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
     "DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
                                          trust_remote_code=True),
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 3b728f2744fca..a4eea7f035c91 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch
 
 import pytest
+import transformers
 from transformers import PretrainedConfig
 
 from vllm import LLM
@@ -11,6 +12,11 @@
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
 def test_can_initialize(model_arch):
     model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    # Compare numerically: as plain strings "4.5.0" sorts after "4.48.0".
+    hf_major_minor = tuple(
+        int(part) for part in transformers.__version__.split(".")[:2])
+    if model_arch == "Cohere2ForCausalLM" and hf_major_minor < (4, 48):
+        pytest.skip(reason="Model introduced in HF >= 4.48.0")
     if not model_info.is_available_online:
         pytest.skip("Model is not available online")
 
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 85e24ca660686..c846e42f1b0c3 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -48,7 +48,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -171,12 +171,28 @@ def __init__(
             rope_scaling=self.rope_scaling,
             is_neox_style=False,
         )
+
+        sliding_window = getattr(config, "sliding_window", None)
+        # Model v2 has sliding windows, v1 does not
+        self.v1 = sliding_window is None
+
+        # Layers are counted from 1 here: every `sliding_window_pattern`-th
+        # layer gets no sliding window, all other layers use it.
+        layer_idx = extract_layer_index(prefix)
+        layer_has_sliding_window = (
+            getattr(config, "sliding_window_pattern", False)
+            and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
+
+        self.sliding_window = (sliding_window
+                               if layer_has_sliding_window else None)
+
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
                               quant_config=quant_config,
+                              per_layer_sliding_window=self.sliding_window,
                               prefix=f"{prefix}.attn")
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,
@@ -206,7 +222,10 @@ def forward(
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        # Rotary embeddings are applied on every v1 layer, but on v2 only on
+        # the layers that use a sliding window.
+        if self.v1 or self.sliding_window:
+            q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 4e77746f312e3..68a2467a813a1 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -41,6 +41,7 @@
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
+    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
     "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),