Commit f8ece6e
[Core][v1] Unify allocating slots in prefill and decode in KV cache manager (vllm-project#12608)

As mentioned in RFC vllm-project#12254, this PR completes one of its tasks: combining allocate_slots and append_slots into a single method.

There should be no functional change, except that in decode, allocate_slots now also raises an exception when num_tokens is zero (as prefill already did); the unit tests are updated accordingly.
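A minimal sketch of the unified call pattern (illustrative only: manager and req stand for an existing KVCacheManager and Request, and num_new_tokens is a made-up count):

    # Prefill: pass the new prefix-cache hits from get_computed_blocks().
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    new_blocks = manager.allocate_slots(req, num_new_tokens, computed_blocks)

    # Decode: same method, with no new computed blocks to report.
    new_blocks = manager.allocate_slots(req, num_new_tokens)

    # Both paths share one failure mode: None means the slots cannot be
    # allocated right now, and the caller must preempt or retry.
    if new_blocks is None:
        ...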

@comaniac @rickyyx @WoosukKwon @youkaichao @heheda12345 @simon-mo

---------

Signed-off-by: Shawn Du <[email protected]>
ShawnD200 authored Feb 2, 2025
1 parent abfcdcd commit f8ece6e
Showing 3 changed files with 78 additions and 116 deletions.
24 changes: 13 additions & 11 deletions tests/v1/core/test_prefix_caching.py
@@ -164,7 +164,7 @@ def test_decode():
     req0.num_computed_tokens = 55
     for _ in range(4):
         req0.append_output_token_ids(8)
-    new_blocks = manager.append_slots(req0, 4)
+    new_blocks = manager.allocate_slots(req0, 4)
     assert new_blocks is not None and len(new_blocks) == 0
     assert manager.req_to_blocks[req0.request_id][-2].block_hash is None

@@ -175,7 +175,7 @@ def test_decode():
     # the preallocated block.
     for _ in range(5 + 10):
         req0.append_output_token_ids(7)
-    new_blocks = manager.append_slots(req0, 15)
+    new_blocks = manager.allocate_slots(req0, 15)
     assert new_blocks is not None and len(new_blocks) == 0
     assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None

@@ -185,7 +185,7 @@ def test_decode():
     # the preallocated block.
     for _ in range(6 + 11):
         req0.append_output_token_ids(12)
-    new_blocks = manager.append_slots(req0, 17)
+    new_blocks = manager.allocate_slots(req0, 17)
     # Plus one preallocated block.
     assert new_blocks is not None and len(new_blocks) == 2

@@ -395,12 +395,14 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
     req.num_computed_tokens = block_size
     assert len(blocks) == 1 + num_preallocated_blocks
 
-    # Assume all computed.
-    manager.append_slots(req, block_size * (len(blocks) - 1))
-    req.num_computed_tokens = block_size * len(blocks)
+    # Assume all computed; only when num_preallocate_tokens > 0 do we need
+    # to consume the previously preallocated blocks.
+    if num_preallocated_blocks > 0:
+        manager.allocate_slots(req, block_size * (len(blocks) - 1))
+        req.num_computed_tokens = block_size * len(blocks)
 
     # Append 1 block.
-    blocks = manager.append_slots(req, block_size)
+    blocks = manager.allocate_slots(req, block_size)
     assert len(blocks) == 1 + num_preallocated_blocks

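For context on the assertions above: the manager derives its per-call preallocation count from the configured token budget. A small sketch of that relationship (the helper mirrors cdiv from vllm.utils; the sample numbers are hypothetical):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, as in vllm.utils.cdiv.
        return -(a // -b)

    # e.g. num_preallocate_tokens = 64 with block_size = 16
    # yields 4 preallocated blocks on each allocation call.
    assert cdiv(64, 16) == 4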

@@ -503,7 +505,7 @@ def test_mm_prefix_caching():
     # Append slots without allocating a new block.
     for _ in range(5):
         req0.append_output_token_ids(8)
-    new_blocks = manager.append_slots(req0, 5)
+    new_blocks = manager.allocate_slots(req0, 5)
     assert new_blocks is not None and len(new_blocks) == 0
 
     # The just completed block should have hashes with extra keys.
@@ -603,7 +605,7 @@ def test_reset_prefix_cache():
     unique_token_ids = [3] * 7
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
-    blocks = manager.allocate_slots(req0, 55, [])
+    blocks = manager.allocate_slots(req0, 55)
     assert [b.block_id for b in blocks] == [0, 1, 2, 3]
 
     unique_token_ids = [4] * 7
@@ -639,7 +641,7 @@ def test_uncache_blocks():
     )
 
     req0 = make_request("0", list(range(30)))
-    blocks = manager.allocate_slots(req0, 30, [])
+    blocks = manager.allocate_slots(req0, 30)
     assert [b.block_id for b in blocks] == [0, 1]
     assert len(manager.cached_block_hash_to_block) == 1

@@ -648,7 +650,7 @@ def test_uncache_blocks():
     # Simulate speculative tokens.
     for _ in range(5):
         req0.append_output_token_ids(8)
-    manager.append_slots(req0, 5)
+    manager.allocate_slots(req0, 5)
     assert len(manager.cached_block_hash_to_block) == 2
 
     # After sampling, assuming only 1 token is accepted.
168 changes: 64 additions & 104 deletions vllm/v1/core/kv_cache_manager.py
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
 
 from vllm.logger import init_logger
 from vllm.utils import cdiv
@@ -67,7 +67,8 @@ def __init__(
         # Mapping from request ID to blocks to track the blocks allocated
         # for each request, so that we can free the blocks when the request
         # is finished.
-        self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
+        self.req_to_blocks: DefaultDict[str,
+                                        List[KVCacheBlock]] = defaultdict(list)
 
     @property
     def usage(self) -> float:
@@ -115,33 +116,75 @@ def get_computed_blocks(
         num_computed_tokens = len(computed_blocks) * self.block_size
         return computed_blocks, num_computed_tokens
 
-    def append_slots(
+    def allocate_slots(
         self,
         request: Request,
         num_tokens: int,
+        new_computed_blocks: Optional[List[KVCacheBlock]] = None
     ) -> Optional[List[KVCacheBlock]]:
-        """Append slots to the block table of the request.
-        We first append slots to already allocated blocks. If the allocated
-        blocks are not enough, we allocate new blocks.
+        """Add slots for a request with new tokens to append.
         Args:
-            request: The request to append slots.
-            num_tokens: The number of tokens to append.
+            request: The request to allocate slots.
+            num_tokens: The number of tokens to allocate. Note that this does
+                not include the tokens that have already been computed.
+            new_computed_blocks: A list of new computed blocks just hitting the
+                prefix caching.
+        Blocks layout:
+        -----------------------------------------------------------------------
+        | < computed > | < new computed > |    < new >    | < pre-allocated > |
+        -----------------------------------------------------------------------
+        |                  < required >                   |
+        --------------------------------------------------
+        |                    < full >                  |
+        ------------------------------------------------
+                                          | <new full> |
+                                          --------------
+        The following *_blocks are illustrated in this layout.
         Returns:
-            A list of new blocks if new blocks are allocated, or None
-            if new blocks are required but cannot be allocated.
+            A list of new allocated blocks.
         """
-        num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
+        if num_tokens == 0:
+            raise ValueError("num_tokens must be greater than 0")
+
+        new_computed_blocks = new_computed_blocks or []
+
+        # The number of computed tokens is the number of computed tokens plus
+        # the new prefix caching hits
+        num_computed_tokens = (request.num_computed_tokens +
+                               len(new_computed_blocks) * self.block_size)
+        num_required_blocks = cdiv(num_computed_tokens + num_tokens,
                                    self.block_size)
         req_blocks = self.req_to_blocks[request.request_id]
-
-        num_new_blocks = num_required_blocks - len(req_blocks)
-        if num_new_blocks > self.free_block_queue.num_free_blocks:
-            # Need to allocate new blocks due to insufficient pre-allocated
-            # slots, but we cannot allocate new blocks due to the limit.
+        num_new_blocks = (num_required_blocks - len(req_blocks) -
+                          len(new_computed_blocks))
+
+        # If a computed block of a request is an eviction candidate (in the
+        # free queue and ref_cnt == 0), it cannot be counted as a free block
+        # when allocating this request.
+        num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
+                                            if blk.ref_cnt == 0)
+        if (num_new_blocks > self.free_block_queue.num_free_blocks -
+                num_evictable_computed_blocks):
+            # Cannot allocate new blocks
            return None
 
+        # Touch the computed blocks to make sure they won't be evicted.
+        if self.enable_caching:
+            self._touch(new_computed_blocks)
+        else:
+            assert not new_computed_blocks, (
+                "Computed blocks should be empty when "
+                "prefix caching is disabled")
+
+        # Append the new computed blocks to the request blocks until now to
+        # avoid the case where the new blocks cannot be allocated.
+        req_blocks.extend(new_computed_blocks)
+
+        # Start to handle new blocks
+
         if num_new_blocks <= 0:
             # No new block is needed.
             new_blocks = []
@@ -160,112 +203,29 @@ def append_slots(
             )
             assert num_new_blocks > 0
 
+            # Concatenate the computed block IDs and the new block IDs.
             new_blocks = self._get_new_blocks(num_new_blocks)
             req_blocks.extend(new_blocks)
 
         if not self.enable_caching:
             return new_blocks
 
-        num_computed_full_blocks = (request.num_computed_tokens //
-                                    self.block_size)
-
-        # NOTE(rickyx): We are assuming the `num_tokens` are actual
-        # tokens rather than lookahead slots (e.g. for speculative decoding).
-        # TODO(rickyx): When supporting speculative decoding, we will need to
-        # differentiate between them so that we can know how many blocks are
-        # full after appending the actual tokens.
-        num_full_blocks_after_append = (request.num_computed_tokens +
-                                        num_tokens) // self.block_size
-        assert num_full_blocks_after_append <= len(req_blocks)
-
-        new_full_blocks = req_blocks[
-            num_computed_full_blocks:num_full_blocks_after_append]
-        if new_full_blocks:
-            self._cache_full_blocks(
-                request=request,
-                blk_start_idx=num_computed_full_blocks,
-                full_blocks=new_full_blocks,
-                prev_block=req_blocks[num_computed_full_blocks - 1]
-                if num_computed_full_blocks >= 1 else None,
-            )
-
-        return new_blocks
-
-    def allocate_slots(
-        self,
-        request: Request,
-        num_tokens: int,
-        computed_blocks: List[KVCacheBlock],
-    ) -> Optional[List[KVCacheBlock]]:
-        """Allocate slots for a new request.
-        Args:
-            request: The request to allocate slots.
-            num_tokens: The number of tokens to allocate. Note that this does
-                not include the tokens that have already been computed.
-            computed_blocks: A list of computed blocks.
-        Returns:
-            A list of new allocated blocks.
-        """
-        if num_tokens == 0:
-            raise ValueError(
-                f"num_tokens must be greater than 0, got {num_tokens}")
-
-        # If a computed block of a request is an eviction candidate (in the
-        # free queue and ref_cnt == 0), it cannot be counted as a free block
-        # when allocating this request.
-        num_evictable_computed_blocks = sum(1 for blk in computed_blocks
-                                            if blk.ref_cnt == 0)
-
-        num_required_blocks = cdiv(num_tokens, self.block_size)
-        if (num_required_blocks > self.free_block_queue.num_free_blocks -
-                num_evictable_computed_blocks):
-            # Cannot allocate new blocks.
-            return None
-
-        # Touch the computed blocks to make sure they won't be evicted.
-        if self.enable_caching:
-            self._touch(computed_blocks)
-        else:
-            assert not computed_blocks, (
-                "Computed blocks should be empty when "
-                "prefix caching is disabled")
-
-        # Determine the number of new blocks to allocate considering
-        # preallocated blocks.
-        num_new_blocks = min(
-            num_required_blocks + self.num_preallocate_blocks,
-            self.free_block_queue.num_free_blocks,
-            # Should not exceed the maximum number of blocks per request.
-            # This is especially because the block table has the shape
-            # [..., max_num_blocks_per_req].
-            # TODO(woosuk): Check and reject requests if
-            # num_prompt_tokens + max_tokens > max_model_len.
-            self.max_num_blocks_per_req - len(computed_blocks),
-        )
-        assert num_new_blocks > 0
-
-        # Concatenate the computed block IDs and the new block IDs.
-        new_blocks = self._get_new_blocks(num_new_blocks)
-        self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
-
-        if not self.enable_caching:
-            return new_blocks
-
-        num_computed_tokens = len(computed_blocks) * self.block_size
         num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
 
-        new_full_blocks = self.req_to_blocks[
-            request.request_id][len(computed_blocks):num_full_blocks]
+        num_computed_full_blocks = num_computed_tokens // self.block_size
+        new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
         if new_full_blocks:
             self._cache_full_blocks(
                 request=request,
-                blk_start_idx=len(computed_blocks),
+                blk_start_idx=num_computed_full_blocks,
+                # The new full blocks are the full blocks that are not computed.
                 full_blocks=new_full_blocks,
-                prev_block=computed_blocks[-1] if computed_blocks else None,
-            )
+                prev_block=(req_blocks[num_computed_full_blocks - 1]
+                            if num_computed_full_blocks > 0 else None))
 
         return new_blocks

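To make the merged accounting concrete, here is a worked example of the num_new_blocks computation above (a sketch: all numbers are hypothetical, with block_size = 16):

    block_size = 16
    num_prev_computed_tokens = 48      # 3 full blocks already in req_blocks
    num_new_computed_blocks = 1        # one fresh prefix-cache hit
    num_tokens = 20                    # new tokens to append

    num_computed_tokens = (num_prev_computed_tokens +
                           num_new_computed_blocks * block_size)        # 64
    # cdiv(84, 16) = 6, written here as inline ceiling division:
    num_required_blocks = -((num_computed_tokens + num_tokens) // -block_size)
    # Subtract the blocks already held and the new computed blocks just
    # appended to req_blocks.
    num_new_blocks = num_required_blocks - 3 - 1
    assert (num_computed_tokens, num_required_blocks, num_new_blocks) == (64, 6, 2)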
2 changes: 1 addition & 1 deletion vllm/v1/core/scheduler.py
@@ -138,7 +138,7 @@ def schedule(self) -> "SchedulerOutput":
             assert num_new_tokens > 0
 
             while True:
-                new_blocks = self.kv_cache_manager.append_slots(
+                new_blocks = self.kv_cache_manager.allocate_slots(
                     request, num_new_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
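For reference, a sketch of how a caller can respond to the None return in the loop above (the preemption policy shown is illustrative, not the scheduler's exact code; running is a hypothetical list of running requests, lowest priority last):

    while True:
        new_blocks = kv_cache_manager.allocate_slots(request, num_new_tokens)
        if new_blocks is not None:
            break  # slots reserved; continue scheduling this request
        if not running:
            break  # nothing left to preempt; give up for this step
        # Free KV cache blocks by preempting another request, then retry.
        victim = running.pop()
        kv_cache_manager.free(victim)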
