From c84ee307e65ea4c90250fd11b9372b84d89607b3 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Tue, 5 Mar 2024 16:59:49 +0100
Subject: [PATCH 01/13] chore(style): run make style

---
 text-generation-inference/integration-tests/conftest.py  | 2 +-
 text-generation-inference/integration-tests/test_gpt2.py | 9 +++++++--
 .../server/text_generation_server/model.py               | 3 ++-
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/text-generation-inference/integration-tests/conftest.py b/text-generation-inference/integration-tests/conftest.py
index f1abf365..5fb08ec6 100644
--- a/text-generation-inference/integration-tests/conftest.py
+++ b/text-generation-inference/integration-tests/conftest.py
@@ -85,7 +85,7 @@ def docker_launcher(
     trust_remote_code: bool = False,
 ):
     # TODO: consider finding out how to forward a port in the container instead of leaving it to 80.
-    #For now this is necessary because TPU dockers require to run with net=host and privileged mode.
+    # For now this is necessary because TPU dockers require to run with net=host and privileged mode.
     port = 80
 
     args = ["--model-id", model_id, "--env"]
diff --git a/text-generation-inference/integration-tests/test_gpt2.py b/text-generation-inference/integration-tests/test_gpt2.py
index 60de1b1f..3219ddd5 100644
--- a/text-generation-inference/integration-tests/test_gpt2.py
+++ b/text-generation-inference/integration-tests/test_gpt2.py
@@ -38,7 +38,9 @@ async def test_model_single_request(tgi_client):
         decoder_input_details=True,
     )
     assert response.details.generated_tokens == 17
-    assert response.generated_text == "\n\nDeep learning is a technique that allows you to learn something from a set of"
+    assert (
+        response.generated_text == "\n\nDeep learning is a technique that allows you to learn something from a set of"
+    )
 
     # Greedy bounded with input
     response = await tgi_client.generate(
@@ -64,7 +66,10 @@ async def test_model_single_request(tgi_client):
         seed=42,
         decoder_input_details=True,
     )
-    assert 'The deep neural networks that we create are essentially "miniature" neural networks that can easily be trained' in response.generated_text
+    assert (
+        'The deep neural networks that we create are essentially "miniature" neural networks that can easily be trained'
+        in response.generated_text
+    )
 
 
 @pytest.mark.asyncio
diff --git a/text-generation-inference/server/text_generation_server/model.py b/text-generation-inference/server/text_generation_server/model.py
index 89e3a769..f6af2d48 100644
--- a/text-generation-inference/server/text_generation_server/model.py
+++ b/text-generation-inference/server/text_generation_server/model.py
@@ -2,10 +2,11 @@ import time
 from pathlib import Path
 from typing import Optional
 
+from huggingface_hub import snapshot_download
 from loguru import logger
 from transformers import AutoConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
-from huggingface_hub import snapshot_download
 
 
 def get_export_kwargs_from_env():

From c48e7eec47f2e32bbb482bf41b143daf35404c89 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Tue, 5 Mar 2024 17:00:50 +0100
Subject: [PATCH 02/13] chore(style): update pyproject to avoid ruff warning

---
 pyproject.toml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index fef57bd8..d21952a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,17 +17,17 @@ line-length = 119
 target-version = ['py38']
 extend-exclude = '.ipynb'
 
-[tool.ruff]
+[lint.ruff]
 # Never enforce `E501` (line length violations).
 ignore = ["C901", "E501", "E741", "W605"]
 select = ["C", "E", "F", "I", "W"]
 line-length = 119
 
 # Ignore import violations in all `__init__.py` files.
-[tool.ruff.per-file-ignores]
+[lint.ruff.per-file-ignores]
 "__init__.py" = ["E402", "F401", "F403", "F811"]
 
-[tool.ruff.isort]
+[lint.ruff.isort]
 lines-after-imports = 2
 known-first-party = ["optimum.tpu"]

From de3acb46f195b92811d6dddf1d489f1a8d4d414c Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 7 Mar 2024 15:22:38 +0000
Subject: [PATCH 03/13] fix(tgi): sequence length should be based on sequence_length config

It was previously using n_positions in some places, but that attribute is
not available on some model configs.
---
 .../server/text_generation_server/generator.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 7dfb46f3..bfac0b7b 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -326,7 +326,7 @@ def warmup(self, batch: Batch) -> int:
                 f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
             )
         self.prefill(batch)
-        return self.model.config.batch_size * self.model.config.n_positions
+        return self.model.config.batch_size * self.model.config.sequence_length
 
     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         """Prefill new requests.
@@ -363,7 +363,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # Tokenize with padding
         padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
         # If needed truncate sequences to fit into the static dimensions
-        seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.n_positions)
+        seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.sequence_length)
         input_ids = padded_inputs.input_ids[:, :seq_length]
         attention_mask = padded_inputs.attention_mask[:, :seq_length]
         # Pause previously active slots during generation and store their last token.
@@ -377,7 +377,11 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             slot_input_ids = input_ids[i : i + 1, :]
             # Padded input ids are also required to set logits processors and stopping criterias
             selector = TokenSelector.create(
-                slot_input_ids, slot.generation_config, self.model, self.model.config.n_positions, seed=slot.seed
+                slot_input_ids,
+                slot.generation_config,
+                self.model,
+                self.model.config.sequence_length,
+                seed=slot.seed,
             )
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
             slot_attention_mask = attention_mask[i]
@@ -507,7 +511,7 @@ def _generate_token(
 
     def _cached_batch(self, batch_id: int, request_ids: List):
         size = len(request_ids)
-        max_tokens = size * self.model.config.n_positions
+        max_tokens = size * self.model.config.sequence_length
         return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)
 
     def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
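
A note on the change above (illustration only, not part of the series): GPT-2-style configs expose the
maximum sequence length as n_positions, while Llama- or Gemma-style configs call it
max_position_embeddings, which is why the server pins its own sequence_length attribute on the config
and reads only that. A minimal sketch of that kind of lookup, assuming nothing beyond transformers; the
helper name is made up for the example:

    from transformers import AutoConfig

    def resolve_max_length(model_id: str, default: int = 1024) -> int:
        # Return the first max-length attribute the config actually defines.
        config = AutoConfig.from_pretrained(model_id)
        for attr in ("sequence_length", "n_positions", "max_position_embeddings"):
            value = getattr(config, attr, None)
            if value is not None:
                return int(value)
        return default

    # GPT-2 resolves through n_positions (1024); Llama/Gemma would resolve through
    # max_position_embeddings instead, and neither family defines the other attribute.
    print(resolve_max_length("openai-community/gpt2"))
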
From c6dc4a7d2c8975d35643f50b71ba5e415011c697 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 7 Mar 2024 15:39:59 +0000
Subject: [PATCH 04/13] feat(modeling): model is immediately loaded on device

---
 .../server/text_generation_server/modeling.py | 22 +++++++++++++------
 .../tests/test_generator_slot.py              |  3 ++-
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/modeling.py b/text-generation-inference/server/text_generation_server/modeling.py
index 3814c07c..5f4bcceb 100644
--- a/text-generation-inference/server/text_generation_server/modeling.py
+++ b/text-generation-inference/server/text_generation_server/modeling.py
@@ -22,6 +22,7 @@
 import torch_xla.core.xla_model as xm
 from loguru import logger
 from transformers import AutoModelForCausalLM
+from transformers.utils import is_accelerate_available
 
 
 # TODO: For now TpuModelForCausalLM is just a shallow wrapper of
@@ -38,7 +39,19 @@ def from_pretrained(
         *model_args: Any,
         **kwargs: Any,
     ):
-        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        if "PJRT_DEVICE" not in environ:
+            logger.info("PJRT_DEVICE environment variable not found. Setting it to 'TPU'.")
+            environ["PJRT_DEVICE"] = "TPU"
+        device = xm.xla_device()
+        if is_accelerate_available():
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path, device_map=device, *model_args, **kwargs
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
+            model.to(device)
         # Update config with specific data)
         if task is not None or getattr(model.config, "task", None) is None:
             model.config.task = task
@@ -47,12 +60,7 @@ def from_pretrained(
         if sequence_length is not None or getattr(model.config, "sequence_length", None) is None:
             model.config.sequence_length = sequence_length
 
-        if "PJRT_DEVICE" not in environ:
-            logger.warning("PJRT_DEVICE environment variable not found. Setting it to 'TPU'.")
-            environ["PJRT_DEVICE"] = "TPU"
-        dev = xm.xla_device()
-        # Do eval, move model to device and compile
-        model.to(dev)
+        # Do eval, and compile
         model.eval()
         model = torch.compile(model, backend="openxla_eval")
 
diff --git a/text-generation-inference/tests/test_generator_slot.py b/text-generation-inference/tests/test_generator_slot.py
index c1a85d72..7578dbfa 100644
--- a/text-generation-inference/tests/test_generator_slot.py
+++ b/text-generation-inference/tests/test_generator_slot.py
@@ -31,7 +31,8 @@ def tokenizer(request):
     ids=["spaces", "chinese-utf8", "emojis"],
 )
 def test_decode_streaming(tokenizer, input_text, generated_text):
-    slot = Slot(0, tokenizer)
+    # Note: device used is cpu to make it faster
+    slot = Slot(0, tokenizer, "cpu")
     request = Request(id=0, inputs=input_text)
     slot.assign(request, GenerationConfig())
     assert slot.cached_text == input_text

From 60ad09ce79a31155b2d2730ad12c1112445d1556 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 7 Mar 2024 17:23:00 +0000
Subject: [PATCH 05/13] debug: added env var to debug on CPU

If the DBG_DEVICE env var is set, it will be used to set the device for
the model.
---
 .../server/text_generation_server/modeling.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/modeling.py b/text-generation-inference/server/text_generation_server/modeling.py
index 5f4bcceb..3b1a3a9e 100644
--- a/text-generation-inference/server/text_generation_server/modeling.py
+++ b/text-generation-inference/server/text_generation_server/modeling.py
@@ -19,7 +19,6 @@ from typing import Any
 
 import torch
-import torch_xla.core.xla_model as xm
 from loguru import logger
 from transformers import AutoModelForCausalLM
 from transformers.utils import is_accelerate_available
@@ -42,7 +41,11 @@ def from_pretrained(
         if "PJRT_DEVICE" not in environ:
             logger.info("PJRT_DEVICE environment variable not found. Setting it to 'TPU'.")
             environ["PJRT_DEVICE"] = "TPU"
-        device = xm.xla_device()
+        if "DBG_DEVICE" in environ:
+            device = environ["DBG_DEVICE"]
+            logger.debug(f"Device set to: {device}")
+        else:
+            device = "xla"
         if is_accelerate_available():
             model = AutoModelForCausalLM.from_pretrained(
                 pretrained_model_name_or_path, device_map=device, *model_args, **kwargs
@@ -62,6 +65,7 @@ def from_pretrained(
 
         # Do eval, and compile
         model.eval()
-        model = torch.compile(model, backend="openxla_eval")
+        if device == "xla":
+            model = torch.compile(model, backend="openxla_eval")
 
         return model
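
To make the combined effect of patches 04 and 05 easier to see, here is a condensed sketch of the device
selection (illustration only, not part of the series; the model id is just an example). With accelerate
installed, device_map places the weights directly on the target device instead of materializing them on
the host and moving them afterwards, and DBG_DEVICE=cpu exercises the same code path without a TPU:

    import os

    from transformers import AutoModelForCausalLM

    os.environ.setdefault("PJRT_DEVICE", "TPU")   # tells torch_xla which PJRT runtime to target
    device = os.environ.get("DBG_DEVICE", "xla")  # e.g. DBG_DEVICE=cpu to debug without a TPU

    if device == "xla":
        import torch_xla.core.xla_model  # importing torch_xla registers the "xla" device type with torch

    # With accelerate available, device_map loads the checkpoint straight onto the chosen device;
    # without it, one would load on CPU first and call model.to(device) afterwards.
    model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", device_map=device)
    model.eval()
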
From 5a94c67cba29af21a3b9f015ef5a5030461da895 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 8 Mar 2024 09:33:54 +0000
Subject: [PATCH 06/13] feat(test): reduce overhead when retrieving model

This will avoid loading the model twice.
---
 .../tests/test_generator.py | 24 +++++++------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/text-generation-inference/tests/test_generator.py b/text-generation-inference/tests/test_generator.py
index 9fc07b04..fa7bfda4 100644
--- a/text-generation-inference/tests/test_generator.py
+++ b/text-generation-inference/tests/test_generator.py
@@ -1,14 +1,13 @@
-from tempfile import TemporaryDirectory
-
 import pytest
-from text_generation_server.generator import TpuGenerator, TpuModelForCausalLM
+import os
+from text_generation_server.generator import TpuGenerator
+from text_generation_server.model import fetch_model
 from text_generation_server.pb.generate_pb2 import (
     Batch,
     NextTokenChooserParameters,
     Request,
     StoppingCriteriaParameters,
 )
-from transformers import AutoTokenizer
 
 
 MODEL_ID = "openai-community/gpt2"
@@ -18,18 +17,11 @@
 
 @pytest.fixture(scope="module")
 def model_path():
-    with TemporaryDirectory() as tmpdir:
-        AutoTokenizer.from_pretrained(MODEL_ID).save_pretrained(tmpdir)
-        model = TpuModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            batch_size=BATCH_SIZE,
-            sequence_length=SEQUENCE_LENGTH,
-        )
-        # Move model to cpu before saving.
-        # TODO: later on this should be handled by TpuModelForCausalLM
-        model.to("cpu")
-        model.save_pretrained(tmpdir)
-        yield tmpdir
+    # Add variables to environment so they can be used in TpuModelForCausalLM
+    os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
+    os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
+    path = fetch_model(MODEL_ID)
+    return path
 
 
 def test_info(model_path):

From 15d16dd918d32e3c5ca030f0b81e0ee5a683a20a Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 8 Mar 2024 10:26:28 +0000
Subject: [PATCH 07/13] feat(modeling): make compilation optional

Make compilation optional; it can be enabled with the environment
variable DBG_COMPILE. This is because:

1. Some models produce bugs when the model is compiled (notably Gemma).
2. Inference input shapes change between calls, triggering recompilation
   and leading to slow performance.
3. With the added xm.mark_step, performance is actually better when the
   model is not compiled: XLA builds a graph anyway, so performance is
   going to be good.
---
 .../server/text_generation_server/generator.py | 4 ++++
 .../server/text_generation_server/modeling.py  | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index bfac0b7b..0563639a 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -6,6 +6,7 @@
 from typing import List, Optional, Tuple
 
 import torch
+import torch_xla.core.xla_model as xm
 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from transformers.generation import GenerationConfig
 
 from .modeling import TpuModelForCausalLM
@@ -460,6 +461,9 @@ def _generate_token(
         )
         # Save KV cache
         self.past_key_values = outputs.past_key_values
+        # Barrier for XLA model
+        xm.mark_step(wait=False)
+
         generations = []
         active_slots = False
         for i, slot in enumerate(self.slots):
diff --git a/text-generation-inference/server/text_generation_server/modeling.py b/text-generation-inference/server/text_generation_server/modeling.py
index 3b1a3a9e..dd420e95 100644
--- a/text-generation-inference/server/text_generation_server/modeling.py
+++ b/text-generation-inference/server/text_generation_server/modeling.py
@@ -65,7 +65,8 @@ def from_pretrained(
 
         # Do eval, and compile
         model.eval()
-        if device == "xla":
+        if device == "xla" and "DBG_COMPILE" in environ:
             model = torch.compile(model, backend="openxla_eval")
+            logger.debug("Model compiled.")
 
         return model
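
The reasoning in patch 07 is easier to follow with a tiny standalone example of how XLA lazy tensors and
xm.mark_step() interact (illustration only, not part of the series; requires torch_xla and a configured
PJRT device):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    x = torch.randn(4, 4, device=device)
    y = torch.randn(4, 4, device=device)

    # Nothing runs here yet: operations on XLA tensors are only recorded into a graph.
    z = (x @ y).relu()

    # mark_step() cuts the recorded graph, compiles it (cached after the first time) and
    # launches it on the device; wait=False lets the host thread continue immediately.
    xm.mark_step(wait=False)

    # Transferring a value back to the host also forces execution of anything still pending.
    print(z.cpu())

Because XLA compiles the per-step graph anyway, wrapping the model in torch.compile is not required for
good performance, which is why the patch gates it behind DBG_COMPILE.
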
From 9ff56ad0951d808bd4ce0b7c19f81ecf9a5f8ee4 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Mon, 11 Mar 2024 15:44:27 +0000
Subject: [PATCH 08/13] feat: add @torch.no_grad decorators to decode and prefill

This is to reduce useless gradient calculations.
---
 .../server/text_generation_server/generator.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 0563639a..d1445958 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -329,6 +329,7 @@ def warmup(self, batch: Batch) -> int:
         self.prefill(batch)
         return self.model.config.batch_size * self.model.config.sequence_length
 
+    @torch.no_grad
     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         """Prefill new requests.
 
@@ -401,6 +402,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         logger.debug("Model ready for decoding")
         return generation, next_batch
 
+    @torch.no_grad
    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        """Decode the specified prefilled requests.

From 0999e21d3503c06b1084b9e3a41aef444cd6702e Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Mon, 11 Mar 2024 15:50:25 +0000
Subject: [PATCH 09/13] chore(generator): create buffers on device to avoid moving them

---
 .../text_generation_server/generator.py | 27 +++++++++++--------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index d1445958..52241b52 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -95,10 +95,11 @@ class State(Enum):
         PAUSE = 1
         READY = 2
 
-    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
+    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase, device: [str, torch.device]):
         self._id = id
         self._tokenizer = tokenizer
         self.clear()
+        self._device = device
 
     def clear(self):
         """Clear the slot and mark it as available."""
@@ -239,8 +240,10 @@ def append(self, next_token: int) -> str:
         Return:
             The corresponding decoded text (if any).
         """
-        self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
-        self._mask = torch.cat([self._mask, torch.LongTensor([1])])
+        self._tokens = torch.cat(
+            [self._tokens, torch.tensor([next_token], device=self._device, dtype=self._tokens.dtype)]
+        )
+        self._mask = torch.cat([self._mask, torch.tensor([1], device=self._device, dtype=self._mask.dtype)])
         self._generated_tokens += 1
         next_text = self._decode_next_tokens()
         # Now that a new token has been generated, we can append the previous one to the generated text
@@ -297,7 +300,7 @@ def __init__(
         tokenizer.padding_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-        self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
+        self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
         self.past_key_values = None
 
     @property
@@ -363,7 +366,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # for unfinished requests.
         inputs = [slot.cached_text for slot in self.slots]
         # Tokenize with padding
-        padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
+        padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.model.device)
         # If needed truncate sequences to fit into the static dimensions
         seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.sequence_length)
         input_ids = padded_inputs.input_ids[:, :seq_length]
@@ -428,14 +431,19 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
                 if input_ids is None:
                     # Create blank inputs covering all slots (even empty ones)
                     input_ids = torch.full(
-                        [self.model.config.batch_size, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
+                        [self.model.config.batch_size, 1],
+                        fill_value=self.tokenizer.eos_token_id,
+                        dtype=torch.int64,
+                        device=self.model.device,
                     )
                 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                 input_ids[i, 0] = slot.next_token
                 if attention_mask is None:
                     # Create default mask covering all slots (even empty ones)
                     attention_mask = torch.zeros(
-                        [self.model.config.batch_size, slot.attention_mask.size(-1)], dtype=torch.int64
+                        [self.model.config.batch_size, slot.attention_mask.size(-1)],
+                        dtype=torch.int64,
+                        device=self.model.device,
                     )
                 attention_mask[i, :] = slot.attention_mask
         if input_ids is None:
@@ -448,10 +456,7 @@ def _generate_token(
         # Obtain position ids using attention mask.
         position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
         position_ids = position_ids[:, -input_ids.shape[-1] :]
-        # Move input params to device
-        input_ids = input_ids.to(self.model.device)
-        attention_mask = attention_mask.to(self.model.device)
-        position_ids = position_ids.to(self.model.device)
+
         # Forward
         outputs = self.model(
             input_ids,
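
A short aside on why patch 09 helps (illustration only, not part of the series): torch.LongTensor([...])
always allocates on the CPU, so appending it to a buffer that lives on the XLA device would mean moving
data between host and device (or hitting a device mismatch) on every generated token; building the
one-element tensor on the same device as the existing buffer keeps everything device-resident.

    import torch

    def append_token(tokens: torch.Tensor, next_token: int) -> torch.Tensor:
        # The new element is created directly on tokens.device, so no transfer is needed.
        new = torch.tensor([next_token], device=tokens.device, dtype=tokens.dtype)
        return torch.cat([tokens, new])

    tokens = torch.zeros(8, dtype=torch.int64)  # stand-in for a slot's cached token ids
    tokens = append_token(tokens, 42)
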
From bd264d941cde06b39513d49ca67babb322f15781 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Tue, 12 Mar 2024 10:39:44 +0000
Subject: [PATCH 10/13] refactor(generator): some model params are passed as dict

This will make it possible to pass different params for different model
configurations later.
---
 .../text_generation_server/generator.py | 25 +++++++++++--------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 52241b52..8a63f320 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -393,10 +393,14 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             slot.reset(slot_input_ids, slot_attention_mask, selector)
         # Clear KV cache
         self.past_key_values = None
+        # Obtain position ids using attention mask.
+        position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
+        position_ids = position_ids[:, -input_ids.shape[-1] :]
+
         # Pause previously active slots during generation.
         # The KV cache of paused slots will be prefilled during generation but new tokens
         # will be ignored, as they have already been generated and sent back in the last decode.
-        generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask)
+        generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids)
         # Reactivate previously active slots for the next decode, and append
         # back their next token.
         for slot, next_token in zip(active_slots, next_tokens):
@@ -426,6 +430,11 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         # Reconstruct input_ids and attention_mask from slots
         input_ids = None
         attention_mask = None
+        position_ids = torch.zeros(
+            [self.model.config.batch_size, 1],
+            dtype=torch.int64,
+            device=self.model.device,
+        )
         for i, slot in enumerate(self.slots):
             if slot.state != Slot.State.EMPTY:
                 if input_ids is None:
@@ -446,25 +455,21 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
                     )
                 attention_mask[i, :] = slot.attention_mask
+                position_ids[i, 0] = slot.generated_tokens
         if input_ids is None:
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
 
-        return self._generate_token(next_batch_id, input_ids, attention_mask)
+        return self._generate_token(next_batch_id, input_ids, attention_mask=attention_mask, position_ids=position_ids)
 
     def _generate_token(
-        self, next_batch_id: int, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None
-    ) -> Tuple[List[Generation], CachedBatch]:
-        # Obtain position ids using attention mask.
-        position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
-        position_ids = position_ids[:, -input_ids.shape[-1] :]
+        self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params) -> Tuple[List[Generation], CachedBatch]:
         # Forward
         outputs = self.model(
             input_ids,
             past_key_values=self.past_key_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
             return_dict=True,
             use_cache=True,
+            **forward_extra_params,
         )
         # Save KV cache
         self.past_key_values = outputs.past_key_values

From 8b78fd3e5dd87e8ade6d62433f5742e2c6d147ff Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Tue, 12 Mar 2024 14:11:45 +0000
Subject: [PATCH 11/13] feat: use static KV cache when available

Some models, like Gemma and Llama, support static KV cache in
transformers. For these, it is possible to use this feature, leading to
much higher performance.
---
 .../text_generation_server/generator.py | 86 +++++++++++++------
 .../tests/test_generator_gemma.py       | 71 +++++++++++++++
 2 files changed, 133 insertions(+), 24 deletions(-)
 create mode 100644 text-generation-inference/tests/test_generator_gemma.py

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 8a63f320..2a282a82 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -8,7 +8,7 @@
 import torch
 import torch_xla.core.xla_model as xm
 from loguru import logger
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
 from transformers.generation import GenerationConfig
 
 from .modeling import TpuModelForCausalLM
@@ -108,7 +108,7 @@ def clear(self):
         self._inputs = ""
         self._generation_config = None
         self._tokens = []
-        self._mask = []
+        self._mask = None
         self._selector = None
         self._generated_tokens = 0
         self._next_text_token_start = 0
@@ -141,6 +141,10 @@ def generation_config(self) -> GenerationConfig:
     def generated_tokens(self) -> int:
         return self._generated_tokens
 
+    @property
+    def cur_position(self) -> int:
+        return self._next_text_token_start
+
     def assign(self, request: Request, generation_config: GenerationConfig):
         """Assign a request to a slot.
@@ -181,7 +185,10 @@ def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, s
         self._next_text_token_start = 0
         self._next_text_token_end = torch.numel(self._tokens)
         self._next_text = ""
-        self._mask = attention_mask.clone()
+        if attention_mask is not None:
+            self._mask = attention_mask.clone()
+        else:
+            self._mask = None
         self._selector = selector
 
     def pause(self):
@@ -243,7 +250,9 @@ def append(self, next_token: int) -> str:
         self._tokens = torch.cat(
             [self._tokens, torch.tensor([next_token], device=self._device, dtype=self._tokens.dtype)]
         )
-        self._mask = torch.cat([self._mask, torch.tensor([1], device=self._device, dtype=self._mask.dtype)])
+        # Update mask only if it was set previously
+        if self._mask is not None:
+            self._mask = torch.cat([self._mask, torch.tensor([1], device=self._device, dtype=self._mask.dtype)])
         self._generated_tokens += 1
         next_text = self._decode_next_tokens()
         # Now that a new token has been generated, we can append the previous one to the generated text
@@ -302,6 +311,14 @@ def __init__(
         self.special_tokens = self.tokenizer.all_special_ids
         self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
         self.past_key_values = None
+        # _setup_cache is specific to some models (e.g.: Gemma and Llama). In those cases it is possible to setup
+        # a static cache, otherwise it is not.
+        self.use_static_cache = True
+        if getattr(self.model, "_setup_cache", False) is False:
+            logger.warning(
+                f"Static cache not available for {self.model.__class__.__name__}. Performance will be affected"
+            )
+            self.use_static_cache = False
 
     @property
     def info(self) -> InfoResponse:
@@ -389,7 +406,11 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
                 seed=slot.seed,
             )
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
-            slot_attention_mask = attention_mask[i]
+            if self.use_static_cache:
+                # Attention mask does not need to be tracked when using static cache
+                slot_attention_mask = None
+            else:
+                slot_attention_mask = attention_mask[i]
             slot.reset(slot_input_ids, slot_attention_mask, selector)
         # Clear KV cache
         self.past_key_values = None
@@ -397,10 +418,17 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # Obtain position ids using attention mask.
         position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
         position_ids = position_ids[:, -input_ids.shape[-1] :]
 
-        # Pause previously active slots during generation.
-        # The KV cache of paused slots will be prefilled during generation but new tokens
-        # will be ignored, as they have already been generated and sent back in the last decode.
-        generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids)
+        extra_args = {}
+        if self.use_static_cache:
+            self.model._setup_cache(StaticCache, len(self.slots), self.model.config.sequence_length)
+            extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
+        else:
+            # Reset/clear KV cache
+            self.past_key_values = None
+        generation, next_batch = self._generate_token(
+            batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
+        )
+
         # Reactivate previously active slots for the next decode, and append
         # back their next token.
@@ -447,32 +475,42 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
                     )
                 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                 input_ids[i, 0] = slot.next_token
-                if attention_mask is None:
-                    # Create default mask covering all slots (even empty ones)
-                    attention_mask = torch.zeros(
-                        [self.model.config.batch_size, slot.attention_mask.size(-1)],
-                        dtype=torch.int64,
-                        device=self.model.device,
-                    )
-                attention_mask[i, :] = slot.attention_mask
-                position_ids[i, 0] = slot.generated_tokens
+                if not self.use_static_cache:
+                    # When using dynamic cache, the whole attention mask needs to be passed over to the model at each iteration.
+                    if attention_mask is None:
+                        # Create default mask covering all slots (even empty ones)
+                        attention_mask = torch.zeros(
+                            [self.model.config.batch_size, slot.attention_mask.size(-1)],
+                            dtype=torch.int64,
+                            device=self.model.device,
+                        )
+                    attention_mask[i, :] = slot.attention_mask
+                position_ids[i, 0] = slot.cur_position
         if input_ids is None:
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
-
-        return self._generate_token(next_batch_id, input_ids, attention_mask=attention_mask, position_ids=position_ids)
+        extra_args = {}
+        if self.use_static_cache:
+            extra_args["cache_position"] = position_ids.max().unsqueeze(0)
+        else:
+            extra_args["attention_mask"] = attention_mask
+            extra_args["past_key_values"] = self.past_key_values
+        return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)
 
     def _generate_token(
-        self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params) -> Tuple[List[Generation], CachedBatch]:
+        self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
+    ) -> Tuple[List[Generation], CachedBatch]:
+        # Add barrier to allow next graph step to always be the same
+        xm.mark_step()
         # Forward
         outputs = self.model(
             input_ids,
-            past_key_values=self.past_key_values,
             return_dict=True,
             use_cache=True,
             **forward_extra_params,
         )
-        # Save KV cache
-        self.past_key_values = outputs.past_key_values
+        if not self.use_static_cache:
+            # Save KV cache
+            self.past_key_values = outputs.past_key_values
 
         # Barrier for XLA model
         xm.mark_step(wait=False)
diff --git a/text-generation-inference/tests/test_generator_gemma.py b/text-generation-inference/tests/test_generator_gemma.py
new file mode 100644
index 00000000..aa7fd4bb
--- /dev/null
+++ b/text-generation-inference/tests/test_generator_gemma.py
@@ -0,0 +1,71 @@
+import pytest
+import os
+from text_generation_server.generator import TpuGenerator
+from text_generation_server.model import fetch_model
+from text_generation_server.pb.generate_pb2 import (
+    Batch,
+    NextTokenChooserParameters,
+    Request,
+    StoppingCriteriaParameters,
+)
+
+
+MODEL_ID = "google/gemma-2b"
+BATCH_SIZE = 4
+SEQUENCE_LENGTH = 1024
+
+
+@pytest.fixture(scope="module")
+def model_path():
+    # Add variables to environment so they can be used in TpuModelForCausalLM
+    os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
+    os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
+    path = fetch_model(MODEL_ID)
+    return path
+
+
+def create_request(
+    id: int,
+    inputs: str,
+    max_new_tokens=20,
+    do_sample: bool = False,
+    top_k: int = 50,
+    top_p: float = 0.9,
+    temperature: float = 1.0,
+    seed: int = 0,
+    repetition_penalty: float = 1.0,
+):
+    parameters = NextTokenChooserParameters(
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        do_sample=do_sample,
+        seed=seed,
+        repetition_penalty=repetition_penalty,
+    )
+    stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
+    return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
+
+
+def test_decode_single(model_path):
+    input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
+    max_new_tokens = 20
+    generated_text = "\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never"
+
+    generator = TpuGenerator.from_pretrained(model_path)
+    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
+    generations, next_batch = generator.prefill(batch)
+    # We already generated one token: call decode max_new_tokens - 1 times
+    for _ in range(max_new_tokens - 1):
+        assert next_batch.size == 1
+        assert next_batch.max_tokens == 1024
+        assert len(generations) == 1
+        assert len(generations[0].tokens.ids) == 1
+        generations, next_batch = generator.decode([next_batch])
+    assert next_batch is None
+    assert len(generations) == 1
+    output = generations[0].generated_text
+    assert output.generated_tokens == max_new_tokens
+    assert output.finish_reason == 0
+    assert output.text == generated_text
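
To show the static-cache calling pattern the generator above relies on, here is a minimal sketch on a
bare transformers model (illustration only, not part of the series). It assumes a transformers release
from the same era as these patches (around 4.38), where supported models such as Gemma and Llama expose
_setup_cache and accept cache_position in forward; the model id is just an example and may require an
HF token (see patch 12). Error handling, attention masks and batching are deliberately left out:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

    model_id = "google/gemma-2b"  # example; any model implementing _setup_cache
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = tokenizer("Deep learning is", return_tensors="pt")
    seq_length = inputs.input_ids.shape[-1]

    # Pre-allocate the KV buffers once: batch_size=1, max_cache_len=1024.
    model._setup_cache(StaticCache, 1, 1024)

    with torch.no_grad():
        # Prefill: cache positions 0..seq_length-1 are written in place.
        outputs = model(**inputs, cache_position=torch.arange(seq_length), use_cache=True)
        next_token = outputs.logits[:, -1:].argmax(-1)

        # Decode: only the new token and its single cache position are passed; the cache never
        # travels back and forth as past_key_values, so tensor shapes stay static across steps.
        position = torch.tensor([seq_length])
        outputs = model(next_token, position_ids=position.unsqueeze(0), cache_position=position, use_cache=True)
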
From 27a2669e646b66372327cc4fa1765af58d0df914 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Wed, 13 Mar 2024 16:01:50 +0000
Subject: [PATCH 12/13] fix(CI): added HF_TOKEN to use models that require it

Also manually install accelerate to avoid memory issues when loading
Gemma.
---
 .github/workflows/test-pytorch-xla-tpu-tgi.yml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi.yml b/.github/workflows/test-pytorch-xla-tpu-tgi.yml
index 83c8a232..22286cb9 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -32,4 +32,6 @@ jobs:
         run: python -c "import torch_xla.core.xla_model as xm; assert xm.xla_device().type == 'xla', 'XLA device not available'"
 
       - name: Build and test TGI server
-        run: make tgi_test
+        run: |
+          pip install accelerate==0.27.2
+          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test

From be1194bef6ba30edc9063f7492dfb7830eb9ec50 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 14 Mar 2024 14:48:25 +0000
Subject: [PATCH 13/13] fix(CI): adapt expected result in do_sample test

The test produces different results now that some operations are done in
a slightly different order.
---
 text-generation-inference/tests/test_generator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/text-generation-inference/tests/test_generator.py b/text-generation-inference/tests/test_generator.py
index fa7bfda4..6baa8547 100644
--- a/text-generation-inference/tests/test_generator.py
+++ b/text-generation-inference/tests/test_generator.py
@@ -109,7 +109,7 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_
         [
             "It was a bright cold day in April, and the clocks were striking thirteen.",
             20,
-            " We sat outside the house, drinking coffee, listening to the traffic. And then, suddenly, we",
+            " We sat outside the house, drinking coffee, listening to the orchestra playing through the window. We could",
             True,
         ],
     ],