Add static KV cache and test on Gemma-2B #4
@@ -6,8 +6,9 @@
 from typing import List, Optional, Tuple

 import torch
 import torch_xla.core.xla_model as xm
 from loguru import logger
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
 from transformers.generation import GenerationConfig

 from .modeling import TpuModelForCausalLM
@@ -94,10 +95,11 @@ class State(Enum):
         PAUSE = 1
         READY = 2

-    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
+    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase, device: [str, torch.device]):
         self._id = id
         self._tokenizer = tokenizer
         self.clear()
+        self._device = device

     def clear(self):
         """Clear the slot and mark it as available."""
@@ -106,7 +108,7 @@ def clear(self):
         self._inputs = ""
         self._generation_config = None
         self._tokens = []
-        self._mask = []
+        self._mask = None
         self._selector = None
         self._generated_tokens = 0
         self._next_text_token_start = 0
@@ -139,6 +141,10 @@ def generation_config(self) -> GenerationConfig:
     def generated_tokens(self) -> int:
         return self._generated_tokens

+    @property
+    def cur_position(self) -> int:
+        return self._next_text_token_start
+
     def assign(self, request: Request, generation_config: GenerationConfig):
         """Assign a request to a slot.
@@ -179,7 +185,10 @@ def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, s
         self._next_text_token_start = 0
         self._next_text_token_end = torch.numel(self._tokens)
         self._next_text = ""
-        self._mask = attention_mask.clone()
+        if attention_mask is not None:
+            self._mask = attention_mask.clone()
+        else:
+            self._mask = None
         self._selector = selector

     def pause(self):
@@ -238,8 +247,12 @@ def append(self, next_token: int) -> str:
         Return:
             The corresponding decoded text (if any).
         """
-        self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
-        self._mask = torch.cat([self._mask, torch.LongTensor([1])])
+        self._tokens = torch.cat(
+            [self._tokens, torch.tensor([next_token], device=self._device, dtype=self._tokens.dtype)]
+        )
+        # Update mask only if it was set previously
+        if self._mask is not None:
+            self._mask = torch.cat([self._mask, torch.tensor([1], device=self._device, dtype=self._mask.dtype)])
Review comment: Maybe for later: could this concatenation be replaced by an in-place set?
Reply: Sure, I'll take a note.
Reply: Having said that, this is handled in a transparent way by models that use the static cache; I guess they already do that inside the model.
         self._generated_tokens += 1
         next_text = self._decode_next_tokens()
         # Now that a new token has been generated, we can append the previous one to the generated text
@@ -296,8 +309,16 @@ def __init__(
         tokenizer.padding_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-        self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
+        self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
         self.past_key_values = None
+        # _setup_cache is specific to some models (e.g.: Gemma and Llama). In those cases it is possible to setup
+        # a static cache, otherwise it is not.
+        self.use_static_cache = True
+        if getattr(self.model, "_setup_cache", False) is False:
+            logger.warning(
+                f"Static cache not available for {self.model.__class__.__name__}. Performance will be affected"
+            )
+            self.use_static_cache = False

     @property
     def info(self) -> InfoResponse:
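For context, the capability check added in __init__ can be exercised on its own, as in the hedged sketch below. It assumes a transformers version from the same era as this PR (around 4.38), where Gemma and Llama expose the private _setup_cache() helper; the model id, batch size and cache length are only examples.

import torch
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
if getattr(model, "_setup_cache", False) is False:
    # Model does not support a static cache; fall back to the dynamic KV cache.
    print(f"Static cache not available for {model.__class__.__name__}")
else:
    # Pre-allocate key/value buffers for a batch of 2 and a 1024-token cache.
    model._setup_cache(StaticCache, 2, 1024)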
@@ -326,8 +347,9 @@ def warmup(self, batch: Batch) -> int:
                 f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
             )
         self.prefill(batch)
-        return self.model.config.batch_size * self.model.config.n_positions
+        return self.model.config.batch_size * self.model.config.sequence_length

     @torch.no_grad
     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         """Prefill new requests.
@@ -361,9 +383,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # for unfinished requests.
         inputs = [slot.cached_text for slot in self.slots]
         # Tokenize with padding
-        padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
+        padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.model.device)
         # If needed truncate sequences to fit into the static dimensions
-        seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.n_positions)
+        seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.sequence_length)
         input_ids = padded_inputs.input_ids[:, :seq_length]
         attention_mask = padded_inputs.attention_mask[:, :seq_length]
         # Pause previously active slots during generation and store their last token.
@@ -377,17 +399,36 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             slot_input_ids = input_ids[i : i + 1, :]
             # Padded input ids are also required to set logits processors and stopping criterias
             selector = TokenSelector.create(
-                slot_input_ids, slot.generation_config, self.model, self.model.config.n_positions, seed=slot.seed
+                slot_input_ids,
+                slot.generation_config,
+                self.model,
+                self.model.config.sequence_length,
+                seed=slot.seed,
             )
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
-            slot_attention_mask = attention_mask[i]
+            if self.use_static_cache:
+                # Attention mask does not need to be tracked when using static cache
+                slot_attention_mask = None
+            else:
+                slot_attention_mask = attention_mask[i]
             slot.reset(slot_input_ids, slot_attention_mask, selector)
-        # Clear KV cache
-        self.past_key_values = None
         # Pause previously active slots during generation.
         # The KV cache of paused slots will be prefilled during generation but new tokens
         # will be ignored, as they have already been generated and sent back in the last decode.
-        generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask)
+        # Obtain position ids using attention mask.
+        position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
+        position_ids = position_ids[:, -input_ids.shape[-1] :]
+
+        extra_args = {}
+        if self.use_static_cache:
+            self.model._setup_cache(StaticCache, len(self.slots), self.model.config.sequence_length)
+            extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
+        else:
+            # Reset/clear KV cache
+            self.past_key_values = None
+        generation, next_batch = self._generate_token(
+            batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
+        )

         # Reactivate previously active slots for the next decode, and append
         # back their next token.
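As a quick illustration of the position-id computation added above (with left padding, as configured on the generator's tokenizer), padded positions collapse to 0 and real tokens are numbered from 0:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
# position_ids is tensor([[0, 0, 0, 1, 2],
#                         [0, 1, 2, 3, 4]])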
@@ -396,6 +437,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         logger.debug("Model ready for decoding")
         return generation, next_batch

     @torch.no_grad
     def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
         """Decode the specified prefilled requests.
@@ -416,46 +458,62 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
         # Reconstruct input_ids and attention_mask from slots
         input_ids = None
         attention_mask = None
+        position_ids = torch.zeros(
+            [self.model.config.batch_size, 1],
+            dtype=torch.int64,
+            device=self.model.device,
+        )
         for i, slot in enumerate(self.slots):
             if slot.state != Slot.State.EMPTY:
                 if input_ids is None:
                     # Create blank inputs covering all slots (even empty ones)
                     input_ids = torch.full(
-                        [self.model.config.batch_size, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
+                        [self.model.config.batch_size, 1],
+                        fill_value=self.tokenizer.eos_token_id,
+                        dtype=torch.int64,
+                        device=self.model.device,
                     )
                 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                 input_ids[i, 0] = slot.next_token
-                if attention_mask is None:
-                    # Create default mask covering all slots (even empty ones)
-                    attention_mask = torch.zeros(
-                        [self.model.config.batch_size, slot.attention_mask.size(-1)], dtype=torch.int64
-                    )
-                attention_mask[i, :] = slot.attention_mask
+                if not self.use_static_cache:
+                    # When using dynamic cache, the whole attention mask needs to be passed over to the model at each iteration.
+                    if attention_mask is None:
+                        # Create default mask covering all slots (even empty ones)
+                        attention_mask = torch.zeros(
+                            [self.model.config.batch_size, slot.attention_mask.size(-1)],
+                            dtype=torch.int64,
+                            device=self.model.device,
+                        )
+                    attention_mask[i, :] = slot.attention_mask
+                position_ids[i, 0] = slot.cur_position
         if input_ids is None:
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
-        return self._generate_token(next_batch_id, input_ids, attention_mask)
+        extra_args = {}
+        if self.use_static_cache:
+            extra_args["cache_position"] = position_ids.max().unsqueeze(0)
+        else:
+            extra_args["attention_mask"] = attention_mask
+            extra_args["past_key_values"] = self.past_key_values
+        return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)

     def _generate_token(
-        self, next_batch_id: int, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None
+        self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
     ) -> Tuple[List[Generation], CachedBatch]:
-        # Obtain position ids using attention mask.
-        position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
-        position_ids = position_ids[:, -input_ids.shape[-1] :]
-        # Move input params to device
-        input_ids = input_ids.to(self.model.device)
-        attention_mask = attention_mask.to(self.model.device)
-        position_ids = position_ids.to(self.model.device)
         # Add barrier to allow next graph step to always be the same
         xm.mark_step()
         # Forward
         outputs = self.model(
             input_ids,
-            past_key_values=self.past_key_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
             return_dict=True,
             use_cache=True,
+            **forward_extra_params,
         )
-        # Save KV cache
-        self.past_key_values = outputs.past_key_values
+        if not self.use_static_cache:
+            # Save KV cache
+            self.past_key_values = outputs.past_key_values
         # Barrier for XLA model
         xm.mark_step(wait=False)

         generations = []
         active_slots = False
         for i, slot in enumerate(self.slots):
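To make the cache_position handling above concrete, here is a small sketch of the values passed in each phase when the static cache is used (plain CPU tensors for illustration; the server builds them on the XLA device):

import torch

seq_length = 5  # padded prompt length
# Prefill: one cache slot per prompt position.
prefill_cache_position = torch.arange(seq_length)  # tensor([0, 1, 2, 3, 4])

# Decode: a single shared index, taken as the largest current position across slots.
position_ids = torch.tensor([[4], [2]])  # per-slot current positions
decode_cache_position = position_ids.max().unsqueeze(0)  # tensor([4])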
@@ -507,7 +565,7 @@ def _generate_token(

     def _cached_batch(self, batch_id: int, request_ids: List):
         size = len(request_ids)
-        max_tokens = size * self.model.config.n_positions
+        max_tokens = size * self.model.config.sequence_length
         return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)

     def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
Review comment (on the Slot device argument): Maybe let's do the conversion from str to torch.device() right away here, to ensure we can fail fast if this device doesn't exist and avoid overhead later down the road?
Reply: The conversion does not check that the device is available. The only way I found to check whether the device is available is to invoke the torch_xla API directly. I can add a check before mapping the model if you wish.
Reply: As discussed offline, adding such a check is probably useless, given that the check will be done implicitly while mapping the model.
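A possible shape for the early conversion discussed here, as a hedged sketch (the helper name is hypothetical and not part of this PR). Note that torch.device(...) only validates the device string; whether an XLA device actually exists is only confirmed when the model is mapped onto it, which matches the conclusion of the thread.

from typing import Union

import torch

def normalize_device(device: Union[str, torch.device]) -> torch.device:
    # Converting eagerly fails fast on malformed strings (e.g. a typo in "xla:0"),
    # but it does not guarantee that the device is available at runtime.
    return torch.device(device) if isinstance(device, str) else device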