diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 05b178aec6832..ac0d265a961f0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -176,7 +176,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.sampler = get_sampler() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + return self.backbone.get_input_embeddings(input_ids) def forward(self, input_ids: torch.Tensor,