From bb24c3ac04ad2957dd76c70310426a7060890b38 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Fri, 15 Nov 2024 18:49:35 -0800
Subject: [PATCH] update

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/phi3_small.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py
index 3704d7b6c26d7..2139cec441807 100644
--- a/vllm/model_executor/models/phi3_small.py
+++ b/vllm/model_executor/models/phi3_small.py
@@ -324,11 +324,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
 
     def forward(
         self,
@@ -337,9 +334,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             if (self.mup_embedding_multiplier is not None
                     and self.mup_embedding_multiplier > 0.0):
                 hidden_states = hidden_states * self.mup_embedding_multiplier
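
The patch changes `get_input_embeddings` from an accessor that returns the embedding module into a method that embeds `input_ids` directly, and it threads an optional `inputs_embeds` tensor through `forward` so callers can supply precomputed embeddings. Below is a minimal, self-contained sketch of that dispatch pattern; the `ToyModel` class and its sizes are illustrative assumptions, not vLLM code, and the real model also handles pipeline-parallel ranks and the mup embedding multiplier, which are omitted here.

```python
# Sketch of the "use inputs_embeds if provided, otherwise embed input_ids"
# pattern introduced by the patch. ToyModel is a hypothetical stand-in.
from typing import Optional

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Embed token IDs directly, mirroring the new method signature.
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Prefer caller-provided embeddings; fall back to the embedding table.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        return hidden_states


model = ToyModel()
ids = torch.tensor([[1, 2, 3]])
# Path 1: the model embeds the token IDs itself.
out_from_ids = model(ids)
# Path 2: embeddings are computed outside the model and passed in,
# e.g. after merging multimodal embeddings into the text embeddings.
precomputed = model.get_input_embeddings(ids)
out_from_embeds = model(ids, inputs_embeds=precomputed)
assert torch.allclose(out_from_ids, out_from_embeds)
```

In this sketch `inputs_embeds` defaults to `None` for convenience; the patched `forward` declares it without a default, so callers in the model runner are expected to pass it explicitly.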