diff --git a/demo_temporary/__init__.py b/demo_temporary/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/demo_temporary/benchmarks/__init__.py b/demo_temporary/benchmarks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/demo_temporary/benchmarks/benchmark_attention_impl.py b/demo_temporary/benchmarks/benchmark_attention_impl.py new file mode 100644 index 0000000000000..d3b45a0175185 --- /dev/null +++ b/demo_temporary/benchmarks/benchmark_attention_impl.py @@ -0,0 +1,90 @@ +import os +import random +import time + + +def benchmark_vllm(args): + random.seed(args.seed) + os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_impl + + import gc + + import torch + + from vllm.entrypoints.wde_llm import LLMEngine + from vllm.model_executor.encode_only.arg_utils import ( # noqa: E501 + EncodeOnlyEngineArgs as EngineArgs) + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + engine_args = EngineArgs(model=args.model, + tokenizer=args.tokenizer, + seed=args.seed, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + max_model_len=args.max_model_len, + device=args.device, + max_num_seqs=32, + scheduling=args.scheduling) + + engine = LLMEngine.from_engine_args(engine_args) + + for batchsize in args.batchsize: + engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize) + + start = time.perf_counter() + for request_id, prompt in enumerate(requests): + engine.add_request(str(request_id), prompt) + + n_step = 0 + while engine.has_unfinished_requests(): + engine.step() + n_step += 1 + end = time.perf_counter() + + elapsed_time = end - start + delay = elapsed_time / n_step + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s, " + f"Delay {delay * 1000:0.2f} ms, n_step {n_step}") + + engine.executor.shutdown_execute_loop() + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + from easydict import EasyDict as edict + + from vllm.attention.prefill_only.selector import AttentionImpls + args = edict() + + args.input_len = 256 + args.num_prompts = 10000 + + args.model = "google-bert/bert-base-uncased" + + args.trust_remote_code = False + args.tokenizer = args.model + args.seed = 0 + args.max_model_len = None + args.device = "cuda" + args.batchsize = [1, 2, 4, 8, 16, 32, 64] + args.scheduling = "double_buffer" + + from concurrent.futures import ProcessPoolExecutor + + def run_vllm(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark_vllm, args) + f.result() + + for dtype, attention_impls in AttentionImpls.items(): + print("dtype:", dtype) + for attention_impl in attention_impls: + print("attention_impl:", attention_impl) + args.attention_impl = attention_impl + args.dtype = dtype + run_vllm(args) diff --git a/demo_temporary/benchmarks/benchmark_bert.py b/demo_temporary/benchmarks/benchmark_bert.py new file mode 100644 index 0000000000000..1d4c7e7d8a99f --- /dev/null +++ b/demo_temporary/benchmarks/benchmark_bert.py @@ -0,0 +1,83 @@ +import random +import time + + +def benchmark_vllm(args): + random.seed(args.seed) + + import gc + + import torch + + from vllm.entrypoints.wde_llm import LLMEngine + from vllm.model_executor.encode_only.arg_utils import ( # noqa: E501 + EncodeOnlyEngineArgs as EngineArgs) + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + engine_args = EngineArgs(model=args.model, + tokenizer=args.tokenizer, + seed=args.seed, + 
trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + max_model_len=args.max_model_len, + device=args.device, + max_num_seqs=32, + scheduling=args.scheduling) + + engine = LLMEngine.from_engine_args(engine_args) + + for batchsize in args.batchsize: + engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize) + + start = time.perf_counter() + for request_id, prompt in enumerate(requests): + engine.add_request(str(request_id), prompt) + + n_step = 0 + while engine.has_unfinished_requests(): + engine.step() + n_step += 1 + end = time.perf_counter() + + elapsed_time = end - start + delay = elapsed_time / n_step + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s, " + f"Delay {delay * 1000:0.2f} ms, n_step {n_step}") + + engine.executor.shutdown_execute_loop() + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + from easydict import EasyDict as edict + args = edict() + + args.input_len = 256 + args.num_prompts = 10000 + + args.model = "google-bert/bert-base-uncased" + + args.trust_remote_code = False + args.tokenizer = args.model + args.seed = 0 + args.max_model_len = None + args.dtype = "half" + args.device = "cuda" + args.batchsize = [1, 2, 4, 8, 16, 32, 64] + + from concurrent.futures import ProcessPoolExecutor + + def run_vllm(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark_vllm, args) + f.result() + + for scheduling in ["sync", "async", "double_buffer"]: + print(scheduling) + args.scheduling = scheduling + run_vllm(args) diff --git a/demo_temporary/examples/__init__.py b/demo_temporary/examples/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/demo_temporary/examples/offline_inference_bert.py b/demo_temporary/examples/offline_inference_bert.py new file mode 100644 index 0000000000000..80c215214c60e --- /dev/null +++ b/demo_temporary/examples/offline_inference_bert.py @@ -0,0 +1,14 @@ +from vllm.entrypoints.wde_llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model="google-bert/bert-base-uncased") + +outputs = llm.encode(prompts) +for output in outputs: + print(output.outputs.shape) diff --git a/tests/attention/__init__.py b/tests/attention/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/attention/prefill_only/__init__.py b/tests/attention/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/attention/prefill_only/test_basic_correctness.py b/tests/attention/prefill_only/test_basic_correctness.py new file mode 100644 index 0000000000000..39e721ea8d6bf --- /dev/null +++ b/tests/attention/prefill_only/test_basic_correctness.py @@ -0,0 +1,89 @@ +import itertools as it + +import pytest +import torch +import torch.nn.functional as F + +from vllm.attention.layer import Attention +from vllm.attention.prefill_only.abstract import AttentionType +from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend, + _Backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + + +def compare_embeddings(embeddings1, embeddings2): + similarities = [ + F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) + for e1, e2 in zip(embeddings1, embeddings2) + ] + return similarities + + +SEQ_LENS = [1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + +@pytest.mark.parametrize("head_dim", [64]) +@pytest.mark.parametrize("num_heads", [8, 16]) 
+@pytest.mark.parametrize("num_kv_heads", [1, 2, 4, 8]) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"]) +@pytest.mark.parametrize("n_seqs", list(range(1, len(SEQ_LENS)))) +def test_basic_correctness(head_dim: int, num_heads: int, num_kv_heads: int, + attn_type: str, dtype: str, n_seqs: int): + assert num_heads % num_kv_heads == 0 + + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + attention_impls = AttentionImpls[dtype] + + seq_lens = SEQ_LENS[:n_seqs] + batchsize = sum(seq_lens) + + query = torch.rand((batchsize, num_heads, head_dim), + dtype=torch_dtype, + device="cuda:0").view((batchsize, -1)) + key = torch.rand((batchsize, num_kv_heads, head_dim), + dtype=torch_dtype, + device="cuda:0").view((batchsize, -1)) + value = torch.rand((batchsize, num_kv_heads, head_dim), + dtype=torch_dtype, + device="cuda:0").view((batchsize, -1)) + + impl_outputs_list = [] + + for attention_impl in attention_impls: + selected_backend = _Backend.backend_name_to_enum(attention_impl) + backend_cls = AttnBackend.get_backend_cls(selected_backend) + + attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type) + + attn_backend = backend_cls(attn_type_enum) + scaling = head_dim**-0.5 + + attn = Attention(num_heads, + head_dim, + scale=scaling, + num_kv_heads=num_kv_heads, + attn_backend=attn_backend) + + metadata_builder = attn_backend.make_metadata_builder() + attn_metadata = metadata_builder(seq_lens=seq_lens) + attn_metadata = attn_metadata.to("cuda:0") + + outputs = attn.forward(query, + key, + value, + kv_cache=None, + attn_metadata=attn_metadata) + + impl_outputs_list.append((attention_impl, outputs)) + + tolerance = 1e-2 + for a, b in it.combinations(impl_outputs_list, 2): + similarities = compare_embeddings(a[1], b[1]) + all_similarities = torch.stack(similarities) + + assert torch.all( + (all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0" diff --git a/tests/attention/prefill_only/test_enum_verify.py b/tests/attention/prefill_only/test_enum_verify.py new file mode 100644 index 0000000000000..1996f4f42cbba --- /dev/null +++ b/tests/attention/prefill_only/test_enum_verify.py @@ -0,0 +1,54 @@ +import pytest + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend) +from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend, + _Backend) + + +def get_attn_backend(attention_impl: str, attn_type: str): + selected_backend = _Backend.backend_name_to_enum(attention_impl) + backend_cls = AttnBackend.get_backend_cls(selected_backend) + + attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type) + + attn_backend = backend_cls(attn_type_enum) + return attn_backend + + +@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"]) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +def test_backend(dtype: str, attn_type: str): + attention_impls = AttentionImpls[dtype] + + for attention_impl in attention_impls: + attn_backend = get_attn_backend(attention_impl, attn_type) + + assert isinstance(attn_backend, PrefillOnlyAttentionBackend) + + +@pytest.mark.parametrize("attn_type", ["ENCODER_DECODER"]) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +def test_ENCODER_DECODER_not_supported(dtype: str, attn_type: str): + attention_impls = AttentionImpls[dtype] + + for attention_impl in attention_impls: + with pytest.raises(NotImplementedError): + 
get_attn_backend(attention_impl, attn_type) + + +def test_not_supported_backend(): + attention_impls = ["not_supported_backend", 0, 1.0] + + for attention_impl in attention_impls: + with pytest.raises(ValueError): + selected_backend = _Backend.backend_name_to_enum(attention_impl) + AttnBackend.get_backend_cls(selected_backend) + + +def test_not_supported_attn_type(): + attn_types = ["not_supported_attn_type", 0, 1.0] + + for attn_type in attn_types: + with pytest.raises(ValueError): + AttentionType.attn_type_name_to_enum(attn_type) diff --git a/tests/core/prefill_only/__init__.py b/tests/core/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/core/prefill_only/test_base_scheduler.py b/tests/core/prefill_only/test_base_scheduler.py new file mode 100644 index 0000000000000..684cb98e6bd1e --- /dev/null +++ b/tests/core/prefill_only/test_base_scheduler.py @@ -0,0 +1,73 @@ +import pytest + +from vllm.core.prefill_only_scheduler import Scheduler +from vllm.inputs.prefill_only.data import Request +from vllm.model_executor.prefill_only.engine_io import RequestOutput + + +class Scheduler4Test(Scheduler): + + def schedule(self): + pass + + +@pytest.mark.parametrize("n_request", [9, 99, 199]) +def test_add_request_and_abort_request(n_request: int): + scheduler = Scheduler4Test(None, None) + + # add requests + for i in range(1, n_request + 1): + scheduler.add_request(Request(request_id=str(i), arrival_time=0.)) + assert len(scheduler.waiting) == i + assert len(scheduler.requests) == i + + # abort irrelevant requests + for i in range(1, n_request + 1): + scheduler.abort_request(request_id=str(100000 + i)) + assert len(scheduler.waiting) == n_request + assert len(scheduler.requests) == n_request + + # abort requests + for i in range(1, n_request + 1): + scheduler.abort_request(request_id=str(i)) + assert len(scheduler.waiting) == n_request + assert len(scheduler.requests) == n_request - i + + # Lazy abort_request, only test whether to abort during scheduling + assert len(scheduler.waiting) == n_request + assert len(scheduler.requests) == 0 + + +@pytest.mark.parametrize("n_request", [9, 99, 199]) +def test_remove_abort_request(n_request: int): + scheduler = Scheduler4Test(None, None) + + request_outputs = [] + for i in range(1, n_request + 1): + scheduler.add_request(Request(request_id=str(i), arrival_time=0.)) + request_outputs.append( + RequestOutput(request_id=str(i), arrival_time=0., finished=True)) + assert len(scheduler.waiting) == i + assert len(scheduler.requests) == i + assert len(request_outputs) == i + + # abort half of requests + for i in range(1, n_request // 2): + scheduler.abort_request(request_id=str(i)) + assert len(scheduler.waiting) == n_request + assert len(scheduler.requests) == n_request - i + + finished_requests = scheduler.remove_abort_request(request_outputs) + assert len(finished_requests) == n_request - n_request // 2 + 1 + assert len(scheduler.requests) == n_request - n_request // 2 + 1 + assert len(scheduler.aborted_requests) == 0 + + finished_request_ids = set(request.request_id + for request in finished_requests + if request.finished) + + assert len(finished_request_ids - scheduler.requests) == 0 + assert len(scheduler.requests - finished_request_ids) == 0 + + scheduler.free_finished_request(finished_requests) + assert len(scheduler.requests) == 0 diff --git a/tests/core/prefill_only/test_prefill_only_scheduler.py b/tests/core/prefill_only/test_prefill_only_scheduler.py new file mode 100644 index 0000000000000..bbda98ccf6ff8 
--- /dev/null +++ b/tests/core/prefill_only/test_prefill_only_scheduler.py @@ -0,0 +1,128 @@ +import pytest + +from vllm.config import SchedulerConfig +from vllm.core.prefill_only_scheduler import PrefillOnlyScheduler +from vllm.inputs.prefill_only.data import (TextOnlyInputs, + TextSchedulableRequest) +from vllm.inputs.prefill_only.preprocessor import Request, RequestProcessor +from vllm.model_executor.prefill_only.engine_io import RequestOutput + + +class TestRequestProcessor(RequestProcessor): + + def __init__(self, num_new_tokens): + self.num_new_tokens = num_new_tokens + + def __call__(self, request: Request) -> TextSchedulableRequest: + return TextSchedulableRequest(**request.__dict__, + inputs=TextOnlyInputs( + prompt_token_ids=[0] * + self.num_new_tokens)) + + def from_engine(cls, engine): + pass + + +@pytest.mark.parametrize("num_new_tokens", [9, 99, 199]) +@pytest.mark.parametrize("n_request", [9, 99, 199]) +@pytest.mark.parametrize("max_num_requests", [1, 2, 3, 5, 7]) +def test_limited_by_max_num_requests(n_request: int, num_new_tokens: int, + max_num_requests: int): + max_model_len = num_new_tokens + 1 + + scheduler = PrefillOnlyScheduler( + scheduler_config=SchedulerConfig(max_num_batched_tokens=max_model_len * + max_num_requests, + max_model_len=max_model_len, + max_num_seqs=max_num_requests), + request_processor=TestRequestProcessor(num_new_tokens=num_new_tokens)) + + for i in range(1, n_request + 1): + scheduler.add_request(Request(request_id=str(i), arrival_time=0.)) + + while scheduler.has_unfinished_requests(): + scheduler_output = scheduler.schedule() + + request_outputs = [ + RequestOutput(request_id=request.request_id, + arrival_time=request.arrival_time, + finished=True) + for request in scheduler_output.scheduled_requests + ] + + scheduler.free_finished_request(request_outputs) + + if scheduler.has_unfinished_requests(): + assert len(scheduler_output.scheduled_requests) == max_num_requests + else: + assert len(scheduler_output.scheduled_requests) <= max_num_requests + assert len(scheduler_output.ignored_requests) == 0 + + +@pytest.mark.parametrize("num_new_tokens", [9, 99, 199]) +@pytest.mark.parametrize("n_request", [9, 99, 199]) +@pytest.mark.parametrize("max_num_requests", [2, 3, 5, 7]) +def test_limited_by_token_budget(n_request: int, num_new_tokens: int, + max_num_requests: int): + scheduler = PrefillOnlyScheduler(scheduler_config=SchedulerConfig( + max_model_len=num_new_tokens + 1, + max_num_seqs=max_num_requests, + max_num_batched_tokens=(num_new_tokens + 1) * (max_num_requests - 1)), + request_processor=TestRequestProcessor( + num_new_tokens=num_new_tokens)) + + for i in range(1, n_request + 1): + scheduler.add_request(Request(request_id=str(i), arrival_time=0.)) + + n_scheduled_requests = 0 + while scheduler.has_unfinished_requests(): + scheduler_output = scheduler.schedule() + n_scheduled_requests += len(scheduler_output.scheduled_requests) + + request_outputs = [ + RequestOutput(request_id=request.request_id, + arrival_time=request.arrival_time, + finished=True) + for request in scheduler_output.scheduled_requests + ] + + scheduler.free_finished_request(request_outputs) + + if scheduler.has_unfinished_requests(): + assert len( + scheduler_output.scheduled_requests) == max_num_requests - 1 + else: + assert len( + scheduler_output.scheduled_requests) <= max_num_requests - 1 + assert len(scheduler_output.ignored_requests) == 0 + + assert n_scheduled_requests == n_request + + +@pytest.mark.parametrize("num_new_tokens", [9, 99, 199]) 
+@pytest.mark.parametrize("n_request", [9, 99, 199]) +@pytest.mark.parametrize("max_num_requests", [2, 3, 5, 7]) +def test_ignored_requests(n_request: int, num_new_tokens: int, + max_num_requests: int): + max_model_len = num_new_tokens // 2 + + scheduler = PrefillOnlyScheduler( + scheduler_config=SchedulerConfig(max_num_batched_tokens=max_model_len * + max_num_requests, + max_model_len=max_model_len, + max_num_seqs=max_num_requests), + request_processor=TestRequestProcessor(num_new_tokens=num_new_tokens)) + + for i in range(1, n_request + 1): + scheduler.add_request(Request(request_id=str(i), arrival_time=0.)) + + n_ignored_requests = 0 + while scheduler.has_unfinished_requests(): + scheduler_output = scheduler.schedule() + + assert len(scheduler_output.scheduled_requests) == 0 + assert len(scheduler_output.ignored_requests) > 0 + + n_ignored_requests += len(scheduler_output.ignored_requests) + + assert n_ignored_requests == n_request diff --git a/tests/inputs/__init__.py b/tests/inputs/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/inputs/prefill_only/__init__.py b/tests/inputs/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/inputs/prefill_only/test_input_processor.py b/tests/inputs/prefill_only/test_input_processor.py new file mode 100644 index 0000000000000..6bc95cfd76c4f --- /dev/null +++ b/tests/inputs/prefill_only/test_input_processor.py @@ -0,0 +1,81 @@ +# mypy: ignore-errors +import pytest + +from vllm.inputs.prefill_only.data import (TextOnlyInputs, TextPrompt, + TokensPrompt, ValidationError) +from vllm.inputs.prefill_only.preprocessor import TextInputProcessor + +input_processor = TextInputProcessor() + + +@pytest.fixture(scope="session") +def request_id(): + return "0" + + +def test_input_processor_1(request_id): + prompt = "test" + request = input_processor(request_id, prompt) + + assert request.inputs == {"prompt": prompt} + + +def test_input_processor_2(request_id): + prompt = "test" + inputs = TextPrompt(prompt=prompt) + request = input_processor(request_id, inputs) + + assert request.inputs == {"prompt": prompt} + + +def test_input_processor_3(request_id): + prompt_token_ids = [0] + inputs = TokensPrompt(prompt_token_ids=prompt_token_ids) + request = input_processor(request_id, inputs) + + assert request.inputs == {"prompt_token_ids": prompt_token_ids} + + +def test_input_processor_4(request_id): + prompt = "test" + prompt_token_ids = [0] + inputs = TextOnlyInputs(prompt_token_ids=prompt_token_ids) + request = input_processor(request_id, inputs) + + assert request.inputs == {"prompt_token_ids": prompt_token_ids} + + inputs = TextOnlyInputs(prompt_token_ids=prompt_token_ids, prompt=prompt) + request = input_processor(request_id, inputs) + + assert request.inputs == { + "prompt_token_ids": prompt_token_ids, + "prompt": prompt + } + + +def test_input_processor_5(request_id): + prompt = "test" + prompt_token_ids = [0] + inputs = {"prompt_token_ids": prompt_token_ids, "prompt": prompt} + + request = input_processor(request_id, inputs) + + assert request.inputs == inputs + + +def test_validation_error(request_id): + with pytest.raises(ValidationError): + inputs = {} + input_processor(request_id, inputs) + + with pytest.raises(ValidationError): + inputs = {"foo": "bar"} + input_processor(request_id, inputs) + + with pytest.raises(ValidationError): + inputs = 0 + input_processor(request_id, inputs) + + with pytest.raises(ValidationError): + inputs = 0.0 + input_processor(request_id, inputs) 
diff --git a/tests/inputs/prefill_only/test_request_processor.py b/tests/inputs/prefill_only/test_request_processor.py new file mode 100644 index 0000000000000..c388df79e8aa0 --- /dev/null +++ b/tests/inputs/prefill_only/test_request_processor.py @@ -0,0 +1,41 @@ +import pytest + +from vllm.inputs.prefill_only.data import TextOnlyInputs, TokensPrompt +from vllm.inputs.prefill_only.preprocessor import (TextInputProcessor, + TextRequestProcessor) +from vllm.inputs.prefill_only.tokenizer import Tokenizer + + +@pytest.fixture(scope="session") +def request_id(): + return "0" + + +TOKENIZER_NAMES = ["facebook/opt-125m", "gpt2"] + + +@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES) +def test_request_processor(request_id: str, tokenizer_name: str): + + tokenizer = Tokenizer(tokenizer_name=tokenizer_name) + input_processor = TextInputProcessor() + request_processor = TextRequestProcessor(tokenizer) + + prompt = "test" + request = input_processor(request_id, prompt) + + assert request.inputs == {"prompt": prompt} + + schedulable_request = request_processor(request) + + assert isinstance(schedulable_request.inputs, TextOnlyInputs) + assert len(schedulable_request.inputs.prompt_token_ids) > 0 + + prompt_token_ids = [0] + request = input_processor(request_id, + TokensPrompt(prompt_token_ids=prompt_token_ids)) + + schedulable_request = request_processor(request) + + assert isinstance(schedulable_request.inputs, TextOnlyInputs) + assert len(schedulable_request.inputs.prompt_token_ids) > 0 diff --git a/tests/test_inputs.py b/tests/inputs/test_inputs.py similarity index 100% rename from tests/test_inputs.py rename to tests/inputs/test_inputs.py diff --git a/tests/model_executor/encode_only/__init__.py b/tests/model_executor/encode_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/model_executor/encode_only/attention_impl/__init__.py b/tests/model_executor/encode_only/attention_impl/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/model_executor/encode_only/attention_impl/basic_correctness.py b/tests/model_executor/encode_only/attention_impl/basic_correctness.py new file mode 100644 index 0000000000000..d657c666e4252 --- /dev/null +++ b/tests/model_executor/encode_only/attention_impl/basic_correctness.py @@ -0,0 +1,87 @@ +import itertools as it +import random +from typing import TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import BatchEncoding, BatchFeature + +from tests.model_executor.prefill_only.utils import (VllmRunner, + compare_embeddings) + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ["google-bert/bert-base-uncased"] + +AttentionImpls_fp32 = ["TORCH_SDPA", "XFORMERS", "TORCH_NAIVE"] +AttentionImpls_fp16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] +AttentionImpls_bf16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] + +AttentionImpls = { + "float": AttentionImpls_fp32, + "half": AttentionImpls_fp16, + "bfloat16": AttentionImpls_bf16, +} + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) 
+@pytest.mark.parametrize("max_num_seqs", [2]) +@pytest.mark.parametrize("scheduling", ["sync"]) +@torch.inference_mode +def test_basic_correctness( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + attention_impls = AttentionImpls[dtype] + + impl_outputs_list = [] + + for attention_impl in attention_impls: + with vllm_runner( + model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling, + attention_impl=attention_impl, + ) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + impl_outputs_list.append((attention_impl, vllm_outputs)) + + tolerance = 1e-2 + for a, b in it.combinations(impl_outputs_list, 2): + similarities = compare_embeddings(a[1], b[1]) + all_similarities = torch.stack(similarities) + + assert torch.all( + (all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0" diff --git a/tests/model_executor/encode_only/models/__init__.py b/tests/model_executor/encode_only/models/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/model_executor/encode_only/models/test_bert.py b/tests/model_executor/encode_only/models/test_bert.py new file mode 100644 index 0000000000000..d59b8885f83f4 --- /dev/null +++ b/tests/model_executor/encode_only/models/test_bert.py @@ -0,0 +1,81 @@ +import random +from typing import List, TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import BatchEncoding, BatchFeature, BertModel + +from tests.model_executor.prefill_only.utils import (HfRunner, VllmRunner, + compare_embeddings) + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class BertHfRunner(HfRunner): + + @torch.inference_mode + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + encoded_input = self.tokenizer(prompts, + padding=True, + truncation=True, + return_tensors="pt").to("cuda") + + outputs = self.model(**encoded_input).pooler_output + return outputs + + +@pytest.fixture(scope="session") +def hf_runner(): + return BertHfRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ["google-bert/bert-base-uncased"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2, 3, 5, 7]) +@pytest.mark.parametrize("scheduling", ["sync", "async", "double_buffer"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, dtype=dtype, auto_cls=BertModel) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git 
a/tests/model_executor/prefill_only/__init__.py b/tests/model_executor/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/model_executor/prefill_only/test_loader.py b/tests/model_executor/prefill_only/test_loader.py new file mode 100644 index 0000000000000..aa2e86fc25f29 --- /dev/null +++ b/tests/model_executor/prefill_only/test_loader.py @@ -0,0 +1,56 @@ +import pytest + +from vllm.attention.prefill_only.abstract import PrefillOnlyAttentionBackend +from vllm.attention.prefill_only.selector import (AttentionImpls, + AttentionType, AttnBackend, + _Backend) +from vllm.config import DeviceConfig, LoadConfig, ModelConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.prefill_only.loader import (get_model_loader, + initialize_model) +from vllm.model_executor.prefill_only.utils import fix_distributed_environment +from vllm.utils import DeviceMemoryProfiler + +MODELS = ["google-bert/bert-base-uncased"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +@pytest.mark.parametrize("attn_type", ["DECODER"]) +def test_loader(model: str, dtype: str, attn_type: str): + attention_impls = AttentionImpls[dtype] + for attention_impl in attention_impls: + selected_backend = _Backend.backend_name_to_enum(attention_impl) + backend_cls = AttnBackend.get_backend_cls(selected_backend) + attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type) + attn_backend = backend_cls(attn_type_enum) + + engine_args = EngineArgs(model=model) + engine_config = engine_args.create_engine_config() + + model_memory_usage = load_model(engine_config.model_config, + engine_config.load_config, + engine_config.device_config, + attn_backend=attn_backend) + + assert model_memory_usage > 0 + + +def load_model(model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, + attn_backend: PrefillOnlyAttentionBackend): + fix_distributed_environment() + + with DeviceMemoryProfiler() as m: + loader = get_model_loader(load_config) + model = initialize_model(model_config=model_config, + load_config=load_config, + device_config=device_config, + attn_backend=attn_backend) + + loader.load_model(model, + model_config=model_config, + device_config=device_config) + + model_memory_usage = m.consumed_memory + return model_memory_usage diff --git a/tests/model_executor/prefill_only/utils.py b/tests/model_executor/prefill_only/utils.py new file mode 100644 index 0000000000000..f620126c08808 --- /dev/null +++ b/tests/model_executor/prefill_only/utils.py @@ -0,0 +1,134 @@ +import gc +import os +from typing import Any, Dict, List, Optional, TypeVar + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, + BatchFeature) + +from vllm.entrypoints.wde_llm import LLM +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_cpu + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class VllmRunner: + + def __init__(self, + model_name: str, + max_num_seqs: int = 4, + tokenizer_name: Optional[str] = None, + dtype: str = "half", + scheduling: str = "sync", + attention_impl: Optional[str] = None, + **kwargs) -> None: + if attention_impl is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = attention_impl + + self.model = LLM(model=model_name, + tokenizer=tokenizer_name, + trust_remote_code=True, + max_num_seqs=max_num_seqs, + dtype=dtype, + scheduling=scheduling, + **kwargs) + + if 
attention_impl is not None: + assert (self.model.llm_engine.attn_backend.get_name().lower() == + attention_impl.lower()) + + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs + outputs.append(embedding) + return outputs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +class HfRunner: + + def wrap_device(self, input: _T) -> _T: + if not is_cpu(): + # Check if the input is already on the GPU + if hasattr(input, "device") and input.device.type == "cuda": + return input # Already on GPU, no need to move + return input.to("cuda") + else: + # Check if the input is already on the CPU + if hasattr(input, "device") and input.device.type == "cpu": + return input # Already on CPU, no need to move + return input.to("cpu") + + def __init__( + self, + model_name: str, + dtype: str = "half", + *, + model_kwargs: Optional[Dict[str, Any]] = None, + auto_cls=AutoModelForCausalLM, + ) -> None: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + self.model_name = model_name + + model_kwargs = model_kwargs if model_kwargs is not None else {} + + self.model = self.wrap_device( + auto_cls.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + **model_kwargs, + )) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + @torch.inference_mode + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + encoded_input = self.tokenizer(prompts, + padding=True, + truncation=True, + return_tensors="pt").to("cuda") + + logits = self.model(**encoded_input).logits + seq_len = encoded_input.attention_mask.sum(axis=1) + + logits_list = [] + for e, s in zip(logits, seq_len): + logits_list.append(e[:s]) + return logits_list + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +def compare_embeddings(embeddings1, embeddings2): + similarities = [ + F.cosine_similarity(e1, e2, dim=0) + for e1, e2 in zip(embeddings1, embeddings2) + ] + return similarities + + +def cleanup(): + gc.collect() + if not is_cpu(): + torch.cuda.empty_cache() diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2bc36ff18a96b..c16e4126c016c 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -18,6 +18,19 @@ class AttentionType(Enum): ENCODER = auto() # Encoder attention between previous layer Q/K/V ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + @staticmethod + def attn_type_name_to_enum(attn_type: str) -> "AttentionType": + assert attn_type is not None + + attn_type_members = AttentionType.__members__ + if attn_type not in attn_type_members: + raise ValueError( + f"Invalid attn_type '{attn_type}'. 
" + f"Available backends: {', '.join(attn_type_members)} " + "(case-sensitive).") + + return AttentionType[attn_type] + class AttentionBackend(ABC): """Abstract class for attention backends.""" diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ecf964fa49d9b..4e15427704706 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, AttentionType +from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import ( @@ -36,6 +36,7 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, prefix: str = "", + attn_backend: Optional[AttentionBackend] = None, ) -> None: super().__init__() if cache_config is not None: @@ -73,14 +74,18 @@ def __init__( self.quant_method = quant_method self.quant_method.create_weights(self) - # During model initialization, the default dtype is set as the model - # weight and activation dtype. - dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, - sliding_window, dtype, kv_cache_dtype, - block_size, blocksparse_params - is not None) - impl_cls = attn_backend.get_impl_cls() + if attn_backend is None: + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + + dtype = torch.get_default_dtype() + self.attn_backend = get_attn_backend( + num_heads, head_size, num_kv_heads, sliding_window, dtype, + kv_cache_dtype, block_size, blocksparse_params is not None)() + else: + self.attn_backend = attn_backend + + impl_cls = self.attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap) @@ -94,6 +99,11 @@ def forward( attn_metadata: AttentionMetadata, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: + if hasattr(self.attn_backend, "attn_type"): + return self.impl.forward(query, key, value, kv_cache, + attn_metadata, self._k_scale, + self._v_scale, + self.attn_backend.attn_type) return self.impl.forward(query, key, diff --git a/vllm/attention/prefill_only/__init__.py b/vllm/attention/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/attention/prefill_only/abstract.py b/vllm/attention/prefill_only/abstract.py new file mode 100644 index 0000000000000..292b03b668af0 --- /dev/null +++ b/vllm/attention/prefill_only/abstract.py @@ -0,0 +1,125 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyAttentionBackend(ABC): + + def __init__(self, attn_type: AttentionType): + if attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyAttentionBackend") + + self._attn_type = attn_type + + @property + def attn_type(self) -> AttentionType: + return self._attn_type + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> 
Type["PrefillOnlyAttentionImpl"]: + raise NotImplementedError + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]: + return PrefillOnlyAttentionMetadata + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: + return PrefillOnlyAttentionMetadataBuilder + + @classmethod + def make_metadata_builder( + cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + + +@dataclass +class PrefillOnlyAttentionMetadata: + max_seq_len: int + seq_lens: List[int] + + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + def to(self, device, non_blocking=False): + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + self.__dict__[k] = v.to(device, non_blocking=non_blocking) + + return self + + +T = TypeVar("T", bound=PrefillOnlyAttentionMetadata) + + +class PrefillOnlyAttentionMetadataBuilder(Generic[T]): + + def __call__(self, seq_lens: List[int]): + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + return PrefillOnlyAttentionMetadata(seq_lens=seq_lens, + max_seq_len=max(seq_lens), + seq_start_loc=seq_start_loc) + + +class PrefillOnlyAttentionImpl(ABC): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: T, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/prefill_only/flash_attn.py b/vllm/attention/prefill_only/flash_attn.py new file mode 100644 index 0000000000000..63566c2bf629a --- /dev/null +++ b/vllm/attention/prefill_only/flash_attn.py @@ -0,0 +1,126 @@ +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend, + PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata) + + +class PrefillOnlyFlashAttentionBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flash_attn" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyFlashAttentionImpl"]: + return PrefillOnlyFlashAttentionImpl + + +class PrefillOnlyFlashAttentionImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, 
+ blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + PrefillOnlyFlashAttentionBackend.get_supported_head_sizes()) + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + from vllm_flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: PrefillOnlyAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + assert kv_cache is None + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyFlashAttentionImpl") + + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + attn_output = self.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_seq_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=causal, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + # Reshape the output tensor. 
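+        # flash_attn_varlen_func returns [num_tokens, num_heads, head_size];
+        # flatten the head dimensions back into hidden_size.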
+ return attn_output.view(num_tokens, hidden_size) diff --git a/vllm/attention/prefill_only/flashinfer.py b/vllm/attention/prefill_only/flashinfer.py new file mode 100644 index 0000000000000..846d74a27f2d9 --- /dev/null +++ b/vllm/attention/prefill_only/flashinfer.py @@ -0,0 +1,22 @@ +from typing import Type + +from vllm.attention.prefill_only.flash_attn import ( + PrefillOnlyFlashAttentionBackend, PrefillOnlyFlashAttentionImpl) + + +class PrefillOnlyFlashInferBackend(PrefillOnlyFlashAttentionBackend): + + @staticmethod + def get_name() -> str: + return "flashinfer" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyFlashInferImpl"]: + return PrefillOnlyFlashInferImpl + + +class PrefillOnlyFlashInferImpl(PrefillOnlyFlashAttentionImpl): + # Because prefill only models do not involve kv cache, + # When using Flashinfer backend in prefill only models, + # you are actually using FLASH ATTN backend + pass diff --git a/vllm/attention/prefill_only/selector.py b/vllm/attention/prefill_only/selector.py new file mode 100644 index 0000000000000..58c6d2662f624 --- /dev/null +++ b/vllm/attention/prefill_only/selector.py @@ -0,0 +1,140 @@ +import enum +from typing import Optional + +import torch + +import vllm.envs as envs +from vllm.attention.prefill_only.abstract import AttentionType +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() + FLASHINFER = enum.auto() + PALLAS = enum.auto() + IPEX = enum.auto() + TORCH_NAIVE = enum.auto() + + @staticmethod + def backend_name_to_enum(backend_name: str) -> "_Backend": + assert backend_name is not None + + backend_members = _Backend.__members__ + if backend_name not in backend_members: + raise ValueError( + f"Invalid attention backend '{backend_name}'. 
" + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + + return _Backend[backend_name] + + +class AttnBackend: + + @classmethod + def from_engine(cls, engine): + model_config = engine.engine_config.model_config + num_heads = model_config.get_num_attention_heads() + head_size = model_config.get_head_size() + num_kv_heads = model_config.get_num_kv_heads() + sliding_window = model_config.get_sliding_window() + dtype = model_config.dtype + + backend = cls.which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype) + + backend_cls = cls.get_backend_cls(backend) + + attn_type = AttentionType.attn_type_name_to_enum( + engine.workflow.attn_type) + + return backend_cls(attn_type) + + @staticmethod + def get_backend_cls(backend): + if backend == _Backend.FLASH_ATTN: + logger.info("Using FLASH ATTN backend.") + from vllm.attention.prefill_only.flash_attn import ( # noqa: E501 + PrefillOnlyFlashAttentionBackend) + return PrefillOnlyFlashAttentionBackend + if backend == _Backend.XFORMERS: + logger.info("Using XFormers backend.") + from vllm.attention.prefill_only.xformers import ( # noqa: E501 + PrefillOnlyXFormersBackend) + return PrefillOnlyXFormersBackend + elif backend == _Backend.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + from vllm.attention.prefill_only.torch_sdpa import ( # noqa: E501 + PrefillOnlyTorchSDPABackend) + return PrefillOnlyTorchSDPABackend + elif backend == _Backend.FLASHINFER: + logger.info("Using Flashinfer backend.") + logger.info("When using Flashinfer backend in encode only models, " + "you are actually using FLASH ATTN backend") + from vllm.attention.prefill_only.flashinfer import ( # noqa: E501 + PrefillOnlyFlashInferBackend) + return PrefillOnlyFlashInferBackend + elif backend == _Backend.TORCH_NAIVE: + logger.info("Using Torch naive backend.") + from vllm.attention.prefill_only.torch_naive import ( # noqa: E501 + PrefillOnlyTorchNAIVEBackend) + return PrefillOnlyTorchNAIVEBackend + else: + raise ValueError("Invalid attention backend.") + + @classmethod + def which_attn_to_use(cls, num_heads: int, head_size: int, + num_kv_heads: int, sliding_window: Optional[int], + dtype: torch.dtype): + # Default case. + selected_backend = _Backend.FLASH_ATTN + + # get_env_variable_attn_backend + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = _Backend.backend_name_to_enum( + backend_by_env_var) + + # FlashAttn in NVIDIA GPUs. + if selected_backend == _Backend.FLASH_ATTN: + if current_platform.get_device_capability()[0] < 8: + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + selected_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + selected_backend = _Backend.XFORMERS + elif sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window." 
+ ) + selected_backend = _Backend.XFORMERS + + return selected_backend + + +AttentionImpls_fp32 = ["TORCH_SDPA", "XFORMERS", "TORCH_NAIVE"] +AttentionImpls_fp16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] +AttentionImpls_bf16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] + +AttentionImpls = { + "float": AttentionImpls_fp32, + "half": AttentionImpls_fp16, + "bfloat16": AttentionImpls_bf16, +} \ No newline at end of file diff --git a/vllm/attention/prefill_only/torch_naive.py b/vllm/attention/prefill_only/torch_naive.py new file mode 100644 index 0000000000000..ac9df84a101bf --- /dev/null +++ b/vllm/attention/prefill_only/torch_naive.py @@ -0,0 +1,150 @@ +import math +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend, + PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata) +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyTorchNAIVEBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch_naive" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyTorchNaiveBackendImpl"]: + return PrefillOnlyTorchNaiveBackendImpl + + +class PrefillOnlyTorchNaiveBackendImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch Naive does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch Naive does not support logits soft cap.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch Naive backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: PrefillOnlyAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + assert kv_cache is None + + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in Torch Naive.") + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyTorchNaiveBackendImpl") + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. 
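+        # [num_tokens, hidden_size] -> [num_tokens, num_heads, head_size]
+        # for the query (num_kv_heads for key/value).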
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + sub_out = scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + is_causal=causal, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def scaled_dot_product_attention(query, + key, + value, + attn_mask=None, + is_causal=False, + scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool, + device=query.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value diff --git a/vllm/attention/prefill_only/torch_sdpa.py b/vllm/attention/prefill_only/torch_sdpa.py new file mode 100644 index 0000000000000..8f0d806fd669a --- /dev/null +++ b/vllm/attention/prefill_only/torch_sdpa.py @@ -0,0 +1,124 @@ +from typing import Any, Dict, List, Optional, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend, + PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata) +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyTorchSDPABackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch_sdpa" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyTorchSDPABackendImpl"]: + return PrefillOnlyTorchSDPABackendImpl + + +class PrefillOnlyTorchSDPABackendImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch SPDA does not support logits soft cap.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, 
dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: PrefillOnlyAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + assert kv_cache is None + + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in TorchSDPA.") + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyTorchSDPABackendImpl") + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + sub_out = scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + dropout_p=0.0, + is_causal=causal, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + # Reshape the output tensor. 
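+        # Flatten [num_tokens, num_heads, head_size] back to
+        # [num_tokens, hidden_size].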
+ return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/attention/prefill_only/xformers.py b/vllm/attention/prefill_only/xformers.py new file mode 100644 index 0000000000000..4497afadc49e4 --- /dev/null +++ b/vllm/attention/prefill_only/xformers.py @@ -0,0 +1,121 @@ +from typing import Any, Dict, List, Optional, Type + +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + BlockDiagonalMask) + +from vllm.attention.prefill_only.abstract import (AttentionType, + PrefillOnlyAttentionBackend, + PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class PrefillOnlyXFormersBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "xformers" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyXFormersImpl"]: + return PrefillOnlyXFormersImpl + + +class PrefillOnlyXFormersImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "XFormers does not support attention logits soft capping.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: PrefillOnlyAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + assert kv_cache is None + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyXFormersImpl") + original_query = query + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + if causal: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) + else: + attn_bias = BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) + + # Add the batch dimension. 
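+        # xformers' memory-efficient attention expects a leading batch
+        # dimension; all sequences are packed along the token axis and the
+        # block-diagonal bias restricts attention to tokens within the same
+        # sequence.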
+ query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + out = xops.memory_efficient_attention_forward(query, + key, + value, + p=0.0, + attn_bias=attn_bias, + scale=self.scale) + return out.view_as(original_query) diff --git a/vllm/core/prefill_only_scheduler.py b/vllm/core/prefill_only_scheduler.py new file mode 100644 index 0000000000000..e6ac228370e6a --- /dev/null +++ b/vllm/core/prefill_only_scheduler.py @@ -0,0 +1,175 @@ +from abc import ABC, abstractmethod +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Iterable, List, Set, Union, cast + +from vllm.config import SchedulerConfig +from vllm.inputs.prefill_only.data import Request, SchedulableRequest +from vllm.inputs.prefill_only.preprocessor import RequestProcessor +from vllm.logger import init_logger +from vllm.model_executor.prefill_only.engine_io import ( + PrefillOnlySchedulerOutput, RequestOutput, SchedulerOutput) + +logger = init_logger(__name__) + + +class Scheduler(ABC): + support_scheduling: List[str] = [] + + def __init__( + self, + scheduler_config: SchedulerConfig, + request_processor: RequestProcessor, + ) -> None: + self.scheduler_config = scheduler_config + self.request_processor = request_processor + + self.waiting: Deque[Request] = deque() + + self.requests: Set[str] = set() + self.aborted_requests: Set[str] = set() + + @classmethod + def from_engine(cls, engine) -> "Scheduler": + raise NotImplementedError + + def add_request(self, request: Request) -> None: + if (request.request_id in self.requests + or request.request_id in self.aborted_requests): + logger.warning("[%s] request_id conflict") + return + + self.waiting.append(request) + self.requests.add(request.request_id) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + + self.requests -= request_ids + self.aborted_requests |= request_ids + + def remove_abort_request( + self, request_outputs: List[RequestOutput]) -> List[RequestOutput]: + if len(self.aborted_requests) == 0: + return request_outputs + + current_ids = set(request.request_id for request in request_outputs) + need_abort = self.aborted_requests & current_ids + + if len(need_abort) == 0: + return request_outputs + + request_outputs = [ + request for request in request_outputs + if request.request_id not in need_abort + ] + self.aborted_requests -= need_abort + + return request_outputs + + def has_unfinished_requests(self) -> bool: + return len(self.requests) != 0 + + def get_num_unfinished_requests(self) -> int: + return len(self.requests) + + @abstractmethod + def schedule(self) -> SchedulerOutput: + raise NotImplementedError + + def free_finished_request(self, request_outputs: List[RequestOutput]): + finished_request_ids = set(request.request_id + for request in request_outputs + if request.finished) + self.requests -= finished_request_ids + + +@dataclass +class PrefillOnlySchedulingBudget: + token_budget: int + max_num_requests: int + _curr_requests: Set[str] = field(default_factory=set) + _num_batched_tokens: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_request: int = 1): + assert num_new_tokens != 0 + assert num_new_request != 0 + a = self.num_batched_tokens + num_new_tokens <= self.token_budget + b = self.num_curr_request + num_new_request <= self.max_num_requests + return a and b + + def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + if req_id in 
self._curr_requests: + return + + self._curr_requests.add(req_id) + self._num_batched_tokens += num_batched_tokens + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_request(self): + return len(self._curr_requests) + + +class PrefillOnlyScheduler(Scheduler): + support_scheduling = ["sync_scheduling", "async_scheduling"] + + def __init__( + self, + scheduler_config: SchedulerConfig, + request_processor: RequestProcessor, + ) -> None: + super().__init__(scheduler_config, request_processor) + + @classmethod + def from_engine(cls, engine): + return cls(engine.engine_config.scheduler_config, + engine.request_processor) + + def schedule(self) -> PrefillOnlySchedulerOutput: + budget = PrefillOnlySchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_requests=self.scheduler_config.max_num_seqs, + ) + + waiting_queue = self.waiting + + scheduled_requests = [] + ignored_requests = [] + while waiting_queue: + request = waiting_queue[0] + + if request.request_id in self.aborted_requests: + self.aborted_requests.remove(request.request_id) + waiting_queue.popleft() + continue + + if not isinstance(request, SchedulableRequest): + request = self.request_processor(request) + waiting_queue[0] = request + + request = cast(SchedulableRequest, request) + + num_new_tokens = request.num_new_tokens + + if num_new_tokens > self.scheduler_config.max_model_len: + self.requests.remove(request.request_id) + waiting_queue.popleft() + ignored_requests.append(request) + continue + + if not budget.can_schedule(num_new_tokens=num_new_tokens): + break + + budget.add_num_batched_tokens(request.request_id, num_new_tokens) + waiting_queue.popleft() + scheduled_requests.append(request) + + return PrefillOnlySchedulerOutput( + scheduled_requests=scheduled_requests, + ignored_requests=ignored_requests) diff --git a/vllm/engine/wde_engine.py b/vllm/engine/wde_engine.py new file mode 100644 index 0000000000000..b962d2b605702 --- /dev/null +++ b/vllm/engine/wde_engine.py @@ -0,0 +1,184 @@ +from queue import Empty, Queue +from typing import Dict, Iterable, List, Optional, Type, Union + +from vllm.config import EngineConfig +from vllm.inputs.prefill_only.data import Inputs, Params +from vllm.logger import init_logger +from vllm.model_executor.prefill_only.arg_utils import EngineArgs +from vllm.model_executor.prefill_only.engine_io import RequestOutput +from vllm.model_executor.prefill_only.workflow import Workflow + +logger = init_logger(__name__) + + +def lazy_import(module): + module_name, class_name = module.split(":") + import importlib + module = importlib.import_module(module_name) + return getattr(module, class_name) + + +class LLMEngine: + + def __init__(self, engine_config: EngineConfig, + workflow_cls: Type[Workflow]) -> None: + self.engine_config = engine_config + self.engine_config.log_config() + self.workflow = workflow_cls.from_engine(self) + + self._maybe_init_async_scheduling() + + self.attn_backend = lazy_import( + self.workflow.AttnBackend).from_engine(self) + self.executor = lazy_import(self.workflow.Executor).from_engine(self) + self.tokenizer = lazy_import(self.workflow.Tokenizer).from_engine(self) + self.model_inputs_builder = lazy_import( + self.workflow.ModelInputBuilder).from_engine(self) + + self.input_processor = lazy_import( + self.workflow.InputProcessor).from_engine(self) + self.request_processor = lazy_import( + self.workflow.RequestProcessor).from_engine(self) + self.scheduler = 
lazy_import(self.workflow.Scheduler).from_engine(self) + self.output_processor = lazy_import( + self.workflow.OutputProcessor).from_engine(self) + + def _maybe_init_async_scheduling(self): + executor_cls = lazy_import(self.workflow.Executor) + scheduler_cls = lazy_import(self.workflow.Scheduler) + + if ("async_scheduling" in executor_cls.support_scheduling + and "async_scheduling" in scheduler_cls.support_scheduling): + logger.info("Use async scheduling") + self.use_async_scheduling = True + + elif ("sync_scheduling" in executor_cls.support_scheduling + and "sync_scheduling" in scheduler_cls.support_scheduling): + logger.info("Use sync scheduling") + self.use_async_scheduling = False + + else: + raise RuntimeError(f"Executor support scheduling: " + f"{executor_cls.support_scheduling}." + f"Scheduler support scheduling: " + f"{executor_cls.support_scheduling}." + f"Not compatible") + + if self.use_async_scheduling: + self.executor_in: Queue = Queue() + self.executor_out: Queue = Queue() + self.max_num_on_the_fly = ( + self.engine_config.scheduler_config.max_num_on_the_fly) + self.num_on_the_fly = 0 + self.step = self.async_step + else: + self.step = self.sync_step + + @classmethod + def from_engine_args(cls, engine_args: Union[Dict, + EngineArgs]) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + + from vllm.model_executor.prefill_only.loader.utils import ( + get_model_workflow) + from vllm.transformers_utils.config import get_config + + if isinstance(engine_args, EngineArgs): + engine_args = engine_args.to_dict() + + hf_config = get_config(engine_args["model"], + engine_args.get("trust_remote_code", False), + engine_args.get("revision", None), + engine_args.get("code_revision", None)) + + workflow_class = get_model_workflow(hf_config) + workflow = lazy_import(workflow_class) + + engine_args = lazy_import(workflow.EngineArgs)(**engine_args) + + engine_config = engine_args.create_engine_config() + engine = cls(engine_config, workflow) + return engine + + def add_request(self, + request_id: str, + inputs: Optional[Union[str, Inputs]] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> None: + request = self.input_processor(request_id, inputs, params, + arrival_time) + + # The raised ValidationError will be passed to the upper call stack + self.scheduler.add_request(request) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + self.scheduler.abort_request(request_id) + + def sync_step(self) -> List[RequestOutput]: + scheduler_output = self.scheduler.schedule() + if scheduler_output.is_empty(): + return [] + + executor_input = self.model_inputs_builder(scheduler_output) + executor_output = self.executor.execute_model(executor_input) + request_outputs = self.output_processor(scheduler_output, + executor_output) + self.scheduler.free_finished_request(request_outputs) + request_outputs = self.scheduler.remove_abort_request(request_outputs) + return request_outputs + + def async_step(self) -> List[RequestOutput]: + self.executor.ensure_start_execute_loop() + self._put_as_many_as_possible() + + if self.num_on_the_fly == 0: + return [] + + return self._get(block=True) + + def _put_as_many_as_possible(self): + while self.num_on_the_fly < self.max_num_on_the_fly: + scheduler_output = self.scheduler.schedule() + if scheduler_output.is_empty(): + break + executor_input = self.model_inputs_builder(scheduler_output) + + self.executor_in.put((scheduler_output, executor_input)) + self.num_on_the_fly += 1 + + def _get(self, 
block): + try: + scheduler_output, executor_output = self.executor_out.get(block) + except Empty: + return + + self.num_on_the_fly -= 1 + + # Theoretically, this put is not needed + # practically, task can be inqueue before doing post-processing + self._put_as_many_as_possible() + + request_outputs = self.output_processor(scheduler_output, + executor_output) + self.scheduler.free_finished_request(request_outputs) + request_outputs = self.scheduler.remove_abort_request(request_outputs) + return request_outputs + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return self.scheduler.get_num_unfinished_requests() + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return self.scheduler.has_unfinished_requests() + + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if executor := getattr(self, "executor", None): + executor.shutdown_execute_loop() diff --git a/vllm/entrypoints/wde_llm.py b/vllm/entrypoints/wde_llm.py new file mode 100644 index 0000000000000..f279d44204ced --- /dev/null +++ b/vllm/entrypoints/wde_llm.py @@ -0,0 +1,144 @@ +from typing import List, Optional, Sequence, Union, cast + +from tqdm import tqdm + +from vllm.engine.wde_engine import LLMEngine +from vllm.inputs.prefill_only.data import Params +from vllm.inputs.prefill_only.data import TextOnlyInputs as PromptInputs +from vllm.inputs.prefill_only.data import ValidationError +from vllm.logger import init_logger +from vllm.model_executor.prefill_only.engine_io import RequestOutput +from vllm.utils import Counter + +logger = init_logger(__name__) + + +class LLM: + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + cpu_offload_gb: float = 0, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + removed_vision_keys = ("image_token_id", "image_feature_size", + "image_input_shape", "image_input_type") + if any(k in kwargs for k in removed_vision_keys): + raise TypeError( + "There is no need to pass vision-related arguments anymore.") + engine_args = dict( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + 
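+            # Any remaining keyword arguments (e.g. disable_log_stats set
+            # above) are forwarded to the engine unchanged.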
**kwargs, + ) + self.llm_engine = LLMEngine.from_engine_args(engine_args) + self.request_counter = Counter() + + def encode( + self, + inputs: Union[Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[Params, Sequence[Params]]] = None, + use_tqdm: bool = True, + ) -> List[RequestOutput]: + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], inputs) + + if pooling_params is None: + # Use default pooling params. + pooling_params = Params() + + self._validate_and_add_requests( + inputs=inputs, + params=pooling_params, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return outputs + + def _validate_and_add_requests( + self, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + params: Optional[Union[Params, Sequence[Params]]] = None, + ) -> None: + + params = params if params is not None else Params() + if isinstance(inputs, PromptInputs): + inputs = [inputs] + + # Add requests to the engine. + for i, request_inputs in enumerate(inputs): + try: + self._add_request( + request_inputs, + params[i] if isinstance(params, Sequence) else params) + except ValidationError as e: + raise e + + def _add_request( + self, + inputs: PromptInputs, + params: Params, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request(request_id, inputs, params) + + def _run_engine(self, *, use_tqdm: bool) -> List[RequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s"), + ) + # Run the engine. + outputs: List[RequestOutput] = [] + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + pbar.update(1) + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. 
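+        # request_id is the stringified Counter value assigned at submission
+        # time, so sorting by int(request_id) restores submission order.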
+ return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/executor/prefill_only_gpu_executor.py b/vllm/executor/prefill_only_gpu_executor.py new file mode 100644 index 0000000000000..042478d486934 --- /dev/null +++ b/vllm/executor/prefill_only_gpu_executor.py @@ -0,0 +1,208 @@ +import atexit +import queue +from queue import Queue +from threading import Thread +from typing import Optional + +import torch + +from vllm.attention.prefill_only.abstract import PrefillOnlyAttentionBackend +from vllm.config import EngineConfig +from vllm.logger import init_logger +from vllm.model_executor.prefill_only.execute_io import (ExecuteInput, + ExecuteOutput) +from vllm.model_executor.prefill_only.workflow import Workflow +from vllm.worker.prefill_only_gpu_worker import WorkerBase, create_worker + +logger = init_logger(__name__) + + +class GPUExecutor: + support_scheduling = ["sync_scheduling"] + + def __init__( + self, + engine_config: EngineConfig, + workflow: Workflow, + attn_backend: PrefillOnlyAttentionBackend, + ) -> None: + self.engine_config = engine_config + self.workflow = workflow + self.attn_backend = attn_backend + self.output_to_cpu = False + self._init_executor() + + @classmethod + def from_engine(cls, engine): + return cls(engine_config=engine.engine_config, + workflow=engine.workflow, + attn_backend=engine.attn_backend) + + def _init_executor(self) -> None: + """Initialize the worker and load the model. + """ + + worker_kwargs = dict( + engine_config=self.engine_config, + attn_backend=self.attn_backend, + ) + worker_kwargs.update(module=self.workflow.Worker) + + self.worker = create_worker(**worker_kwargs) + self.worker.init_device() + self.worker.load_model() + + def execute_model(self, + executor_input: ExecuteInput) -> Optional[ExecuteOutput]: + executor_input.model_input.to(self.worker.device) + output = self.worker(executor_input) + if self.output_to_cpu: + output.to("cpu") + return output + + def shutdown_execute_loop(self): + pass + + +class GPUAsyncExecutor(GPUExecutor): + support_scheduling = ["async_scheduling"] + + def __init__(self, engine_config: EngineConfig, workflow: Workflow, + attn_backend: PrefillOnlyAttentionBackend, executor_in: Queue, + executor_out: Queue) -> None: + super().__init__(engine_config, workflow, attn_backend) + self.executor_in = executor_in + self.executor_out = executor_out + + self.executor_thread: Optional[Thread] = None + + if self.engine_config.scheduler_config.scheduling == "double_buffer": + self.execute_loop = double_buffer_execute_loop + else: + self.execute_loop = simple_execute_loop + + @classmethod + def from_engine(cls, engine): + return cls(engine_config=engine.engine_config, + workflow=engine.workflow, + attn_backend=engine.attn_backend, + executor_in=engine.executor_in, + executor_out=engine.executor_out) + + def ensure_start_execute_loop(self): + if self.executor_thread is None or not self.executor_thread.is_alive(): + self.executor_thread = Thread(target=self.execute_loop, + args=(self.worker, self.executor_in, + self.executor_out, + self.output_to_cpu), + daemon=True) + self.executor_thread.start() + atexit.register(self.shutdown_execute_loop) + + def shutdown_execute_loop(self): + if (self.executor_thread is not None + and self.executor_thread.is_alive()): + self.executor_in.put(None) + self.executor_thread.join() + atexit.unregister(self.shutdown_execute_loop) + + +def simple_execute_loop(worker: WorkerBase, + executor_in: Queue, + executor_out: Queue, + output_to_cpu: bool = False): + + def 
execute_model(executor_input: ExecuteInput) -> ExecuteOutput: + executor_input.model_input.to(worker.device) + output = worker(executor_input) + if output_to_cpu: + output.to("cpu") + return output + + while True: + o = executor_in.get() + if o is None: + break + + scheduler_output, executor_input = o + executor_output = execute_model(executor_input) + executor_out.put((scheduler_output, executor_output)) + + +def double_buffer_execute_loop(worker: WorkerBase, + executor_in: Queue, + executor_out: Queue, + output_to_cpu: bool = False): + from dataclasses import dataclass + + from vllm.model_executor.prefill_only.engine_io import SchedulerOutput + + @dataclass + class Task: + scheduler_output: SchedulerOutput + executor_input: ExecuteInput + executor_output: Optional[ExecuteOutput] + + @classmethod + def get(cls, block): + o = executor_in.get(block) + if o is None: + return None + + scheduler_output, executor_input = o + + task = cls(scheduler_output=scheduler_output, + executor_input=executor_input, + executor_output=None) + return task + + current_task: Optional[Task] = None + next_task: Optional[Task] = None + compute_stream = torch.cuda.Stream() + io_stream = torch.cuda.Stream() + + go_on = True + while go_on: + if current_task is None: + current_task = Task.get(block=True) + if current_task is None: + break + + with torch.cuda.stream(compute_stream): + current_task.executor_input.model_input.to(worker.device, + non_blocking=True) + current_task.executor_output = worker( + current_task.executor_input) + end_compute = torch.cuda.Event() + else: + with torch.cuda.stream(compute_stream): + end_compute = torch.cuda.Event() + + try: + next_task = Task.get(block=False) + if next_task is None: + go_on = False + else: + with torch.cuda.stream(io_stream): + next_task.executor_input.model_input.to(worker.device, + non_blocking=True) + + compute_stream.wait_stream(io_stream) + + with torch.cuda.stream(compute_stream): + next_task.executor_output = worker( + next_task.executor_input) + except queue.Empty: + pass + + end_compute.wait() + if output_to_cpu: + with torch.cuda.stream(io_stream): + assert current_task.executor_output is not None + current_task.executor_output.to("cpu", non_blocking=True) + io_stream.synchronize() + executor_out.put( + (current_task.scheduler_output, current_task.executor_output)) + + current_task = next_task + next_task = None diff --git a/vllm/inputs/prefill_only/__init__.py b/vllm/inputs/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/inputs/prefill_only/data.py b/vllm/inputs/prefill_only/data.py new file mode 100644 index 0000000000000..8a2595d095a8b --- /dev/null +++ b/vllm/inputs/prefill_only/data.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + + +class Params: + pass + + +class Inputs: + pass + + +@dataclass +class TextPrompt(Inputs): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + +@dataclass +class TokensPrompt(Inputs): + """Schema for a tokenized prompt.""" + + prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + +@dataclass +class TextOnlyInputs(Inputs): + prompt_token_ids: List[int] + """The token IDs of the prompt.""" + + prompt: Optional[str] = None + """ + The original prompt text corresponding to the token IDs, if available. 
+ """ + + +PromptInput = Union[str, Dict, TextPrompt, TokensPrompt, TextOnlyInputs] + + +@dataclass +class Request: + request_id: str + arrival_time: float + + +@dataclass +class TextRequest(Request): + inputs: Dict + + +class ValidationError(ValueError): + pass + + +class SchedulableRequest(Request): + + @property + def num_new_tokens(self): + raise NotImplementedError + + +@dataclass +class TextSchedulableRequest(SchedulableRequest): + inputs: TextOnlyInputs + + @property + def num_new_tokens(self): + return len(self.inputs.prompt_token_ids) diff --git a/vllm/inputs/prefill_only/preprocessor.py b/vllm/inputs/prefill_only/preprocessor.py new file mode 100644 index 0000000000000..3ed327be4db6c --- /dev/null +++ b/vllm/inputs/prefill_only/preprocessor.py @@ -0,0 +1,125 @@ +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, cast + +from vllm.inputs.prefill_only.data import (Params, PromptInput, Request, + SchedulableRequest, TextOnlyInputs, + TextPrompt, TextRequest, + TextSchedulableRequest, + TokensPrompt, ValidationError) +from vllm.inputs.prefill_only.tokenizer import Tokenizer + + +class InputProcessor(ABC): + """ + Input(request_id, inputs, params, arrival_time) -> InputProcessor -> Request + """ + + @abstractmethod + def __call__(self, + request_id: str, + inputs: Optional[Any] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> Request: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine): + raise NotImplementedError + + +class TextInputProcessor(InputProcessor): + + def __call__(self, + request_id: str, + inputs: Optional[PromptInput] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> TextRequest: + + if isinstance(inputs, str): + inputs = {"prompt": inputs} + elif isinstance(inputs, TextPrompt): + inputs = {"prompt": inputs.prompt} + elif isinstance(inputs, TokensPrompt): + inputs = {"prompt_token_ids": inputs.prompt_token_ids} + elif isinstance(inputs, TextOnlyInputs): + _inputs: Dict[str, Any] = { + "prompt_token_ids": inputs.prompt_token_ids + } + + if inputs.prompt is not None: + _inputs["prompt"] = inputs.prompt + + inputs = _inputs + + elif isinstance(inputs, dict): + if "prompt" not in inputs and "prompt_token_ids" not in inputs: + raise ValidationError('"prompt" and "prompt_token_ids" ' + 'have at least one in inputs.') + inputs = { + k: v + for k, v in inputs.items() + if k in {"prompt", "prompt_token_ids"} + } + else: + raise ValidationError( + f"Input does not support {type(inputs)} data type") + + if not arrival_time: + arrival_time = time.time() + request = TextRequest(request_id=str(request_id), + inputs=inputs, + arrival_time=arrival_time) + return request + + @classmethod + def from_engine(cls, engine): + return cls() + + +class RequestProcessor(ABC): + """ + Request -> RequestProcessor -> SchedulableRequest + """ + + @abstractmethod + def __call__(self, request: Request) -> SchedulableRequest: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine): + raise NotImplementedError + + +class TextRequestProcessor(RequestProcessor): + + def __init__(self, tokenizer: Tokenizer): + self.tokenizer = tokenizer + + def __call__(self, request: Request) -> TextSchedulableRequest: + assert isinstance(request, TextRequest) + + request = cast(TextRequest, request) + + inputs = request.inputs + + if "prompt_token_ids" not in inputs: + tokenizer = self.tokenizer + + prompt_token_ids = 
tokenizer.encode(inputs["prompt"]) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + schedulable_request = TextSchedulableRequest( + request_id=request.request_id, + inputs=TextOnlyInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt")), + arrival_time=request.arrival_time) + + return schedulable_request + + @classmethod + def from_engine(cls, engine): + return cls(engine.tokenizer) diff --git a/vllm/inputs/prefill_only/tokenizer.py b/vllm/inputs/prefill_only/tokenizer.py new file mode 100644 index 0000000000000..018983c72d698 --- /dev/null +++ b/vllm/inputs/prefill_only/tokenizer.py @@ -0,0 +1,32 @@ +from vllm.transformers_utils.tokenizer import get_tokenizer + + +class Tokenizer: + + def __init__(self, tokenizer_name: str, **kwargs): + self.tokenizer_name = tokenizer_name + self.tokenizer_kwargs = kwargs + + self.tokenizer = get_tokenizer(tokenizer_name=self.tokenizer_name, + **self.tokenizer_kwargs) + + @classmethod + def from_engine(cls, engine): + init_kwargs = dict( + tokenizer_name=engine.engine_config.model_config.tokenizer, + tokenizer_mode=engine.engine_config.model_config.tokenizer_mode, + trust_remote_code=engine.engine_config.model_config. + trust_remote_code, + revision=engine.engine_config.model_config.tokenizer_revision) + + return cls(**init_kwargs) + + def __call__(self, *args, **kwargs): + return self.tokenizer(*args, **kwargs) + + def encode(self, *args, **kwargs): + return self.tokenizer.encode(*args, **kwargs) + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id \ No newline at end of file diff --git a/vllm/model_executor/encode_only/__init__.py b/vllm/model_executor/encode_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/encode_only/arg_utils.py b/vllm/model_executor/encode_only/arg_utils.py new file mode 100644 index 0000000000000..3e7c55ccede25 --- /dev/null +++ b/vllm/model_executor/encode_only/arg_utils.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +from vllm.logger import init_logger +from vllm.model_executor.encode_only.config import (EncodeOnlyEngineConfig, + ModelConfig, + PrefillOnlySchedulerConfig) +from vllm.model_executor.prefill_only.arg_utils import EngineArgs +from vllm.model_executor.prefill_only.config import (DeviceConfig, LoadConfig, + filter_unexpected_fields) + +logger = init_logger(__name__) + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +@filter_unexpected_fields +@dataclass +class EncodeOnlyEngineArgs(EngineArgs): + """Arguments for vLLM engine.""" + model: str + served_model_name: Optional[Union[List[str]]] = None + tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + kv_cache_dtype: str = 'auto' + quantization_param_path: Optional[str] = None + disable_sliding_window: bool = False + seed: int = 0 + + max_model_len: Optional[int] = None + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_num_on_the_fly: int = 3 + scheduling: str = "sync" + + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None + rope_theta: Optional[float] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + disable_custom_all_reduce: bool = False + device: 
str = 'auto' + model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + + def create_engine_config(self) -> EncodeOnlyEngineConfig: + device_config = DeviceConfig(device=self.device) + model_config = ModelConfig( + model=self.model, + tokenizer=self.tokenizer, # type: ignore + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name) + + scheduler_config = PrefillOnlySchedulerConfig( + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + max_num_on_the_fly=self.max_num_on_the_fly, + scheduling=self.scheduling) + + load_config = LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + return EncodeOnlyEngineConfig(model_config=model_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config) diff --git a/vllm/model_executor/encode_only/config.py b/vllm/model_executor/encode_only/config.py new file mode 100644 index 0000000000000..9947c2f811fda --- /dev/null +++ b/vllm/model_executor/encode_only/config.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass, fields + +from vllm.logger import init_logger +from vllm.model_executor.prefill_only.config import ( # noqa: E501 + EngineConfig, ModelConfig, PrefillOnlySchedulerConfig) + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +@dataclass(frozen=True) +class EncodeOnlyEngineConfig(EngineConfig): + model_config: ModelConfig + scheduler_config: PrefillOnlySchedulerConfig + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. 
+ """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) + + def log_config(self): + from vllm.version import __version__ as VLLM_VERSION + logger.info( + "Initializing an Encode Only engine (v%s) with config: " + "model=%r, tokenizer=%r, " + "tokenizer_mode=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, " + "device_config=%s, served_model_name=%s, " + "max_num_on_the_fly=%d, scheduling=%s)", VLLM_VERSION, + self.model_config.model, self.model_config.tokenizer, + self.model_config.tokenizer_mode, + self.model_config.trust_remote_code, self.model_config.dtype, + self.model_config.max_model_len, self.load_config.download_dir, + self.load_config.load_format, self.device_config.device, + self.model_config.served_model_name, + self.scheduler_config.max_num_on_the_fly, + self.scheduler_config.scheduling) diff --git a/vllm/model_executor/encode_only/engine_io.py b/vllm/model_executor/encode_only/engine_io.py new file mode 100644 index 0000000000000..756be032e537f --- /dev/null +++ b/vllm/model_executor/encode_only/engine_io.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import List + +import torch + +from vllm.model_executor.prefill_only.engine_io import RequestOutput + + +@dataclass +class EncodeOnlyRequestOutput(RequestOutput): + prompt_token_ids: List[int] + outputs: torch.Tensor \ No newline at end of file diff --git a/vllm/model_executor/encode_only/execute_io.py b/vllm/model_executor/encode_only/execute_io.py new file mode 100644 index 0000000000000..7a1293f9d4aa8 --- /dev/null +++ b/vllm/model_executor/encode_only/execute_io.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.model_executor.prefill_only.execute_io import ExecuteOutput + + +@dataclass +class EncodeOnlyExecuteOutput(ExecuteOutput): + last_hidden_states: torch.Tensor + pooled_output: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/encode_only/modelzoo/__init__.py b/vllm/model_executor/encode_only/modelzoo/__init__.py new file mode 100644 index 0000000000000..9236d1a3a648a --- /dev/null +++ b/vllm/model_executor/encode_only/modelzoo/__init__.py @@ -0,0 +1,8 @@ +TASK = "encode_only" +PREFIX = f"vllm.model_executor.{TASK}.modelzoo" +WORKFLOW = "vllm.model_executor.encode_only.workflow:EncodeOnlyWorkflow" + +# Architecture -> (module, workflow). +ENCODE_ONLY_MODELS = { + "BertForMaskedLM": (PREFIX + ".bert:BertForMaskedLM", WORKFLOW), +} diff --git a/vllm/model_executor/encode_only/modelzoo/bert.py b/vllm/model_executor/encode_only/modelzoo/bert.py new file mode 100644 index 0000000000000..1ed97a96c59fc --- /dev/null +++ b/vllm/model_executor/encode_only/modelzoo/bert.py @@ -0,0 +1,409 @@ +# Derived from Bert implementation posted on HuggingFace; license below: +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # noqa: E501 +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import BertConfig +from transformers.utils import logging + +from vllm.attention import Attention, AttentionBackend, AttentionMetadata +from vllm.model_executor.encode_only.execute_io import EncodeOnlyExecuteOutput +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import is_pp_missing_parameter + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + + def __init__(self, config: BertConfig, + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.config = config + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + assert self.position_embedding_type == "absolute" + + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, quant_config=quant_config) + self.token_type_embeddings0 = None + self.position_embeddings = VocabParallelEmbedding( + config.max_position_embeddings, + config.hidden_size, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def init_token_type_embeddings0(self): + del self.token_type_embeddings0 + self.register_buffer( + "token_type_embeddings0", + torch.zeros(self.config.hidden_size, + dtype=self.word_embeddings.weight.dtype, + device=self.word_embeddings.weight.device)) + + def forward(self, input_ids, positions): + embeddings = self.word_embeddings(input_ids) + if self.token_type_embeddings0 is not None: + token_type_embeddings = self.token_type_embeddings0 + embeddings += token_type_embeddings + + embeddings += self.position_embeddings(positions) + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + num_kv_heads = config.num_attention_heads + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = hidden_size // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + num_heads, + num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + attn_backend=attn_backend) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn(q, + k, + v, + kv_cache=None, + attn_metadata=attn_metadata) + return attn_output + + +class 
BertSelfOutput(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.self = BertSelfAttention(config, attn_backend) + self.output = BertSelfOutput(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + + self_outputs = self.self(hidden_states, attn_metadata) + attention_output = self.output(self_outputs, hidden_states) + return attention_output + + +class BertIntermediate(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.attention = BertAttention(config, attn_backend, quant_config) + self.intermediate = BertIntermediate(config, quant_config) + self.output = BertOutput(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + attention_output = self.attention(hidden_states, attn_metadata) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.layer = nn.ModuleList([ + BertLayer(config, attn_backend, quant_config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, attn_metadata) + return hidden_states + + +class BertPooler(nn.Module): + + def 
__init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.hidden_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.activation = nn.Tanh() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + seq_start_loc = attn_metadata.seq_start_loc + first_token_tensor = hidden_states[seq_start_loc[:-1]] + pooled_output, _ = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + add_pooling_layer: bool = True, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.embeddings = BertEmbeddings(config, quant_config) + self.encoder = BertEncoder(config, attn_backend, quant_config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + embedding_output = self.embeddings( + input_ids=input_ids, + positions=positions, + ) + sequence_output = self.encoder(embedding_output, attn_metadata) + pooled_output = self.pooler( + sequence_output, + attn_metadata) if self.pooler is not None else None + return sequence_output, pooled_output + + +class BertForMaskedLM(nn.Module): + _ignore_weights_keys = [ + "cls.predictions.transform.LayerNorm.gamma", + "cls.predictions.transform.dense.weight", + "cls.seq_relationship.weight", + ] + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None, + *args, + **kwargs): + super().__init__() + self.config = config + self.quant_config = quant_config + + self.bert = BertModel(config, + attn_backend, + quant_config=quant_config, + add_pooling_layer=True) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> EncodeOnlyExecuteOutput: + sequence_output, pooled_output = self.bert( + input_ids, + positions, + attn_metadata, + ) + + return EncodeOnlyExecuteOutput(last_hidden_states=sequence_output, + pooled_output=pooled_output) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v") + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + for name, loaded_weight in weights: + if hasattr(self, "prefix"): + name = self.prefix + name + + if name in self._ignore_weights_keys: + continue + + if name == "bert.embeddings.token_type_embeddings.weight": + # token_type_ids is all zero, + # so we only need token_type_embeddings[0] + self.bert.embeddings.init_token_type_embeddings0() + default_weight_loader( + self.bert.embeddings.token_type_embeddings0, + loaded_weight[0]) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # 
https://huggingface.co/google-bert/bert-base-uncased/discussions/70 + # https://github.com/huggingface/transformers/blob/fee86516a48c92133847fc7b44ca2f83c7c5634d/src/transformers/modeling_utils.py#L691-L720 + if "LayerNorm.gamma" in name: + name = name.replace("LayerNorm.gamma", "LayerNorm.weight") + if "LayerNorm.beta" in name: + name = name.replace("LayerNorm.beta", "LayerNorm.bias") + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, + params_dict) # type: ignore + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/encode_only/output_processor.py b/vllm/model_executor/encode_only/output_processor.py new file mode 100644 index 0000000000000..f2694bfead4c1 --- /dev/null +++ b/vllm/model_executor/encode_only/output_processor.py @@ -0,0 +1,53 @@ +from typing import List, cast + +from vllm.model_executor.encode_only.engine_io import EncodeOnlyRequestOutput +from vllm.model_executor.encode_only.execute_io import EncodeOnlyExecuteOutput +from vllm.model_executor.prefill_only.engine_io import ( + PrefillOnlySchedulerOutput, SchedulerOutput) +from vllm.model_executor.prefill_only.output_processor import (OutputProcessor, + RequestOutput) + + +class PrefillOnlyModelOutputProcessor(OutputProcessor): + + def __init__(self): + pass + + @classmethod + def from_engine(cls, engine): + return cls() + + def __call__( + self, scheduler_output: SchedulerOutput, + execute_output: EncodeOnlyExecuteOutput) -> List[RequestOutput]: + assert isinstance(scheduler_output, PrefillOnlySchedulerOutput) + scheduler_output = cast(PrefillOnlySchedulerOutput, scheduler_output) + + if execute_output.pooled_output is not None: + request_outputs = [] + for request, outputs in zip(scheduler_output.scheduled_requests, + execute_output.pooled_output): + prompt_token_ids = request.inputs.prompt_token_ids + request_outputs.append( + EncodeOnlyRequestOutput(request_id=request.request_id, + arrival_time=request.arrival_time, + prompt_token_ids=prompt_token_ids, + finished=True, + outputs=outputs)) + return request_outputs + else: + request_outputs = [] + offset = 0 + for request in scheduler_output.scheduled_requests: + prompt_token_ids = request.inputs.prompt_token_ids + n_tokens = len(prompt_token_ids) + request_outputs.append( + EncodeOnlyRequestOutput(request_id=request.request_id, + arrival_time=request.arrival_time, + prompt_token_ids=prompt_token_ids, + finished=True, + outputs=execute_output. + last_hidden_states[offset:offset + + n_tokens])) + offset += n_tokens + return request_outputs diff --git a/vllm/model_executor/encode_only/workflow.py b/vllm/model_executor/encode_only/workflow.py new file mode 100644 index 0000000000000..7de6fc364aad3 --- /dev/null +++ b/vllm/model_executor/encode_only/workflow.py @@ -0,0 +1,9 @@ +from vllm.model_executor.prefill_only.workflow import PrefillOnlyWorkflow + + +class EncodeOnlyWorkflow(PrefillOnlyWorkflow): + EngineArgs: str = ("vllm.model_executor.encode_only.arg_utils" + ":EncodeOnlyEngineArgs") + OutputProcessor: str = ("vllm.model_executor.encode_only." 
+ "output_processor:PrefillOnlyModelOutputProcessor") + attn_type: str = "ENCODER" diff --git a/vllm/model_executor/prefill_only/__init__.py b/vllm/model_executor/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/prefill_only/arg_utils.py b/vllm/model_executor/prefill_only/arg_utils.py new file mode 100644 index 0000000000000..a10ee22c28cca --- /dev/null +++ b/vllm/model_executor/prefill_only/arg_utils.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass, fields +from typing import List, Optional, Union + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str + served_model_name: Optional[Union[List[str]]] = None + tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + seed: int = 0 + + def to_dict(self): + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/model_executor/prefill_only/config.py b/vllm/model_executor/prefill_only/config.py new file mode 100644 index 0000000000000..afd15c3c8a998 --- /dev/null +++ b/vllm/model_executor/prefill_only/config.py @@ -0,0 +1,819 @@ +import enum +import json +from dataclasses import dataclass, field, fields +from typing import List, Optional, Union + +import torch +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.transformers_utils.config import get_config, get_hf_text_config +from vllm.utils import (is_cpu, is_hip, is_neuron, is_openvino, is_xpu, + print_warning_once) + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class DeviceConfig: + device: Optional[torch.device] + + def __init__(self, device: str = "auto") -> None: + if device == "auto": + # Automated device type detection + if is_neuron(): + self.device_type = "neuron" + elif is_openvino(): + self.device_type = "openvino" + elif is_cpu(): + self.device_type = "cpu" + elif is_xpu(): + self.device_type = "xpu" + else: + # We don't call torch.cuda.is_available() here to + # avoid initializing CUDA before workers are forked + self.device_type = "cuda" + else: + # Device type is assigned explicitly + self.device_type = device + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron", "openvino"]: + self.device = torch.device("cpu") + elif self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" + BITSANDBYTES = "bitsandbytes" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. 
+ "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. + + """ + + load_format: Union[str, LoadFormat] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") + + +class CacheConfig: + """Configuration for the KV cache. + + Args: + block_size: Size of a cache block in number of tokens. + gpu_memory_utilization: Fraction of GPU memory to use for the + vLLM execution. + swap_space: Size of the CPU swap space per GPU (in GiB). + cache_dtype: Data type for kv cache storage. + num_gpu_blocks_override: Number of GPU blocks to use. This overrides the + profiled num_gpu_blocks if specified. Does nothing if None. + """ + + def __init__( + self, + block_size: int, + gpu_memory_utilization: float, + swap_space: int, + cache_dtype: str, + num_gpu_blocks_override: Optional[int] = None, + sliding_window: Optional[int] = None, + enable_prefix_caching: bool = False, + cpu_offload_gb: float = 0, + ) -> None: + self.block_size = block_size + self.gpu_memory_utilization = gpu_memory_utilization + self.swap_space_bytes = swap_space * _GB + self.num_gpu_blocks_override = num_gpu_blocks_override + self.cache_dtype = cache_dtype + self.sliding_window = sliding_window + self.enable_prefix_caching = enable_prefix_caching + self.cpu_offload_gb = cpu_offload_gb + self._verify_args() + self._verify_cache_dtype() + self._verify_prefix_caching() + + # Will be set after profiling. + self.num_gpu_blocks = None + self.num_cpu_blocks = None + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. 
Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + if self.cache_dtype == "fp8": + raise NotImplementedError( + "Prefix caching is not supported for fp8 cache_dtype. " + "Run with --kv-cache-dtype auto to use prefix caching.") + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + code_revision: The specific revision to use for the model code on + Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + quantization_param_path: Path to JSON file containing scaling factors. + Used to load KV cache scaling factors into the model when KV cache + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the + model dtype is FP8_E4M3 on ROCm. + disable_sliding_window: Whether to disable sliding window. If True, + we will disable the sliding window functionality of the model. + If the model does not support sliding window, this argument is + ignored. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. 
+ """ + + def __init__( + self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + # The tokenizer version is consistent with the model version by default. + if tokenizer_revision is None: + self.tokenizer_revision = revision + else: + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.quantization_param_path = quantization_param_path + self.disable_sliding_window = disable_sliding_window + self.skip_tokenizer_init = skip_tokenizer_init + + self.hf_config = get_config(self.model, trust_remote_code, revision, + code_revision, rope_scaling, rope_theta) + self.hf_text_config = get_hf_text_config(self.hf_config) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + + if (not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. Disabling sliding " + "window and capping the max length to the sliding window size " + f"({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window()) + self.served_model_name = get_served_model_name(model, + served_model_name) + + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + self._verify_quantization() + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto' or 'slow'.") + self.tokenizer_mode = tokenizer_mode + + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + return quant_cfg + + def _verify_quantization(self) -> None: + supported_quantization = [*QUANTIZATION_METHODS] + rocm_supported_quantization = ["gptq", "squeezellm"] + optimized_quantization_methods = [ + "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", + "fbgemm_fp8", "compressed_tensors", "compressed-tensors" + ] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. 
+ quant_cfg = self._parse_quant_hf_config() + + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + + # Detect which checkpoint is it + for _, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + if is_hip( + ) and self.quantization not in rocm_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in ROCm.") + if self.quantization not in optimized_quantization_methods: + logger.warning( + "%s quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.", self.quantization) + + def get_hf_config_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled.""" + + # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in + # addition to sliding window size. We check if that field is present + # and if it's False, return None. + if (hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window): + return None + return getattr(self.hf_text_config, "sliding_window", None) + + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + + def get_vocab_size(self) -> int: + return self.hf_text_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_text_config.hidden_size + + def get_head_size(self) -> int: + # TODO remove hard code + if hasattr(self.hf_text_config, "model_type" + ) and self.hf_text_config.model_type == 'deepseek_v2': + # FlashAttention supports only head_size 32, 64, 128, 256, + # we need to pad head_size 192 to 256 + return 256 + if hasattr(self.hf_text_config, "head_dim"): + return self.hf_text_config.head_dim + # FIXME(woosuk): This may not be true for all models. + return (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads) + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_text_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. 
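+ # For example, GPTBigCode (StarCoder) and Falcon-7B set multi_query=True,
+ # so all query heads share a single KV head and this branch returns 1.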
+ return 1 + + # For DBRX and MPT + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self) -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, total_num_kv_heads) + + def get_num_attention_heads(self) -> int: + num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + return num_heads + + def get_num_layers(self) -> int: + + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) + + return total_num_hidden_layers + + def get_layers_block_type(self) -> List[str]: + num_layers = self.get_num_layers() + # Transformers supports layers_block_type @property + return getattr(self.hf_config, "layers_block_type", + ["attention"] * num_layers) + + def get_num_attention_layers(self) -> int: + return len( + [t for t in self.get_layers_block_type() if t == "attention"]) + + +class SchedulerConfig: + pass + + +class ParallelConfig: + pass + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + if config.model_type == "gemma2": + logger.info( + "For Gemma 2, we downcast float32 to bfloat16 instead " + "of float16 by default. Please specify `dtype` if you " + "want to use float16.") + torch_dtype = torch.bfloat16 + else: + # Following the common practice, we use float16 for float32 + # models. + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. 
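+ # e.g. google-bert/bert-base-uncased ships float32 weights, so requesting
+ # dtype="half" takes this branch and downcasts the weights to float16.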
+ logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, List[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[int], +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys. + max_len_key = None + for key in possible_keys: + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + max_len_key = "sliding_window" \ + if sliding_window_len < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, sliding_window_len) + + # If none of the keys were found in the config, use a default and + # log a warning. + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + "%s. Assuming the model's maximum length is %d.", possible_keys, + default_max_len) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None: + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + + # The correct one should be "longrope", kept "su" here + # to be backward compatible + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. 
Please raise an issue so we can " + "investigate.") + + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. + if max_model_len is None: + max_model_len = int(derived_max_model_len) + elif max_model_len > derived_max_model_len: + # Some models might have a separate key for specifying model_max_length + # that will be bigger than derived_max_model_len. We compare user input + # with model_max_length and allow this override when it's smaller. + model_max_length = getattr(hf_config, "model_max_length", None) + if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") + else: + msg = ( + f"User-specified max_model_len ({max_model_len}) is greater " + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" + f"{model_max_length} in model's config.json). This may lead " + "to incorrect model outputs or CUDA errors.") + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning( + "%s Make sure the value is correct and within the " + "model context size.", msg) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") + return int(max_model_len) + + +@dataclass(frozen=True) +class EngineConfig: + model_config: ModelConfig + device_config: DeviceConfig + load_config: LoadConfig + scheduler_config: SchedulerConfig + parallel_config: Optional[ParallelConfig] = None + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. 
+ """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) + + def log_config(self): + from vllm.version import __version__ as VLLM_VERSION + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, " + "quantization=%s, " + "quantization_param_path=%s, device_config=%s, " + "seed=%d, served_model_name=%s)", + VLLM_VERSION, + self.model_config.model, + self.model_config.tokenizer, + self.model_config.skip_tokenizer_init, + self.model_config.tokenizer_mode, + self.model_config.revision, + self.model_config.rope_scaling, + self.model_config.rope_theta, + self.model_config.tokenizer_revision, + self.model_config.trust_remote_code, + self.model_config.dtype, + self.model_config.max_model_len, + self.load_config.download_dir, + self.load_config.load_format, + self.model_config.quantization, + self.model_config.quantization_param_path, + self.device_config.device, + self.model_config.seed, + self.model_config.served_model_name, + ) + + +def filter_unexpected_fields(cls): + original_init = cls.__init__ + + def new_init(self, *args, **kwargs): + expected_fields = {field.name for field in fields(cls)} + cleaned_kwargs = { + key: value + for key, value in kwargs.items() if key in expected_fields + } + original_init(self, *args, **cleaned_kwargs) + + cls.__init__ = new_init + return cls + + +class PrefillOnlySchedulerConfig(SchedulerConfig): + + def __init__(self, + max_model_len: int, + max_num_batched_tokens: Optional[int] = None, + max_num_requests: Optional[int] = None, + max_num_seqs: Optional[int] = None, + max_num_on_the_fly: int = 3, + scheduling: str = "sync") -> None: + self.max_model_len = max_model_len + self.max_num_requests: int = 0 + self.max_num_batched_tokens: int = 0 + self.max_num_on_the_fly: int = max_num_on_the_fly + self.scheduling = scheduling + + self.set_args(max_num_batched_tokens, max_num_requests, max_num_seqs) + + def set_args(self, + max_num_batched_tokens: Optional[int] = None, + max_num_requests: Optional[int] = None, + max_num_seqs: Optional[int] = None): + if max_num_seqs is not None: + self.max_num_requests = max_num_seqs + elif max_num_requests is not None: + self.max_num_requests = max_num_requests + else: + raise ValueError("At least one of max_num_seqs " + "and max_num_requests is not None.") + + if max_num_batched_tokens is not None: + self.max_num_batched_tokens = max_num_batched_tokens + else: + self.max_num_batched_tokens = (self.max_model_len * + self.max_num_requests) + + self._verify_args() + + def _verify_args(self) -> None: + if self.max_num_batched_tokens < self.max_model_len: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_model_len " + f"({self.max_model_len}).") + + if self.max_num_on_the_fly < 2: + raise ValueError( + f"max_num_on_the_fly {self.max_num_on_the_fly} must " + "be greater than 1") + + if self.scheduling not in ["sync", "async", "double_buffer"]: + raise ValueError(f"scheduling {self.scheduling} must " + f"in sync, async and double_buffer") + + @property + def max_num_seqs(self) -> int: + return self.max_num_requests diff --git a/vllm/model_executor/prefill_only/engine_io.py b/vllm/model_executor/prefill_only/engine_io.py new file mode 100644 index 0000000000000..2a226ba9aa44b --- /dev/null 
+++ b/vllm/model_executor/prefill_only/engine_io.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import List + +from vllm.inputs.prefill_only.data import Request, SchedulableRequest + + +@dataclass +class SchedulerOutput: + pass + + +@dataclass +class PrefillOnlySchedulerOutput(SchedulerOutput): + scheduled_requests: List[SchedulableRequest] + ignored_requests: List[SchedulableRequest] + + def is_empty(self) -> bool: + return not self.scheduled_requests + + +@dataclass +class RequestOutput(Request): + finished: bool diff --git a/vllm/model_executor/prefill_only/execute_io.py b/vllm/model_executor/prefill_only/execute_io.py new file mode 100644 index 0000000000000..db8330f333da3 --- /dev/null +++ b/vllm/model_executor/prefill_only/execute_io.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.attention.prefill_only.abstract import PrefillOnlyAttentionBackend + + +@dataclass +class ModelInput: + pass + + +@dataclass +class WorkerInput: + pass + + +@dataclass +class ExecuteInput: + worker_input: Optional[WorkerInput] + model_input: Optional[ModelInput] + + +class ExecuteOutput: + + def to(self, target_device, non_blocking=False): + for k in self.__dict__: + self.__dict__[k] = self.__dict__[k].to(device=target_device, + non_blocking=non_blocking) + + +@dataclass +class ModelInputForGPU(ModelInput): + input_ids: torch.Tensor + positions: torch.Tensor + attn_metadata: PrefillOnlyAttentionBackend + + def to(self, target_device, non_blocking=False): + for k in self.__dict__: + self.__dict__[k] = self.__dict__[k].to(device=target_device, + non_blocking=non_blocking) + + def to_dict(self): + return self.__dict__ + + +class PrefillOnlyExecuteInput(ExecuteInput): + worker_input = None + model_input: ModelInputForGPU diff --git a/vllm/model_executor/prefill_only/loader/__init__.py b/vllm/model_executor/prefill_only/loader/__init__.py new file mode 100644 index 0000000000000..de53821b2a130 --- /dev/null +++ b/vllm/model_executor/prefill_only/loader/__init__.py @@ -0,0 +1,4 @@ +from vllm.model_executor.prefill_only.loader.loader import (get_model_loader, + initialize_model) + +__all__ = ["get_model_loader", "initialize_model"] diff --git a/vllm/model_executor/prefill_only/loader/loader.py b/vllm/model_executor/prefill_only/loader/loader.py new file mode 100644 index 0000000000000..bb630a8ff9d08 --- /dev/null +++ b/vllm/model_executor/prefill_only/loader/loader.py @@ -0,0 +1,631 @@ +# ruff: noqa: SIM117 +import fnmatch +import glob +import json +import math +import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, Generator, List, Optional, Tuple + +import huggingface_hub +import numpy as np +import torch +from huggingface_hub import HfApi, hf_hub_download +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.attention.prefill_only.abstract import PrefillOnlyAttentionBackend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, + ModelConfig, SchedulerConfig) +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, + get_quant_config, np_cache_weights_iterator, pt_weights_iterator, + 
safetensors_weights_iterator) +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_pin_memory_available + +from .utils import get_model_architecture, set_default_torch_dtype + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + +logger = init_logger(__name__) + + +def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. 
Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None + + +def initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + attn_backend: PrefillOnlyAttentionBackend, + cache_config: Optional[CacheConfig] = None, +) -> nn.Module: + """Initialize a model with the given configurations.""" + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model_class = get_model_architecture(model_config)[0] + quant_config = _get_quantization_config(model_config, load_config) + + return model_class(config=model_config.hf_config, + cache_config=cache_config, + quant_config=quant_config, + attn_backend=attn_backend) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + """Load a model with the given configurations.""" + ... + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. 
+ if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, + self.load_config.download_dir, revision) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], + fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + model_name_or_path, self.load_config.download_dir, hf_folder, + hf_weights_files) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + return weights_iterator + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + model.load_weights( + self._get_weights_iterator(model_config.model, + model_config.revision, + fall_back_to_pt=getattr( + model, + "fall_back_to_pt_during_load", + True)), ) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. 
This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + return model.eval() + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + default_target_modules = [ + "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", + "o_proj" + ] + + possible_config_file_names = ["adapter_config.json"] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # we don't need to quantize the whole model, only the target modules + # that are specified in the adapter config file. If the adapter config + # file is not provided, we will quantize the default modules. + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + self.target_modules = self.default_target_modules + return + + qlora_adapter = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + config_file_path = self._get_config_file(qlora_adapter) + + with open(config_file_path, "r") as f: + config = json.load(f) + self.target_modules = config["target_modules"] + + def _get_config_file(self, qlora_adapter: str) -> str: + is_local = os.path.isdir(qlora_adapter) + config_file_path = None + if is_local: + for file in self.possible_config_file_names: + config_file_path = os.path.join(qlora_adapter, file) + if os.path.exists(config_file_path): + break + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) + for file in self.possible_config_file_names: + if file in repo_files: + config_file_path = hf_hub_download(repo_id=qlora_adapter, + filename=file) + break + + if not config_file_path: + raise ValueError( + f"Cannot find adapter config file in {qlora_adapter}") + + return config_file_path + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None) -> Tuple[List[str], str]: + """Retrieve weight files. Download the files if necessary. 
+ + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return glob.glob(os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + if matched_pattern != "*.safetensors": + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, matched_pattern == "*.safetensors" + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + else: + return pt_weights_iterator(hf_weights_files) + + def _get_quantized_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], pre_quant: bool + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + from bitsandbytes.functional import QuantState + if bitsandbytes.__version__ < "0.42.0": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.42.0.") + from bitsandbytes.functional import quantize_4bit + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.42.0 via " + "`pip install bitsandbytes>=0.42.0` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict = {} + + def quantized_checkpoint() -> Generator: + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith(".weight"): + continue + # TODO: only nf4 quantization is supported for now + if weight_name.endswith(".quant_state.bitsandbytes__fp4"): + raise NotImplementedError( + "Only bitsandbytes_nf4 quantization" + f"is supported for now. {weight_name} is fp4 quantized" + ) + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." 
in k: + quant_state[k] = temp_state_dict[k] + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__nf4 in CPU + quant_state[param_name + + ".quant_state.bitsandbytes__nf4"] = quant_state[ + param_name + + ".quant_state.bitsandbytes__nf4"].cpu().data + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + # Filter out all weights whose suffix is not ".weight" + if not weight_name.endswith(".weight"): + continue + if weight_name + ".quant_state.bitsandbytes__nf4" \ + in temp_state_dict: + quant_state = _parse_quant_state(weight_name, + temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", + ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def generator() -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if any(target_module in weight_name + for target_module in self.target_modules): + weight_name = weight_name.replace(".weight", ".qweight") + # bitsandbytes requires data in GPU + loaded_weight = weight_tensor.cuda().data + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight + + if pre_quant: + return quantized_checkpoint(), quant_state_dict + return generator(), quant_state_dict + + def _load_weights(self, model_config: ModelConfig, + model: nn.Module) -> None: + if not hasattr(model, 'load_weights'): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(self).__name__}.") + + if not hasattr(model, 'bitsandbytes_stacked_params_mapping'): + raise AttributeError( + f"Model {type(self).__name__} does not support BitsAndBytes " + "quantization yet.") + + logger.info("Loading weights with BitsAndBytes quantization. 
" + " May take a while ...") + + is_quantized_checkpoint = False + quant_config = getattr(model_config.hf_config, "quantization_config", + None) + if quant_config is not None and quant_config.get( + 'quant_method') == "bitsandbytes": + is_quantized_checkpoint = True + + qweight_iterator, quant_state_dict = \ + self._get_quantized_weights_iterator( + model_config.model, model_config.revision, is_quantized_checkpoint) + + model.load_weights(qweight_iterator) + + torch.cuda.empty_cache() + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + for quant_param_name in quant_state_dict: + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, index + ) in model.bitsandbytes_stacked_params_mapping.items(): + if shard_name in quant_param_name: + shard_index = index + quant_param_name = quant_param_name.replace( + shard_name, weight_name) + break + + if quant_param_name not in param_dict: + raise ValueError( + f"Parameter {quant_param_name} not found in the model.") + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = math.prod( + quant_state.shape) // pack_ratio + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + self._load_weights(model_config, model) + + return model.eval() + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/prefill_only/loader/utils.py b/vllm/model_executor/prefill_only/loader/utils.py new file mode 100644 index 0000000000000..22ea4f6568fa6 --- /dev/null +++ b/vllm/model_executor/prefill_only/loader/utils.py @@ -0,0 +1,48 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import ModelConfig +from vllm.model_executor.prefill_only.modelzoo import ModelRegistry + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = 
torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_model_workflow(hf_config: PretrainedConfig) -> str: + architectures = getattr(hf_config, "architectures", []) + + for arch in architectures: + workflow = ModelRegistry.get_workflow(arch) + if workflow is not None: + return workflow + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/vllm/model_executor/prefill_only/model_input_builder.py b/vllm/model_executor/prefill_only/model_input_builder.py new file mode 100644 index 0000000000000..d5dd0bf04f3a3 --- /dev/null +++ b/vllm/model_executor/prefill_only/model_input_builder.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod +from typing import cast + +import torch + +from vllm.attention.prefill_only.abstract import ( + PrefillOnlyAttentionMetadataBuilder) +from vllm.model_executor.prefill_only.engine_io import ( + PrefillOnlySchedulerOutput, SchedulerOutput) +from vllm.model_executor.prefill_only.execute_io import (ExecuteInput, + ModelInputForGPU) +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +class ModelInputBuilder(ABC): + """ + scheduler_output = scheduler.schedule() + SchedulerOutput -> ModelInputBuilder -> ExecuteInput + """ + + @abstractmethod + def __call__(self, scheduler_output: SchedulerOutput) -> ExecuteInput: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine): + raise NotImplementedError + + +class PrefillOnlyModelInputBuilder(ModelInputBuilder): + + def __init__( + self, + attention_metadata_builder: PrefillOnlyAttentionMetadataBuilder): + self.attention_metadata_builder = attention_metadata_builder + + @classmethod + def from_engine(cls, engine): + return cls(engine.attn_backend.get_builder_cls()()) + + def __call__(self, scheduler_output: SchedulerOutput) -> ExecuteInput: + assert isinstance(scheduler_output, PrefillOnlySchedulerOutput) + scheduler_output = cast(PrefillOnlySchedulerOutput, scheduler_output) + + input_tokens = [] + input_positions = [] + seq_lens = [] + for request in scheduler_output.scheduled_requests: + prompt_token_ids = request.inputs.prompt_token_ids + n_tokens = len(prompt_token_ids) + input_tokens.extend(prompt_token_ids) + input_positions.extend(list(range(0, n_tokens))) + seq_lens.append(n_tokens) + + input_ids = torch.tensor(input_tokens, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + positions = torch.tensor(input_positions, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + attn_metadata = self.attention_metadata_builder(seq_lens) + + model_input = ModelInputForGPU(input_ids=input_ids, + positions=positions, + attn_metadata=attn_metadata) + + return ExecuteInput(worker_input=None, model_input=model_input) diff --git a/vllm/model_executor/prefill_only/modelzoo.py 
b/vllm/model_executor/prefill_only/modelzoo.py new file mode 100644 index 0000000000000..fdcb02c7df659 --- /dev/null +++ b/vllm/model_executor/prefill_only/modelzoo.py @@ -0,0 +1,65 @@ +import functools +import importlib +from typing import Dict, List, Optional, Tuple, Type + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.model_executor.encode_only.modelzoo import ENCODE_ONLY_MODELS + +logger = init_logger(__name__) + +_MODELS_LIST = [ENCODE_ONLY_MODELS] + +# Architecture -> (module, workflow). +_MODELS: Dict[str, Tuple[str, str]] = dict() +for m in _MODELS_LIST: + _MODELS.update(**m) + +# Architecture -> type. +# out of tree models +_OOT_MODELS: Dict[str, Type[nn.Module]] = {} + + +class ModelRegistry: + + @staticmethod + @functools.lru_cache(maxsize=128) + def _get_model(model_arch: str): + module_str, workflow = _MODELS[model_arch] + module_name, model_cls_name = module_str.split(":") + module = importlib.import_module(module_name) + return getattr(module, model_cls_name, None) + + @staticmethod + def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] + if model_arch not in _MODELS: + return None + return ModelRegistry._get_model(model_arch) + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + + @staticmethod + @functools.lru_cache(maxsize=128) + def get_workflow(model_arch: str): + module_str, workflow = _MODELS[model_arch] + return workflow + + @staticmethod + def register_model(model_arch: str, model_cls: Type[nn.Module]): + if model_arch in _MODELS: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls.__name__) + global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls + + +__all__ = [ + "ModelRegistry", +] diff --git a/vllm/model_executor/prefill_only/output_processor.py b/vllm/model_executor/prefill_only/output_processor.py new file mode 100644 index 0000000000000..e7b70531fa9ad --- /dev/null +++ b/vllm/model_executor/prefill_only/output_processor.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod +from typing import List + +import torch + +from vllm.model_executor.prefill_only.engine_io import (RequestOutput, + SchedulerOutput) + + +class OutputProcessor(ABC): + """ + scheduler_output, execute_output -> OutputProcessor -> RequestOutput + """ + + @abstractmethod + def __call__(self, scheduler_output: SchedulerOutput, + execute_output: torch.Tensor) -> List[RequestOutput]: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine): + raise NotImplementedError diff --git a/vllm/model_executor/prefill_only/utils.py b/vllm/model_executor/prefill_only/utils.py new file mode 100644 index 0000000000000..bf3621addf52a --- /dev/null +++ b/vllm/model_executor/prefill_only/utils.py @@ -0,0 +1,53 @@ +from typing import List, Optional + +import torch + + +class FakeGroupCoordinator: + rank: int = 0 + ranks: List[int] = [0] + world_size: int = 1 + local_rank: int = 0 + rank_in_group: int = 0 + + def destroy(self): + pass + + @property + def first_rank(self): + return self.ranks[0] + + @property + def last_rank(self): + return self.ranks[-1] + + @property + def is_first_rank(self): + return self.rank == self.first_rank + + @property + def is_last_rank(self): + return self.rank == self.last_rank + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + return input_ + + def all_gather(self, input_: torch.Tensor, dim: 
int = -1) -> torch.Tensor: + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + return input_ + + +def fix_distributed_environment(): + # This dirty_fix can make ParallelLinear etc. work properly. + # Why should tp and model layers be coupled together? + + import vllm.distributed.parallel_state + + fake_parallel_group = FakeGroupCoordinator() + vllm.distributed.parallel_state._TP = fake_parallel_group + vllm.distributed.parallel_state._PP = fake_parallel_group diff --git a/vllm/model_executor/prefill_only/workflow.py b/vllm/model_executor/prefill_only/workflow.py new file mode 100644 index 0000000000000..e38ba8f753cd7 --- /dev/null +++ b/vllm/model_executor/prefill_only/workflow.py @@ -0,0 +1,43 @@ +class Workflow: + EngineArgs: str + Scheduler: str + AttnBackend: str + attn_type: str + Tokenizer: str = "vllm.inputs.prefill_only.tokenizer:Tokenizer" + InputProcessor: str + RequestProcessor: str + OutputProcessor: str + ModelInputBuilder: str + Executor: str + Worker: str + + @classmethod + def from_engine(cls, engine): + return cls() + + +class PrefillOnlyWorkflow(Workflow): + InputProcessor: str = ("vllm.inputs.prefill_only.preprocessor" + ":TextInputProcessor") + RequestProcessor: str = ("vllm.inputs.prefill_only.preprocessor" + ":TextRequestProcessor") + ModelInputBuilder: str = ( + "vllm.model_executor.prefill_only.model_input_builder" + ":PrefillOnlyModelInputBuilder") + Worker: str = "vllm.worker.prefill_only_gpu_worker:Worker" + Executor: str = "vllm.executor.prefill_only_gpu_executor" + Scheduler: str = "vllm.core.prefill_only_scheduler:PrefillOnlyScheduler" + AttnBackend: str = "vllm.attention.prefill_only.selector:AttnBackend" + + @classmethod + def from_engine(cls, engine): + workflow = cls() + + if engine.engine_config.scheduler_config.scheduling in ["sync"]: + workflow.Executor += ":GPUExecutor" + elif engine.engine_config.scheduler_config.scheduling in [ + "async", "double_buffer" + ]: + workflow.Executor += ":GPUAsyncExecutor" + + return workflow diff --git a/vllm/wde/__init__.py b/vllm/wde/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/__init__.py b/vllm/wde/core/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/arg_utils.py b/vllm/wde/core/arg_utils.py new file mode 100644 index 0000000000000..a10ee22c28cca --- /dev/null +++ b/vllm/wde/core/arg_utils.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass, fields +from typing import List, Optional, Union + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str + served_model_name: Optional[Union[List[str]]] = None + tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + seed: int = 0 + + def to_dict(self): + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/wde/core/config.py b/vllm/wde/core/config.py new file mode 100644 index 0000000000000..18f288a11b999 --- /dev/null +++ b/vllm/wde/core/config.py @@ -0,0 +1,761 @@ +import enum +import json +from dataclasses import dataclass, field, fields +from typing import List, Optional, Union + +import torch +from 
transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.transformers_utils.config import get_config, get_hf_text_config +from vllm.utils import (is_cpu, is_hip, is_neuron, is_openvino, is_xpu, + print_warning_once) + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class DeviceConfig: + device: Optional[torch.device] + + def __init__(self, device: str = "auto") -> None: + if device == "auto": + # Automated device type detection + if is_neuron(): + self.device_type = "neuron" + elif is_openvino(): + self.device_type = "openvino" + elif is_cpu(): + self.device_type = "cpu" + elif is_xpu(): + self.device_type = "xpu" + else: + # We don't call torch.cuda.is_available() here to + # avoid initializing CUDA before workers are forked + self.device_type = "cuda" + else: + # Device type is assigned explicitly + self.device_type = device + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron", "openvino"]: + self.device = torch.device("cpu") + elif self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" + BITSANDBYTES = "bitsandbytes" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. 
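+
+     Example (values are illustrative):
+
+         LoadConfig(load_format="safetensors",
+                    download_dir="/tmp/hf-cache",
+                    ignore_patterns=["original/**/*"])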
+ + """ + + load_format: Union[str, LoadFormat] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") + + +class CacheConfig: + """Configuration for the KV cache. + + Args: + block_size: Size of a cache block in number of tokens. + gpu_memory_utilization: Fraction of GPU memory to use for the + vLLM execution. + swap_space: Size of the CPU swap space per GPU (in GiB). + cache_dtype: Data type for kv cache storage. + num_gpu_blocks_override: Number of GPU blocks to use. This overrides the + profiled num_gpu_blocks if specified. Does nothing if None. + """ + + def __init__( + self, + block_size: int, + gpu_memory_utilization: float, + swap_space: int, + cache_dtype: str, + num_gpu_blocks_override: Optional[int] = None, + sliding_window: Optional[int] = None, + enable_prefix_caching: bool = False, + cpu_offload_gb: float = 0, + ) -> None: + self.block_size = block_size + self.gpu_memory_utilization = gpu_memory_utilization + self.swap_space_bytes = swap_space * _GB + self.num_gpu_blocks_override = num_gpu_blocks_override + self.cache_dtype = cache_dtype + self.sliding_window = sliding_window + self.enable_prefix_caching = enable_prefix_caching + self.cpu_offload_gb = cpu_offload_gb + self._verify_args() + self._verify_cache_dtype() + self._verify_prefix_caching() + + # Will be set after profiling. + self.num_gpu_blocks = None + self.num_cpu_blocks = None + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. 
" + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + if self.cache_dtype == "fp8": + raise NotImplementedError( + "Prefix caching is not supported for fp8 cache_dtype. " + "Run with --kv-cache-dtype auto to use prefix caching.") + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + code_revision: The specific revision to use for the model code on + Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + quantization_param_path: Path to JSON file containing scaling factors. + Used to load KV cache scaling factors into the model when KV cache + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the + model dtype is FP8_E4M3 on ROCm. + disable_sliding_window: Whether to disable sliding window. If True, + we will disable the sliding window functionality of the model. + If the model does not support sliding window, this argument is + ignored. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. 
+ """ + + def __init__( + self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + # The tokenizer version is consistent with the model version by default. + if tokenizer_revision is None: + self.tokenizer_revision = revision + else: + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.quantization_param_path = quantization_param_path + self.disable_sliding_window = disable_sliding_window + self.skip_tokenizer_init = skip_tokenizer_init + + self.hf_config = get_config(self.model, trust_remote_code, revision, + code_revision, rope_scaling, rope_theta) + self.hf_text_config = get_hf_text_config(self.hf_config) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + + if (not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. Disabling sliding " + "window and capping the max length to the sliding window size " + f"({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window()) + self.served_model_name = get_served_model_name(model, + served_model_name) + + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + self._verify_quantization() + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto' or 'slow'.") + self.tokenizer_mode = tokenizer_mode + + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + return quant_cfg + + def _verify_quantization(self) -> None: + supported_quantization = [*QUANTIZATION_METHODS] + rocm_supported_quantization = ["gptq", "squeezellm"] + optimized_quantization_methods = [ + "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", + "fbgemm_fp8", "compressed_tensors", "compressed-tensors" + ] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. 
+ quant_cfg = self._parse_quant_hf_config() + + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + + # Detect which checkpoint is it + for _, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + if is_hip( + ) and self.quantization not in rocm_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in ROCm.") + if self.quantization not in optimized_quantization_methods: + logger.warning( + "%s quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.", self.quantization) + + def get_hf_config_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled.""" + + # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in + # addition to sliding window size. We check if that field is present + # and if it's False, return None. + if (hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window): + return None + return getattr(self.hf_text_config, "sliding_window", None) + + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + + def get_vocab_size(self) -> int: + return self.hf_text_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_text_config.hidden_size + + def get_head_size(self) -> int: + # TODO remove hard code + if hasattr(self.hf_text_config, "model_type" + ) and self.hf_text_config.model_type == 'deepseek_v2': + # FlashAttention supports only head_size 32, 64, 128, 256, + # we need to pad head_size 192 to 256 + return 256 + if hasattr(self.hf_text_config, "head_dim"): + return self.hf_text_config.head_dim + # FIXME(woosuk): This may not be true for all models. + return (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads) + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_text_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. 
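+            # (With multi-query attention, all query heads share a single
+            # K/V projection, so there is exactly one KV head.)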
+ return 1 + + # For DBRX and MPT + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self) -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, total_num_kv_heads) + + def get_num_attention_heads(self) -> int: + num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + return num_heads + + def get_num_layers(self) -> int: + + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) + + return total_num_hidden_layers + + def get_layers_block_type(self) -> List[str]: + num_layers = self.get_num_layers() + # Transformers supports layers_block_type @property + return getattr(self.hf_config, "layers_block_type", + ["attention"] * num_layers) + + def get_num_attention_layers(self) -> int: + return len( + [t for t in self.get_layers_block_type() if t == "attention"]) + + +class SchedulerConfig: + pass + + +class ParallelConfig: + pass + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + if config.model_type == "gemma2": + logger.info( + "For Gemma 2, we downcast float32 to bfloat16 instead " + "of float16 by default. Please specify `dtype` if you " + "want to use float16.") + torch_dtype = torch.bfloat16 + else: + # Following the common practice, we use float16 for float32 + # models. + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. 
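+            # Serving an fp32 checkpoint in half precision halves the weight
+            # memory; any other mismatch (e.g. fp16 vs bf16) only triggers
+            # the warning in the final branch below.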
+ logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, List[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[int], +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys. + max_len_key = None + for key in possible_keys: + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + max_len_key = "sliding_window" \ + if sliding_window_len < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, sliding_window_len) + + # If none of the keys were found in the config, use a default and + # log a warning. + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + "%s. Assuming the model's maximum length is %d.", possible_keys, + default_max_len) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None: + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + + # The correct one should be "longrope", kept "su" here + # to be backward compatible + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. 
Please raise an issue so we can " + "investigate.") + + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. + if max_model_len is None: + max_model_len = int(derived_max_model_len) + elif max_model_len > derived_max_model_len: + # Some models might have a separate key for specifying model_max_length + # that will be bigger than derived_max_model_len. We compare user input + # with model_max_length and allow this override when it's smaller. + model_max_length = getattr(hf_config, "model_max_length", None) + if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") + else: + msg = ( + f"User-specified max_model_len ({max_model_len}) is greater " + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" + f"{model_max_length} in model's config.json). This may lead " + "to incorrect model outputs or CUDA errors.") + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning( + "%s Make sure the value is correct and within the " + "model context size.", msg) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") + return int(max_model_len) + + +@dataclass(frozen=True) +class EngineConfig: + model_config: ModelConfig + device_config: DeviceConfig + load_config: LoadConfig + scheduler_config: SchedulerConfig + parallel_config: Optional[ParallelConfig] = None + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. 
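+
+        For example (illustrative), a component could be constructed with
+        SomeWorker(**engine_config.to_dict()), receiving model_config,
+        device_config, load_config, scheduler_config and parallel_config
+        as keyword arguments.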
+ """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) + + def log_config(self): + from vllm.version import __version__ as VLLM_VERSION + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, " + "quantization=%s, " + "quantization_param_path=%s, device_config=%s, " + "seed=%d, served_model_name=%s)", + VLLM_VERSION, + self.model_config.model, + self.model_config.tokenizer, + self.model_config.skip_tokenizer_init, + self.model_config.tokenizer_mode, + self.model_config.revision, + self.model_config.rope_scaling, + self.model_config.rope_theta, + self.model_config.tokenizer_revision, + self.model_config.trust_remote_code, + self.model_config.dtype, + self.model_config.max_model_len, + self.load_config.download_dir, + self.load_config.load_format, + self.model_config.quantization, + self.model_config.quantization_param_path, + self.device_config.device, + self.model_config.seed, + self.model_config.served_model_name, + ) + + +def filter_unexpected_fields(cls): + original_init = cls.__init__ + + def new_init(self, *args, **kwargs): + expected_fields = {field.name for field in fields(cls)} + cleaned_kwargs = { + key: value + for key, value in kwargs.items() if key in expected_fields + } + original_init(self, *args, **cleaned_kwargs) + + cls.__init__ = new_init + return cls diff --git a/vllm/wde/core/layers/__init__.py b/vllm/wde/core/layers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/layers/attention/__init__.py b/vllm/wde/core/layers/attention/__init__.py new file mode 100644 index 0000000000000..79bcc8ff6e9b5 --- /dev/null +++ b/vllm/wde/core/layers/attention/__init__.py @@ -0,0 +1,8 @@ +from vllm.wde.core.layers.attention.abstract import (AttentionBackend, + AttentionMetadata, + AttentionType) +from vllm.wde.core.layers.attention.layer import Attention + +__all__ = [ + "Attention", "AttentionMetadata", "AttentionBackend", "AttentionType" +] diff --git a/vllm/wde/core/layers/attention/abstract.py b/vllm/wde/core/layers/attention/abstract.py new file mode 100644 index 0000000000000..0e7ce7b71e4da --- /dev/null +++ b/vllm/wde/core/layers/attention/abstract.py @@ -0,0 +1,124 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum, auto +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar + +import torch + + +class AttentionType(Enum): + DECODER = auto() # Decoder attention between previous layer Q/K/V + ENCODER = auto() # Encoder attention between previous layer Q/K/V + ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + + @staticmethod + def attn_type_name_to_enum(attn_type: str) -> "AttentionType": + assert attn_type is not None + + attn_type_members = AttentionType.__members__ + if attn_type not in attn_type_members: + raise ValueError( + f"Invalid attn_type '{attn_type}'. 
" + f"Available backends: {', '.join(attn_type_members)} " + "(case-sensitive).") + + return AttentionType[attn_type] + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + + def __init__(self, attn_type: AttentionType): + self._attn_type = attn_type + + @property + def attn_type(self) -> AttentionType: + return self._attn_type + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + raise NotImplementedError + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @classmethod + def make_metadata_builder(cls, *args, + **kwargs) -> "AttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + + +@dataclass +class AttentionMetadata: + pass + + def to(self, device, non_blocking=False): + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + self.__dict__[k] = v.to(device, non_blocking=non_blocking) + + return self + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self) -> None: + raise NotImplementedError + + @abstractmethod + def __call__(self, *args, **kwargs) -> T: + raise NotImplementedError + + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/wde/core/layers/attention/layer.py b/vllm/wde/core/layers/attention/layer.py new file mode 100644 index 0000000000000..9d9e55a07aed6 --- /dev/null +++ b/vllm/wde/core/layers/attention/layer.py @@ -0,0 +1,101 @@ +"""Attention layer.""" +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.wde.core.layers.attention.abstract import AttentionBackend + + +class Attention(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. 
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + attn_backend: AttentionBackend, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + prefix: str = "", + ) -> None: + super().__init__() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + # block_size = cache_config.block_size + sliding_window = cache_config.sliding_window + else: + kv_cache_dtype = "auto" + # block_size = 16 + sliding_window = None + if num_kv_heads is None: + num_kv_heads = num_heads + + # The default k/v_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self._k_scale = 1.0 + self._v_scale = 1.0 + quant_method = quant_config.get_quant_method( + self, prefix=prefix) if quant_config else None + if quant_method is not None: + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + impl_cls = attn_backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap) + self.attn_type = attn_backend.attn_type + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + return self.impl.forward(query, key, value, attn_metadata, kv_cache, + self._k_scale, self._v_scale, self.attn_type) + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" + return s diff --git a/vllm/wde/core/llm_engine.py b/vllm/wde/core/llm_engine.py new file mode 100644 index 0000000000000..1edebf52e9275 --- /dev/null +++ b/vllm/wde/core/llm_engine.py @@ -0,0 +1,234 @@ +from contextlib import contextmanager +from queue import Empty, Queue +from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Type, Union + +from vllm.logger import init_logger +from vllm.wde.core.arg_utils import EngineArgs +from vllm.wde.core.config import EngineConfig +from vllm.wde.core.schema.engine_io import Inputs, Params, RequestOutput +from vllm.wde.core.workflow import Workflow + +logger = init_logger(__name__) +_O = RequestOutput + + +def lazy_import(module): + module_name, class_name = module.split(":") + import importlib + module = 
importlib.import_module(module_name) + return getattr(module, class_name) + + +class LLMEngine: + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + def __init__(self, engine_config: EngineConfig, + workflow_cls: Type[Workflow]) -> None: + self.engine_config = engine_config + self.engine_config.log_config() + self.workflow = workflow_cls.from_engine(self) + + self._maybe_init_async_scheduling() + + self.attn_backend = lazy_import( + self.workflow.AttnBackend).from_engine(self) + self.executor = lazy_import(self.workflow.Executor).from_engine(self) + self.tokenizer = lazy_import(self.workflow.Tokenizer).from_engine(self) + self.model_inputs_builder = lazy_import( + self.workflow.ModelInputBuilder).from_engine(self) + + self.input_processor = lazy_import( + self.workflow.InputProcessor).from_engine(self) + self.request_processor = lazy_import( + self.workflow.RequestProcessor).from_engine(self) + self.scheduler = lazy_import(self.workflow.Scheduler).from_engine(self) + self.output_processor = lazy_import( + self.workflow.OutputProcessor).from_engine(self) + + def _maybe_init_async_scheduling(self): + executor_cls = lazy_import(self.workflow.Executor) + scheduler_cls = lazy_import(self.workflow.Scheduler) + + if ("async_scheduling" in executor_cls.support_scheduling + and "async_scheduling" in scheduler_cls.support_scheduling): + logger.info("Use async scheduling") + self.use_async_scheduling = True + + elif ("sync_scheduling" in executor_cls.support_scheduling + and "sync_scheduling" in scheduler_cls.support_scheduling): + logger.info("Use sync scheduling") + self.use_async_scheduling = False + + else: + raise RuntimeError(f"Executor support scheduling: " + f"{executor_cls.support_scheduling}." + f"Scheduler support scheduling: " + f"{executor_cls.support_scheduling}." 
+ f"Not compatible") + + if self.use_async_scheduling: + self.executor_in = Queue() + self.executor_out = Queue() + self.max_num_on_the_fly = ( + self.engine_config.scheduler_config.max_num_on_the_fly) + self.num_on_the_fly = 0 + self.step = self.async_step + else: + self.step = self.sync_step + + @classmethod + def from_engine_args(cls, engine_args: Union[Dict, + EngineArgs]) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + + from vllm.transformers_utils.config import get_config + from vllm.wde.core.loader.utils import get_model_workflow + + if isinstance(engine_args, EngineArgs): + engine_args = engine_args.to_dict() + + hf_config = get_config(engine_args["model"], + engine_args.get("trust_remote_code", False), + engine_args.get("revision", None), + engine_args.get("code_revision", None)) + + workflow_class = get_model_workflow(hf_config) + workflow = lazy_import(workflow_class) + + engine_args = lazy_import(workflow.EngineArgs)(**engine_args) + + engine_config = engine_args.create_engine_config() + engine = cls(engine_config, workflow) + return engine + + def add_request(self, + request_id: str, + inputs: Optional[Union[str, Inputs]] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> None: + request = self.input_processor(request_id, inputs, params, + arrival_time) + + # The raised ValidationError will be passed to the upper call stack + self.scheduler.add_request(request) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + self.scheduler.abort_request(request_id) + + def sync_step(self) -> List[RequestOutput]: + scheduler_output = self.scheduler.schedule() + if scheduler_output.is_empty(): + return [] + + executor_input = self.model_inputs_builder(scheduler_output) + executor_output = self.executor.execute_model(executor_input) + request_outputs = self.output_processor(scheduler_output, + executor_output) + self.scheduler.free_finished_request(request_outputs) + request_outputs = self.scheduler.remove_abort_request(request_outputs) + return request_outputs + + def async_step(self) -> List[RequestOutput]: + self.executor.ensure_start_execute_loop() + self._put_as_many_as_possible() + + if self.num_on_the_fly == 0: + return [] + + return self._get(block=True) + + def _put_as_many_as_possible(self): + while self.num_on_the_fly < self.max_num_on_the_fly: + scheduler_output = self.scheduler.schedule() + if scheduler_output.is_empty(): + break + executor_input = self.model_inputs_builder(scheduler_output) + + self.executor_in.put((scheduler_output, executor_input)) + self.num_on_the_fly += 1 + + def _get(self, block): + try: + scheduler_output, executor_output = self.executor_out.get(block) + except Empty: + return + + self.num_on_the_fly -= 1 + + # Theoretically, this put is not needed + # practically, task can be inqueue before doing post-processing + self._put_as_many_as_possible() + + request_outputs = self.output_processor(scheduler_output, + executor_output) + self.scheduler.free_finished_request(request_outputs) + request_outputs = self.scheduler.remove_abort_request(request_outputs) + return request_outputs + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return self.scheduler.get_num_unfinished_requests() + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return self.scheduler.has_unfinished_requests() + + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # 
the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if executor := getattr(self, "executor", None): + executor.shutdown_execute_loop() diff --git a/vllm/wde/core/loader/__init__.py b/vllm/wde/core/loader/__init__.py new file mode 100644 index 0000000000000..927f173387f8c --- /dev/null +++ b/vllm/wde/core/loader/__init__.py @@ -0,0 +1,3 @@ +from vllm.wde.core.loader.loader import get_model_loader, initialize_model + +__all__ = ["get_model_loader", "initialize_model"] diff --git a/vllm/wde/core/loader/loader.py b/vllm/wde/core/loader/loader.py new file mode 100644 index 0000000000000..ed26f73b5a4fb --- /dev/null +++ b/vllm/wde/core/loader/loader.py @@ -0,0 +1,631 @@ +# ruff: noqa: SIM117 +import fnmatch +import glob +import json +import math +import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, Generator, List, Optional, Tuple + +import huggingface_hub +import numpy as np +import torch +from huggingface_hub import HfApi, hf_hub_download +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, + get_quant_config, np_cache_weights_iterator, pt_weights_iterator, + safetensors_weights_iterator) +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_pin_memory_available +from vllm.wde.core.config import DeviceConfig, LoadConfig, LoadFormat +from vllm.wde.core.layers.attention.abstract import AttentionBackend +from vllm.wde.core.loader.utils import (get_model_architecture, + set_default_torch_dtype) + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + +logger = init_logger(__name__) + + +def _get_quantization_config( + model_config: ModelConfig, 
+ load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None + + +def initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + attn_backend: AttentionBackend, + cache_config: Optional[CacheConfig] = None, +) -> nn.Module: + """Initialize a model with the given configurations.""" + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model_class = get_model_architecture(model_config)[0] + quant_config = _get_quantization_config(model_config, load_config) + + return model_class(config=model_config.hf_config, + cache_config=cache_config, + quant_config=quant_config, + attn_backend=attn_backend) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + """Load a model with the given configurations.""" + ... + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. 
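+
+        Returns a tuple (hf_folder, hf_weights_files, use_safetensors).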
+ + If the model is not local, it will be downloaded.""" + model_name_or_path = self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, + self.load_config.download_dir, revision) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], + fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + model_name_or_path, self.load_config.download_dir, hf_folder, + hf_weights_files) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + return weights_iterator + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + model.load_weights( + self._get_weights_iterator(model_config.model, + model_config.revision, + fall_back_to_pt=getattr( + 
model, + "fall_back_to_pt_during_load", + True)), ) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + return model.eval() + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + default_target_modules = [ + "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", + "o_proj" + ] + + possible_config_file_names = ["adapter_config.json"] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # we don't need to quantize the whole model, only the target modules + # that are specified in the adapter config file. If the adapter config + # file is not provided, we will quantize the default modules. + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + self.target_modules = self.default_target_modules + return + + qlora_adapter = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + config_file_path = self._get_config_file(qlora_adapter) + + with open(config_file_path, "r") as f: + config = json.load(f) + self.target_modules = config["target_modules"] + + def _get_config_file(self, qlora_adapter: str) -> str: + is_local = os.path.isdir(qlora_adapter) + config_file_path = None + if is_local: + for file in self.possible_config_file_names: + config_file_path = os.path.join(qlora_adapter, file) + if os.path.exists(config_file_path): + break + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) + for file in self.possible_config_file_names: + if file in repo_files: + config_file_path = hf_hub_download(repo_id=qlora_adapter, + filename=file) + break + + if not config_file_path: + raise ValueError( + f"Cannot find adapter config file in {qlora_adapter}") + + return config_file_path + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None) -> Tuple[List[str], str]: + """Retrieve weight files. Download the files if necessary. 
+ + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return glob.glob(os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + if matched_pattern != "*.safetensors": + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, matched_pattern == "*.safetensors" + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + else: + return pt_weights_iterator(hf_weights_files) + + def _get_quantized_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], pre_quant: bool + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + from bitsandbytes.functional import QuantState + if bitsandbytes.__version__ < "0.42.0": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.42.0.") + from bitsandbytes.functional import quantize_4bit + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.42.0 via " + "`pip install bitsandbytes>=0.42.0` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict = {} + + def quantized_checkpoint() -> Generator: + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith(".weight"): + continue + # TODO: only nf4 quantization is supported for now + if weight_name.endswith(".quant_state.bitsandbytes__fp4"): + raise NotImplementedError( + "Only bitsandbytes_nf4 quantization" + f"is supported for now. {weight_name} is fp4 quantized" + ) + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." 
in k: + quant_state[k] = temp_state_dict[k] + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__nf4 in CPU + quant_state[param_name + + ".quant_state.bitsandbytes__nf4"] = quant_state[ + param_name + + ".quant_state.bitsandbytes__nf4"].cpu().data + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + # Filter out all weights whose suffix is not ".weight" + if not weight_name.endswith(".weight"): + continue + if weight_name + ".quant_state.bitsandbytes__nf4" \ + in temp_state_dict: + quant_state = _parse_quant_state(weight_name, + temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", + ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def generator() -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if any(target_module in weight_name + for target_module in self.target_modules): + weight_name = weight_name.replace(".weight", ".qweight") + # bitsandbytes requires data in GPU + loaded_weight = weight_tensor.cuda().data + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight + + if pre_quant: + return quantized_checkpoint(), quant_state_dict + return generator(), quant_state_dict + + def _load_weights(self, model_config: ModelConfig, + model: nn.Module) -> None: + if not hasattr(model, 'load_weights'): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(self).__name__}.") + + if not hasattr(model, 'bitsandbytes_stacked_params_mapping'): + raise AttributeError( + f"Model {type(self).__name__} does not support BitsAndBytes " + "quantization yet.") + + logger.info("Loading weights with BitsAndBytes quantization. 
" + " May take a while ...") + + is_quantized_checkpoint = False + quant_config = getattr(model_config.hf_config, "quantization_config", + None) + if quant_config is not None and quant_config.get( + 'quant_method') == "bitsandbytes": + is_quantized_checkpoint = True + + qweight_iterator, quant_state_dict = \ + self._get_quantized_weights_iterator( + model_config.model, model_config.revision, is_quantized_checkpoint) + + model.load_weights(qweight_iterator) + + torch.cuda.empty_cache() + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + for quant_param_name in quant_state_dict: + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, index + ) in model.bitsandbytes_stacked_params_mapping.items(): + if shard_name in quant_param_name: + shard_index = index + quant_param_name = quant_param_name.replace( + shard_name, weight_name) + break + + if quant_param_name not in param_dict: + raise ValueError( + f"Parameter {quant_param_name} not found in the model.") + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = math.prod( + quant_state.shape) // pack_ratio + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + self._load_weights(model_config, model) + + return model.eval() + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/vllm/wde/core/loader/utils.py b/vllm/wde/core/loader/utils.py new file mode 100644 index 0000000000000..c03c1acaf7dbd --- /dev/null +++ b/vllm/wde/core/loader/utils.py @@ -0,0 +1,48 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.wde.core.config import ModelConfig +from vllm.wde.core.modelzoo import ModelRegistry + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + 
torch.set_default_dtype(old_dtype) + + +def get_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_model_workflow(hf_config: PretrainedConfig) -> str: + architectures = getattr(hf_config, "architectures", []) + + for arch in architectures: + workflow = ModelRegistry.get_workflow(arch) + if workflow is not None: + return workflow + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/vllm/wde/core/modelzoo.py b/vllm/wde/core/modelzoo.py new file mode 100644 index 0000000000000..626f18cc81e48 --- /dev/null +++ b/vllm/wde/core/modelzoo.py @@ -0,0 +1,64 @@ +import functools +import importlib +from typing import Dict, List, Optional, Type + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.wde.encode_only.modelzoo import ENCODE_ONLY_MODELS + +logger = init_logger(__name__) + +_MODELS_LIST = [ENCODE_ONLY_MODELS] + +_MODELS = dict() +for m in _MODELS_LIST: + _MODELS.update(**m) + +# Architecture -> type. +# out of tree models +_OOT_MODELS: Dict[str, Type[nn.Module]] = {} + + +class ModelRegistry: + + @staticmethod + @functools.lru_cache(maxsize=128) + def _get_model(model_arch: str): + module_str, workflow = _MODELS[model_arch] + module_name, model_cls_name = module_str.split(":") + module = importlib.import_module(module_name) + return getattr(module, model_cls_name, None) + + @staticmethod + def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] + if model_arch not in _MODELS: + return None + return ModelRegistry._get_model(model_arch) + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + + @staticmethod + @functools.lru_cache(maxsize=128) + def get_workflow(model_arch: str): + module_str, workflow = _MODELS[model_arch] + return workflow + + @staticmethod + def register_model(model_arch: str, model_cls: Type[nn.Module]): + if model_arch in _MODELS: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls.__name__) + global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls + + +__all__ = [ + "ModelRegistry", +] diff --git a/vllm/wde/core/processor/__init__.py b/vllm/wde/core/processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/processor/input_processor.py b/vllm/wde/core/processor/input_processor.py new file mode 100644 index 0000000000000..748d46e6f8ad9 --- /dev/null +++ b/vllm/wde/core/processor/input_processor.py @@ -0,0 +1,120 @@ +import time +from abc import ABC, abstractmethod +from typing import Optional, Union + +from vllm.wde.core.processor.tokenizer import Tokenizer +from vllm.wde.core.schema.engine_io import (Inputs, Params, PromptInput, + Request, SchedulableRequest, + TextOnlyInputs, TextPrompt, + TextRequest, + TextSchedulableRequest, + TokensPrompt, ValidationError) + + 
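+# Typical flow (a sketch only; in practice the engine wires these up via the
+# from_engine classmethods):
+#   input_processor = TextInputProcessor()
+#   request = input_processor("0", "Hello world")        # -> TextRequest
+#   request_processor = TextRequestProcessor(tokenizer)  # tokenizer: Tokenizer
+#   schedulable_request = request_processor(request)     # -> TextSchedulableRequest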
+class InputProcessor(ABC): + """ + Input(request_id, inputs, params, arrival_time) -> InputProcessor -> Request + """ + + @abstractmethod + def __call__(self, + request_id: str, + inputs: Optional[Union[str, Inputs]] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> Request: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine): + raise NotImplementedError + + +class TextInputProcessor(InputProcessor): + + def __call__(self, + request_id: str, + inputs: Optional[PromptInput] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> TextRequest: + + if isinstance(inputs, str): + inputs = {"prompt": inputs} + elif isinstance(inputs, TextPrompt): + inputs = {"prompt": inputs.prompt} + elif isinstance(inputs, TokensPrompt): + inputs = {"prompt_token_ids": inputs.prompt_token_ids} + elif isinstance(inputs, TextOnlyInputs): + _inputs = {"prompt_token_ids": inputs.prompt_token_ids} + + if inputs.prompt is not None: + _inputs["prompt"] = inputs.prompt + + inputs = _inputs + + elif isinstance(inputs, dict): + if "prompt" not in inputs and "prompt_token_ids" not in inputs: + raise ValidationError('"prompt" and "prompt_token_ids" ' + 'have at least one in inputs.') + inputs = { + k: v + for k, v in inputs.items() + if k in {"prompt", "prompt_token_ids"} + } + else: + raise ValidationError( + f"Input does not support {type(inputs)} data type") + + if not arrival_time: + arrival_time = time.time() + request = TextRequest(request_id=str(request_id), + inputs=inputs, + arrival_time=arrival_time) + return request + + @classmethod + def from_engine(cls, engine): + return cls() + + +class RequestProcessor(ABC): + """ + Request -> RequestProcessor -> SchedulableRequest + """ + + @abstractmethod + def __call__(self, request: Request) -> SchedulableRequest: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine): + raise NotImplementedError + + +class TextRequestProcessor(RequestProcessor): + + def __init__(self, tokenizer: Tokenizer): + self.tokenizer = tokenizer + + def __call__(self, request: TextRequest) -> TextSchedulableRequest: + inputs = request.inputs + + if "prompt_token_ids" not in inputs: + tokenizer = self.tokenizer + + prompt_token_ids = tokenizer.encode(inputs["prompt"]) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + schedulable_request = TextSchedulableRequest( + request_id=request.request_id, + inputs=TextOnlyInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt")), + arrival_time=request.arrival_time) + + return schedulable_request + + @classmethod + def from_engine(cls, engine): + return cls(engine.tokenizer) diff --git a/vllm/wde/core/processor/model_input_builder.py b/vllm/wde/core/processor/model_input_builder.py new file mode 100644 index 0000000000000..572b8af7bb837 --- /dev/null +++ b/vllm/wde/core/processor/model_input_builder.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + +from vllm.wde.core.schema.engine_io import SchedulerOutput +from vllm.wde.core.schema.execute_io import ExecuteInput + + +class ModelInputBuilder(ABC): + """ + scheduler_output = scheduler.schedule() + SchedulerOutput -> ModelInputBuilder -> ExecuteInput + """ + + @abstractmethod + def __call__(self, scheduler_output: SchedulerOutput) -> ExecuteInput: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine): + raise NotImplementedError diff --git a/vllm/wde/core/processor/output_processor.py 
b/vllm/wde/core/processor/output_processor.py new file mode 100644 index 0000000000000..17d833c34d0d2 --- /dev/null +++ b/vllm/wde/core/processor/output_processor.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod +from typing import List + +import torch + +from vllm.wde.core.schema.engine_io import RequestOutput, SchedulerOutput + + +class OutputProcessor(ABC): + """ + scheduler_output, execute_output -> OutputProcessor -> RequestOutput + """ + + @abstractmethod + def __call__(self, scheduler_output: SchedulerOutput, + execute_output: torch.Tensor) -> List[RequestOutput]: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine): + raise NotImplementedError diff --git a/vllm/wde/core/processor/tokenizer.py b/vllm/wde/core/processor/tokenizer.py new file mode 100644 index 0000000000000..cb2d1057cf7ba --- /dev/null +++ b/vllm/wde/core/processor/tokenizer.py @@ -0,0 +1,162 @@ +from typing import Optional, Union + +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class Tokenizer: + + def __init__(self, tokenizer_name: str, **kwargs): + self.tokenizer_name = tokenizer_name + self.tokenizer_kwargs = kwargs + + self.tokenizer = get_tokenizer(tokenizer_name=self.tokenizer_name, + **self.tokenizer_kwargs) + + @classmethod + def from_engine(cls, engine): + init_kwargs = dict( + tokenizer_name=engine.engine_config.model_config.tokenizer, + tokenizer_mode=engine.engine_config.model_config.tokenizer_mode, + trust_remote_code=engine.engine_config.model_config. + trust_remote_code, + revision=engine.engine_config.model_config.tokenizer_revision) + + return cls(**init_kwargs) + + def __call__(self, *args, **kwargs): + return self.tokenizer(*args, **kwargs) + + def encode(self, *args, **kwargs): + return self.tokenizer.encode(*args, **kwargs) + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id + + +def get_cached_tokenizer( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Get tokenizer with cached properties. + + This will patch the tokenizer object in place. + + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. 
This + function caches these properties for faster access.""" + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = ( + tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + tokenizer_len = len(tokenizer) + + class CachedTokenizer(tokenizer.__class__): # type: ignore + + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + def __len__(self): + return tokenizer_len + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + +def get_tokenizer( + tokenizer_name: str, + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via HuggingFace or ModelScope. + """ + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + import os + + import huggingface_hub + from modelscope.hub.snapshot_download import snapshot_download + + # Only set the tokenizer here, model will be downloaded on the workers. + if not os.path.exists(tokenizer_name): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + # Ignore weights - we only need the tokenizer. + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + tokenizer_name = tokenizer_path + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError( + "Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + if "truncation_side" not in kwargs: + kwargs["truncation_side"] = "left" + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + if (not trust_remote_code and + ("does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e))): + err_msg = ( + "Failed to load the tokenizer. If the tokenizer is a custom " + "tokenizer not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + except AttributeError as e: + if "BaichuanTokenizer" in str(e): + # This is for the error "'BaichuanTokenizer' object has no + # attribute 'sp_model'". + from vllm.transformers_utils.tokenizers import BaichuanTokenizer + tokenizer = BaichuanTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + else: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. 
Consider using a fast tokenizer instead.") + return get_cached_tokenizer(tokenizer) diff --git a/vllm/wde/core/scheduler.py b/vllm/wde/core/scheduler.py new file mode 100644 index 0000000000000..3cc45fb59001a --- /dev/null +++ b/vllm/wde/core/scheduler.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from collections import deque +from typing import Deque, Iterable, List, Union + +from vllm.logger import init_logger +from vllm.wde.core.config import SchedulerConfig +from vllm.wde.core.processor.input_processor import RequestProcessor +from vllm.wde.core.schema.engine_io import (Request, RequestOutput, + SchedulerOutput) + +logger = init_logger(__name__) + + +class Scheduler(ABC): + support_scheduling = [] + + def __init__( + self, + scheduler_config: SchedulerConfig, + request_processor: RequestProcessor, + ) -> None: + self.scheduler_config = scheduler_config + self.request_processor = request_processor + + self.waiting: Deque[Request] = deque() + + self.requests = set() + self.aborted_requests = set() + + @classmethod + def from_engine(cls, engine) -> "Scheduler": + raise NotImplementedError + + def add_request(self, request: Request) -> None: + if (request.request_id in self.requests + or request.request_id in self.aborted_requests): + logger.warning("[%s] request_id conflict") + return + + self.waiting.append(request) + self.requests.add(request.request_id) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + + self.requests -= request_ids + self.aborted_requests |= request_ids + + def remove_abort_request( + self, request_outputs: List[RequestOutput]) -> List[RequestOutput]: + if len(self.aborted_requests) == 0: + return request_outputs + + current_ids = set(request.request_id for request in request_outputs) + need_abort = self.aborted_requests & current_ids + + if len(need_abort) == 0: + return request_outputs + + request_outputs = [ + request for request in request_outputs + if request.request_id not in need_abort + ] + self.aborted_requests -= need_abort + + return request_outputs + + def has_unfinished_requests(self) -> bool: + return len(self.requests) != 0 + + def get_num_unfinished_requests(self) -> int: + return len(self.requests) + + @abstractmethod + def schedule(self) -> SchedulerOutput: + raise NotImplementedError + + def free_finished_request(self, request_outputs: List[RequestOutput]): + finished_request_ids = set(request.request_id + for request in request_outputs + if request.finished) + self.requests -= finished_request_ids diff --git a/vllm/wde/core/schema/__init__.py b/vllm/wde/core/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/schema/engine_io.py b/vllm/wde/core/schema/engine_io.py new file mode 100644 index 0000000000000..7179f6f389ae9 --- /dev/null +++ b/vllm/wde/core/schema/engine_io.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + + +class Params: + pass + + +class Inputs: + pass + + +@dataclass +class TextPrompt(Inputs): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + +@dataclass +class TokensPrompt(Inputs): + """Schema for a tokenized prompt.""" + + prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + +@dataclass +class TextOnlyInputs(Inputs): + prompt_token_ids: List[int] + """The token IDs of the prompt.""" + + prompt: 
Optional[str] = None + """ + The original prompt text corresponding to the token IDs, if available. + """ + + +PromptInput = Union[str, Dict, TextPrompt, TokensPrompt, TextOnlyInputs] + + +@dataclass +class Request: + request_id: str + arrival_time: float + + +@dataclass +class TextRequest(Request): + inputs: Dict + + +class ValidationError(ValueError): + pass + + +class SchedulableRequest(Request): + pass + + +@dataclass +class TextSchedulableRequest(SchedulableRequest): + inputs: TextOnlyInputs + + @property + def num_new_tokens(self): + return len(self.inputs.prompt_token_ids) + + +@dataclass +class SchedulerOutput: + pass + + +@dataclass +class RequestOutput(Request): + finished: bool diff --git a/vllm/wde/core/schema/execute_io.py b/vllm/wde/core/schema/execute_io.py new file mode 100644 index 0000000000000..67d91a8b078dd --- /dev/null +++ b/vllm/wde/core/schema/execute_io.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ModelInput: + pass + + +@dataclass +class WorkerInput: + pass + + +@dataclass +class ExecuteInput: + worker_input: Optional[WorkerInput] + model_input: Optional[ModelInput] + + +class ExecuteOutput: + pass diff --git a/vllm/wde/core/worker/__init__.py b/vllm/wde/core/worker/__init__.py new file mode 100644 index 0000000000000..7c59cbdf14007 --- /dev/null +++ b/vllm/wde/core/worker/__init__.py @@ -0,0 +1,3 @@ +from .worker import WorkerBase, WorkerWrapperBase, create_worker + +__all__ = ["create_worker", "WorkerBase", "WorkerWrapperBase"] \ No newline at end of file diff --git a/vllm/wde/core/worker/utils.py b/vllm/wde/core/worker/utils.py new file mode 100644 index 0000000000000..bf3621addf52a --- /dev/null +++ b/vllm/wde/core/worker/utils.py @@ -0,0 +1,53 @@ +from typing import List, Optional + +import torch + + +class FakeGroupCoordinator: + rank: int = 0 + ranks: List[int] = [0] + world_size: int = 1 + local_rank: int = 0 + rank_in_group: int = 0 + + def destroy(self): + pass + + @property + def first_rank(self): + return self.ranks[0] + + @property + def last_rank(self): + return self.ranks[-1] + + @property + def is_first_rank(self): + return self.rank == self.first_rank + + @property + def is_last_rank(self): + return self.rank == self.last_rank + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + return input_ + + +def fix_distributed_environment(): + # This dirty_fix can make ParallelLinear etc. work properly. + # Why should tp and model layers be coupled together? 
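+ # The FakeGroupCoordinator above reports world_size == 1 and implements
+ # its collectives as identity ops, so ParallelLinear and related layers
+ # behave as if they were running on a single, non-distributed device.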
+ + import vllm.distributed.parallel_state + + fake_parallel_group = FakeGroupCoordinator() + vllm.distributed.parallel_state._TP = fake_parallel_group + vllm.distributed.parallel_state._PP = fake_parallel_group diff --git a/vllm/wde/core/worker/worker.py b/vllm/wde/core/worker/worker.py new file mode 100644 index 0000000000000..9cdfe9628dda3 --- /dev/null +++ b/vllm/wde/core/worker/worker.py @@ -0,0 +1,104 @@ +import importlib +import os +from abc import ABC, abstractmethod +from typing import Callable, Dict, Optional, Type + +from vllm.logger import init_logger +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) +from vllm.wde.core.schema.execute_io import ExecuteInput, ExecuteOutput + +logger = init_logger(__name__) + + +class WorkerBase(ABC): + + @abstractmethod + def __call__(self, execute_input: ExecuteInput) -> ExecuteOutput: + raise NotImplementedError + + +class WorkerWrapperBase: + """ + The whole point of this class is to lazily initialize the worker. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + + If worker_class_fn is specified, it will be executed to get the worker + class. + Otherwise, the worker class will be obtained by dynamically importing it + using worker_module_name and worker_class_name. + """ + + def __init__( + self, + worker_module_name: str, + worker_class_name: str, + trust_remote_code: bool = False, + worker_class_fn: Optional[Callable[[], + Type[WorkerBase]]] = None) -> None: + self.worker_module_name = worker_module_name + self.worker_class_name = worker_class_name + self.worker_class_fn = worker_class_fn + self.worker: Optional[WorkerBase] = None + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + @staticmethod + def update_environment_variables(envs: Dict[str, str]) -> None: + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, *args, **kwargs): + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. + """ + enable_trace_function_call_for_thread() + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + if self.worker_class_fn: + worker_class = self.worker_class_fn() + else: + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) + + self.worker = worker_class(*args, **kwargs) + assert self.worker is not None + + def execute_method(self, method, *args, **kwargs): + try: + target = self if self.worker is None else self.worker + executor = getattr(target, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. 
" + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + + +def create_worker(module, envs=None, **kwargs): + module_name, class_name = module.split(":") + wrapper = WorkerWrapperBase( + worker_module_name=module_name, + worker_class_name=class_name, + ) + if envs: + wrapper.update_environment_variables(envs) + + wrapper.init_worker(**kwargs) + return wrapper.worker \ No newline at end of file diff --git a/vllm/wde/core/workflow.py b/vllm/wde/core/workflow.py new file mode 100644 index 0000000000000..1a87ecd25eda7 --- /dev/null +++ b/vllm/wde/core/workflow.py @@ -0,0 +1,16 @@ +class Workflow: + EngineArgs: str + Scheduler: str + AttnBackend: str + attn_type: str + Tokenizer: str = "vllm.wde.core.processor.tokenizer:Tokenizer" + InputProcessor: str + RequestProcessor: str + OutputProcessor: str + ModelInputBuilder: str + Executor: str + Worker: str + + @classmethod + def from_engine(cls, engine): + return cls() diff --git a/vllm/wde/encode_only/__init__.py b/vllm/wde/encode_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/encode_only/arg_utils.py b/vllm/wde/encode_only/arg_utils.py new file mode 100644 index 0000000000000..65a56e80b4df0 --- /dev/null +++ b/vllm/wde/encode_only/arg_utils.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +from vllm.logger import init_logger +from vllm.wde.core.arg_utils import EngineArgs +from vllm.wde.core.config import (DeviceConfig, LoadConfig, + filter_unexpected_fields) +from vllm.wde.encode_only.config import (EncodeOnlyEngineConfig, ModelConfig, + PrefillOnlySchedulerConfig) + +logger = init_logger(__name__) + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +@filter_unexpected_fields +@dataclass +class EncodeOnlyEngineArgs(EngineArgs): + """Arguments for vLLM engine.""" + model: str + served_model_name: Optional[Union[List[str]]] = None + tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + kv_cache_dtype: str = 'auto' + quantization_param_path: Optional[str] = None + disable_sliding_window: bool = False + seed: int = 0 + + max_model_len: Optional[int] = None + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_num_on_the_fly: int = 3 + scheduling: str = "async" + + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None + rope_theta: Optional[float] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + disable_custom_all_reduce: bool = False + device: str = 'auto' + model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + + def create_engine_config(self) -> EncodeOnlyEngineConfig: + device_config = DeviceConfig(device=self.device) + model_config = ModelConfig( + model=self.model, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + 
quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name) + + scheduler_config = PrefillOnlySchedulerConfig( + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + max_num_on_the_fly=self.max_num_on_the_fly, + scheduling=self.scheduling) + + load_config = LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + return EncodeOnlyEngineConfig(model_config=model_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config) diff --git a/vllm/wde/encode_only/config.py b/vllm/wde/encode_only/config.py new file mode 100644 index 0000000000000..3a0c2a2c4575a --- /dev/null +++ b/vllm/wde/encode_only/config.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass, fields + +from vllm.logger import init_logger +from vllm.wde.core.config import EngineConfig, ModelConfig +from vllm.wde.prefill_only.config import PrefillOnlySchedulerConfig + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +@dataclass(frozen=True) +class EncodeOnlyEngineConfig(EngineConfig): + model_config: ModelConfig + scheduler_config: PrefillOnlySchedulerConfig + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. + """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) + + def log_config(self): + from vllm.version import __version__ as VLLM_VERSION + logger.info( + "Initializing an Encode Only engine (v%s) with config: " + "model=%r, tokenizer=%r, " + "tokenizer_mode=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, " + "device_config=%s, served_model_name=%s, " + "max_num_on_the_fly=%d, scheduling=%s)", VLLM_VERSION, + self.model_config.model, self.model_config.tokenizer, + self.model_config.tokenizer_mode, + self.model_config.trust_remote_code, self.model_config.dtype, + self.model_config.max_model_len, self.load_config.download_dir, + self.load_config.load_format, self.device_config.device, + self.model_config.served_model_name, + self.scheduler_config.max_num_on_the_fly, + self.scheduler_config.scheduling) + if self.parallel_config is not None: + logger.info("Parallel config: data_parallel_size=%d", + self.parallel_config.data_parallel_size) diff --git a/vllm/wde/encode_only/modelzoo/__init__.py b/vllm/wde/encode_only/modelzoo/__init__.py new file mode 100644 index 0000000000000..47c0c59ce178a --- /dev/null +++ b/vllm/wde/encode_only/modelzoo/__init__.py @@ -0,0 +1,8 @@ +TASK = "encode_only" +PREFIX = f"vllm.wde.{TASK}.modelzoo" +WORKFLOW = "vllm.wde.encode_only.workflow:EncodeOnlyWorkflow" + +# Architecture -> (module, workflow). +ENCODE_ONLY_MODELS = { + "BertForMaskedLM": (PREFIX + ".bert:BertForMaskedLM", WORKFLOW), +} diff --git a/vllm/wde/encode_only/modelzoo/bert.py b/vllm/wde/encode_only/modelzoo/bert.py new file mode 100644 index 0000000000000..500faba97ceda --- /dev/null +++ b/vllm/wde/encode_only/modelzoo/bert.py @@ -0,0 +1,404 @@ +# Derived from Bert implementation posted on HuggingFace; license below: +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # noqa: E501 +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import BertConfig +from transformers.utils import logging + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import is_pp_missing_parameter +from vllm.wde.core.layers.attention import (Attention, AttentionBackend, + AttentionMetadata) +from vllm.wde.encode_only.schema.execute_io import EncodeOnlyExecuteOutput + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + + def __init__(self, config: BertConfig, quant_config: QuantizationConfig): + super().__init__() + self.config = config + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + assert self.position_embedding_type == "absolute" + + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, quant_config=quant_config) + self.token_type_embeddings0 = None + self.position_embeddings = VocabParallelEmbedding( + config.max_position_embeddings, + config.hidden_size, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def init_token_type_embeddings0(self): + del self.token_type_embeddings0 + self.register_buffer( + "token_type_embeddings0", + torch.zeros(self.config.hidden_size, + dtype=self.word_embeddings.weight.dtype, + device=self.word_embeddings.weight.device)) + + def forward(self, input_ids, positions): + embeddings = self.word_embeddings(input_ids) + if self.token_type_embeddings0 is not None: + token_type_embeddings = self.token_type_embeddings0 + embeddings += token_type_embeddings + + embeddings += self.position_embeddings(positions) + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + num_kv_heads = config.num_attention_heads + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = hidden_size // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + num_heads, + num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + 
self.scaling, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + attn_backend=attn_backend) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn(q, k, v, attn_metadata) + return attn_output + + +class BertSelfOutput(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.self = BertSelfAttention(config, attn_backend) + self.output = BertSelfOutput(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + self_outputs = self.self(hidden_states, attn_metadata) + attention_output = self.output(self_outputs, hidden_states) + return attention_output + + +class BertIntermediate(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.attention = BertAttention(config, attn_backend, quant_config) + self.intermediate = BertIntermediate(config, quant_config) + self.output = BertOutput(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + attention_output = self.attention(hidden_states, attn_metadata) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + 
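# One BertLayer per hidden layer, all sharing the same attention backend. +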
self.layer = nn.ModuleList([ + BertLayer(config, attn_backend, quant_config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, attn_metadata) + return hidden_states + + +class BertPooler(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.hidden_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.activation = nn.Tanh() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + seq_start_loc = attn_metadata.seq_start_loc + first_token_tensor = hidden_states[seq_start_loc[:-1]] + pooled_output, _ = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + add_pooling_layer: bool = True, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.embeddings = BertEmbeddings(config, quant_config) + self.encoder = BertEncoder(config, attn_backend, quant_config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor]: + embedding_output = self.embeddings( + input_ids=input_ids, + positions=positions, + ) + sequence_output = self.encoder(embedding_output, attn_metadata) + pooled_output = self.pooler( + sequence_output, + attn_metadata) if self.pooler is not None else None + return sequence_output, pooled_output + + +class BertForMaskedLM(nn.Module): + _ignore_weights_keys = [ + "cls.predictions.transform.LayerNorm.gamma", + "cls.predictions.transform.dense.weight", + "cls.seq_relationship.weight", + ] + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None, + *args, + **kwargs): + super().__init__() + self.config = config + self.quant_config = quant_config + + self.bert = BertModel(config, + attn_backend, + quant_config=quant_config, + add_pooling_layer=True) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> EncodeOnlyExecuteOutput: + + sequence_output, pooled_output = self.bert( + input_ids, + positions, + attn_metadata, + ) + + return EncodeOnlyExecuteOutput(last_hidden_states=sequence_output, + pooled_output=pooled_output) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v") + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + for name, loaded_weight in weights: + if hasattr(self, "prefix"): + name = self.prefix + name + + if name in self._ignore_weights_keys: + continue + + if name == "bert.embeddings.token_type_embeddings.weight": + # token_type_ids is all zero, + # so we only need token_type_embeddings[0] + self.bert.embeddings.init_token_type_embeddings0() + default_weight_loader( + self.bert.embeddings.token_type_embeddings0, + loaded_weight[0]) + continue + + for (param_name, weight_name, shard_id) 
in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # https://huggingface.co/google-bert/bert-base-uncased/discussions/70 + # https://github.com/huggingface/transformers/blob/fee86516a48c92133847fc7b44ca2f83c7c5634d/src/transformers/modeling_utils.py#L691-L720 + if "LayerNorm.gamma" in name: + name = name.replace("LayerNorm.gamma", "LayerNorm.weight") + if "LayerNorm.beta" in name: + name = name.replace("LayerNorm.beta", "LayerNorm.bias") + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/wde/encode_only/processor/__init__.py b/vllm/wde/encode_only/processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/encode_only/processor/output_processor.py b/vllm/wde/encode_only/processor/output_processor.py new file mode 100644 index 0000000000000..b8c47046d2544 --- /dev/null +++ b/vllm/wde/encode_only/processor/output_processor.py @@ -0,0 +1,49 @@ +from typing import List + +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.output_processor import (OutputProcessor, + RequestOutput) +from vllm.wde.encode_only.schema.engine_io import EncodeOnlyRequestOutput +from vllm.wde.encode_only.schema.execute_io import EncodeOnlyExecuteOutput +from vllm.wde.prefill_only.schema.engine_io import PrefillOnlySchedulerOutput + + +class PrefillOnlyModelOutputProcessor(OutputProcessor): + + def __init__(self): + pass + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls() + + def __call__( + self, scheduler_output: PrefillOnlySchedulerOutput, + execute_output: EncodeOnlyExecuteOutput) -> List[RequestOutput]: + if execute_output.pooled_output is not None: + request_outputs = [] + for request, outputs in zip(scheduler_output.scheduled_requests, + execute_output.pooled_output): + prompt_token_ids = request.inputs.prompt_token_ids + request_outputs.append( + EncodeOnlyRequestOutput(request_id=request.request_id, + arrival_time=request.arrival_time, + prompt_token_ids=prompt_token_ids, + finished=True, + outputs=outputs)) + return request_outputs + else: + request_outputs = [] + offset = 0 + for request in scheduler_output.scheduled_requests: + prompt_token_ids = request.inputs.prompt_token_ids + n_tokens = len(prompt_token_ids) + request_outputs.append( + EncodeOnlyRequestOutput( + request_id=request.request_id, + arrival_time=request.arrival_time, + prompt_token_ids=prompt_token_ids, + finished=True, + outputs=execute_output[offset:offset + n_tokens])) + offset += n_tokens + return request_outputs diff --git a/vllm/wde/encode_only/schema/__init__.py b/vllm/wde/encode_only/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/encode_only/schema/engine_io.py b/vllm/wde/encode_only/schema/engine_io.py new file mode 100644 index 0000000000000..7d5440ec77ac2 --- /dev/null +++ 
b/vllm/wde/encode_only/schema/engine_io.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import List + +import torch + +from vllm.wde.core.schema.engine_io import RequestOutput + + +@dataclass +class EncodeOnlyRequestOutput(RequestOutput): + prompt_token_ids: List[int] + outputs: torch.Tensor \ No newline at end of file diff --git a/vllm/wde/encode_only/schema/execute_io.py b/vllm/wde/encode_only/schema/execute_io.py new file mode 100644 index 0000000000000..6c43d809bc32c --- /dev/null +++ b/vllm/wde/encode_only/schema/execute_io.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.wde.core.schema.execute_io import ExecuteOutput + + +@dataclass +class EncodeOnlyExecuteOutput(ExecuteOutput): + last_hidden_states: torch.Tensor + pooled_output: Optional[torch.Tensor] = None diff --git a/vllm/wde/encode_only/workflow.py b/vllm/wde/encode_only/workflow.py new file mode 100644 index 0000000000000..ad751b5de3baf --- /dev/null +++ b/vllm/wde/encode_only/workflow.py @@ -0,0 +1,8 @@ +from vllm.wde.prefill_only.workflow import PrefillOnlyWorkflow + + +class EncodeOnlyWorkflow(PrefillOnlyWorkflow): + EngineArgs: str = "vllm.wde.encode_only.arg_utils:EncodeOnlyEngineArgs" + OutputProcessor: str = ("vllm.wde.encode_only.processor." + "output_processor:PrefillOnlyModelOutputProcessor") + attn_type: str = "ENCODER" diff --git a/vllm/wde/entrypoints/__init__.py b/vllm/wde/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/entrypoints/llm.py b/vllm/wde/entrypoints/llm.py new file mode 100644 index 0000000000000..f003572f629cf --- /dev/null +++ b/vllm/wde/entrypoints/llm.py @@ -0,0 +1,139 @@ +from typing import List, Optional, Sequence, Union, cast + +from tqdm import tqdm + +from vllm.logger import init_logger +from vllm.utils import Counter +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.schema.engine_io import Params, RequestOutput +from vllm.wde.core.schema.engine_io import TextOnlyInputs as PromptInputs +from vllm.wde.core.schema.engine_io import ValidationError + +logger = init_logger(__name__) + + +class LLM: + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + cpu_offload_gb: float = 0, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + removed_vision_keys = ("image_token_id", "image_feature_size", + "image_input_shape", "image_input_type") + if any(k in kwargs for k in removed_vision_keys): + raise TypeError( + "There is no need to pass vision-related arguments anymore.") + engine_args = dict( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + 
swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + **kwargs, + ) + self.llm_engine = LLMEngine.from_engine_args(engine_args) + self.request_counter = Counter() + + def encode( + self, + inputs: Union[Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[Params, Sequence[Params]]] = None, + use_tqdm: bool = True, + ) -> List[RequestOutput]: + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], inputs) + + if pooling_params is None: + # Use default pooling params. + pooling_params = Params() + + self._validate_and_add_requests( + inputs=inputs, + params=pooling_params, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) + + def _validate_and_add_requests( + self, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + params: Optional[Union[Params, Sequence[Params]]] = None, + ) -> None: + + # Add requests to the engine. + for i, request_inputs in enumerate(inputs): + try: + self._add_request( + request_inputs, + params[i] if isinstance(params, Sequence) else params) + except ValidationError as e: + raise e + + def _add_request( + self, + inputs: PromptInputs, + params: Params, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request(request_id, inputs, params) + + def _run_engine(self, *, use_tqdm: bool) -> List[RequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s"), + ) + # Run the engine. + outputs: List[RequestOutput] = [] + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + pbar.update(1) + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. 
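+ # request_id values are assigned from the engine-level Counter in
+ # _add_request, so sorting by their integer value restores submission order.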
+ return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/wde/prefill_only/__init__.py b/vllm/wde/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/config.py b/vllm/wde/prefill_only/config.py new file mode 100644 index 0000000000000..deaa744a26396 --- /dev/null +++ b/vllm/wde/prefill_only/config.py @@ -0,0 +1,54 @@ +from typing import Optional + +from vllm.wde.core.config import SchedulerConfig + + +class PrefillOnlySchedulerConfig(SchedulerConfig): + + def __init__(self, + max_model_len: int, + max_num_batched_tokens: Optional[int] = None, + max_num_requests: Optional[int] = None, + max_num_seqs: Optional[int] = None, + max_num_on_the_fly: Optional[int] = 3, + scheduling: str = "sync") -> None: + self.max_model_len = max_model_len + self.max_num_requests: int = 0 + self.max_num_batched_tokens: int = 0 + self.max_num_on_the_fly: int = max_num_on_the_fly + self.scheduling = scheduling + + self.set_args(max_num_batched_tokens, max_num_requests, max_num_seqs) + + def set_args(self, + max_num_batched_tokens: Optional[int] = None, + max_num_requests: Optional[int] = None, + max_num_seqs: Optional[int] = None): + if max_num_seqs is not None: + self.max_num_requests = max_num_seqs + else: + self.max_num_requests = max_num_requests + + if max_num_batched_tokens is not None: + self.max_num_batched_tokens = max_num_batched_tokens + else: + self.max_num_batched_tokens = (self.max_model_len * + self.max_num_requests) + + self._verify_args() + + def _verify_args(self) -> None: + if self.max_num_batched_tokens < self.max_model_len: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_model_len " + f"({self.max_model_len}).") + + if self.max_num_on_the_fly < 2: + raise ValueError( + f"max_num_on_the_fly {self.max_num_on_the_fly} must " + "be greater than 1") + + if self.scheduling not in ["sync", "async", "double_buffer"]: + raise ValueError(f"scheduling {self.scheduling} must " + f"in sync, async and double_buffer") \ No newline at end of file diff --git a/vllm/wde/prefill_only/executor/__init__.py b/vllm/wde/prefill_only/executor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/executor/gpu_executor.py b/vllm/wde/prefill_only/executor/gpu_executor.py new file mode 100644 index 0000000000000..153d21a773e6b --- /dev/null +++ b/vllm/wde/prefill_only/executor/gpu_executor.py @@ -0,0 +1,207 @@ +import atexit +import queue +from queue import Queue +from threading import Thread +from typing import Optional + +import torch + +from vllm.logger import init_logger +from vllm.wde.core.config import EngineConfig +from vllm.wde.core.layers.attention import AttentionBackend +from vllm.wde.core.schema.execute_io import ExecuteInput, ExecuteOutput +from vllm.wde.core.worker import WorkerBase, create_worker +from vllm.wde.core.workflow import Workflow + +logger = init_logger(__name__) + + +class GPUExecutor: + support_scheduling = ["sync_scheduling"] + + def __init__( + self, + engine_config: EngineConfig, + workflow: Workflow, + attn_backend: AttentionBackend, + ) -> None: + self.engine_config = engine_config + self.workflow = workflow + self.attn_backend = attn_backend + self.output_to_cpu = False + self._init_executor() + + @classmethod + def from_engine(cls, engine): + return cls(engine_config=engine.engine_config, + workflow=engine.workflow, + attn_backend=engine.attn_backend) + + def _init_executor(self) -> 
None: + """Initialize the worker and load the model. + """ + + worker_kwargs = dict( + engine_config=self.engine_config, + attn_backend=self.attn_backend, + ) + worker_kwargs.update(module=self.workflow.Worker) + + self.worker = create_worker(**worker_kwargs) + self.worker.init_device() + self.worker.load_model() + + def execute_model(self, + executor_input: ExecuteInput) -> Optional[ExecuteOutput]: + executor_input.model_input.to(self.worker.device) + output = self.worker(executor_input) + if self.output_to_cpu: + output.to("cpu") + return output + + def shutdown_execute_loop(self): + pass + + +class GPUAsyncExecutor(GPUExecutor): + support_scheduling = ["async_scheduling"] + + def __init__(self, engine_config: EngineConfig, workflow: Workflow, + attn_backend: AttentionBackend, executor_in: Queue, + executor_out: Queue) -> None: + super().__init__(engine_config, workflow, attn_backend) + self.executor_in = executor_in + self.executor_out = executor_out + + self.executor_thread: Optional[Thread] = None + + if self.engine_config.scheduler_config.scheduling == "double_buffer": + self.execute_loop = double_buffer_execute_loop + else: + self.execute_loop = simple_execute_loop + + @classmethod + def from_engine(cls, engine): + return cls(engine_config=engine.engine_config, + workflow=engine.workflow, + attn_backend=engine.attn_backend, + executor_in=engine.executor_in, + executor_out=engine.executor_out) + + def ensure_start_execute_loop(self): + if self.executor_thread is None or not self.executor_thread.is_alive(): + self.executor_thread = Thread(target=self.execute_loop, + args=(self.worker, self.executor_in, + self.executor_out, + self.output_to_cpu), + daemon=True) + self.executor_thread.start() + atexit.register(self.shutdown_execute_loop) + + def shutdown_execute_loop(self): + if self.executor_thread.is_alive(): + self.executor_in.put(None) + self.executor_thread.join() + atexit.unregister(self.shutdown_execute_loop) + + +def simple_execute_loop(worker: WorkerBase, + executor_in: Queue, + executor_out: Queue, + output_to_cpu: bool = False): + + def execute_model(executor_input: ExecuteInput) -> Optional[ExecuteOutput]: + executor_input.model_input.to(worker.device) + output = worker(executor_input) + if output_to_cpu: + output.to("cpu") + return output + + while True: + o = executor_in.get() + if o is None: + break + + scheduler_output, executor_input = o + executor_output = execute_model(executor_input) + if output_to_cpu: + executor_output.to("cpu") + executor_out.put((scheduler_output, executor_output)) + + +def double_buffer_execute_loop(worker: WorkerBase, + executor_in: Queue, + executor_out: Queue, + output_to_cpu: bool = False): + from dataclasses import dataclass + + from vllm.wde.core.schema.engine_io import SchedulerOutput + + @dataclass + class Task: + scheduler_output: SchedulerOutput + executor_input: ExecuteInput + executor_output: Optional[ExecuteOutput] + + @classmethod + def get(cls, block): + o = executor_in.get(block) + if o is None: + return None + + scheduler_output, executor_input = o + + task = cls(scheduler_output=scheduler_output, + executor_input=executor_input, + executor_output=None) + return task + + current_task: Optional[Task] = None + next_task: Optional[Task] = None + compute_stream = torch.cuda.Stream() + io_stream = torch.cuda.Stream() + + go_on = True + while go_on: + if current_task is None: + current_task = Task.get(block=True) + if current_task is None: + break + + with torch.cuda.stream(compute_stream): + 
current_task.executor_input.model_input.to(worker.device, + non_blocking=True) + current_task.executor_output = worker( + current_task.executor_input) + end_compute = torch.cuda.Event() + else: + with torch.cuda.stream(compute_stream): + end_compute = torch.cuda.Event() + + try: + next_task = Task.get(block=False) + if next_task is None: + go_on = False + else: + with torch.cuda.stream(io_stream): + next_task.executor_input.model_input.to(worker.device, + non_blocking=True) + + compute_stream.wait_stream(io_stream) + + with torch.cuda.stream(compute_stream): + next_task.executor_output = worker( + next_task.executor_input) + except queue.Empty: + pass + + end_compute.wait() + if output_to_cpu: + with torch.cuda.stream(io_stream): + current_task.executor_output.to("cpu", non_blocking=True) + io_stream.synchronize() + executor_out.put( + (current_task.scheduler_output, current_task.executor_output)) + + current_task = next_task + next_task = None diff --git a/vllm/wde/prefill_only/layers/__init__.py b/vllm/wde/prefill_only/layers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/layers/attention/__init__.py b/vllm/wde/prefill_only/layers/attention/__init__.py new file mode 100644 index 0000000000000..8b137891791fe --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm/wde/prefill_only/layers/attention/backends/__init__.py b/vllm/wde/prefill_only/layers/attention/backends/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/layers/attention/backends/abstract.py b/vllm/wde/prefill_only/layers/attention/backends/abstract.py new file mode 100644 index 0000000000000..ee6384b8b2278 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/abstract.py @@ -0,0 +1,67 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Generic, List, Optional, TypeVar + +import torch + +from vllm.utils import is_pin_memory_available +from vllm.wde.core.layers.attention.abstract import (AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyAttentionBackend(AttentionBackend, ABC): + + def __init__(self, attn_type: AttentionType): + if attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyAttentionBackend") + + super().__init__(attn_type) + + +class PrefillOnlyAttentionImpl(AttentionImpl, ABC): + pass + + +@dataclass +class PrefillOnlyAttentionMetadata(AttentionMetadata): + max_seq_len: int + seq_lens: list[int] + + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
+ seq_start_loc: Optional[torch.Tensor] + + +T = TypeVar("T", bound=AttentionMetadata) + + +class PrefillOnlyAttentionMetadataBuilder(AttentionMetadataBuilder, + Generic[T]): + + def __init__(self): + pass + + def __call__(self, seq_lens: List[int]): + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + return PrefillOnlyAttentionMetadata(seq_lens=seq_lens, + max_seq_len=max(seq_lens), + seq_start_loc=seq_start_loc) diff --git a/vllm/wde/prefill_only/layers/attention/backends/flash_attn.py b/vllm/wde/prefill_only/layers/attention/backends/flash_attn.py new file mode 100644 index 0000000000000..6b4305f7afa1c --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/flash_attn.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + + +class PrefillOnlyFlashAttentionBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flash_attn" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyFlashAttentionImpl"]: + return PrefillOnlyFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyFlashAttentionMetadata"]: + return PrefillOnlyFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: + return PrefillOnlyFlashAttentionMetadataBuilder + + +@dataclass +class PrefillOnlyFlashAttentionMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyFlashAttentionMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyFlashAttentionMetadata] +): + pass + + +class PrefillOnlyFlashAttentionImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + + from vllm_flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. 
+ logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + PrefillOnlyFlashAttentionBackend.get_supported_head_sizes()) + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyFlashAttentionMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyFlashAttentionImpl") + + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + attn_output = self.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_seq_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=causal, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + # Reshape the output tensor. 
+ return attn_output.view(num_tokens, hidden_size) diff --git a/vllm/wde/prefill_only/layers/attention/backends/flashinfer.py b/vllm/wde/prefill_only/layers/attention/backends/flashinfer.py new file mode 100644 index 0000000000000..6dfe56a03fa69 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/flashinfer.py @@ -0,0 +1,156 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + + +class PrefillOnlyFlashInferBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flashinfer" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyFlashInferImpl"]: + return PrefillOnlyFlashInferImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]: + return PrefillOnlyFlashInferMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyFlashInferMetadataBuilder"]: + return PrefillOnlyFlashInferMetadataBuilder + + +@dataclass +class PrefillOnlyFlashInferMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyFlashInferMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyFlashInferMetadata]): + pass + + +class PrefillOnlyFlashInferImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError("PrefillOnlyFlashInferImpl does not " + "support block-sparse attention.") + + from vllm_flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + PrefillOnlyFlashInferBackend.get_supported_head_sizes()) + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. 
" + f"Supported head sizes are: {support_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyFlashInferMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyFlashInferImpl") + + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + # Because encode only models do not involve kv cache + # When using Flashinfer backend in encode only models, + # you are actually using FLASH ATTN backend + attn_output = self.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_seq_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=causal, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + # Reshape the output tensor. 
+ return attn_output.view(num_tokens, hidden_size) diff --git a/vllm/wde/prefill_only/layers/attention/backends/torch_naive.py b/vllm/wde/prefill_only/layers/attention/backends/torch_naive.py new file mode 100644 index 0000000000000..10319d4fecd41 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/torch_naive.py @@ -0,0 +1,161 @@ +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.utils import is_pin_memory_available +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyTorchNAIVEBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch_naive" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyTorchNaiveBackendImpl"]: + return PrefillOnlyTorchNaiveBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyTorchNaiveMetadata"]: + return PrefillOnlyTorchNaiveMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: + return PrefillOnlyTorchNaiveMetadataBuilder + + +@dataclass +class PrefillOnlyTorchNaiveMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyTorchNaiveMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyTorchNaiveMetadata]): + pass + + +class PrefillOnlyTorchNaiveBackendImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch naive does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch naive does not support logits soft cap.") + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyTorchNaiveMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in TorchNaive.") + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyTorchNaiveBackendImpl") + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. 
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + sub_out = scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + is_causal=causal, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def scaled_dot_product_attention(query, + key, + value, + attn_mask=None, + is_causal=False, + scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool, + device=query.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value diff --git a/vllm/wde/prefill_only/layers/attention/backends/torch_sdpa.py b/vllm/wde/prefill_only/layers/attention/backends/torch_sdpa.py new file mode 100644 index 0000000000000..8a8cba9c2aaaa --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/torch_sdpa.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.utils import is_pin_memory_available +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyTorchSDPABackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch_sdpa" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyTorchSDPABackendImpl"]: + return PrefillOnlyTorchSDPABackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyTorchSDPAMetadata"]: + return PrefillOnlyTorchSDPAMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: + return PrefillOnlyTorchSDPAMetadataBuilder + + +@dataclass +class PrefillOnlyTorchSDPAMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyTorchSDPAMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyTorchSDPAMetadata]): + pass + + +class PrefillOnlyTorchSDPABackendImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, 
+ head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch SPDA does not support logits soft cap.") + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyTorchSDPAMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in TorchSDPA.") + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyTorchSDPABackendImpl") + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + sub_out = scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + dropout_p=0.0, + is_causal=causal, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + # Reshape the output tensor. 
+ return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/wde/prefill_only/layers/attention/backends/xformers.py b/vllm/wde/prefill_only/layers/attention/backends/xformers.py new file mode 100644 index 0000000000000..03c9fd5166496 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/xformers.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + BlockDiagonalMask) + +from vllm.logger import init_logger +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + +logger = init_logger(__name__) + + +class PrefillOnlyXFormersBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "xformers" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyXFormersImpl"]: + return PrefillOnlyXFormersImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]: + return PrefillOnlyXFormersMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyXFormersMetadataBuilder"]: + return PrefillOnlyXFormersMetadataBuilder + + +@dataclass +class PrefillOnlyXFormersMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyXFormersMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyXFormersMetadata]): + pass + + +class PrefillOnlyXFormersImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "XFormers does not support attention logits soft capping.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + assert self.alibi_slopes is None + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyXFormersMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyXFormersImpl") + original_query = query + + # Reshape the query, key, and value tensors. 
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + if causal: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) + else: + attn_bias = BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) + + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + out = xops.memory_efficient_attention_forward(query, + key, + value, + p=0.0, + attn_bias=attn_bias, + scale=self.scale) + return out.view_as(original_query) diff --git a/vllm/wde/prefill_only/layers/attention/selector.py b/vllm/wde/prefill_only/layers/attention/selector.py new file mode 100644 index 0000000000000..eeee6a240bde5 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/selector.py @@ -0,0 +1,140 @@ +import enum +from typing import Optional + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.wde.core.layers.attention.abstract import AttentionType + +logger = init_logger(__name__) + + +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() + FLASHINFER = enum.auto() + PALLAS = enum.auto() + IPEX = enum.auto() + TORCH_NAIVE = enum.auto() + + @staticmethod + def backend_name_to_enum(backend_name: str) -> "_Backend": + assert backend_name is not None + + backend_members = _Backend.__members__ + if backend_name not in backend_members: + raise ValueError( + f"Invalid attention backend '{backend_name}'. 
" + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + + return _Backend[backend_name] + + +class AttnBackend: + + @classmethod + def from_engine(cls, engine): + model_config = engine.engine_config.model_config + num_heads = model_config.get_num_attention_heads() + head_size = model_config.get_head_size() + num_kv_heads = model_config.get_num_kv_heads() + sliding_window = model_config.get_sliding_window() + dtype = model_config.dtype + + backend = cls.which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype) + + backend_cls = cls.get_backend_cls(backend) + + attn_type = AttentionType.attn_type_name_to_enum( + engine.workflow.attn_type) + + return backend_cls(attn_type) + + @staticmethod + def get_backend_cls(backend): + if backend == _Backend.FLASH_ATTN: + logger.info("Using FLASH ATTN backend.") + from vllm.wde.prefill_only.layers.attention.backends.flash_attn import ( # noqa: E501 + PrefillOnlyFlashAttentionBackend) + return PrefillOnlyFlashAttentionBackend + if backend == _Backend.XFORMERS: + logger.info("Using XFormers backend.") + from vllm.wde.prefill_only.layers.attention.backends.xformers import ( # noqa: E501 + PrefillOnlyXFormersBackend) + return PrefillOnlyXFormersBackend + elif backend == _Backend.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + from vllm.wde.prefill_only.layers.attention.backends.torch_sdpa import ( # noqa: E501 + PrefillOnlyTorchSDPABackend) + return PrefillOnlyTorchSDPABackend + elif backend == _Backend.FLASHINFER: + logger.info("Using Flashinfer backend.") + logger.info("When using Flashinfer backend in encode only models, " + "you are actually using FLASH ATTN backend") + from vllm.wde.prefill_only.layers.attention.backends.flashinfer import ( # noqa: E501 + PrefillOnlyFlashInferBackend) + return PrefillOnlyFlashInferBackend + elif backend == _Backend.TORCH_NAIVE: + logger.info("Using Torch naive backend.") + from vllm.wde.prefill_only.layers.attention.backends.torch_naive import ( # noqa: E501 + PrefillOnlyTorchNAIVEBackend) + return PrefillOnlyTorchNAIVEBackend + else: + raise ValueError("Invalid attention backend.") + + @classmethod + def which_attn_to_use(cls, num_heads: int, head_size: int, + num_kv_heads: int, sliding_window: Optional[int], + dtype: torch.dtype): + # Default case. + selected_backend = _Backend.FLASH_ATTN + + # get_env_variable_attn_backend + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = _Backend.backend_name_to_enum( + backend_by_env_var) + + # FlashAttn in NVIDIA GPUs. + if selected_backend == _Backend.FLASH_ATTN: + if current_platform.get_device_capability()[0] < 8: + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + selected_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + selected_backend = _Backend.XFORMERS + elif sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window." 
+ ) + selected_backend = _Backend.XFORMERS + + return selected_backend + + +AttentionImpls_fp32 = ["TORCH_SDPA", "XFORMERS", "TORCH_NAIVE"] +AttentionImpls_fp16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] +AttentionImpls_bf16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] + +AttentionImpls = { + "float": AttentionImpls_fp32, + "half": AttentionImpls_fp16, + "bfloat16": AttentionImpls_bf16, +} \ No newline at end of file diff --git a/vllm/wde/prefill_only/processor/__init__.py b/vllm/wde/prefill_only/processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/processor/model_input_builder.py b/vllm/wde/prefill_only/processor/model_input_builder.py new file mode 100644 index 0000000000000..de754573434e9 --- /dev/null +++ b/vllm/wde/prefill_only/processor/model_input_builder.py @@ -0,0 +1,52 @@ +import torch + +from vllm.utils import is_pin_memory_available +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.model_input_builder import ModelInputBuilder +from vllm.wde.core.schema.execute_io import ExecuteInput +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionMetadataBuilder) +from vllm.wde.prefill_only.schema.engine_io import PrefillOnlySchedulerOutput +from vllm.wde.prefill_only.schema.execute_io import ModelInputForGPU + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyModelInputBuilder(ModelInputBuilder): + + def __init__( + self, + attention_metadata_builder: PrefillOnlyAttentionMetadataBuilder): + self.attention_metadata_builder = attention_metadata_builder + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls(engine.attn_backend.get_builder_cls()()) + + def __call__(self, + scheduler_output: PrefillOnlySchedulerOutput) -> ExecuteInput: + input_tokens = [] + input_positions = [] + seq_lens = [] + for request in scheduler_output.scheduled_requests: + prompt_token_ids = request.inputs.prompt_token_ids + n_tokens = len(prompt_token_ids) + input_tokens.extend(prompt_token_ids) + input_positions.extend(list(range(0, n_tokens))) + seq_lens.append(n_tokens) + + input_ids = torch.tensor(input_tokens, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + positions = torch.tensor(input_positions, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + attn_metadata = self.attention_metadata_builder(seq_lens) + + model_input = ModelInputForGPU(input_ids=input_ids, + positions=positions, + attn_metadata=attn_metadata) + + return ExecuteInput(worker_input=None, model_input=model_input) diff --git a/vllm/wde/prefill_only/runner/__init__.py b/vllm/wde/prefill_only/runner/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/runner/model_runner.py b/vllm/wde/prefill_only/runner/model_runner.py new file mode 100644 index 0000000000000..93f6b77d59458 --- /dev/null +++ b/vllm/wde/prefill_only/runner/model_runner.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import DeviceMemoryProfiler, is_pin_memory_available +from vllm.wde.core.config import DeviceConfig, LoadConfig, ModelConfig +from vllm.wde.core.layers.attention import AttentionBackend +from vllm.wde.core.schema.execute_io import ExecuteOutput +from vllm.wde.prefill_only.config import PrefillOnlySchedulerConfig +from vllm.wde.prefill_only.schema.execute_io import ModelInputForGPU + +logger = 
init_logger(__name__) + + +class ModelRunner: + + def __init__( + self, + model_config: ModelConfig, + scheduler_config: PrefillOnlySchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + attn_backend: AttentionBackend, + ): + self.model_config = model_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.load_config = load_config + self.attn_backend = attn_backend + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + # Lazy initialization + self.model: nn.Module # Set after load_model + + def load_model(self) -> None: + from vllm.wde.core.loader.loader import (get_model_loader, + initialize_model) + + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + loader = get_model_loader(self.load_config) + self.model = initialize_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + attn_backend=self.attn_backend) + + loader.load_model(self.model, + model_config=self.model_config, + device_config=self.device_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForGPU, + ) -> ExecuteOutput: + return self.model(**model_input.to_dict()) diff --git a/vllm/wde/prefill_only/scheduler.py b/vllm/wde/prefill_only/scheduler.py new file mode 100644 index 0000000000000..fb107b260d0d3 --- /dev/null +++ b/vllm/wde/prefill_only/scheduler.py @@ -0,0 +1,97 @@ +from dataclasses import dataclass, field +from typing import Set + +from vllm.logger import init_logger +from vllm.wde.core.processor.input_processor import RequestProcessor +from vllm.wde.core.scheduler import Scheduler +from vllm.wde.prefill_only.config import PrefillOnlySchedulerConfig +from vllm.wde.prefill_only.schema.engine_io import (PrefillOnlySchedulerOutput, + SchedulableRequest) + +logger = init_logger(__name__) + + +@dataclass +class SchedulingBudget: + token_budget: int + max_num_requests: int + _curr_requests: Set[str] = field(default_factory=set) + _num_batched_tokens: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_request: int = 1): + assert num_new_tokens != 0 + assert num_new_request != 0 + a = self.num_batched_tokens + num_new_tokens <= self.token_budget + b = self.num_curr_request + num_new_request <= self.max_num_requests + return a and b + + def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + if req_id in self._curr_requests: + return + + self._curr_requests.add(req_id) + self._num_batched_tokens += num_batched_tokens + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_request(self): + return len(self._curr_requests) + + +class PrefillOnlyScheduler(Scheduler): + support_scheduling = ["sync_scheduling", "async_scheduling"] + + def __init__( + self, + scheduler_config: PrefillOnlySchedulerConfig, + request_processor: RequestProcessor, + ) -> None: + super().__init__(scheduler_config, request_processor) + + @classmethod + def from_engine(cls, engine): + return cls(engine.engine_config.scheduler_config, + engine.request_processor) + + def schedule(self) -> PrefillOnlySchedulerOutput: + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_requests=self.scheduler_config.max_num_requests, + ) + + waiting_queue = 
self.waiting + + scheduled_requests = [] + ignored_requests = [] + while waiting_queue: + request = waiting_queue[0] + if request.request_id in self.aborted_requests: + self.aborted_requests.remove(request.request_id) + waiting_queue.popleft() + continue + + if not isinstance(request, SchedulableRequest): + request = self.request_processor(request) + waiting_queue[0] = request + + num_new_tokens = request.num_new_tokens + + if num_new_tokens > self.scheduler_config.max_model_len: + self.requests.remove(request.request_id) + waiting_queue.popleft() + ignored_requests.append(request) + continue + + if not budget.can_schedule(num_new_tokens=num_new_tokens): + break + + budget.add_num_batched_tokens(request.request_id, num_new_tokens) + waiting_queue.popleft() + scheduled_requests.append(request) + + return PrefillOnlySchedulerOutput( + scheduled_requests=scheduled_requests, + ignored_requests=ignored_requests) diff --git a/vllm/wde/prefill_only/schema/__init__.py b/vllm/wde/prefill_only/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/schema/engine_io.py b/vllm/wde/prefill_only/schema/engine_io.py new file mode 100644 index 0000000000000..0610a02c974e5 --- /dev/null +++ b/vllm/wde/prefill_only/schema/engine_io.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import List + +from vllm.wde.core.schema.engine_io import SchedulableRequest, SchedulerOutput + + +@dataclass +class PrefillOnlySchedulerOutput(SchedulerOutput): + scheduled_requests: List[SchedulableRequest] + ignored_requests: List[SchedulableRequest] + + def is_empty(self) -> bool: + return not self.scheduled_requests \ No newline at end of file diff --git a/vllm/wde/prefill_only/schema/execute_io.py b/vllm/wde/prefill_only/schema/execute_io.py new file mode 100644 index 0000000000000..26a419428a85b --- /dev/null +++ b/vllm/wde/prefill_only/schema/execute_io.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass + +import torch + +from vllm.wde.core.layers.attention import AttentionMetadata +from vllm.wde.core.schema.execute_io import ExecuteInput, ModelInput + + +@dataclass +class ModelInputForGPU(ModelInput): + input_ids: torch.Tensor + positions: torch.Tensor + attn_metadata: AttentionMetadata + + def to(self, target_device, non_blocking=False): + for k in self.__dict__: + self.__dict__[k] = self.__dict__[k].to(device=target_device, + non_blocking=non_blocking) + + def to_dict(self): + return self.__dict__ + + +class PrefillOnlyExecuteInput(ExecuteInput): + worker_input = None + model_input: ModelInputForGPU diff --git a/vllm/wde/prefill_only/worker/__init__.py b/vllm/wde/prefill_only/worker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/worker/gpu_worker.py b/vllm/wde/prefill_only/worker/gpu_worker.py new file mode 100644 index 0000000000000..017ff277c4c86 --- /dev/null +++ b/vllm/wde/prefill_only/worker/gpu_worker.py @@ -0,0 +1,91 @@ +import os + +import torch + +from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform +from vllm.wde.core.config import (DeviceConfig, EngineConfig, LoadConfig, + ModelConfig) +from vllm.wde.core.layers.attention import AttentionBackend +from vllm.wde.core.schema.execute_io import ExecuteOutput +from vllm.wde.core.worker import WorkerBase +from vllm.wde.prefill_only.config import PrefillOnlySchedulerConfig +from vllm.wde.prefill_only.runner.model_runner import ModelRunner +from vllm.wde.prefill_only.schema.execute_io import 
PrefillOnlyExecuteInput + + +class Worker(WorkerBase): + + def __init__( + self, + engine_config: EngineConfig, + attn_backend: AttentionBackend, + ) -> None: + self.model_config: ModelConfig = engine_config.model_config + self.scheduler_config: PrefillOnlySchedulerConfig = ( + engine_config.scheduler_config) + self.device_config: DeviceConfig = engine_config.device_config + self.load_config: LoadConfig = engine_config.load_config + self.device = self.device_config.device + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner = ModelRunner(self.model_config, + self.scheduler_config, + self.device_config, self.load_config, + attn_backend) + + def init_device(self) -> None: + from vllm.wde.core.worker.utils import fix_distributed_environment + + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + + fix_distributed_environment() + + # Set random seed. + set_random_seed(self.model_config.seed) + + @torch.inference_mode + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode + def __call__(self, + execute_input: PrefillOnlyExecuteInput) -> ExecuteOutput: + output = self.model_runner.execute_model(execute_input.model_input) + return output + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: + compute_capability = current_platform.get_device_capability() + if compute_capability[0] < 8: + gpu_name = torch.cuda.get_device_name() + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU has compute capability " + f"{compute_capability[0]}.{compute_capability[1]}. " + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/wde/prefill_only/workflow.py b/vllm/wde/prefill_only/workflow.py new file mode 100644 index 0000000000000..bb9193fe23981 --- /dev/null +++ b/vllm/wde/prefill_only/workflow.py @@ -0,0 +1,30 @@ +from vllm.wde.core.workflow import Workflow + + +class PrefillOnlyWorkflow(Workflow): + InputProcessor: str = ("vllm.wde.core.processor." + "input_processor:TextInputProcessor") + RequestProcessor: str = ("vllm.wde.core.processor." + "input_processor:TextRequestProcessor") + + ModelInputBuilder: str = ( + "vllm.wde.prefill_only.processor." 
+ "model_input_builder:PrefillOnlyModelInputBuilder") + Worker: str = "vllm.wde.prefill_only.worker.gpu_worker:Worker" + Executor: str = "vllm.wde.prefill_only.executor.gpu_executor" + Scheduler: str = "vllm.wde.prefill_only.scheduler:PrefillOnlyScheduler" + AttnBackend: str = ("vllm.wde.prefill_only.layers." + "attention.selector:AttnBackend") + + @classmethod + def from_engine(cls, engine): + workflow = cls() + + if engine.engine_config.scheduler_config.scheduling in ["sync"]: + workflow.Executor += ":GPUExecutor" + elif engine.engine_config.scheduler_config.scheduling in [ + "async", "double_buffer" + ]: + workflow.Executor += ":GPUAsyncExecutor" + + return workflow diff --git a/vllm/worker/prefill_only_gpu_worker.py b/vllm/worker/prefill_only_gpu_worker.py new file mode 100644 index 0000000000000..a415e3394a8fc --- /dev/null +++ b/vllm/worker/prefill_only_gpu_worker.py @@ -0,0 +1,192 @@ +import importlib +import os +from abc import ABC, abstractmethod +from typing import Callable, Dict, Optional, Type + +import torch + +from vllm.attention.prefill_only.abstract import PrefillOnlyAttentionBackend +from vllm.config import (DeviceConfig, EngineConfig, LoadConfig, ModelConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor.prefill_only.execute_io import ( + ExecuteInput, ExecuteOutput, PrefillOnlyExecuteInput) +from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) + +from .prefill_only_model_runner import ModelRunner + +logger = init_logger(__name__) + + +class WorkerBase(ABC): + + @abstractmethod + def __call__(self, execute_input: ExecuteInput) -> ExecuteOutput: + raise NotImplementedError + + +class WorkerWrapperBase: + """ + The whole point of this class is to lazily initialize the worker. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + + If worker_class_fn is specified, it will be executed to get the worker + class. + Otherwise, the worker class will be obtained by dynamically importing it + using worker_module_name and worker_class_name. + """ + + def __init__( + self, + worker_module_name: str, + worker_class_name: str, + trust_remote_code: bool = False, + worker_class_fn: Optional[Callable[[], + Type[WorkerBase]]] = None) -> None: + self.worker_module_name = worker_module_name + self.worker_class_name = worker_class_name + self.worker_class_fn = worker_class_fn + self.worker: Optional[WorkerBase] = None + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + @staticmethod + def update_environment_variables(envs: Dict[str, str]) -> None: + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, *args, **kwargs): + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. 
+ """ + enable_trace_function_call_for_thread() + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + if self.worker_class_fn: + worker_class = self.worker_class_fn() + else: + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) + + self.worker = worker_class(*args, **kwargs) + assert self.worker is not None + + def execute_method(self, method, *args, **kwargs): + try: + target = self if self.worker is None else self.worker + executor = getattr(target, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. " + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + + +def create_worker(module, envs=None, **kwargs): + module_name, class_name = module.split(":") + wrapper = WorkerWrapperBase( + worker_module_name=module_name, + worker_class_name=class_name, + ) + if envs: + wrapper.update_environment_variables(envs) + + wrapper.init_worker(**kwargs) + return wrapper.worker + + +class Worker(WorkerBase): + + def __init__( + self, + engine_config: EngineConfig, + attn_backend: PrefillOnlyAttentionBackend, + ) -> None: + self.model_config: ModelConfig = engine_config.model_config + self.scheduler_config: SchedulerConfig = ( + engine_config.scheduler_config) + self.device_config: DeviceConfig = engine_config.device_config + self.load_config: LoadConfig = engine_config.load_config + self.device = self.device_config.device + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner = ModelRunner(self.model_config, + self.scheduler_config, + self.device_config, self.load_config, + attn_backend) + + def init_device(self) -> None: + from vllm.model_executor.prefill_only.utils import ( + fix_distributed_environment) + + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + + fix_distributed_environment() + + # Set random seed. 
+ set_random_seed(self.model_config.seed) + + @torch.inference_mode + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode + def __call__(self, + execute_input: PrefillOnlyExecuteInput) -> ExecuteOutput: + output = self.model_runner.execute_model(execute_input.model_input) + return output + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: + compute_capability = current_platform.get_device_capability() + if compute_capability[0] < 8: + gpu_name = torch.cuda.get_device_name() + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU has compute capability " + f"{compute_capability[0]}.{compute_capability[1]}. " + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/worker/prefill_only_model_runner.py b/vllm/worker/prefill_only_model_runner.py new file mode 100644 index 0000000000000..e5c8cf5b762a0 --- /dev/null +++ b/vllm/worker/prefill_only_model_runner.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn + +from vllm.attention.prefill_only.abstract import PrefillOnlyAttentionBackend +from vllm.config import DeviceConfig, LoadConfig, ModelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.model_executor.prefill_only.execute_io import (ExecuteOutput, + ModelInputForGPU) +from vllm.utils import DeviceMemoryProfiler, is_pin_memory_available + +logger = init_logger(__name__) + + +class ModelRunner: + + def __init__( + self, + model_config: ModelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + attn_backend: PrefillOnlyAttentionBackend, + ): + self.model_config = model_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.load_config = load_config + self.attn_backend = attn_backend + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + # Lazy initialization + self.model: nn.Module # Set after load_model + + def load_model(self) -> None: + from vllm.model_executor.prefill_only.loader import (get_model_loader, + initialize_model) + + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + loader = get_model_loader(self.load_config) + self.model = initialize_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + attn_backend=self.attn_backend) + + loader.load_model(self.model, + model_config=self.model_config, + device_config=self.device_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForGPU, + ) -> ExecuteOutput: + return self.model(**model_input.to_dict())
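The async executors in gpu_executor.py above (GPUAsyncExecutor with simple_execute_loop and double_buffer_execute_loop) exchange work with the engine through a pair of queues and use a plain None as the shutdown sentinel. Below is a minimal CPU-only sketch of that producer/consumer pattern, not part of the diff, using only the standard library and a stand-in callable instead of the real Worker class:

from queue import Queue
from threading import Thread

def execute_loop(worker, executor_in: Queue, executor_out: Queue):
    while True:
        item = executor_in.get()
        if item is None:          # sentinel: shut the loop down
            break
        scheduler_output, executor_input = item
        executor_out.put((scheduler_output, worker(executor_input)))

executor_in, executor_out = Queue(), Queue()
thread = Thread(target=execute_loop,
                args=(lambda x: x * 2, executor_in, executor_out),
                daemon=True)
thread.start()
executor_in.put(("step-0", 21))
print(executor_out.get())  # ('step-0', 42)
executor_in.put(None)      # mirrors shutdown_execute_loop()
thread.join()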
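PrefillOnlyAttentionMetadataBuilder turns per-request sequence lengths into cumulative offsets (seq_start_loc), which is the cu_seqlens layout the varlen flash-attn kernels consume. A standalone sketch of that cumsum, assuming only PyTorch and not the patch's classes:

import torch

def build_seq_start_loc(seq_lens):
    # Cumulative token offsets, with a leading zero: [0, len0, len0+len1, ...]
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long)
    seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])
    return seq_start_loc

# Example from the PrefillOnlyAttentionMetadata comment:
# sequence lengths [4, 6] -> offsets [0, 4, 10].
assert build_seq_start_loc([4, 6]).tolist() == [0, 4, 10]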
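The torch_sdpa and torch_naive backends run attention over a packed batch: all sequences are concatenated along the token dimension and each [start:end] slice is attended to separately, so tokens never attend across request boundaries. A self-contained demonstration of that slicing scheme with torch.nn.functional.scaled_dot_product_attention (shapes and toy data only, not the patch code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_size = 2, 8
seq_lens = [3, 5]
total = sum(seq_lens)

# Packed projections: [total_tokens, num_heads, head_size]
q = torch.randn(total, num_heads, head_size)
k = torch.randn(total, num_heads, head_size)
v = torch.randn(total, num_heads, head_size)

out = torch.empty_like(q)
start = 0
for seq_len in seq_lens:
    end = start + seq_len
    # Move heads in front of tokens, add a batch dim, run SDPA per sequence.
    sub = F.scaled_dot_product_attention(
        q[start:end].movedim(0, 1)[None],
        k[start:end].movedim(0, 1)[None],
        v[start:end].movedim(0, 1)[None],
        is_causal=False)             # encoder-style, bidirectional
    out[start:end] = sub.squeeze(0).movedim(1, 0)
    start = end

print(out.view(total, num_heads * head_size).shape)  # torch.Size([8, 16])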
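For grouped-query and multi-query attention, the XFormers backend reshapes queries to [tokens, num_kv_heads, queries_per_kv, head_size] and broadcasts each KV head across its query group with expand, which is a view rather than a copy. A small shape check with illustrative head counts, not taken from the diff:

import torch

tokens, num_heads, num_kv_heads, head_size = 10, 8, 2, 64
queries_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

query = torch.randn(tokens, num_heads, head_size)
key = torch.randn(tokens, num_kv_heads, head_size)

query = query.view(tokens, num_kv_heads, queries_per_kv, head_size)
key = key[:, :, None, :].expand(tokens, num_kv_heads, queries_per_kv, head_size)

print(query.shape, key.shape)  # both torch.Size([10, 2, 4, 64])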
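AttnBackend.which_attn_to_use defaults to FLASH_ATTN, honours the VLLM_ATTENTION_BACKEND environment override, and falls back to XFORMERS for pre-Ampere GPUs, for dtypes other than fp16/bf16, and for sliding-window models. A sketch of that decision tree with the vLLM plumbing removed; the function name and return strings here are illustrative, not the patch's API:

import torch

def pick_backend(device_capability, dtype, sliding_window, env_override=None):
    backend = env_override or "FLASH_ATTN"
    if backend == "FLASH_ATTN":
        if (device_capability[0] < 8
                or dtype not in (torch.float16, torch.bfloat16)
                or sliding_window is not None):
            backend = "XFORMERS"
    return backend

print(pick_backend((7, 5), torch.float16, None))   # XFORMERS (Turing GPU)
print(pick_backend((8, 0), torch.float32, None))   # XFORMERS (fp32 dtype)
print(pick_backend((8, 0), torch.bfloat16, None))  # FLASH_ATTN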
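PrefillOnlyScheduler admits whole requests greedily until either the token budget (max_num_batched_tokens) or the request budget (max_num_requests) would be exceeded; anything left over stays queued for the next step. The same admission rule in plain Python, using toy (request_id, num_tokens) tuples rather than the real SchedulableRequest objects:

from collections import deque

def schedule(waiting, token_budget, max_num_requests):
    scheduled, tokens = [], 0
    while waiting:
        req_id, num_tokens = waiting[0]
        if (tokens + num_tokens > token_budget
                or len(scheduled) + 1 > max_num_requests):
            break
        waiting.popleft()
        scheduled.append(req_id)
        tokens += num_tokens
    return scheduled, tokens

waiting = deque([("0", 200), ("1", 300), ("2", 600)])
print(schedule(waiting, token_budget=512, max_num_requests=32))
# -> (['0', '1'], 500); request "2" stays queued for the next step.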
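The Workflow class strings ("package.module:ClassName") and create_worker defer imports until a component is actually built. The underlying mechanism is just importlib plus getattr; this sketch demonstrates it with a standard-library class instead of the patch's modules:

import importlib

def resolve(path: str):
    # Split "package.module:ClassName" and import the attribute lazily.
    module_name, class_name = path.split(":")
    return getattr(importlib.import_module(module_name), class_name)

counter_cls = resolve("collections:Counter")
print(counter_cls("lazy loading"))  # Counter({'l': 2, 'a': 2, ...})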