Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Roger Wang <[email protected]>
  • Loading branch information
ywang96 committed Nov 16, 2024
1 parent b113ec2 commit bb24c3a
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions vllm/model_executor/models/phi3_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
Expand All @@ -337,9 +334,13 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
if (self.mup_embedding_multiplier is not None
and self.mup_embedding_multiplier > 0.0):
hidden_states = hidden_states * self.mup_embedding_multiplier
Expand Down

0 comments on commit bb24c3a

Please sign in to comment.