[Hardware][CPU] Support chunked-prefill and prefix-caching on CPU #10355

Open · wants to merge 7 commits into main
6 changes: 6 additions & 0 deletions .buildkite/run-cpu-test.sh
@@ -57,6 +57,12 @@ function cpu_tests() {
pytest -s -v \
tests/quantization/test_ipex_quant.py"

# Run chunked-prefill and prefix-cache test
docker exec cpu-test bash -c "
set -e
pytest -s -v -k cpu_only \
tests/basic_correctness/test_chunked_prefill.py"
Review comment (Member): We may need to increase the timeout; I think 25 minutes isn't enough now.


# online inference
docker exec cpu-test bash -c "
set -e
9 changes: 4 additions & 5 deletions docs/source/getting_started/cpu-installation.rst
@@ -5,11 +5,10 @@ Installation with CPU

vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:

- Tensor Parallel (``-tp = N``)
- Quantization (``INT8 W8A8, AWQ``)

.. note::
More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
Review comment (Collaborator) on lines -11 to -12: Since FP8 KV cache is still under development, perhaps we can keep this note temporarily.

- Tensor Parallel
- Model Quantization (``INT8 W8A8, AWQ``)
- Chunked-prefill
- Prefix-caching

Table of contents:

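The documentation hunk above now lists chunked-prefill and prefix-caching as supported CPU features. For orientation, a minimal offline-inference sketch that enables both follows; the model name, dtype, chunk size, and prompt are illustrative (loosely mirroring the new CPU tests below), not values prescribed by this PR.

```python
from vllm import LLM, SamplingParams

# Minimal sketch: enable chunked prefill and prefix caching on the CPU backend.
llm = LLM(
    model="facebook/opt-125m",        # small example model, as in the new tests
    dtype="bfloat16",                 # one of the supported CPU dtypes
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
    max_num_batched_tokens=256,       # prefill chunk size per scheduler step
)

params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```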
4 changes: 2 additions & 2 deletions docs/source/serving/compatibility_matrix.rst
@@ -311,15 +311,15 @@ Feature x Hardware
- ✅
- ✅
- ✅
-
-
- ✅
* - :ref:`APC <apc>`
- `<https://github.com/vllm-project/vllm/issues/3687>`__
- ✅
- ✅
- ✅
- ✅
-
-
- ✅
* - :ref:`LoRA <lora>`
- ✅
71 changes: 70 additions & 1 deletion tests/basic_correctness/test_chunked_prefill.py
@@ -12,6 +12,7 @@
import pytest

from tests.kernels.utils import override_backend_env_variable
from vllm.platforms import current_platform

from ..models.utils import check_logprobs_close, check_outputs_equal
from ..utils import multi_gpu_test
@@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("dtype", ["half"])
def test_with_prefix_caching(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
tensor_parallel_size: int,
dtype: str,
) -> None:
"""
Checks exact match decode with and without prefix caching
@@ -233,7 +236,7 @@ def test_with_prefix_caching(
for enable in (True, False):
with vllm_runner(
model,
dtype="half",
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True,
enable_prefix_caching=enable,
@@ -260,3 +263,69 @@ def test_with_prefix_caching(
name_0="w/o prefix caching",
name_1="with prefix caching",
)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"])
@pytest.mark.cpu_only
Review comment (Member) on the cpu_only mark: Can you define this mark in pyproject.toml?
(A marker-registration sketch follows this file's diff.)

@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_models_cpu(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
attention_backend: str,
monkeypatch,
) -> None:
test_models(
hf_runner,
vllm_runner,
example_prompts,
model,
dtype,
max_tokens,
chunked_prefill_token_size,
enforce_eager,
tensor_parallel_size,
attention_backend,
monkeypatch,
)


@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.cpu_only
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_with_prefix_caching_cpu(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
tensor_parallel_size: int,
dtype: str,
) -> None:
test_with_prefix_caching(
vllm_runner,
max_tokens,
enforce_eager,
chunk_size,
tensor_parallel_size,
dtype,
)
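
On the review comment above asking whether the cpu_only mark can be defined in pyproject.toml: a minimal registration sketch follows. Since the rest of this page is Python, it is shown as a conftest.py hook, with the assumed pyproject.toml entry noted in the comment; the marker description and configuration keys are assumptions, not taken from this PR.

```python
# conftest.py (sketch): register the custom "cpu_only" marker so pytest does
# not warn about an unknown mark. The reviewer's suggestion would instead add,
# under [tool.pytest.ini_options] in pyproject.toml:
#     markers = ["cpu_only: tests that only run on the CPU backend"]
# The hook below is the programmatic equivalent.


def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "cpu_only: tests that only run on the CPU backend",
    )
```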
80 changes: 44 additions & 36 deletions vllm/attention/backends/torch_sdpa.py
@@ -9,16 +9,8 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.platforms import current_platform

if current_platform.is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:
from vllm.attention.ops.paged_attn import PagedAttention


class TorchSDPABackend(AttentionBackend):
@@ -71,9 +63,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
chunked_prefill: bool
seq_lens: Optional[List[int]] = None # For non-chunked prefill

# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None

# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
@@ -123,20 +121,14 @@ def is_all_cross_attn_metadata_set(self):

@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self

return None
if self.num_prefill_tokens == 0:
return None
return self

@property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
if self.num_decode_tokens == 0:
return None

return self

def get_seq_lens(
@@ -409,19 +401,35 @@ def forward(
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens

output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
or prefill_meta.block_tables.numel() == 0):
output = self._run_sdpa_forward(query,
key,
value,
prefill_meta,
attn_type=attn_type)
if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore
self._run_sdpa_forward(output,
query,
key,
value,
prefill_meta,
attn_type=attn_type)
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
assert not self.need_mask
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)

if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
@@ -433,8 +441,9 @@ def forward(
block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type)

output = PagedAttention.forward_decode(
query,
PagedAttention.forward_decode(
output[attn_metadata.num_prefill_tokens:, :, :],
query[attn_metadata.num_prefill_tokens:, :, :],
key_cache,
value_cache,
block_tables_arg,
@@ -453,12 +462,13 @@

def _run_sdpa_forward(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
) -> None:
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
Expand All @@ -479,7 +489,6 @@ def _run_sdpa_forward(
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)

output = torch.empty_like(query)
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
Expand All @@ -502,7 +511,6 @@ def _run_sdpa_forward(
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output


def _make_alibi_bias(
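
For context on the chunked-prefill fields added to TorchSDPAMetadata above (query_start_loc, kv_start_loc, max_query_len, max_kv_len): they describe the packed variable-length layout handed to flash_attn_varlen_func, in which the prefill queries of all sequences are concatenated and each start-offset tensor records where a sequence begins. The sketch below shows how such offsets are conventionally derived from per-sequence lengths, assuming the usual exclusive-cumsum layout; the exact construction in vLLM's CPU model runner may differ in its details.

```python
import torch


def build_start_locs(query_lens, kv_lens):
    """Sketch: turn per-sequence lengths into cumulative start offsets.

    query_lens[i] is the number of new tokens sequence i contributes to this
    prefill chunk; kv_lens[i] is the total context length (cached + new) that
    its queries attend to.
    """
    def starts(lens):
        out = torch.zeros(len(lens) + 1, dtype=torch.int32)
        out[1:] = torch.cumsum(torch.tensor(lens, dtype=torch.int32), dim=0)
        return out

    return starts(query_lens), starts(kv_lens)


# Example: two sequences; the first resumes after a 16-token prefix-cache hit.
q_locs, kv_locs = build_start_locs(query_lens=[8, 4], kv_lens=[24, 4])
print(q_locs.tolist())   # [0, 8, 12]  -> max_query_len = 8
print(kv_locs.tolist())  # [0, 24, 28] -> max_kv_len = 24
```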