add a wrapper for running phi-3-mini with kv cache
helunwencser committed Jul 25, 2024
1 parent b13f8aa commit a7fffeb
Showing 3 changed files with 59 additions and 22 deletions.
11 changes: 11 additions & 0 deletions examples/models/phi-3-mini/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .phi_3_mini import Phi3Mini

__all__ = [
    Phi3Mini,
]
34 changes: 12 additions & 22 deletions examples/models/phi-3-mini/eager.py
@@ -12,7 +12,7 @@

 import torch
 
-from extension.llm.transformers.static_cache import ETStaticCache
+from phi_3_mini import Phi3Mini
 
 from transformers import AutoTokenizer, Phi3ForCausalLM

@@ -42,38 +42,28 @@ def _generate_token(args, model, prompt_tokens):
 def _generate_token_with_kv_cache(args, model, prompt_tokens):
     print("Generating tokens:", end="", flush=True)
 
-    result = model.forward(
-        input_ids=prompt_tokens,
-        use_cache=True,
-        return_dict=True,
-        past_key_values=ETStaticCache(
-            model.config,
-            prompt_tokens.shape[0],
-            args.seq_len + prompt_tokens.shape[-1],
-            device=model.device,
-            dtype=model.dtype,
-        ),
-    )
+    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])
 
-    current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
-    current_key_value = result.past_key_values
+    for input_pos in range(prompt_tokens.shape[-1]):
+        result = model.forward(
+            input_ids=prompt_tokens[:, input_pos : input_pos + 1],
+            cache_position=torch.arange(0, input_pos, device=model.model.device),
+        )
 
+    current_token = torch.argmax(result[:, -1, :], dim=-1).item()
     print(f" {current_token}", end="", flush=True)
 
     generated_tokens = [current_token]
 
     while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
         result = model.forward(
             input_ids=torch.tensor([[current_token]], dtype=torch.long),
-            use_cache=True,
-            return_dict=True,
-            past_key_values=current_key_value,
             cache_position=torch.arange(
-                0, prompt_tokens.shape[-1] + len(generated_tokens), device=model.device
+                0,
+                prompt_tokens.shape[-1] + len(generated_tokens),
+                device=model.model.device,
             ),
         )
-        current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
-        current_key_value = result.past_key_values
+        current_token = torch.argmax(result[:, -1, :], dim=-1).item()
         print(f" {current_token}", end="", flush=True)
         generated_tokens.append(current_token)

36 changes: 36 additions & 0 deletions examples/models/phi-3-mini/phi_3_mini.py
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn

from extension.llm.transformers.static_cache import ETStaticCache
from transformers import Phi3ForCausalLM


class Phi3Mini(torch.nn.Module):

    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        super().__init__()
        self.model = model
        self.cache = ETStaticCache(
            config=model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_seq_len,
            device=self.model.device,
            dtype=self.model.dtype,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        cache_position: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        return self.model.forward(
            input_ids=input_ids,
            use_cache=True,
            return_dict=True,
            past_key_values=self.cache,
            cache_position=cache_position,
        ).logits
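
For reference, a condensed driver sketch for this wrapper, mirroring the prefill/decode flow that eager.py adopts in this commit; the checkpoint id, prompt, and seq_len value are illustrative assumptions and not part of the change.

# Hypothetical driver for the Phi3Mini wrapper; checkpoint id, prompt, and
# generation budget are assumptions, the call pattern follows eager.py above.
import torch

from transformers import AutoTokenizer, Phi3ForCausalLM

from phi_3_mini import Phi3Mini

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

prompt_tokens = tokenizer("Tell me a story", return_tensors="pt").input_ids
seq_len = 128  # assumed generation budget

with torch.no_grad():
    # The wrapper owns an ETStaticCache sized for the prompt plus the budget.
    wrapper = Phi3Mini(model, 1, seq_len + prompt_tokens.shape[-1])

    # Prefill: push the prompt through one token at a time; the KV cache is
    # updated in place inside the wrapper, and forward() returns logits only.
    for input_pos in range(prompt_tokens.shape[-1]):
        logits = wrapper.forward(
            input_ids=prompt_tokens[:, input_pos : input_pos + 1],
            cache_position=torch.arange(0, input_pos, device=model.device),
        )

    # Decode: greedy argmax, feeding back one token per step.
    generated = [torch.argmax(logits[:, -1, :], dim=-1).item()]
    while generated[-1] != tokenizer.eos_token_id and len(generated) < seq_len:
        logits = wrapper.forward(
            input_ids=torch.tensor([[generated[-1]]], dtype=torch.long),
            cache_position=torch.arange(
                0, prompt_tokens.shape[-1] + len(generated), device=model.device
            ),
        )
        generated.append(torch.argmax(logits[:, -1, :], dim=-1).item())

print(tokenizer.decode(generated))

Keeping the ETStaticCache inside the module means the caller passes only input_ids and cache_position and gets logits back, which is the flat signature the updated eager.py loop relies on.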
