feat: basic speculative decoding support in mlx_lm.generate / mlx_lm.server

This basic version only supports bs=1, temp=0, max_kv_size=None.
Support for samplers, the rotating cache, and batching is deferred to future
commits to keep this diff small.
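
Example invocation (a sketch: the model paths are placeholders, and all flags
other than the new --draft-model already exist in the CLI; --max-kv-size 0 is
needed because speculative decoding does not yet support the rotating cache):

    python -m mlx_lm.generate --model <target-model-or-hf-repo> \
        --draft-model <smaller-model-or-hf-repo> \
        --prompt "Why is the sky blue?" --temp 0.0 --max-kv-size 0

    python -m mlx_lm.server --model <target-model-or-hf-repo> \
        --draft-model <smaller-model-or-hf-repo>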
llllvvuu committed Aug 26, 2024
1 parent 5105b31 commit fc4e076
Showing 5 changed files with 131 additions and 32 deletions.
12 changes: 10 additions & 2 deletions llms/mlx_lm/generate.py
@@ -23,6 +23,12 @@ def setup_arg_parser():
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--draft-model",
type=str,
required=False,
help="The path to the local model directory or Hugging Face repo for speculative decoding.",
)
parser.add_argument(
"--adapter-path",
type=str,
@@ -81,7 +87,7 @@ def setup_arg_parser():
"--max-kv-size",
type=int,
default=1024,
help="Set the maximum key-value cache size",
help="Set the maximum key-value cache size (0 for unlimited)",
)
return parser

@@ -132,6 +138,7 @@ def main():
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
draft_model = load(args.draft_model)[0] if args.draft_model is not None else None

if args.use_default_chat_template:
if tokenizer.chat_template is None:
@@ -159,7 +166,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=args.max_kv_size,
max_kv_size=args.max_kv_size if args.max_kv_size > 0 else None,
draft_model=draft_model,
)


18 changes: 11 additions & 7 deletions llms/mlx_lm/models/base.py
@@ -9,7 +9,6 @@


class KVCache:

def __init__(self, head_dim, n_kv_heads):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
@@ -23,6 +22,13 @@ def __init__(self, head_dim, n_kv_heads):
self.offset = 0
self.step = 256

def drop(self, n):
if n >= self.offset:
self.keys = self.values = None
self.offset = 0
elif n > 0:
self.offset -= n
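# Illustrative sketch of how drop() is meant to be used (the numbers are made up):
#   cache.offset == 100   # cached state for 100 tokens, the last 2 being rejected drafts
#   cache.drop(2)         # offset -> 98; the next update_and_fetch() overwrites those slots
#   cache.drop(1000)      # dropping more than the current offset resets the cache entirely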

def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
@@ -33,11 +39,10 @@ def update_and_fetch(self, keys, values):
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
self.values = mx.concatenate(
[self.values[..., :prev, :], new_v], axis=2
)
else:
self.keys, self.values = new_k, new_v

@@ -51,7 +56,6 @@ def state(self):


class RotatingKVCache:

def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
12 changes: 12 additions & 0 deletions llms/mlx_lm/server.py
@@ -104,6 +104,8 @@ def __init__(self, cli_args: argparse.Namespace):
# Preload the default model if it is provided
if self.cli_args.model is not None:
self.load("default_model")
if self.cli_args.draft_model is not None:
self.draft_model, _ = load(self.cli_args.draft_model)

def _validate_model_path(self, model_path: str):
model_path = Path(model_path)
@@ -161,6 +163,7 @@ def __init__(self, model_provider: ModelProvider, *args, **kwargs):
"""
self.created = int(time.time())
self.model_provider = model_provider
self.draft_model = model_provider.draft_model
super().__init__(*args, **kwargs)

def _set_cors_headers(self):
@@ -412,6 +415,8 @@ def handle_completion(
generate_step(
prompts=prompt[None],
model=self.model,
draft_model=self.draft_model,
tokenizer=self.tokenizer,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
@@ -501,6 +506,8 @@ def handle_stream(
generate_step(
prompts=prompt[None],
model=self.model,
draft_model=self.draft_model,
tokenizer=self.tokenizer,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
@@ -649,6 +656,11 @@ def main():
type=str,
help="The path to the MLX model weights, tokenizer, and config",
)
parser.add_argument(
"--draft-model",
type=str,
help="The path to the MLX model weights and config for speculative decoding",
)
parser.add_argument(
"--adapter-path",
type=str,
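With a draft model loaded, requests hit the same endpoints as before and
speculative decoding is applied transparently. A sketch of a request (assuming
the server's existing OpenAI-style chat endpoint and its default port of 8080;
temperature must be 0.0 because the new generate_step checks reject sampling
when a draft model is set):

    curl http://127.0.0.1:8080/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d '{"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64, "temperature": 0.0}'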
109 changes: 88 additions & 21 deletions llms/mlx_lm/utils.py
@@ -102,6 +102,24 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path


def create_cache(model: nn.Module, max_kv_size: Optional[int]) -> list[KVCache]:
if hasattr(model, "make_cache"):
return model.make_cache()
else:
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
if max_kv_size is not None:
return [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
return [KVCache(model.head_dim, n) for n in kv_heads]


def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
"""
Apply repetition penalty to specific logits based on the given context.
@@ -129,6 +147,7 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
def generate_step(
prompts: mx.array,
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
temp: float = 0.0,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
@@ -138,6 +157,8 @@
logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
draft_model: Optional[nn.Module] = None,
speculation_lookahead: int = 5,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -161,6 +182,11 @@
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
draft_model (nn.Module, optional): The model to use for drafting
(speculative decoding). Speculative decoding is currently only
supported for bs=1, temp=0, max_kv_size=None.
speculation_lookahead (int, optional): Number of tokens to generate
speculatively. Only used if `draft_model` is provided.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -173,6 +199,24 @@
f"Shape of prompts should be (bs, seq_len), got {prompts.shape}"
)

if draft_model is not None:
if prompts.shape[0] != 1:
# https://github.com/huggingface/transformers/issues/32165
raise ValueError(
f"Speculative decoding currently only supports batch size 1, got batch size {prompts.shape[0]}"
)
if temp != 0:
# Samplers would need to be refactored to return
# transformed logprobs instead of sampled tokens
raise ValueError(
f"Speculative decoding currently only supports greedy sampling, got temp={temp}"
)
if max_kv_size is not None:
# `RotatingKVCache` assumes one token generated at a time per prompt
raise ValueError(
f"Speculative decoding currently does not support max_kv_size, got max_kv_size={max_kv_size}"
)

def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
@@ -200,21 +244,8 @@ def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
)

y = prompts
if hasattr(model, "make_cache"):
cache = model.make_cache()
else:
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
if max_kv_size is not None:
cache = [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
cache = [KVCache(model.head_dim, n) for n in kv_heads]
cache = create_cache(model, max_kv_size)
draft_cache = create_cache(draft_model, max_kv_size) if draft_model else None

repetition_context = prompts

@@ -243,16 +274,52 @@ def _step(y):
while y.shape[1] > prefill_step_size:
model(y[:, :prefill_step_size], cache=cache)
mx.eval([c.state for c in cache])
if draft_model is not None:
draft_model(y[:, :prefill_step_size], cache=draft_cache)
mx.eval([c.state for c in draft_cache])
y = y[:, prefill_step_size:]

old_y = y
y, logprobs = _step(y)
mx.async_eval(y)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
if draft_model is not None:
draft_model(old_y, cache=draft_cache)
mx.eval(y)
yield y, logprobs
y, logprobs = next_y, next_logprobs
while True:
if draft_model is not None:
draft_y = y
draft_tokens: list[mx.array] = []
for _ in range(speculation_lookahead):
draft_logits = draft_model(draft_y, cache=draft_cache)
draft_y = mx.argmax(draft_logits[:, -1, :], axis=-1, keepdims=True)
draft_tokens.append(draft_y)
if draft_y.item() == tokenizer.eos_token_id:
break
actual = mx.concatenate(draft_tokens, axis=-1)
verify_tokens = mx.concatenate([y, actual[:, :-1]], axis=-1)
logits = model(verify_tokens, cache=cache)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
expected = logits.argmax(axis=-1)
n_accepted = (expected == actual).astype(mx.uint8).cummin().sum().item()
for i in range(n_accepted):
y = draft_tokens[i]
yield y, logprobs[:, i, :]
n_rejected = len(draft_tokens) - n_accepted
if n_rejected > 0:
y = expected[:, n_accepted : n_accepted + 1]
yield y, logprobs[:, n_accepted, :]
n_rejected -= 1
for c in cache:
c.drop(n_rejected)
for c in draft_cache:
c.drop(n_rejected)
else:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
mx.eval(y)
yield y, logprobs
y, logprobs = next_y, next_logprobs
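# A tiny pure-Python sketch of the acceptance rule above (illustrative token ids
# only, no MLX calls): with greedy decoding a draft token is kept only while every
# earlier draft token also matched the target model's argmax.
#
#   draft    = [5, 9, 3, 7]   # proposed by the draft model
#   expected = [5, 9, 8, 2]   # target model's argmax at the same positions
#   matches  = [1, 1, 0, 0]   # (expected == actual), so cummin().sum() == 2
#
# n_accepted == 2: yield draft[0] and draft[1], then yield expected[2] as the
# correction token. One rejected draft token remains after that, so both the
# target cache and the draft cache are rewound with drop(1) before the next step.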


def stream_generate(
@@ -283,7 +350,7 @@ def stream_generate(

detokenizer.reset()
for (token, _), n in zip(
generate_step(prompt_tokens[None], model, **kwargs),
generate_step(prompt_tokens[None], model, tokenizer, **kwargs),
range(max_tokens),
):
token = token.item()
@@ -346,7 +413,7 @@ def generate(
tic = time.perf_counter()

for (tokens, logprobs), n in zip(
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt_tokens, model, tokenizer, **kwargs),
range(max_tokens),
):
if n == 0:
12 changes: 10 additions & 2 deletions llms/tests/test_models.py
@@ -8,7 +8,6 @@


class TestModels(unittest.TestCase):

def test_kv_cache(self):
cache = KVCache(32, 4)

@@ -29,6 +28,16 @@ def test_kv_cache(self):
self.assertTrue(mx.array_equal(v_up, expected))
self.assertEqual(cache.offset, cache.step + 1)

cache.drop(5)
k = mx.ones((1, 4, 3, 32), mx.float16)
v = mx.ones((1, 4, 3, 32), mx.float16)
k_up, v_up = cache.update_and_fetch(k, v)

expected = mx.ones((1, 4, cache.step - 1, 32), mx.float16)
self.assertTrue(mx.array_equal(k_up, expected))
self.assertTrue(mx.array_equal(v_up, expected))
self.assertEqual(cache.offset, cache.step - 1)
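# Arithmetic behind the assertions above (step defaults to 256): offset was
# step + 1 after the earlier updates, drop(5) rewinds it to step - 4, and adding
# 3 more tokens brings it to step - 1, all within the already-allocated buffer.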

def test_rotating_kv_cache(self):
b, h, d = 1, 2, 32
cache = RotatingKVCache(d, h, max_size=8, step=4)
@@ -88,7 +97,6 @@ def test_rotating_kv_cache(self):
idx = 2

def model_test_runner(self, model, model_type, vocab_size, num_layers):

self.assertEqual(len(model.layers), num_layers)
self.assertEqual(model.model_type, model_type)

