Add static KV cache and test on Gemma-2B #4

Merged
merged 13 commits on Mar 15, 2024
Changes from all commits
4 changes: 3 additions & 1 deletion .github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -32,4 +32,6 @@ jobs:
run: python -c "import torch_xla.core.xla_model as xm; assert xm.xla_device().type == 'xla', 'XLA device not available'"

- name: Build and test TGI server
run: make tgi_test
run: |
pip install accelerate==0.27.2
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -17,17 +17,17 @@ line-length = 119
target-version = ['py38']
extend-exclude = '.ipynb'

[tool.ruff]
[lint.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
line-length = 119

# Ignore import violations in all `__init__.py` files.
[tool.ruff.per-file-ignores]
[lint.ruff.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[tool.ruff.isort]
[lint.ruff.isort]
lines-after-imports = 2
known-first-party = ["optimum.tpu"]

2 changes: 1 addition & 1 deletion text-generation-inference/integration-tests/conftest.py
@@ -85,7 +85,7 @@ def docker_launcher(
trust_remote_code: bool = False,
):
# TODO: consider finding out how to forward a port in the container instead of leaving it to 80.
#For now this is necessary because TPU dockers require to run with net=host and privileged mode.
# For now this is necessary because TPU dockers require to run with net=host and privileged mode.
port = 80

args = ["--model-id", model_id, "--env"]
9 changes: 7 additions & 2 deletions text-generation-inference/integration-tests/test_gpt2.py
@@ -38,7 +38,9 @@ async def test_model_single_request(tgi_client):
decoder_input_details=True,
)
assert response.details.generated_tokens == 17
assert response.generated_text == "\n\nDeep learning is a technique that allows you to learn something from a set of"
assert (
response.generated_text == "\n\nDeep learning is a technique that allows you to learn something from a set of"
)

# Greedy bounded with input
response = await tgi_client.generate(
@@ -64,7 +66,10 @@ async def test_model_single_request(tgi_client):
seed=42,
decoder_input_details=True,
)
assert 'The deep neural networks that we create are essentially "miniature" neural networks that can easily be trained' in response.generated_text
assert (
'The deep neural networks that we create are essentially "miniature" neural networks that can easily be trained'
in response.generated_text
)


@pytest.mark.asyncio
134 changes: 96 additions & 38 deletions text-generation-inference/server/text_generation_server/generator.py
@@ -6,8 +6,9 @@
from typing import List, Optional, Tuple

import torch
import torch_xla.core.xla_model as xm
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
from transformers.generation import GenerationConfig

from .modeling import TpuModelForCausalLM
@@ -94,10 +95,11 @@ class State(Enum):
PAUSE = 1
READY = 2

def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase, device: [str, torch.device]):
self._id = id
self._tokenizer = tokenizer
self.clear()
self._device = device

Review comment (Member):
Maybe let's do the conversion from str to torch.device() right away here, so we can fail fast if the device doesn't exist and avoid overhead later down the road?

Reply (Collaborator, author):
The conversion does not check that the device is available. The only way I found to check whether the device is available is to invoke the torch_xla API directly. I can add a check before mapping the model if you wish.

Reply (Collaborator, author):
As discussed offline, adding such a check is probably unnecessary, given that the check happens implicitly when the model is mapped onto the device.
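
As a side note, a minimal sketch of the fail-fast idea discussed above (a hypothetical helper, not part of this PR; it assumes torch_xla is installed and leaves the real availability check to the moment the model is mapped onto the device):

```python
from typing import Union

import torch
import torch_xla.core.xla_model as xm


def resolve_device(device: Union[str, torch.device]) -> torch.device:
    """Convert early so a malformed device string fails fast.

    Note: this does not guarantee the device is usable; for XLA the
    effective check only happens when the model is moved onto it.
    """
    if isinstance(device, torch.device):
        return device
    if device == "xla":
        # Asking torch_xla for the device is the most direct availability check.
        return xm.xla_device()
    return torch.device(device)
```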


def clear(self):
"""Clear the slot and mark it as available."""
@@ -106,7 +108,7 @@ def clear(self):
self._inputs = ""
self._generation_config = None
self._tokens = []
self._mask = []
self._mask = None
self._selector = None
self._generated_tokens = 0
self._next_text_token_start = 0
@@ -139,6 +141,10 @@ def generation_config(self) -> GenerationConfig:
def generated_tokens(self) -> int:
return self._generated_tokens

@property
def cur_position(self) -> int:
return self._next_text_token_start

def assign(self, request: Request, generation_config: GenerationConfig):
"""Assign a request to a slot.

@@ -179,7 +185,10 @@ def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, s
self._next_text_token_start = 0
self._next_text_token_end = torch.numel(self._tokens)
self._next_text = ""
self._mask = attention_mask.clone()
if attention_mask is not None:
self._mask = attention_mask.clone()
else:
self._mask = None
self._selector = selector

def pause(self):
@@ -238,8 +247,12 @@ def append(self, next_token: int) -> str:
Return:
The corresponding decoded text (if any).
"""
self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
self._mask = torch.cat([self._mask, torch.LongTensor([1])])
self._tokens = torch.cat(
[self._tokens, torch.tensor([next_token], device=self._device, dtype=self._tokens.dtype)]
)
# Update mask only if it was set previously
if self._mask is not None:
self._mask = torch.cat([self._mask, torch.tensor([1], device=self._device, dtype=self._mask.dtype)])

Review comment (Member, @mfuntowicz, Mar 15, 2024):
Maybe for later: can this concatenation be replaced by an in-place set from 0 to 1?

Reply (Collaborator, author):
Sure, I'll take a note.

Reply (Collaborator, author):
Having said that, this is handled transparently by models that use the static cache; I suspect they already do the in-place update inside the model.
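
For illustration, a minimal sketch of the in-place alternative mentioned above (an assumption about how it could look, not the PR's implementation; the pre-allocated length and position tracking are hypothetical):

```python
import torch

# Pre-allocate the mask once at the static sequence length, then flip
# individual positions in place instead of concatenating a new tensor per token.
sequence_length = 1024
mask = torch.zeros(sequence_length, dtype=torch.int64)


def mark_position(mask: torch.Tensor, position: int) -> None:
    # In-place set from 0 to 1 for the newly generated token's slot.
    mask[position] = 1


mark_position(mask, 0)  # e.g. the first generated token
```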

self._generated_tokens += 1
next_text = self._decode_next_tokens()
# Now that a new token has been generated, we can append the previous one to the generated text
@@ -296,8 +309,16 @@ def __init__(
tokenizer.padding_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
self.past_key_values = None
# _setup_cache is specific to some models (e.g.: Gemma and Llama). In those cases it is possible to setup
# a static cache, otherwise it is not.
self.use_static_cache = True
if getattr(self.model, "_setup_cache", False) is False:
logger.warning(
f"Static cache not available for {self.model.__class__.__name__}. Performance will be affected"
)
self.use_static_cache = False

@property
def info(self) -> InfoResponse:
@@ -326,8 +347,9 @@ def warmup(self, batch: Batch) -> int:
f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
)
self.prefill(batch)
return self.model.config.batch_size * self.model.config.n_positions
return self.model.config.batch_size * self.model.config.sequence_length

@torch.no_grad
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.

@@ -361,9 +383,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# for unfinished requests.
inputs = [slot.cached_text for slot in self.slots]
# Tokenize with padding
padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.model.device)
# If needed truncate sequences to fit into the static dimensions
seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.n_positions)
seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.sequence_length)
input_ids = padded_inputs.input_ids[:, :seq_length]
attention_mask = padded_inputs.attention_mask[:, :seq_length]
# Pause previously active slots during generation and store their last token.
@@ -377,17 +399,36 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
slot_input_ids = input_ids[i : i + 1, :]
# Padded input ids are also required to set logits processors and stopping criterias
selector = TokenSelector.create(
slot_input_ids, slot.generation_config, self.model, self.model.config.n_positions, seed=slot.seed
slot_input_ids,
slot.generation_config,
self.model,
self.model.config.sequence_length,
seed=slot.seed,
)
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
slot_attention_mask = attention_mask[i]
if self.use_static_cache:
# Attention mask does not need to be tracked when using static cache
slot_attention_mask = None
else:
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
# Clear KV cache
self.past_key_values = None
# Pause previously active slots during generation.
# The KV cache of paused slots will be prefilled during generation but new tokens
# will be ignored, as they have already been generated and sent back in the last decode.
generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask)
# Obtain position ids using attention mask.
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
position_ids = position_ids[:, -input_ids.shape[-1] :]

extra_args = {}
if self.use_static_cache:
self.model._setup_cache(StaticCache, len(self.slots), self.model.config.sequence_length)
extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
else:
# Reset/clear KV cache
self.past_key_values = None
generation, next_batch = self._generate_token(
batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
)

# Reactivate previously active slots for the next decode, and append
# back their next token.
for slot, next_token in zip(active_slots, next_tokens):
@@ -396,6 +437,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
logger.debug("Model ready for decoding")
return generation, next_batch

@torch.no_grad
def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
"""Decode the specified prefilled requests.

@@ -416,46 +458,62 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
# Reconstruct input_ids and attention_mask from slots
input_ids = None
attention_mask = None
position_ids = torch.zeros(
[self.model.config.batch_size, 1],
dtype=torch.int64,
device=self.model.device,
)
for i, slot in enumerate(self.slots):
if slot.state != Slot.State.EMPTY:
if input_ids is None:
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[self.model.config.batch_size, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
[self.model.config.batch_size, 1],
fill_value=self.tokenizer.eos_token_id,
dtype=torch.int64,
device=self.model.device,
)
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids[i, 0] = slot.next_token
if attention_mask is None:
# Create default mask covering all slots (even empty ones)
attention_mask = torch.zeros(
[self.model.config.batch_size, slot.attention_mask.size(-1)], dtype=torch.int64
)
attention_mask[i, :] = slot.attention_mask
if not self.use_static_cache:
# When using dynamic cache, the whole attention mask needs to be passed over to the model at each iteration.
if attention_mask is None:
# Create default mask covering all slots (even empty ones)
attention_mask = torch.zeros(
[self.model.config.batch_size, slot.attention_mask.size(-1)],
dtype=torch.int64,
device=self.model.device,
)
attention_mask[i, :] = slot.attention_mask
position_ids[i, 0] = slot.cur_position
if input_ids is None:
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
return self._generate_token(next_batch_id, input_ids, attention_mask)
extra_args = {}
if self.use_static_cache:
extra_args["cache_position"] = position_ids.max().unsqueeze(0)
else:
extra_args["attention_mask"] = attention_mask
extra_args["past_key_values"] = self.past_key_values
return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)

def _generate_token(
self, next_batch_id: int, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None
self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
) -> Tuple[List[Generation], CachedBatch]:
# Obtain position ids using attention mask.
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
position_ids = position_ids[:, -input_ids.shape[-1] :]
# Move input params to device
input_ids = input_ids.to(self.model.device)
attention_mask = attention_mask.to(self.model.device)
position_ids = position_ids.to(self.model.device)
# Add barrier to allow next graph step to always be the same
xm.mark_step()
# Forward
outputs = self.model(
input_ids,
past_key_values=self.past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
use_cache=True,
**forward_extra_params,
)
# Save KV cache
self.past_key_values = outputs.past_key_values
if not self.use_static_cache:
# Save KV cache
self.past_key_values = outputs.past_key_values
# Barrier for XLA model
xm.mark_step(wait=False)

generations = []
active_slots = False
for i, slot in enumerate(self.slots):
@@ -507,7 +565,7 @@ def _generate_token(

def _cached_batch(self, batch_id: int, request_ids: List):
size = len(request_ids)
max_tokens = size * self.model.config.n_positions
max_tokens = size * self.model.config.sequence_length
return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)

def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
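
To summarize the flow introduced in generator.py above, a condensed sketch of the prefill/decode pattern with a static KV cache (it assumes a transformers version where models such as Gemma and Llama expose `_setup_cache`, and the custom `config.sequence_length` attribute set by this server; it is not the full generator logic):

```python
import torch
from transformers import PreTrainedModel, StaticCache


def prefill_with_static_cache(
    model: PreTrainedModel, input_ids: torch.Tensor, position_ids: torch.Tensor, batch_size: int
):
    """Prefill: allocate the static cache once, then pass explicit cache positions
    so KV entries are written in place inside the model."""
    model._setup_cache(StaticCache, batch_size, model.config.sequence_length)
    cache_position = torch.arange(input_ids.shape[-1], device=model.device)
    return model(
        input_ids,
        position_ids=position_ids,
        cache_position=cache_position,
        return_dict=True,
        use_cache=True,
    )


def decode_with_static_cache(
    model: PreTrainedModel, next_token_ids: torch.Tensor, position_ids: torch.Tensor
):
    """Decode: a single new cache position per step; the caller does not track
    past_key_values because the cache lives inside the model."""
    cache_position = position_ids.max().unsqueeze(0)
    return model(
        next_token_ids,
        position_ids=position_ids,
        cache_position=cache_position,
        return_dict=True,
        use_cache=True,
    )
```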
@@ -2,10 +2,11 @@
import time
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
from loguru import logger
from transformers import AutoConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from huggingface_hub import snapshot_download


def get_export_kwargs_from_env():
@@ -19,9 +19,9 @@
from typing import Any

import torch
import torch_xla.core.xla_model as xm
from loguru import logger
from transformers import AutoModelForCausalLM
from transformers.utils import is_accelerate_available


# TODO: For now TpuModelForCausalLM is just a shallow wrapper of
@@ -38,7 +38,23 @@ def from_pretrained(
*model_args: Any,
**kwargs: Any,
):
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
if "PJRT_DEVICE" not in environ:
logger.info("PJRT_DEVICE environment variable not found. Setting it to 'TPU'.")
environ["PJRT_DEVICE"] = "TPU"
if "DBG_DEVICE" in environ:
device = environ["DBG_DEVICE"]
logger.debug(f"Device set to: {device}")
else:
device = "xla"
if is_accelerate_available():
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, device_map=device, *model_args, **kwargs
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
model.to(device)
# Update config with specific data)
if task is not None or getattr(model.config, "task", None) is None:
model.config.task = task
@@ -47,13 +63,10 @@ def from_pretrained(
if sequence_length is not None or getattr(model.config, "sequence_length", None) is None:
model.config.sequence_length = sequence_length

if "PJRT_DEVICE" not in environ:
logger.warning("PJRT_DEVICE environment variable not found. Setting it to 'TPU'.")
environ["PJRT_DEVICE"] = "TPU"
dev = xm.xla_device()
# Do eval, move model to device and compile
model.to(dev)
# Do eval, and compile
model.eval()
model = torch.compile(model, backend="openxla_eval")
if device == "xla" and "DBG_COMPILE" in environ:
model = torch.compile(model, backend="openxla_eval")
logger.debug("Model compiled.")

return model
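
A short usage note on the environment-variable handling added above; a sketch of how one might configure loading (the specific values are illustrative, not defaults of the library):

```python
import os

# Variables read by TpuModelForCausalLM.from_pretrained after this change:
os.environ.setdefault("PJRT_DEVICE", "TPU")  # set automatically to "TPU" if missing
# os.environ["DBG_DEVICE"] = "cpu"           # debug: load on another device than "xla"
# os.environ["DBG_COMPILE"] = "1"            # compile with the "openxla_eval" backend (xla device only)
```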