From 7fbd21ad8b1b7a260ce004f9a4a68164efa8bfd4 Mon Sep 17 00:00:00 2001
From: L Lllvvuu
Date: Mon, 26 Aug 2024 01:36:10 -0700
Subject: [PATCH] feat: basic speculative decoding support in mlx_lm.generate /
 mlx_lm.server

Support for samplers, the rotating KV cache, and batching is deferred to
future commits in order to keep this diff small.
---
 llms/mlx_lm/generate.py    |  12 +++-
 llms/mlx_lm/models/base.py |   9 +++
 llms/mlx_lm/server.py      |  13 +++++
 llms/mlx_lm/utils.py       | 109 ++++++++++++++++++++++++++++++-------
 4 files changed, 120 insertions(+), 23 deletions(-)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 6707d25c3..88798014e 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -23,6 +23,12 @@ def setup_arg_parser():
         default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
+    parser.add_argument(
+        "--draft-model",
+        type=str,
+        required=False,
+        help="The path to the local model directory or Hugging Face repo for speculative decoding.",
+    )
     parser.add_argument(
         "--adapter-path",
         type=str,
@@ -81,7 +87,7 @@ def setup_arg_parser():
         "--max-kv-size",
         type=int,
         default=1024,
-        help="Set the maximum key-value cache size",
+        help="Set the maximum key-value cache size (0 for unlimited)",
     )
     return parser
 
@@ -132,6 +138,7 @@ def main():
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
     )
+    draft_model = load(args.draft_model)[0] if args.draft_model is not None else None
 
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
@@ -159,7 +166,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
-        max_kv_size=args.max_kv_size,
+        max_kv_size=args.max_kv_size if args.max_kv_size > 0 else None,
+        draft_model=draft_model,
     )
 
 
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 3e84554cb..2c85d1012 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -23,6 +23,15 @@ def __init__(self, head_dim, n_kv_heads):
         self.offset = 0
         self.step = 256
 
+    def drop(self, n):
+        if n >= self.offset:
+            self.keys = self.values = None
+            self.offset = 0
+        elif n > 0:
+            self.keys = self.keys[..., :-n, :]
+            self.values = self.values[..., :-n, :]
+            self.offset -= n
+
     def update_and_fetch(self, keys, values):
         prev = self.offset
         if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index aa2c5ed7c..d137215e0 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -104,6 +104,9 @@ def __init__(self, cli_args: argparse.Namespace):
         # Preload the default model if it is provided
         if self.cli_args.model is not None:
             self.load("default_model")
+        self.draft_model = None
+        if self.cli_args.draft_model is not None:
+            self.draft_model, _ = load(self.cli_args.draft_model)
 
     def _validate_model_path(self, model_path: str):
         model_path = Path(model_path)
@@ -161,6 +164,7 @@ def __init__(self, model_provider: ModelProvider, *args, **kwargs):
         """
         self.created = int(time.time())
         self.model_provider = model_provider
+        self.draft_model = model_provider.draft_model
         super().__init__(*args, **kwargs)
 
     def _set_cors_headers(self):
@@ -412,6 +416,8 @@ def handle_completion(
             generate_step(
                 prompts=prompt[None],
                 model=self.model,
+                draft_model=self.draft_model,
+                tokenizer=self.tokenizer,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -501,6 +507,8 @@ def handle_stream(
             generate_step(
                 prompts=prompt[None],
                 model=self.model,
+                draft_model=self.draft_model,
+                tokenizer=self.tokenizer,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
@@ -649,6 +657,11 @@ def main():
         type=str,
         help="The path to the MLX model weights, tokenizer, and config",
     )
+    parser.add_argument(
+        "--draft-model",
+        type=str,
+        help="The path to the MLX model weights and config for speculative decoding",
+    )
     parser.add_argument(
         "--adapter-path",
         type=str,
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 58ab084fb..4b6aca418 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -102,6 +102,24 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path
 
 
+def create_cache(model: nn.Module, max_kv_size: Optional[int]) -> list[KVCache]:
+    if hasattr(model, "make_cache"):
+        return model.make_cache()
+    else:
+        kv_heads = (
+            [model.n_kv_heads] * len(model.layers)
+            if isinstance(model.n_kv_heads, int)
+            else model.n_kv_heads
+        )
+        if max_kv_size is not None:
+            return [
+                RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+                for n in kv_heads
+            ]
+        else:
+            return [KVCache(model.head_dim, n) for n in kv_heads]
+
+
 def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
     """
     Apply repetition penalty to specific logits based on the given context.
@@ -129,6 +147,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
 def generate_step(
     prompts: mx.array,
     model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     temp: float = 0.0,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
@@ -138,6 +157,8 @@ def generate_step(
     logit_bias: Optional[Dict[int, float]] = None,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
+    draft_model: Optional[nn.Module] = None,
+    speculation_lookahead: int = 5,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -161,6 +182,11 @@ def generate_step(
         prefill_step_size (int): Step size for processing the prompt.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
             entries (except the first 4 tokens) will be overwritten.
+        draft_model (nn.Module, optional): The model to use for drafting
+            (speculative decoding). Speculative decoding is currently only
+            supported for bs=1, temp=0, max_kv_size=None.
+        speculation_lookahead (int, optional): Number of tokens to generate
+            speculatively. Only used if `draft_model` is provided.
 
     Yields:
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -173,6 +199,24 @@
             f"Shape of prompts should be (bs, seq_len), got {prompts.shape}"
         )
 
+    if draft_model is not None:
+        if prompts.shape[0] != 1:
+            # https://github.com/huggingface/transformers/issues/32165
+            raise ValueError(
+                f"Speculative decoding currently only supports batch size 1, got batch size {prompts.shape[0]}"
+            )
+        if temp != 0:
+            # Samplers would need to be refactored to return
+            # transformed logprobs instead of sampled tokens
+            raise ValueError(
+                f"Speculative decoding currently only supports greedy sampling, got temp={temp}"
+            )
+        if max_kv_size is not None:
+            # `RotatingKVCache` assumes one token generated at a time per prompt
+            raise ValueError(
+                f"Speculative decoding currently does not support max_kv_size, got max_kv_size={max_kv_size}"
+            )
+
     def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
         if logit_bias:
             indices = mx.array(list(logit_bias.keys()))
@@ -200,21 +244,8 @@ def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
         )
 
     y = prompts
-    if hasattr(model, "make_cache"):
-        cache = model.make_cache()
-    else:
-        kv_heads = (
-            [model.n_kv_heads] * len(model.layers)
-            if isinstance(model.n_kv_heads, int)
-            else model.n_kv_heads
-        )
-        if max_kv_size is not None:
-            cache = [
-                RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
-                for n in kv_heads
-            ]
-        else:
-            cache = [KVCache(model.head_dim, n) for n in kv_heads]
+    cache = create_cache(model, max_kv_size)
+    draft_cache = create_cache(draft_model, max_kv_size) if draft_model else None
 
     repetition_context = prompts
 
@@ -243,16 +274,52 @@ def _step(y):
     while y.shape[1] > prefill_step_size:
         model(y[:, :prefill_step_size], cache=cache)
         mx.eval([c.state for c in cache])
+        if draft_model is not None:
+            draft_model(y[:, :prefill_step_size], cache=draft_cache)
+            mx.eval([c.state for c in draft_cache])
         y = y[:, prefill_step_size:]
 
+    old_y = y
     y, logprobs = _step(y)
     mx.async_eval(y)
-    while True:
-        next_y, next_logprobs = _step(y)
-        mx.async_eval(next_y)
+    if draft_model is not None:
+        draft_model(old_y, cache=draft_cache)
         mx.eval(y)
         yield y, logprobs
-        y, logprobs = next_y, next_logprobs
+    while True:
+        if draft_model is not None:
+            draft_y = y
+            draft_tokens: list[mx.array] = []
+            for _ in range(speculation_lookahead):
+                draft_logits = draft_model(draft_y, cache=draft_cache)
+                draft_y = mx.argmax(draft_logits[:, -1, :], axis=-1, keepdims=True)
+                draft_tokens.append(draft_y)
+                if draft_y.item() == tokenizer.eos_token_id:
+                    break
+            actual = mx.concatenate(draft_tokens, axis=-1)
+            verify_tokens = mx.concatenate([y, actual[:, :-1]], axis=-1)
+            logits = model(verify_tokens, cache=cache)
+            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+            expected = logits.argmax(axis=-1)
+            n_accepted = (expected == actual).astype(mx.uint8).cummin().sum().item()
+            for i in range(n_accepted):
+                y = draft_tokens[i]
+                yield y, logprobs[:, i, :]
+            n_rejected = len(draft_tokens) - n_accepted
+            if n_rejected > 0:
+                y = expected[:, n_accepted : n_accepted + 1]
+                yield y, logprobs[:, n_accepted, :]
+                n_rejected -= 1
+            for c in cache:
+                c.drop(n_rejected)
+            for c in draft_cache:
+                c.drop(n_rejected)
+        else:
+            next_y, next_logprobs = _step(y)
+            mx.async_eval(next_y)
+            mx.eval(y)
+            yield y, logprobs
+            y, logprobs = next_y, next_logprobs
 
 
 def stream_generate(
@@ -283,7 +350,7 @@
     detokenizer.reset()
 
     for (token, _), n in zip(
-        generate_step(prompt_tokens[None], model, **kwargs),
+        generate_step(prompt_tokens[None], model, tokenizer, **kwargs),
         range(max_tokens),
     ):
         token = token.item()
@@ -346,7 +413,7 @@
     tic = time.perf_counter()
 
     for (tokens, logprobs), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
+        generate_step(prompt_tokens, model, tokenizer, **kwargs),
         range(max_tokens),
     ):
         if n == 0:
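
Example usage (a sketch, not part of the patch itself): the commands below
assume the patch is applied and that a smaller draft model sharing the target
model's tokenizer is available; the model paths are placeholders. Passing
`--temp 0.0` and `--max-kv-size 0` is required because the new checks in
generate_step only allow greedy sampling with an unbounded KV cache when a
draft model is supplied.

    # One-off generation with speculative decoding (paths are examples only)
    python -m mlx_lm.generate \
        --model <path-or-hf-repo-of-target-model> \
        --draft-model <path-or-hf-repo-of-smaller-draft-model> \
        --prompt "Explain speculative decoding in one sentence." \
        --temp 0.0 \
        --max-kv-size 0

    # OpenAI-compatible server with a draft model preloaded
    python -m mlx_lm.server \
        --model <path-or-hf-repo-of-target-model> \
        --draft-model <path-or-hf-repo-of-smaller-draft-model>

The same restrictions apply to server requests: with a draft model configured,
only single-prompt, temperature-0 requests pass the ValueError checks added in
generate_step.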