[Core]: (Last/N) Support prefill only models by Workflow Defined Engine #8964

Draft
wants to merge 13 commits into base: main
Empty file added demo_temporary/__init__.py
Empty file.
Empty file.
90 changes: 90 additions & 0 deletions demo_temporary/benchmarks/benchmark_attention_impl.py
@@ -0,0 +1,90 @@
import os
import random
import time


def benchmark_vllm(args):
    random.seed(args.seed)
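    # Select the attention backend via VLLM_ATTENTION_BACKEND before importing
    # vLLM below, so the prefill-only selector can (presumably) pick it up when
    # the engine is built.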
    os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_impl

    import gc

    import torch

    from vllm.entrypoints.wde_llm import LLMEngine
    from vllm.model_executor.encode_only.arg_utils import (  # noqa: E501
        EncodeOnlyEngineArgs as EngineArgs)

    prompt = "if" * args.input_len
    requests = [prompt for _ in range(args.num_prompts)]

    engine_args = EngineArgs(model=args.model,
                             tokenizer=args.tokenizer,
                             seed=args.seed,
                             trust_remote_code=args.trust_remote_code,
                             dtype=args.dtype,
                             max_model_len=args.max_model_len,
                             device=args.device,
                             max_num_seqs=32,
                             scheduling=args.scheduling)

    engine = LLMEngine.from_engine_args(engine_args)

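    # Reuse the same engine across batch sizes by updating the scheduler's
    # max_num_seqs before each sweep.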
    for batchsize in args.batchsize:
        engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize)

        start = time.perf_counter()
        for request_id, prompt in enumerate(requests):
            engine.add_request(str(request_id), prompt)

        n_step = 0
        while engine.has_unfinished_requests():
            engine.step()
            n_step += 1
        end = time.perf_counter()

        elapsed_time = end - start
        delay = elapsed_time / n_step

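        # Throughput counts finished requests per second; delay is the average
        # wall-clock time per engine.step() call.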
print(f"Batchsize {batchsize}, Throughput: "
f"{len(requests) / elapsed_time:.4f} requests/s, "
f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")

engine.executor.shutdown_execute_loop()
gc.collect()
torch.cuda.empty_cache()


if __name__ == '__main__':
    from easydict import EasyDict as edict

    from vllm.attention.prefill_only.selector import AttentionImpls
    args = edict()

    args.input_len = 256
    args.num_prompts = 10000

    args.model = "google-bert/bert-base-uncased"

    args.trust_remote_code = False
    args.tokenizer = args.model
    args.seed = 0
    args.max_model_len = None
    args.device = "cuda"
    args.batchsize = [1, 2, 4, 8, 16, 32, 64]
    args.scheduling = "double_buffer"

    from concurrent.futures import ProcessPoolExecutor

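    # Run each configuration in its own subprocess so that GPU memory and
    # global state are fully released between runs.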
    def run_vllm(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark_vllm, args)
            f.result()

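    # Benchmark every (dtype, attention implementation) combination exposed by
    # AttentionImpls.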
    for dtype, attention_impls in AttentionImpls.items():
        print("dtype:", dtype)
        for attention_impl in attention_impls:
            print("attention_impl:", attention_impl)
            args.attention_impl = attention_impl
            args.dtype = dtype
            run_vllm(args)
83 changes: 83 additions & 0 deletions demo_temporary/benchmarks/benchmark_bert.py
@@ -0,0 +1,83 @@
import random
import time


def benchmark_vllm(args):
    random.seed(args.seed)

    import gc

    import torch

    from vllm.entrypoints.wde_llm import LLMEngine
    from vllm.model_executor.encode_only.arg_utils import (  # noqa: E501
        EncodeOnlyEngineArgs as EngineArgs)

    prompt = "if" * args.input_len
    requests = [prompt for _ in range(args.num_prompts)]

    engine_args = EngineArgs(model=args.model,
                             tokenizer=args.tokenizer,
                             seed=args.seed,
                             trust_remote_code=args.trust_remote_code,
                             dtype=args.dtype,
                             max_model_len=args.max_model_len,
                             device=args.device,
                             max_num_seqs=32,
                             scheduling=args.scheduling)

    engine = LLMEngine.from_engine_args(engine_args)

    for batchsize in args.batchsize:
        engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize)

        start = time.perf_counter()
        for request_id, prompt in enumerate(requests):
            engine.add_request(str(request_id), prompt)

        n_step = 0
        while engine.has_unfinished_requests():
            engine.step()
            n_step += 1
        end = time.perf_counter()

        elapsed_time = end - start
        delay = elapsed_time / n_step

        print(f"Batchsize {batchsize}, Throughput: "
              f"{len(requests) / elapsed_time:.4f} requests/s, "
              f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")

    engine.executor.shutdown_execute_loop()
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    from easydict import EasyDict as edict
    args = edict()

    args.input_len = 256
    args.num_prompts = 10000

    args.model = "google-bert/bert-base-uncased"

    args.trust_remote_code = False
    args.tokenizer = args.model
    args.seed = 0
    args.max_model_len = None
    args.dtype = "half"
    args.device = "cuda"
    args.batchsize = [1, 2, 4, 8, 16, 32, 64]

    from concurrent.futures import ProcessPoolExecutor

    def run_vllm(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark_vllm, args)
            f.result()

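    # Benchmark each scheduling mode (sync, async, double_buffer) in a fresh
    # subprocess.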
    for scheduling in ["sync", "async", "double_buffer"]:
        print(scheduling)
        args.scheduling = scheduling
        run_vllm(args)
Empty file.
14 changes: 14 additions & 0 deletions demo_temporary/examples/offline_inference_bert.py
@@ -0,0 +1,14 @@
from vllm.entrypoints.wde_llm import LLM

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

llm = LLM(model="google-bert/bert-base-uncased")

outputs = llm.encode(prompts)
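# Each output's .outputs field is presumably the embedding tensor for the
# corresponding prompt; print its shape.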
for output in outputs:
    print(output.outputs.shape)
Empty file added tests/attention/__init__.py
Empty file.
Empty file.
89 changes: 89 additions & 0 deletions tests/attention/prefill_only/test_basic_correctness.py
@@ -0,0 +1,89 @@
import itertools as it

import pytest
import torch
import torch.nn.functional as F

from vllm.attention.layer import Attention
from vllm.attention.prefill_only.abstract import AttentionType
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend,
                                                   _Backend)
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE


def compare_embeddings(embeddings1, embeddings2):
    similarities = [
        F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
        for e1, e2 in zip(embeddings1, embeddings2)
    ]
    return similarities


SEQ_LENS = [1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


@pytest.mark.parametrize("head_dim", [64])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4, 8])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"])
@pytest.mark.parametrize("n_seqs", list(range(1, len(SEQ_LENS))))
def test_basic_correctness(head_dim: int, num_heads: int, num_kv_heads: int,
                           attn_type: str, dtype: str, n_seqs: int):
    assert num_heads % num_kv_heads == 0

    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

    attention_impls = AttentionImpls[dtype]

    seq_lens = SEQ_LENS[:n_seqs]
    batchsize = sum(seq_lens)

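    # Pack all sequences into one flattened (total_tokens, hidden_size) batch;
    # the prefill-only backends presumably consume varlen inputs described by
    # the attention metadata below.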
    query = torch.rand((batchsize, num_heads, head_dim),
                       dtype=torch_dtype,
                       device="cuda:0").view((batchsize, -1))
    key = torch.rand((batchsize, num_kv_heads, head_dim),
                     dtype=torch_dtype,
                     device="cuda:0").view((batchsize, -1))
    value = torch.rand((batchsize, num_kv_heads, head_dim),
                       dtype=torch_dtype,
                       device="cuda:0").view((batchsize, -1))

    impl_outputs_list = []

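    # Build each attention backend for the requested attention type and run it
    # on the same inputs.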
    for attention_impl in attention_impls:
        selected_backend = _Backend.backend_name_to_enum(attention_impl)
        backend_cls = AttnBackend.get_backend_cls(selected_backend)

        attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type)

        attn_backend = backend_cls(attn_type_enum)
        scaling = head_dim**-0.5

        attn = Attention(num_heads,
                         head_dim,
                         scale=scaling,
                         num_kv_heads=num_kv_heads,
                         attn_backend=attn_backend)

        metadata_builder = attn_backend.make_metadata_builder()
        attn_metadata = metadata_builder(seq_lens=seq_lens)
        attn_metadata = attn_metadata.to("cuda:0")

        outputs = attn.forward(query,
                               key,
                               value,
                               kv_cache=None,
                               attn_metadata=attn_metadata)

        impl_outputs_list.append((attention_impl, outputs))

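    # All implementations should agree: the pairwise cosine similarity of their
    # outputs must stay within `tolerance` of 1.0.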
    tolerance = 1e-2
    for a, b in it.combinations(impl_outputs_list, 2):
        similarities = compare_embeddings(a[1], b[1])
        all_similarities = torch.stack(similarities)

        assert torch.all(
            (all_similarities <= 1.0 + tolerance)
            & (all_similarities >= 1.0 - tolerance)
        ), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0"
54 changes: 54 additions & 0 deletions tests/attention/prefill_only/test_enum_verify.py
@@ -0,0 +1,54 @@
import pytest

from vllm.attention.prefill_only.abstract import (AttentionType,
                                                   PrefillOnlyAttentionBackend)
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend,
                                                   _Backend)


def get_attn_backend(attention_impl: str, attn_type: str):
    selected_backend = _Backend.backend_name_to_enum(attention_impl)
    backend_cls = AttnBackend.get_backend_cls(selected_backend)

    attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type)

    attn_backend = backend_cls(attn_type_enum)
    return attn_backend


@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
def test_backend(dtype: str, attn_type: str):
    attention_impls = AttentionImpls[dtype]

    for attention_impl in attention_impls:
        attn_backend = get_attn_backend(attention_impl, attn_type)

        assert isinstance(attn_backend, PrefillOnlyAttentionBackend)


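# Prefill-only attention supports only the DECODER and ENCODER attention types;
# ENCODER_DECODER is expected to raise NotImplementedError.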
@pytest.mark.parametrize("attn_type", ["ENCODER_DECODER"])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
def test_ENCODER_DECODER_not_supported(dtype: str, attn_type: str):
    attention_impls = AttentionImpls[dtype]

    for attention_impl in attention_impls:
        with pytest.raises(NotImplementedError):
            get_attn_backend(attention_impl, attn_type)


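# Unrecognized backend names (and non-string values) should be rejected with
# ValueError.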
def test_not_supported_backend():
    attention_impls = ["not_supported_backend", 0, 1.0]

    for attention_impl in attention_impls:
        with pytest.raises(ValueError):
            selected_backend = _Backend.backend_name_to_enum(attention_impl)
            AttnBackend.get_backend_cls(selected_backend)


def test_not_supported_attn_type():
    attn_types = ["not_supported_attn_type", 0, 1.0]

    for attn_type in attn_types:
        with pytest.raises(ValueError):
            AttentionType.attn_type_name_to_enum(attn_type)
Empty file.