-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Hardware][CPU] Support chunked-prefill and prefix-caching on CPU #10355
Open
bigPYJ1151
wants to merge
7
commits into
vllm-project:main
Choose a base branch
from
bigPYJ1151:upstream_chunked_prefill
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+529
−365
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
5f99876
support chunked prefill
bigPYJ1151 ba2575b
rebase
bigPYJ1151 5980981
fix test
bigPYJ1151 c45b967
support prefix cache
bigPYJ1151 46be06a
add tests
bigPYJ1151 d33a175
update doc
bigPYJ1151 96cafe5
fix multi modal chunked prefill
bigPYJ1151 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,11 +5,10 @@ Installation with CPU | |
|
||
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features: | ||
|
||
- Tensor Parallel (``-tp = N``) | ||
- Quantization (``INT8 W8A8, AWQ``) | ||
|
||
.. note:: | ||
More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon. | ||
Comment on lines
-11
to
-12
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since |
||
- Tensor Parallel | ||
- Model Quantization (``INT8 W8A8, AWQ``) | ||
- Chunked-prefill | ||
- Prefix-caching | ||
|
||
Table of contents: | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
import pytest | ||
|
||
from tests.kernels.utils import override_backend_env_variable | ||
from vllm.platforms import current_platform | ||
|
||
from ..models.utils import check_logprobs_close, check_outputs_equal | ||
from ..utils import multi_gpu_test | ||
|
@@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache( | |
# NOTE: Increasing this in this suite will fail CI because we currently cannot | ||
# reset distributed env properly. Use a value > 1 just when you test. | ||
@pytest.mark.parametrize("tensor_parallel_size", [1]) | ||
@pytest.mark.parametrize("dtype", ["half"]) | ||
def test_with_prefix_caching( | ||
vllm_runner, | ||
max_tokens: int, | ||
enforce_eager: bool, | ||
chunk_size: int, | ||
tensor_parallel_size: int, | ||
dtype: str, | ||
) -> None: | ||
""" | ||
Checks exact match decode with and without prefix caching | ||
|
@@ -233,7 +236,7 @@ def test_with_prefix_caching( | |
for enable in (True, False): | ||
with vllm_runner( | ||
model, | ||
dtype="half", | ||
dtype=dtype, | ||
max_num_batched_tokens=max_num_batched_tokens, | ||
enable_chunked_prefill=True, | ||
enable_prefix_caching=enable, | ||
|
@@ -260,3 +263,69 @@ def test_with_prefix_caching( | |
name_0="w/o prefix caching", | ||
name_1="with prefix caching", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", ["facebook/opt-125m"]) | ||
@pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
@pytest.mark.parametrize("max_tokens", [32]) | ||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) | ||
@pytest.mark.parametrize("enforce_eager", [False]) | ||
# NOTE: Increasing this in this suite will fail CI because we currently cannot | ||
# reset distributed env properly. Use a value > 1 just when you test. | ||
@pytest.mark.parametrize("tensor_parallel_size", [1]) | ||
@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"]) | ||
@pytest.mark.cpu_only | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you define this mark in |
||
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") | ||
def test_models_cpu( | ||
hf_runner, | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
chunked_prefill_token_size: int, | ||
enforce_eager: bool, | ||
tensor_parallel_size: int, | ||
attention_backend: str, | ||
monkeypatch, | ||
) -> None: | ||
test_models( | ||
hf_runner, | ||
vllm_runner, | ||
example_prompts, | ||
model, | ||
dtype, | ||
max_tokens, | ||
chunked_prefill_token_size, | ||
enforce_eager, | ||
tensor_parallel_size, | ||
attention_backend, | ||
monkeypatch, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("max_tokens", [16]) | ||
@pytest.mark.parametrize("enforce_eager", [False]) | ||
@pytest.mark.parametrize("chunk_size", [30, 32]) | ||
# NOTE: Increasing this in this suite will fail CI because we currently cannot | ||
# reset distributed env properly. Use a value > 1 just when you test. | ||
@pytest.mark.parametrize("tensor_parallel_size", [1]) | ||
@pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
@pytest.mark.cpu_only | ||
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") | ||
def test_with_prefix_caching_cpu( | ||
vllm_runner, | ||
max_tokens: int, | ||
enforce_eager: bool, | ||
chunk_size: int, | ||
tensor_parallel_size: int, | ||
dtype: str, | ||
) -> None: | ||
test_with_prefix_caching( | ||
vllm_runner, | ||
max_tokens, | ||
enforce_eager, | ||
chunk_size, | ||
tensor_parallel_size, | ||
dtype, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may need to increase the timeout, I think 25 minutes isn't enough now.