From ef92993abb499787b5961df758faae653a3e1f7a Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Wed, 21 Aug 2024 00:46:19 +0800 Subject: [PATCH 1/3] feat: support batch input in `generate()` The `prompt` argument can now be either a `str` or `list[str]`. The change to `generate()` is backwards-compatible. The changes to `generate_step()`, `top_p_sampling()`, and `min_p_sampling()` are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred. --- llms/mlx_lm/sample_utils.py | 64 +++++++++------ llms/mlx_lm/server.py | 7 +- llms/mlx_lm/utils.py | 157 ++++++++++++++++++++++-------------- 3 files changed, 139 insertions(+), 89 deletions(-) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 20b008fac..1b403b393 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -26,7 +26,10 @@ def min_p_sampling( 0.99-0.8 range. min_tokens_to_keep (int, optional): Minimum number of tokens that cannot be filtered. Default: ``1``. - + temperature: Temperature parameter for softmax distribution reshaping. + Returns: + token(s) selected based on the min-p criterion. + Shape: same as logits, but with the last dimension having size 1. """ if not (0 <= min_p <= 1.0): raise ValueError( @@ -39,14 +42,14 @@ def min_p_sampling( # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 # Softmax probabilities - probs = mx.softmax(logits * (1 / temperature), axis=-1) + probs = mx.softmax(logits / temperature, axis=-1) # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logits).squeeze(0) - sorted_probs = probs[..., sorted_indices] + sorted_indices = mx.argsort(-logits) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) # Top probability - top_probs = probs[..., sorted_indices[0]] + top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1) # Calculate the min_p threshold scaled_min_p = min_p * top_probs @@ -58,13 +61,18 @@ def min_p_sampling( # Create pool of tokens with probability less than scaled min_p selected_probs = mx.where(tokens_to_remove, 0, sorted_probs) - # Return sampled token - sorted_token = mx.random.categorical(mx.log(selected_probs)) - return sorted_indices[sorted_token] + # Return sampled token(s) + sampled_indices = mx.random.categorical(mx.log(selected_probs)) + tokens = mx.take_along_axis( + sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1 + ) + return tokens.squeeze(-1) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) -def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: +def top_p_sampling( + logits: mx.array, top_p: float, temperature: float, axis: int = -1 +) -> mx.array: """ Apply top-p (nucleus) sampling to logits. @@ -72,29 +80,35 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr logits: The logits from the model's output. top_p: The cumulative probability threshold for top-p filtering. temperature: Temperature parameter for softmax distribution reshaping. + axis: The axis along which to apply top-p sampling. Returns: - token selected based on the top-p criterion. + token(s) selected based on the top-p criterion. 
""" - # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 - probs = mx.softmax(logits * (1 / temperature), axis=-1) + # Apply temperature and compute softmax + probs = mx.softmax(logits / temperature, axis=axis) - # sort probs in ascending order - sorted_indices = mx.argsort(probs, axis=-1) - sorted_probs = probs[..., sorted_indices.squeeze(0)] + # Sort probs in descending order + sorted_indices = mx.argsort(-probs, axis=axis) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis) - cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + # Compute cumulative probabilities + cumulative_probs = mx.cumsum(sorted_probs, axis=axis) - # select tokens with cumulative probs below threshold - top_probs = mx.where( - cumulative_probs > 1 - top_p, - sorted_probs, - 0, - ) + # Create a mask for probs above the threshold + mask = cumulative_probs <= top_p + + # Apply the mask to the sorted probabilities + masked_probs = sorted_probs * mask - sorted_token = mx.random.categorical(mx.log(top_probs)) - token = sorted_indices.squeeze(0)[sorted_token] + # Sample from the normalized probabilities + sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis) + + # Gather the original token indices + tokens = mx.take_along_axis( + sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis + ) - return token + return tokens.squeeze(axis) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 79ac18361..aa2c5ed7c 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -410,7 +410,7 @@ def handle_completion( top_tokens = [] for (token, logprobs), _ in zip( generate_step( - prompt=prompt, + prompts=prompt[None], model=self.model, temp=self.temperature, top_p=self.top_p, @@ -420,6 +420,8 @@ def handle_completion( ), range(self.max_tokens), ): + token = token.item() + logprobs = logprobs.squeeze(0) detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) @@ -497,7 +499,7 @@ def handle_stream( for (token, _), _ in zip( generate_step( - prompt=prompt, + prompts=prompt[None], model=self.model, temp=self.temperature, top_p=self.top_p, @@ -506,6 +508,7 @@ def handle_stream( ), range(self.max_tokens), ): + token = token.item() detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 441967667..10609ecc3 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -9,7 +9,7 @@ import time from pathlib import Path from textwrap import dedent -from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union import mlx.core as mx import mlx.nn as nn @@ -117,17 +117,17 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f logits (mx.array): Logits with repetition penalty applied to generated tokens. 
""" if len(generated_tokens) > 0: - indices = mx.array([token for token in generated_tokens]) - selected_logits = logits[:, indices] + indices = generated_tokens + selected_logits = mx.take_along_axis(logits, indices, axis=-1) selected_logits = mx.where( selected_logits < 0, selected_logits * penalty, selected_logits / penalty ) - logits[:, indices] = selected_logits + logits[mx.arange(indices.shape[0])[:, None], indices] = selected_logits return logits def generate_step( - prompt: mx.array, + prompts: mx.array, model: nn.Module, temp: float = 0.0, repetition_penalty: Optional[float] = None, @@ -143,7 +143,7 @@ def generate_step( A generator producing token ids based on the given prompt from the model. Args: - prompt (mx.array): The input prompt. + prompts (mx.array): The input prompt(s). Shape: ``(bs, seq_len)``. model (nn.Module): The model to use for generation. temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``. @@ -164,27 +164,33 @@ def generate_step( Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing - one token and a vector of log probabilities. + one token and a vector of log probabilities per prompt. + Shapes: ``(bs, 1), (bs, vocab_size)``. """ - def sample(logits: mx.array) -> Tuple[mx.array, float]: + if prompts.ndim != 2: + raise ValueError( + f"Shape of prompts should be (bs, seq_len), got {prompts.shape}" + ) + + def sample(logits: mx.array) -> Tuple[mx.array, mx.array]: if logit_bias: indices = mx.array(list(logit_bias.keys())) values = mx.array(list(logit_bias.values())) logits[:, indices] += values - logprobs = logits - mx.logsumexp(logits) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) if temp == 0: - token = mx.argmax(logits, axis=-1) + tokens = mx.argmax(logits, axis=-1) else: if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) + tokens = top_p_sampling(logits, top_p, temp) elif min_p != 0.0: - token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) + tokens = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) else: - token = categorical_sampling(logits, temp) + tokens = categorical_sampling(logits, temp) - return token, logprobs + return mx.expand_dims(tokens, axis=-1), logprobs if repetition_penalty and ( repetition_penalty < 0 or not isinstance(repetition_penalty, float) @@ -193,7 +199,7 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: f"repetition_penalty must be a non-negative float, got {repetition_penalty}" ) - y = prompt + y = prompts if hasattr(model, "make_cache"): cache = model.make_cache() else: @@ -210,14 +216,14 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: else: cache = [KVCache(model.head_dim, n) for n in kv_heads] - repetition_context = prompt.tolist() + repetition_context = prompts if repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] + repetition_context = repetition_context[:, -repetition_context_size:] def _step(y): nonlocal repetition_context - logits = model(y[None], cache=cache) + logits = model(y, cache=cache) logits = logits[:, -1, :] if repetition_penalty: @@ -225,27 +231,27 @@ def _step(y): logits, repetition_context, repetition_penalty ) y, logprobs = sample(logits) - repetition_context.append(y.item()) + repetition_context = mx.concatenate([repetition_context, y], axis=-1) else: y, logprobs = sample(logits) if repetition_context_size: - if len(repetition_context) > repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] - 
return y, logprobs.squeeze(0) + if repetition_context.shape[1] > repetition_context_size: + repetition_context = repetition_context[:, -repetition_context_size:] + return y, logprobs - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=cache) + while y.shape[1] > prefill_step_size: + model(y[:, :prefill_step_size], cache=cache) mx.eval([c.state for c in cache]) - y = y[prefill_step_size:] + y = y[:, prefill_step_size:] y, logprobs = _step(y) - mx.async_eval(y) while True: next_y, next_logprobs = _step(y) mx.async_eval(next_y) - yield y.item(), logprobs + mx.eval(y) + yield y, logprobs y, logprobs = next_y, next_logprobs @@ -277,9 +283,10 @@ def stream_generate( detokenizer.reset() for (token, _), n in zip( - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt_tokens[None], model, **kwargs), range(max_tokens), ): + token = token.item() if token == tokenizer.eos_token_id: break detokenizer.add_token(token) @@ -294,19 +301,19 @@ def stream_generate( def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, + prompt: Union[str, List[str]], max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> Union[str, List[str]]: """ Generate a complete response from the model. Args: model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (str): The string prompt. + prompts (str): The string prompt(s). max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. @@ -315,56 +322,82 @@ def generate( kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. 
""" + is_batch = isinstance(prompt, list) if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - if verbose: - print("=" * 10) - print("Prompt:", prompt) - - prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer + if is_batch: + tokenizer._tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer._tokenizer.pad_token = tokenizer.eos_token + tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id + prompt_tokens = mx.array( + tokenizer._tokenizer(prompt, padding=True)["input_ids"] + ) + output_toks = [] + else: + prompt_tokens = mx.array(tokenizer.encode(prompt))[None] + detokenizer = tokenizer.detokenizer + detokenizer.reset() + if verbose: + print("=" * 10) + print("Prompt:", prompt) tic = time.perf_counter() - detokenizer.reset() - for (token, logprobs), n in zip( + for (tokens, logprobs), n in zip( generate_step(prompt_tokens, model, **kwargs), range(max_tokens), ): if n == 0: prompt_time = time.perf_counter() - tic tic = time.perf_counter() - if token == tokenizer.eos_token_id: + if (tokens == tokenizer.eos_token_id).all(): break - detokenizer.add_token(token) - - if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) - else: - print(detokenizer.last_segment, end="", flush=True) - - token_count = n + 1 - detokenizer.finalize() + if is_batch: + output_toks.append(tokens) + else: + token = tokens.item() + logprobs = logprobs.squeeze(0) + detokenizer.add_token(token) + if verbose: + if formatter: + # We have to finalize so that the prob corresponds to the last segment + detokenizer.finalize() + formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) + else: + print(detokenizer.last_segment, end="", flush=True) + + if is_batch: + output_toks = mx.concatenate(output_toks, axis=1) + token_count = output_toks.size + response = [ + response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0] + for response in tokenizer.batch_decode(output_toks.tolist()) + ] + else: + token_count = n + detokenizer.finalize() + response = detokenizer.text if verbose: gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: + if token_count <= 0: print("No tokens generated for this prompt") - return + if is_batch: + for p, resp in zip(prompt, response): + print("=" * 10) + print("Prompt:", p) + print(resp) + else: + print(detokenizer.last_segment, flush=True) prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 - print(f"Peak memory: {peak_mem:.3f} GB") + gen_tps = token_count / gen_time + print("=" * 10) + print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + print(f"Generation: {gen_tps:.3f} tokens-per-sec") - return detokenizer.text + return response def load_config(model_path: Path) -> dict: From 5105b31cf75c81369ba21ed2692e547317bb5bdf Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 23 Aug 2024 16:27:50 +0900 Subject: [PATCH 2/3] feat: show batch generation progress --- llms/mlx_lm/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 10609ecc3..58ab084fb 100644 --- a/llms/mlx_lm/utils.py +++ 
b/llms/mlx_lm/utils.py @@ -356,6 +356,8 @@ def generate( break if is_batch: output_toks.append(tokens) + if verbose: + print(".", end="", flush=True) else: token = tokens.item() logprobs = logprobs.squeeze(0) @@ -385,6 +387,7 @@ def generate( if token_count <= 0: print("No tokens generated for this prompt") if is_batch: + print() for p, resp in zip(prompt, response): print("=" * 10) print("Prompt:", p) From 7d0e1cc4e0774a6d1e7092c3ed98454f7ed4f3ac Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Mon, 26 Aug 2024 01:36:10 -0700 Subject: [PATCH 3/3] feat: basic speculative decoding support in mlx_lm.generate / mlx_lm.server This basic version only supports bs=1, temp=0, max_kv_size=None. Supporting samplers, rotating cache, and batching are deferred to future commits in order to keep this diff small. --- llms/mlx_lm/generate.py | 12 ++++- llms/mlx_lm/models/base.py | 18 ++++--- llms/mlx_lm/server.py | 12 +++++ llms/mlx_lm/utils.py | 102 +++++++++++++++++++++++++++++-------- llms/tests/test_models.py | 12 ++++- 5 files changed, 124 insertions(+), 32 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 6707d25c3..88798014e 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -23,6 +23,12 @@ def setup_arg_parser(): default="mlx_model", help="The path to the local model directory or Hugging Face repo.", ) + parser.add_argument( + "--draft-model", + type=str, + required=False, + help="The path to the local model directory or Hugging Face repo for speculative decoding.", + ) parser.add_argument( "--adapter-path", type=str, @@ -81,7 +87,7 @@ def setup_arg_parser(): "--max-kv-size", type=int, default=1024, - help="Set the maximum key-value cache size", + help="Set the maximum key-value cache size (0 for unlimited)", ) return parser @@ -132,6 +138,7 @@ def main(): adapter_path=args.adapter_path, tokenizer_config=tokenizer_config, ) + draft_model = load(args.draft_model)[0] if args.draft_model is not None else None if args.use_default_chat_template: if tokenizer.chat_template is None: @@ -159,7 +166,8 @@ def main(): formatter=formatter, temp=args.temp, top_p=args.top_p, - max_kv_size=args.max_kv_size, + max_kv_size=args.max_kv_size if args.max_kv_size > 0 else None, + draft_model=draft_model, ) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 3e84554cb..81d3e2681 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -9,7 +9,6 @@ class KVCache: - def __init__(self, head_dim, n_kv_heads): self.n_kv_heads = n_kv_heads if isinstance(head_dim, int): @@ -23,6 +22,13 @@ def __init__(self, head_dim, n_kv_heads): self.offset = 0 self.step = 256 + def drop(self, n): + if n >= self.offset: + self.keys = self.values = None + self.offset = 0 + elif n > 0: + self.offset -= n + def update_and_fetch(self, keys, values): prev = self.offset if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: @@ -33,11 +39,10 @@ def update_and_fetch(self, keys, values): new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) if self.keys is not None: - if prev % self.step != 0: - self.keys = self.keys[..., :prev, :] - self.values = self.values[..., :prev, :] - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) + self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2) + self.values = mx.concatenate( + [self.values[..., :prev, :], new_v], axis=2 + ) else: self.keys, self.values = new_k, new_v @@ -51,7 +56,6 @@ def state(self): class 
RotatingKVCache: - def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): self.n_kv_heads = n_kv_heads if isinstance(head_dim, int): diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index aa2c5ed7c..978734a2b 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -104,6 +104,8 @@ def __init__(self, cli_args: argparse.Namespace): # Preload the default model if it is provided if self.cli_args.model is not None: self.load("default_model") + if self.cli_args.draft_model is not None: + self.draft_model, _ = load(self.cli_args.model) def _validate_model_path(self, model_path: str): model_path = Path(model_path) @@ -161,6 +163,7 @@ def __init__(self, model_provider: ModelProvider, *args, **kwargs): """ self.created = int(time.time()) self.model_provider = model_provider + self.draft_model = model_provider.draft_model super().__init__(*args, **kwargs) def _set_cors_headers(self): @@ -412,6 +415,8 @@ def handle_completion( generate_step( prompts=prompt[None], model=self.model, + draft_model=self.draft_model, + tokenizer=self.tokenizer, temp=self.temperature, top_p=self.top_p, repetition_penalty=self.repetition_penalty, @@ -501,6 +506,8 @@ def handle_stream( generate_step( prompts=prompt[None], model=self.model, + draft_model=self.draft_model, + tokenizer=self.tokenizer, temp=self.temperature, top_p=self.top_p, repetition_penalty=self.repetition_penalty, @@ -649,6 +656,11 @@ def main(): type=str, help="The path to the MLX model weights, tokenizer, and config", ) + parser.add_argument( + "--draft-model", + type=str, + help="The path to the MLX model weights and config for speculative decoding", + ) parser.add_argument( "--adapter-path", type=str, diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 58ab084fb..6d9404079 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -102,6 +102,24 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path return model_path +def create_cache(model: nn.Module, max_kv_size: Optional[int]) -> list[KVCache]: + if hasattr(model, "make_cache"): + return model.make_cache() + else: + kv_heads = ( + [model.n_kv_heads] * len(model.layers) + if isinstance(model.n_kv_heads, int) + else model.n_kv_heads + ) + if max_kv_size is not None: + return [ + RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4) + for n in kv_heads + ] + else: + return [KVCache(model.head_dim, n) for n in kv_heads] + + def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float): """ Apply repetition penalty to specific logits based on the given context. @@ -129,6 +147,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f def generate_step( prompts: mx.array, model: nn.Module, + tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], temp: float = 0.0, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = 20, @@ -138,6 +157,8 @@ def generate_step( logit_bias: Optional[Dict[int, float]] = None, prefill_step_size: int = 512, max_kv_size: Optional[int] = None, + draft_model: Optional[nn.Module] = None, + speculation_lookahead: int = 5, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -161,6 +182,11 @@ def generate_step( prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. 
+ draft_model (nn.Module, optional): The model to use for drafting + (speculative decoding). Speculative decoding is currently only + supported for bs=1, temp=0, max_kv_size=None. + speculation_lookahead (int, optional): Number of tokens to generate + speculatively. Only used if `draft_model` is provided. Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing @@ -173,6 +199,24 @@ def generate_step( f"Shape of prompts should be (bs, seq_len), got {prompts.shape}" ) + if draft_model is not None: + if prompts.shape[0] != 1: + # https://github.com/huggingface/transformers/issues/32165 + raise ValueError( + f"Speculative decoding currently only supports batch size 1, got batch size {prompts.shape[0]}" + ) + if temp != 0: + # Samplers would need to be refactored to return + # transformed logprobs instead of sampled tokens + raise ValueError( + f"Speculative decoding currently only supports greedy sampling, got temp={temp}" + ) + if max_kv_size is not None: + # `RotatingKVCache` assumes one token generated at a time per prompt + raise ValueError( + f"Speculative decoding currently does not support max_kv_size, got max_kv_size={max_kv_size}" + ) + def sample(logits: mx.array) -> Tuple[mx.array, mx.array]: if logit_bias: indices = mx.array(list(logit_bias.keys())) @@ -200,21 +244,8 @@ def sample(logits: mx.array) -> Tuple[mx.array, mx.array]: ) y = prompts - if hasattr(model, "make_cache"): - cache = model.make_cache() - else: - kv_heads = ( - [model.n_kv_heads] * len(model.layers) - if isinstance(model.n_kv_heads, int) - else model.n_kv_heads - ) - if max_kv_size is not None: - cache = [ - RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4) - for n in kv_heads - ] - else: - cache = [KVCache(model.head_dim, n) for n in kv_heads] + cache = create_cache(model, max_kv_size) + draft_cache = create_cache(draft_model, max_kv_size) if draft_model else None repetition_context = prompts @@ -243,16 +274,45 @@ def _step(y): while y.shape[1] > prefill_step_size: model(y[:, :prefill_step_size], cache=cache) mx.eval([c.state for c in cache]) + if draft_model is not None: + draft_model(y[:, :prefill_step_size], cache=draft_cache) + mx.eval([c.state for c in draft_cache]) y = y[:, prefill_step_size:] + old_y = y y, logprobs = _step(y) mx.async_eval(y) - while True: - next_y, next_logprobs = _step(y) - mx.async_eval(next_y) + if draft_model is not None: + draft_model(old_y, cache=draft_cache) mx.eval(y) yield y, logprobs - y, logprobs = next_y, next_logprobs + while True: + if draft_model is not None: + draft_input = y + draft = mx.zeros((1, 0), dtype=y.dtype) + for _ in range(speculation_lookahead): + draft_logits = draft_model(draft_input, cache=draft_cache) + draft_input = mx.argmax(draft_logits[:, -1, :], axis=-1, keepdims=True) + draft = mx.concatenate([draft, draft_input], axis=-1) + if draft_input.item() == tokenizer.eos_token_id: + break + input_tokens = mx.concatenate([y, draft[:, :-1]], axis=-1) + logits = model(input_tokens, cache=cache) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + output = logits.argmax(axis=-1) + n_accepted = (output == draft).astype(mx.uint8).cummin().sum().item() + n_used = min(n_accepted + 1, draft.shape[1]) + for i in range(n_used): + y = output[:, i : i + 1] + yield y, logprobs[:, i, :] + for c in cache + draft_cache: + c.drop(draft.shape[1] - n_used) + else: + next_y, next_logprobs = _step(y) + mx.async_eval(next_y) + mx.eval(y) + yield y, logprobs + y, logprobs = next_y, next_logprobs def stream_generate( @@ -283,7 
+343,7 @@ def stream_generate( detokenizer.reset() for (token, _), n in zip( - generate_step(prompt_tokens[None], model, **kwargs), + generate_step(prompt_tokens[None], model, tokenizer, **kwargs), range(max_tokens), ): token = token.item() @@ -346,7 +406,7 @@ def generate( tic = time.perf_counter() for (tokens, logprobs), n in zip( - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt_tokens, model, tokenizer, **kwargs), range(max_tokens), ): if n == 0: diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index fcf1dc331..2ded81282 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -8,7 +8,6 @@ class TestModels(unittest.TestCase): - def test_kv_cache(self): cache = KVCache(32, 4) @@ -29,6 +28,16 @@ def test_kv_cache(self): self.assertTrue(mx.array_equal(v_up, expected)) self.assertEqual(cache.offset, cache.step + 1) + cache.drop(5) + k = mx.ones((1, 4, 3, 32), mx.float16) + v = mx.ones((1, 4, 3, 32), mx.float16) + k_up, v_up = cache.update_and_fetch(k, v) + + expected = mx.ones((1, 4, cache.step - 1, 32), mx.float16) + self.assertTrue(mx.array_equal(k_up, expected)) + self.assertTrue(mx.array_equal(v_up, expected)) + self.assertEqual(cache.offset, cache.step - 1) + def test_rotating_kv_cache(self): b, h, d = 1, 2, 32 cache = RotatingKVCache(d, h, max_size=8, step=4) @@ -88,7 +97,6 @@ def test_rotating_kv_cache(self): idx = 2 def model_test_runner(self, model, model_type, vocab_size, num_layers): - self.assertEqual(len(model.layers), num_layers) self.assertEqual(model.model_type, model_type)
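
Usage note (not part of the patch): a minimal sketch of the batched `generate()` API introduced in PATCH 1/3, assuming the existing `mlx_lm.load`/`mlx_lm.generate` entry points; the model path is a placeholder. With a single string prompt the call is unchanged and returns a `str`; with a `list[str]` it returns one response per prompt.

    from mlx_lm import load, generate

    model, tokenizer = load("path/to/model")  # placeholder model path

    # Single prompt: backwards-compatible, returns a str.
    text = generate(model, tokenizer, prompt="Hello", max_tokens=32)

    # Batch of prompts: returns a list[str], one response per prompt.
    # With PATCH 2/3, verbose=True prints one "." per generated step.
    texts = generate(
        model,
        tokenizer,
        prompt=["Hello", "Write a haiku about the sea."],
        max_tokens=32,
        verbose=True,
    )
    print(texts[0])
    print(texts[1])

Note that direct callers of `generate_step()`, `top_p_sampling()`, and `min_p_sampling()` must adapt to the new batched shapes, as described in the PATCH 1/3 message.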
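Usage note (not part of the patch): a minimal sketch of speculative decoding via the `draft_model` keyword added in PATCH 3/3, which `generate()` forwards to `generate_step()` through `**kwargs`. Model paths are placeholders, and per the patch this path currently supports only batch size 1, greedy sampling (`temp=0`), and `max_kv_size=None`.

    from mlx_lm import load, generate

    model, tokenizer = load("path/to/target-model")      # placeholder target model
    draft_model, _ = load("path/to/small-draft-model")   # placeholder draft model

    text = generate(
        model,
        tokenizer,
        prompt="Explain KV caches in one paragraph.",
        max_tokens=128,
        temp=0.0,                 # greedy sampling is required by this patch
        draft_model=draft_model,  # forwarded to generate_step()
        speculation_lookahead=5,  # number of draft tokens proposed per step
    )
    print(text)

The same behavior is reachable from the CLI and server through the new `--draft-model` argument added in this patch.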