From a921d8be9dd8b98795b4d8076f3af4f48dc3d24d Mon Sep 17 00:00:00 2001 From: Casper Date: Wed, 22 Nov 2023 21:31:27 +0100 Subject: [PATCH 01/13] =?UTF-8?q?[DOCS]=C2=A0Add=20engine=20args=20documen?= =?UTF-8?q?tation=20(#1741)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/index.rst | 1 + docs/source/models/engine_args.rst | 114 +++++++++++++++++++++++++++++ vllm/engine/arg_utils.py | 4 + 3 files changed, 119 insertions(+) create mode 100644 docs/source/models/engine_args.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index eb98aa6049bfb..caa1935cbfe46 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -73,6 +73,7 @@ Documentation models/supported_models models/adding_model + models/engine_args .. toctree:: :maxdepth: 1 diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst new file mode 100644 index 0000000000000..a70c22e9af11a --- /dev/null +++ b/docs/source/models/engine_args.rst @@ -0,0 +1,114 @@ +.. _engine_args: + +Engine Arguments +================ + +Below, you can find an explanation of every engine argument for vLLM: + +.. option:: --model + + Name or path of the huggingface model to use. + +.. option:: --tokenizer + + Name or path of the huggingface tokenizer to use. + +.. option:: --revision + + The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + +.. option:: --tokenizer-revision + + The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + +.. option:: --tokenizer-mode {auto,slow} + + The tokenizer mode. + + * "auto" will use the fast tokenizer if available. + * "slow" will always use the slow tokenizer. + +.. option:: --trust-remote-code + + Trust remote code from huggingface. + +.. option:: --download-dir + + Directory to download and load the weights, default to the default cache dir of huggingface. + +.. option:: --load-format {auto,pt,safetensors,npcache,dummy} + + The format of the model weights to load. + + * "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. + * "pt" will load the weights in the pytorch bin format. + * "safetensors" will load the weights in the safetensors format. + * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. + * "dummy" will initialize the weights with random values, mainly for profiling. + +.. option:: --dtype {auto,half,float16,bfloat16,float,float32} + + Data type for model weights and activations. + + * "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. + * "half" for FP16. Recommended for AWQ quantization. + * "float16" is the same as "half". + * "bfloat16" for a balance between precision and range. + * "float" is shorthand for FP32 precision. + * "float32" for FP32 precision. + +.. option:: --max-model-len + + Model context length. If unspecified, will be automatically derived from the model config. + +.. option:: --worker-use-ray + + Use Ray for distributed serving, will be automatically set when using more than 1 GPU. + +.. option:: --pipeline-parallel-size (-pp) + + Number of pipeline stages. + +.. option:: --tensor-parallel-size (-tp) + + Number of tensor parallel replicas. + +.. option:: --max-parallel-loading-workers + + Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models. + +.. option:: --block-size {8,16,32} + + Token block size for contiguous chunks of tokens. + +.. option:: --seed + + Random seed for operations. + +.. option:: --swap-space + + CPU swap space size (GiB) per GPU. + +.. option:: --gpu-memory-utilization + + The percentage of GPU memory to be used for the model executor. + +.. option:: --max-num-batched-tokens + + Maximum number of batched tokens per iteration. + +.. option:: --max-num-seqs + + Maximum number of sequences per iteration. + +.. option:: --max-paddings + + Maximum number of paddings in a batch. + +.. option:: --disable-log-stats + + Disable logging statistics. + +.. option:: --quantization (-q) {awq,squeezellm,None} + + Method used to quantize the weights. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c7e476c704740..746b0e64ece7b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -42,6 +42,10 @@ def __post_init__(self): def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Shared CLI arguments for vLLM engine.""" + + # NOTE: If you update any of the arguments below, please also + # make sure to update docs/source/models/engine_args.rst + # Model arguments parser.add_argument( '--model', From 4cea74c73b2e0981aadfefb3a00e8186d065c897 Mon Sep 17 00:00:00 2001 From: ljss <31004720+beginlner@users.noreply.github.com> Date: Thu, 23 Nov 2023 04:51:09 +0800 Subject: [PATCH 02/13] Set top_p=0 and top_k=-1 in greedy sampling (#1748) --- vllm/sampling_params.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 68ba762f899dc..f9eca1a9fc43c 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -147,6 +147,8 @@ def __init__( self._verify_non_beam_search() if self.temperature < _SAMPLING_EPS: # Zero temperature means greedy sampling. + self.top_p = 1.0 + self.top_k = -1 self._verify_greedy_sampling() def _verify_args(self) -> None: @@ -214,10 +216,6 @@ def _verify_greedy_sampling(self) -> None: if self.best_of > 1: raise ValueError("best_of must be 1 when using greedy sampling." f"Got {self.best_of}.") - if self.top_p < 1.0 - _SAMPLING_EPS: - raise ValueError("top_p must be 1 when using greedy sampling.") - if self.top_k != -1: - raise ValueError("top_k must be -1 when using greedy sampling.") @cached_property def sampling_type(self) -> SamplingType: From de23687d168ebeaa8872c27f05b8292bab0fac71 Mon Sep 17 00:00:00 2001 From: ljss <31004720+beginlner@users.noreply.github.com> Date: Thu, 23 Nov 2023 06:41:44 +0800 Subject: [PATCH 03/13] Fix repetition penalty aligned with huggingface (#1577) --- vllm/model_executor/layers/sampler.py | 76 +++++++++++++++++---------- vllm/sampling_params.py | 6 +-- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 9fcc2f20675c0..c874ec5921155 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -21,7 +21,7 @@ class Sampler(nn.Module): 1. Discard the hidden states that are not used for sampling (i.e., all tokens except the final one in each prompt). 2. Compute the logits for the next tokens. - 3. Apply presence and frequency penalties. + 3. Apply presence, frequency and repetition penalties. 4. Apply temperature scaling. 5. Apply top-p and top-k truncation. 6. Sample the next tokens. @@ -50,14 +50,12 @@ def forward( # Apply logits processors (if any). logits = _apply_logits_processors(logits, input_metadata) # Apply presence and frequency penalties. - output_tokens = _get_output_tokens(input_metadata) - assert len(output_tokens) == logits.shape[0] presence_penalties, frequency_penalties, repetition_penalties = ( _get_penalties(input_metadata)) assert len(presence_penalties) == logits.shape[0] assert len(frequency_penalties) == logits.shape[0] assert len(repetition_penalties) == logits.shape[0] - logits = _apply_penalties(logits, output_tokens, presence_penalties, + logits = _apply_penalties(logits, input_metadata, presence_penalties, frequency_penalties, repetition_penalties) # Apply temperature scaling. @@ -146,7 +144,10 @@ def _get_penalties( return presence_penalties, frequency_penalties, repetition_penalties -def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: +def _get_prompt_and_output_tokens( + input_metadata: InputMetadata +) -> Tuple[List[List[int]], List[List[int]]]: + prompt_tokens: List[List[int]] = [] output_tokens: List[List[int]] = [] for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group @@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: # NOTE: prompt token positions do not need output tokens to # compute penalties. prompt_len = input_metadata.prompt_lens[i] + prompt_tokens.extend([] for _ in range(prompt_len - 1)) output_tokens.extend([] for _ in range(prompt_len - 1)) for seq_id in seq_ids: seq_data = input_metadata.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) output_tokens.append(seq_data.output_token_ids) - return output_tokens + return prompt_tokens, output_tokens + + +def _get_bin_counts_and_mask( + logits: torch.Tensor, + tokens: List[List[int]], + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + max_len = max(len(tokens) for tokens in tokens) + padded_tokens = [ + tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens + ] + tokens_tensor = torch.tensor(padded_tokens, + dtype=torch.long, + device=logits.device) + + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=logits.device) + bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask def _apply_logits_processors(logits: torch.Tensor, @@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor, def _apply_penalties( logits: torch.Tensor, - output_tokens: List[List[int]], + input_metadata: InputMetadata, presence_penalties: List[float], frequency_penalties: List[float], repetition_penalties: List[float], ) -> torch.Tensor: num_seqs, vocab_size = logits.shape for i in range(num_seqs): - if not output_tokens[i]: - continue p = presence_penalties[i] f = frequency_penalties[i] r = repetition_penalties[i] @@ -206,24 +233,15 @@ def _apply_penalties( # Return early if all sequences have zero penalties. return logits - max_output_len = max(len(tokens) for tokens in output_tokens) - padded_output_tokens = [ - tokens + [vocab_size] * (max_output_len - len(tokens)) - for tokens in output_tokens - ] - output_tokens_tensor = torch.tensor(padded_output_tokens, - dtype=torch.long, - device=logits.device) + prompt_tokens, output_tokens = ( + _get_prompt_and_output_tokens(input_metadata)) + assert len(prompt_tokens) == logits.shape[0] + assert len(output_tokens) == logits.shape[0] - # Compute the bin counts for the output tokens. - # vocab_size + 1 for padding. - bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=logits.device) - bin_counts.scatter_add_(1, output_tokens_tensor, - torch.ones_like(output_tokens_tensor)) - bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin. - mask = bin_counts > 0 + prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask( + logits, prompt_tokens, vocab_size, num_seqs) + output_bin_counts, output_mask = _get_bin_counts_and_mask( + logits, output_tokens, vocab_size, num_seqs) repetition_penalties = torch.tensor(repetition_penalties, dtype=logits.dtype, @@ -236,14 +254,14 @@ def _apply_penalties( device=logits.device) repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) - repetition_penalties[~mask] = 1.0 + repetition_penalties[~(prompt_mask | output_mask)] = 1.0 logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties) # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts - logits -= presence_penalties.unsqueeze(dim=1) * mask + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index f9eca1a9fc43c..5a08169c48a36 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -42,9 +42,9 @@ class SamplingParams: model to use new tokens, while values < 0 encourage the model to repeat tokens. repetition_penalty: Float that penalizes new tokens based on whether - they appear in the generated text so far. Values > 1 encourage the - model to use new tokens, while values < 1 encourage the model to - repeat tokens. + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens. temperature: Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling. From e0c6f556e85053059c74ab6b5cee396baf3b4316 Mon Sep 17 00:00:00 2001 From: Yanming W Date: Thu, 23 Nov 2023 16:31:19 -0800 Subject: [PATCH 04/13] [Build] Avoid building too many extensions (#1624) --- .../kernels/benchmark_paged_attention.py | 6 +- csrc/activation.cpp | 28 ------- csrc/attention.cpp | 42 ---------- csrc/{cache.cpp => cache.h} | 19 ----- csrc/cuda_utils.cpp | 13 --- csrc/cuda_utils.h | 5 ++ csrc/layernorm.cpp | 24 ------ csrc/ops.h | 75 +++++++++++++++++ csrc/pos_encoding.cpp | 16 ---- csrc/pybind.cpp | 80 ++++++++++++++++++ csrc/quantization.cpp | 19 ----- setup.py | 82 +++---------------- tests/kernels/test_activation.py | 8 +- tests/kernels/test_attention.py | 6 +- tests/kernels/test_cache.py | 2 +- tests/kernels/test_layernorm.py | 4 +- tests/kernels/test_pos_encoding.py | 4 +- vllm/model_executor/layers/activation.py | 8 +- vllm/model_executor/layers/attention.py | 8 +- vllm/model_executor/layers/layernorm.py | 6 +- .../model_executor/layers/quantization/awq.py | 5 +- .../layers/quantization/squeezellm.py | 5 +- .../model_executor/layers/rotary_embedding.py | 9 +- vllm/utils.py | 2 +- vllm/worker/cache_engine.py | 2 +- 25 files changed, 206 insertions(+), 272 deletions(-) delete mode 100644 csrc/activation.cpp delete mode 100644 csrc/attention.cpp rename csrc/{cache.cpp => cache.h} (58%) delete mode 100644 csrc/cuda_utils.cpp create mode 100644 csrc/cuda_utils.h delete mode 100644 csrc/layernorm.cpp create mode 100644 csrc/ops.h delete mode 100644 csrc/pos_encoding.cpp create mode 100644 csrc/pybind.cpp delete mode 100644 csrc/quantization.cpp diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 0ef8030767677..91fcf5340298a 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -4,7 +4,7 @@ import torch -from vllm import attention_ops +from vllm._C import ops NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -98,7 +98,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: for _ in range(num_iters): if version == "v1": - attention_ops.paged_attention_v1( + ops.paged_attention_v1( output, query, key_cache, @@ -112,7 +112,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: alibi_slopes, ) elif version == "v2": - attention_ops.paged_attention_v2( + ops.paged_attention_v2( output, exp_sums, max_logits, diff --git a/csrc/activation.cpp b/csrc/activation.cpp deleted file mode 100644 index c100f89ac7377..0000000000000 --- a/csrc/activation.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include - -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - m.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - m.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); -} diff --git a/csrc/attention.cpp b/csrc/attention.cpp deleted file mode 100644 index bd93fd71b733d..0000000000000 --- a/csrc/attention.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include -#include - -void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& head_mapping, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes); - -void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& head_mapping, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - m.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); -} diff --git a/csrc/cache.cpp b/csrc/cache.h similarity index 58% rename from csrc/cache.cpp rename to csrc/cache.h index 9ae17bb2985c6..da49d9103214b 100644 --- a/csrc/cache.cpp +++ b/csrc/cache.h @@ -26,22 +26,3 @@ void gather_cached_kv( torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - m.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - m.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - m.def( - "gather_cached_kv", - &gather_cached_kv, - "Gather key and value from the cache into contiguous QKV tensors"); -} diff --git a/csrc/cuda_utils.cpp b/csrc/cuda_utils.cpp deleted file mode 100644 index e7f22ec89d7b4..0000000000000 --- a/csrc/cuda_utils.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include - -int get_device_attribute( - int attribute, - int device_id); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "get_device_attribute", - &get_device_attribute, - "Gets the specified device attribute."); -} - diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h new file mode 100644 index 0000000000000..85cb199b9aa0c --- /dev/null +++ b/csrc/cuda_utils.h @@ -0,0 +1,5 @@ +#include + +int get_device_attribute( + int attribute, + int device_id); diff --git a/csrc/layernorm.cpp b/csrc/layernorm.cpp deleted file mode 100644 index c341a7097962c..0000000000000 --- a/csrc/layernorm.cpp +++ /dev/null @@ -1,24 +0,0 @@ -#include - -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); - -void fused_add_rms_norm( - torch::Tensor& input, - torch::Tensor& residual, - torch::Tensor& weight, - float epsilon); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); - m.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); -} diff --git a/csrc/ops.h b/csrc/ops.h new file mode 100644 index 0000000000000..cfb18fbefd7a9 --- /dev/null +++ b/csrc/ops.h @@ -0,0 +1,75 @@ +#include + +void paged_attention_v1( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len, + const c10::optional& alibi_slopes); + +void paged_attention_v2( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len, + const c10::optional& alibi_slopes); + +void rms_norm( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& weight, + float epsilon); + +void fused_add_rms_norm( + torch::Tensor& input, + torch::Tensor& residual, + torch::Tensor& weight, + float epsilon); + +void rotary_embedding( + torch::Tensor& positions, + torch::Tensor& query, + torch::Tensor& key, + int head_size, + torch::Tensor& cos_sin_cache, + bool is_neox); + +void silu_and_mul( + torch::Tensor& out, + torch::Tensor& input); + +void gelu_new( + torch::Tensor& out, + torch::Tensor& input); + +void gelu_fast( + torch::Tensor& out, + torch::Tensor& input); + +torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters); + +void squeezellm_gemm( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor lookup_table); diff --git a/csrc/pos_encoding.cpp b/csrc/pos_encoding.cpp deleted file mode 100644 index eee0cf0d0fa09..0000000000000 --- a/csrc/pos_encoding.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); -} diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp new file mode 100644 index 0000000000000..9e31429690021 --- /dev/null +++ b/csrc/pybind.cpp @@ -0,0 +1,80 @@ +#include "cache.h" +#include "cuda_utils.h" +#include "ops.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // vLLM custom ops + pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); + + // Attention ops + ops.def( + "paged_attention_v1", + &paged_attention_v1, + "Compute the attention between an input query and the cached keys/values using PagedAttention."); + ops.def( + "paged_attention_v2", + &paged_attention_v2, + "PagedAttention V2."); + + // Activation ops + ops.def( + "silu_and_mul", + &silu_and_mul, + "Activation function used in SwiGLU."); + ops.def( + "gelu_new", + &gelu_new, + "GELU implementation used in GPT-2."); + ops.def( + "gelu_fast", + &gelu_fast, + "Approximate GELU implementation."); + + // Layernorm + ops.def( + "rms_norm", + &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + + ops.def( + "fused_add_rms_norm", + &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); + + // Rotary embedding + ops.def( + "rotary_embedding", + &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + + // Quantization ops + ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); + + // Cache ops + pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); + cache_ops.def( + "swap_blocks", + &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def( + "copy_blocks", + ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def( + "reshape_and_cache", + &reshape_and_cache, + "Reshape the key and value tensors and cache them"); + cache_ops.def( + "gather_cached_kv", + &gather_cached_kv, + "Gather key and value from the cache into contiguous QKV tensors"); + + // Cuda utils + pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); + cuda_utils.def( + "get_device_attribute", + &get_device_attribute, + "Gets the specified device attribute."); +} diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp deleted file mode 100644 index dfe17a496c780..0000000000000 --- a/csrc/quantization.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include - -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); - -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); -} diff --git a/setup.py b/setup.py index 36f4913435628..2b040e88f0aa4 100644 --- a/setup.py +++ b/setup.py @@ -142,87 +142,25 @@ def get_torch_arch_list() -> Set[str]: NVCC_FLAGS += ["--threads", str(num_threads)] ext_modules = [] - -# Cache operations. -cache_extension = CUDAExtension( - name="vllm.cache_ops", - sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(cache_extension) - -# Attention kernels. -attention_extension = CUDAExtension( - name="vllm.attention_ops", - sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(attention_extension) - -# Positional encoding kernels. -positional_encoding_extension = CUDAExtension( - name="vllm.pos_encoding_ops", - sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(positional_encoding_extension) - -# Layer normalization kernels. -layernorm_extension = CUDAExtension( - name="vllm.layernorm_ops", - sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(layernorm_extension) - -# Activation kernels. -activation_extension = CUDAExtension( - name="vllm.activation_ops", - sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(activation_extension) - -# Quantization kernels. -quantization_extension = CUDAExtension( - name="vllm.quantization_ops", +vllm_extension = CUDAExtension( + name="vllm._C", sources=[ - "csrc/quantization.cpp", + "csrc/cache_kernels.cu", + "csrc/attention/attention_kernels.cu", + "csrc/pos_encoding_kernels.cu", + "csrc/activation_kernels.cu", + "csrc/layernorm_kernels.cu", "csrc/quantization/awq/gemm_kernels.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + "csrc/cuda_utils_kernels.cu", + "csrc/pybind.cpp", ], extra_compile_args={ "cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS, }, ) -ext_modules.append(quantization_extension) - -# Misc. CUDA utils. -cuda_utils_extension = CUDAExtension( - name="vllm.cuda_utils", - sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(cuda_utils_extension) +ext_modules.append(vllm_extension) def get_path(*filepath) -> str: diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 0b3ad0aa255a1..978b377ea94d4 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from transformers.activations import get_activation -from vllm import activation_ops +from vllm._C import ops DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing @@ -31,7 +31,7 @@ def test_silu_and_mul( torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - activation_ops.silu_and_mul(out, x) + ops.silu_and_mul(out, x) ref_out = ref_silu_and_mul(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -51,7 +51,7 @@ def test_gelu_new( torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - activation_ops.gelu_new(out, x) + ops.gelu_new(out, x) ref_out = get_activation("gelu_new")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -70,6 +70,6 @@ def test_gelu_fast( torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - activation_ops.gelu_fast(out, x) + ops.gelu_fast(out, x) ref_out = get_activation("gelu_fast")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index e76416d88311d..a65d4d54d7c82 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,7 +6,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm import attention_ops +from vllm._C import ops from vllm.utils import get_max_shared_memory_bytes FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -165,7 +165,7 @@ def test_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) if version == "v1": - attention_ops.paged_attention_v1( + ops.paged_attention_v1( output, query, key_cache, @@ -194,7 +194,7 @@ def test_paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) - attention_ops.paged_attention_v2( + ops.paged_attention_v2( output, exp_sums, max_logits, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index e15e7ba91bcb0..9b5d7687a3fec 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -3,7 +3,7 @@ import pytest import torch -from vllm import cache_ops +from vllm._C import cache_ops DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [83] # Arbitrary values for testing diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index a63ef5cc76ffd..ee5228d68e4db 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from vllm import layernorm_ops +from vllm._C import ops DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing @@ -48,7 +48,7 @@ def test_rms_norm( ref = RefRMSNorm(hidden_size).to(dtype).cuda() out = torch.empty_like(x) - layernorm_ops.rms_norm( + ops.rms_norm( out, x, ref.weight.data, diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index d660417440844..7d22bdab4625b 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from vllm import pos_encoding_ops +from vllm._C import ops IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -145,7 +145,7 @@ def test_rotary_embedding( # Run the kernel. The kernel is in-place, so we need to clone the inputs. out_query = query.clone() out_key = key.clone() - pos_encoding_ops.rotary_embedding( + ops.rotary_embedding( positions, out_query, out_key, diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index ecab0c8d3256a..5c0def823edea 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm import activation_ops +from vllm._C import ops from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -26,7 +26,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - activation_ops.silu_and_mul(out, x) + ops.silu_and_mul(out, x) return out @@ -34,7 +34,7 @@ class NewGELU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) - activation_ops.gelu_new(out, x) + ops.gelu_new(out, x) return out @@ -42,7 +42,7 @@ class FastGELU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) - activation_ops.gelu_fast(out, x) + ops.gelu_fast(out, x) return out diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index e51bb311decd9..63271ba5b9327 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -7,8 +7,8 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) -from vllm import attention_ops -from vllm import cache_ops +from vllm._C import ops +from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.rotary_embedding import get_rope @@ -163,7 +163,7 @@ def single_query_cached_kv_attention( max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: # Run PagedAttention V1. - attention_ops.paged_attention_v1( + ops.paged_attention_v1( output, query, key_cache, @@ -190,7 +190,7 @@ def single_query_cached_kv_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) - attention_ops.paged_attention_v2( + ops.paged_attention_v2( output, exp_sums, max_logits, diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 275efa0b7dc3f..69fba087099ef 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm import layernorm_ops +from vllm._C import ops class RMSNorm(nn.Module): @@ -29,7 +29,7 @@ def forward( residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: - layernorm_ops.fused_add_rms_norm( + ops.fused_add_rms_norm( x, residual, self.weight.data, @@ -37,7 +37,7 @@ def forward( ) return x, residual out = torch.empty_like(x) - layernorm_ops.rms_norm( + ops.rms_norm( out, x, self.weight.data, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 0ab5819d930aa..95d419e64f049 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,7 +3,7 @@ import torch from torch.nn.parameter import Parameter -from vllm import quantization_ops +from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -151,8 +151,7 @@ def apply_weights(self, pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) - out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros, - pack_factor) + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: out = out + bias return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 61ec8b79b6ddc..aa6bd0652424f 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -3,7 +3,7 @@ import torch from torch.nn.parameter import Parameter -from vllm import quantization_ops +from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -116,8 +116,7 @@ def apply_weights(self, reshaped_x = x.reshape(-1, x.shape[-1]) # NOTE: The output tensor should be zero-initialized. out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, - lookup_table) + ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: out = out + bias diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 1b88e9a3b8057..162bb0b533e4f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn -from vllm import pos_encoding_ops +from vllm._C import ops class RotaryEmbedding(nn.Module): @@ -87,11 +87,10 @@ def forward( query: torch.Tensor, key: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - # pos_encoding_ops.rotary_embedding() is an in-place operation that + # ops.rotary_embedding() is an in-place operation that # updates the query and key tensors. - pos_encoding_ops.rotary_embedding(positions, query, key, - self.head_size, self.cos_sin_cache, - self.is_neox_style) + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key diff --git a/vllm/utils.py b/vllm/utils.py index 34d3084856af8..47e51048fed45 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,7 +5,7 @@ import psutil import torch -from vllm import cuda_utils +from vllm._C import cuda_utils class Device(enum.Enum): diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index cdb7902082653..1dd0243f8f3a3 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -3,7 +3,7 @@ import torch -from vllm import cache_ops +from vllm._C import cache_ops from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.utils import in_wsl From 7c600440f7560348e571f021f2b2d1469de5264d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 23 Nov 2023 23:04:44 -0800 Subject: [PATCH 05/13] Fix model docstrings (#1764) --- vllm/model_executor/models/aquila.py | 6 +----- vllm/model_executor/models/baichuan.py | 6 +----- vllm/model_executor/models/bloom.py | 6 +----- vllm/model_executor/models/chatglm.py | 6 +----- vllm/model_executor/models/gpt2.py | 6 +----- vllm/model_executor/models/gpt_bigcode.py | 6 +----- vllm/model_executor/models/gpt_j.py | 6 +----- vllm/model_executor/models/gpt_neox.py | 6 +----- vllm/model_executor/models/llama.py | 6 +----- vllm/model_executor/models/mistral.py | 6 +----- vllm/model_executor/models/opt.py | 6 +----- vllm/model_executor/models/phi_1_5.py | 6 +----- vllm/model_executor/models/qwen.py | 6 +----- vllm/model_executor/models/yi.py | 6 +----- 14 files changed, 14 insertions(+), 70 deletions(-) diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 8372da562cf2e..889239cdb4e0e 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -20,11 +20,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only LLaMA model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only LLaMA model compatible with HuggingFace weights.""" from typing import Any, Dict, List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 93cbc1a8516a7..61cc2192b01bb 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -17,11 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only BaiChuan model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only BaiChuan model compatible with HuggingFace weights.""" import math from typing import List, Optional, Tuple diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 0eb3fdbb9ae3a..99ccd7442f31b 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -15,11 +15,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only BLOOM model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only BLOOM model compatible with HuggingFace weights.""" import math from typing import List, Optional, Tuple diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 2a113a155aedd..db426a94214cf 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -1,11 +1,7 @@ # coding=utf-8 # Adapted from # https://github.com/THUDM/ChatGLM2-6B -"""Inference-only ChatGLM model compatible with THUDM weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only ChatGLM model compatible with THUDM weights.""" from typing import List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 0f9f74d32ae3c..5dce59f77eea2 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -16,11 +16,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only GPT-2 model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only GPT-2 model compatible with HuggingFace weights.""" from typing import List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 47a5d7711e370..9b69fc90b13aa 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -17,11 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only GPTBigCode model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" from typing import List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 9093d642a68fb..1f0f7d4206c88 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -15,11 +15,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only GPT-J model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only GPT-J model compatible with HuggingFace weights.""" from typing import List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 8c0667d88d953..b289ddc51da85 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -15,11 +15,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only GPT-NeoX model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" from typing import List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c3192e8069703..8e7344da4888e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -20,11 +20,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only LLaMA model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only LLaMA model compatible with HuggingFace weights.""" from typing import Any, Dict, List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 793e25b635978..d18572610741c 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -20,11 +20,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only Mistral model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only Mistral model compatible with HuggingFace weights.""" from typing import List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 4c8ff596b4732..8d88ccd706eaa 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -16,11 +16,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only OPT model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only OPT model compatible with HuggingFace weights.""" from typing import List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py index 18cd40f39a0af..7ef614601da39 100644 --- a/vllm/model_executor/models/phi_1_5.py +++ b/vllm/model_executor/models/phi_1_5.py @@ -34,11 +34,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Inference-only Phi-1.5 model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" from typing import List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index ce13cae7ee002..d581838f6ce8f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -3,11 +3,7 @@ # https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py # Copyright (c) Alibaba Cloud. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE -"""Inference-only QWen model compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only QWen model compatible with HuggingFace weights.""" from typing import Any, Dict, List, Optional, Tuple import torch diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index 8faa106f202f5..c457132855cdc 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -20,11 +20,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only Yi model (https://01.ai) compatible with HuggingFace weights. - -The input of the model is flattened to a 1D tensor of tokens. The model uses -InputMetadata to extract the original 2D shape of the input. -""" +"""Inference-only Yi model (https://01.ai) compatible with HuggingFace weights.""" from typing import Any, Dict, List, Optional, Tuple import torch From 665cbcec4b963f6ab7b696f3d7e3393a7909003d Mon Sep 17 00:00:00 2001 From: Yunmo Chen <16273544+wanmok@users.noreply.github.com> Date: Mon, 27 Nov 2023 13:29:17 +0800 Subject: [PATCH 06/13] Added echo function to OpenAI API server. (#1504) --- vllm/entrypoints/openai/api_server.py | 92 ++++++++++++++++++++------- vllm/entrypoints/openai/protocol.py | 3 +- 2 files changed, 71 insertions(+), 24 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a9c9fbed0cbaa..4143e1af8ae04 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -160,16 +160,26 @@ async def show_available_models(): return ModelList(data=model_cards) -def create_logprobs(token_ids: List[int], - id_logprobs: List[Dict[int, float]], - initial_text_offset: int = 0) -> LogProbs: +def create_logprobs( + token_ids: List[int], + top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, + num_output_top_logprobs: Optional[int] = None, + initial_text_offset: int = 0, +) -> LogProbs: """Create OpenAI-style logprobs.""" logprobs = LogProbs() last_token_len = 0 - for token_id, id_logprob in zip(token_ids, id_logprobs): + if num_output_top_logprobs: + logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is not None: + token_logprob = step_top_logprobs[token_id] + else: + token_logprob = None token = tokenizer.convert_ids_to_tokens(token_id) logprobs.tokens.append(token) - logprobs.token_logprobs.append(id_logprob[token_id]) + logprobs.token_logprobs.append(token_logprob) if len(logprobs.text_offset) == 0: logprobs.text_offset.append(initial_text_offset) else: @@ -177,10 +187,11 @@ def create_logprobs(token_ids: List[int], last_token_len) last_token_len = len(token) - logprobs.top_logprobs.append({ - tokenizer.convert_ids_to_tokens(i): p - for i, p in id_logprob.items() - }) + if num_output_top_logprobs: + logprobs.top_logprobs.append({ + tokenizer.convert_ids_to_tokens(i): p + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) return logprobs @@ -371,8 +382,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request): for the API specification. This API mimics the OpenAI Completion API. NOTE: Currently we do not support the following features: - - echo (since the vLLM engine does not currently support - getting the logprobs of prompt tokens) - suffix (the language models we currently support do not support suffix) - logit_bias (to be supported by vLLM engine) @@ -383,11 +392,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request): if error_check_ret is not None: return error_check_ret - if request.echo: - # We do not support echo since the vLLM engine does not - # currently support getting the logprobs of prompt tokens. - return create_error_response(HTTPStatus.BAD_REQUEST, - "echo is not currently supported") + # OpenAI API supports echoing the prompt when max_tokens is 0. + echo_without_generation = request.echo and request.max_tokens == 0 if request.suffix is not None: # The language models we currently support do not support suffix. @@ -443,9 +449,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request): stop=request.stop, stop_token_ids=request.stop_token_ids, ignore_eos=request.ignore_eos, - max_tokens=request.max_tokens, + max_tokens=request.max_tokens + if not echo_without_generation else 1, logprobs=request.logprobs, use_beam_search=request.use_beam_search, + prompt_logprobs=request.logprobs if request.echo else None, skip_special_tokens=request.skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, ) @@ -495,24 +503,42 @@ def create_stream_response_json( async def completion_stream_generator() -> AsyncGenerator[str, None]: previous_texts = [""] * request.n previous_num_tokens = [0] * request.n + has_echoed = [False] * request.n async for res in result_generator: res: RequestOutput for output in res.outputs: i = output.index delta_text = output.text[len(previous_texts[i]):] + token_ids = output.token_ids[previous_num_tokens[i]:] + top_logprobs = output.logprobs[previous_num_tokens[i]:] + offsets = len(previous_texts[i]) + if request.echo and not has_echoed[i]: + if not echo_without_generation: + delta_text = res.prompt + delta_text + token_ids = res.prompt_token_ids + token_ids + top_logprobs = res.prompt_logprobs + top_logprobs + else: + delta_text = res.prompt + token_ids = res.prompt_token_ids + top_logprobs = res.prompt_logprobs + has_echoed[i] = True if request.logprobs is not None: logprobs = create_logprobs( - output.token_ids[previous_num_tokens[i]:], - output.logprobs[previous_num_tokens[i]:], - len(previous_texts[i])) + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=offsets, + ) else: logprobs = None previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) + finish_reason = output.finish_reason response_json = create_stream_response_json( index=i, text=delta_text, logprobs=logprobs, + finish_reason=finish_reason, ) yield f"data: {response_json}\n\n" if output.finish_reason is not None: @@ -551,14 +577,36 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: final_res = res assert final_res is not None choices = [] + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt for output in final_res.outputs: if request.logprobs is not None: - logprobs = create_logprobs(output.token_ids, output.logprobs) + if not echo_without_generation: + token_ids = output.token_ids + top_logprobs = output.logprobs + if request.echo: + token_ids = prompt_token_ids + token_ids + top_logprobs = prompt_logprobs + top_logprobs + else: + token_ids = prompt_token_ids + top_logprobs = prompt_logprobs + logprobs = create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) else: logprobs = None + if not echo_without_generation: + output_text = output.text + if request.echo: + output_text = prompt_text + output_text + else: + output_text = prompt_text choice_data = CompletionResponseChoice( index=output.index, - text=output.text, + text=output_text, logprobs=logprobs, finish_reason=output.finish_reason, ) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 39db35620307f..797f0a7115e6e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -106,8 +106,7 @@ class LogProbs(BaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, - float]]] = Field(default_factory=list) + top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None class CompletionResponseChoice(BaseModel): From a8b150c5950a353d458b47ce319585698ef41e3f Mon Sep 17 00:00:00 2001 From: ljss <31004720+beginlner@users.noreply.github.com> Date: Tue, 28 Nov 2023 03:18:26 +0800 Subject: [PATCH 07/13] Init model on GPU to reduce CPU memory footprint (#1796) --- vllm/model_executor/model_loader.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 54b87c4b866e3..fe6ff36edb882 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -87,9 +87,9 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - model = model_class(model_config.hf_config, linear_method) + with torch.device("cuda"): + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": - model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) @@ -97,5 +97,4 @@ def get_model(model_config: ModelConfig) -> nn.Module: # Load the weights from the cached or downloaded files. model.load_weights(model_config.model, model_config.download_dir, model_config.load_format, model_config.revision) - model = model.cuda() return model.eval() From a1125ad4df7011e74bbbcf88d268e521278161fa Mon Sep 17 00:00:00 2001 From: explainerauthors <152090505+explainerauthors@users.noreply.github.com> Date: Tue, 28 Nov 2023 10:19:35 -0800 Subject: [PATCH 08/13] Correct comments in parallel_state.py (#1818) --- vllm/model_executor/parallel_utils/parallel_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index 53871c85a8620..9a5e2889381d9 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -"""Model and data parallel groups.""" +"""Tensor and pipeline parallel groups.""" import torch @@ -84,7 +84,7 @@ def initialize_model_parallel( def model_parallel_is_initialized(): - """Check if model and data parallel groups are initialized.""" + """Check if tensor and pipeline parallel groups are initialized.""" return (_TENSOR_MODEL_PARALLEL_GROUP is not None and _PIPELINE_MODEL_PARALLEL_GROUP is not None) From b9438904842b729f622f286447fc22c94bd2735f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 Nov 2023 11:22:44 -0800 Subject: [PATCH 09/13] Fix OPT param names (#1819) --- vllm/model_executor/models/opt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 8d88ccd706eaa..2bbcd0030f6fc 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -331,6 +331,9 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "lm_head.weight" in name: continue + if name.startswith("decoder."): + name = "model." + name + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue From 708e6c18b0eff37643458424391bcc68ef7f3467 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 28 Nov 2023 14:08:01 -0800 Subject: [PATCH 10/13] [FIX] Fix class naming (#1803) --- vllm/engine/llm_engine.py | 8 ++++---- vllm/model_executor/layers/sampler.py | 6 +++--- vllm/sequence.py | 20 ++++++++++---------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e33d8aa2a2131..f79b5e84f49ac 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -12,8 +12,8 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceGroupOutputs, - SequenceOutputs, SequenceStatus) + SequenceGroupMetadata, SequenceGroupOutput, + SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) from vllm.utils import Counter @@ -363,7 +363,7 @@ def _check_beam_search_early_stopping( return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutputs) -> None: + outputs: SequenceGroupOutput) -> None: # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs if prompt_logprobs is not None: @@ -384,7 +384,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Process the child samples for each parent sequence for parent in parent_seqs: - child_samples: List[SequenceOutputs] = parent_child_dict[ + child_samples: List[SequenceOutput] = parent_child_dict[ parent.seq_id] if len(child_samples) == 0: # This parent sequence has no children samples. Remove diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c874ec5921155..b545587fd2044 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -9,7 +9,7 @@ tensor_model_parallel_all_gather) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceData, SequenceGroupOutputs, SequenceOutputs) + SequenceData, SequenceGroupOutput, SequenceOutput) _SAMPLING_EPS = 1e-5 @@ -641,7 +641,7 @@ def _build_sampler_output( next_token_ids, group_sample_logprobs): seq_outputs.append( - SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs)) + SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( - SequenceGroupOutputs(seq_outputs, group_prompt_logprobs)) + SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) return sampler_output diff --git a/vllm/sequence.py b/vllm/sequence.py index ecfaee6e8c3d6..7d36eeac0aa02 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -352,7 +352,7 @@ def __init__( self.block_tables = block_tables -class SequenceOutputs: +class SequenceOutput: """The model output associated with a sequence. Args: @@ -374,40 +374,40 @@ def __init__( self.logprobs = logprobs def __repr__(self) -> str: - return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, " + return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " f"logprobs={self.logprobs})") def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceOutputs): + if not isinstance(other, SequenceOutput): raise NotImplementedError() return (self.parent_seq_id == other.parent_seq_id and self.output_token == other.output_token and self.logprobs == other.logprobs) -class SequenceGroupOutputs: - """The model outputs associated with a sequence group.""" +class SequenceGroupOutput: + """The model output associated with a sequence group.""" def __init__( self, - samples: List[SequenceOutputs], + samples: List[SequenceOutput], prompt_logprobs: Optional[PromptLogprobs], ) -> None: self.samples = samples self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: - return (f"SequenceGroupOutputs(samples={self.samples}, " + return (f"SequenceGroupOutput(samples={self.samples}, " f"prompt_logprobs={self.prompt_logprobs})") def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceGroupOutputs): + if not isinstance(other, SequenceGroupOutput): raise NotImplementedError() return (self.samples == other.samples and self.prompt_logprobs == other.prompt_logprobs) -# For each sequence group, we generate a list of SequenceOutputs object, +# For each sequence group, we generate a list of SequenceOutput object, # each of which contains one possible candidate for the next token. -SamplerOutput = List[SequenceGroupOutputs] +SamplerOutput = List[SequenceGroupOutput] From 6ed068a71a58110f41c9cba76035f4c086840eb1 Mon Sep 17 00:00:00 2001 From: explainerauthors <152090505+explainerauthors@users.noreply.github.com> Date: Tue, 28 Nov 2023 16:34:05 -0800 Subject: [PATCH 11/13] Use the type BlockTable (#1791) --- vllm/core/block_manager.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 4ce87f6e5061b..f6251f667c65c 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -7,6 +7,10 @@ from vllm.utils import Device +# Mapping: logical block number -> physical block. +BlockTable = List[PhysicalTokenBlock] + + class BlockAllocator: """Manages free physical token blocks for a device. @@ -26,7 +30,7 @@ def __init__( self.num_blocks = num_blocks # Initialize the free blocks. - self.free_blocks: List[PhysicalTokenBlock] = [] + self.free_blocks: BlockTable = [] for i in range(num_blocks): block = PhysicalTokenBlock(device=device, block_number=i, @@ -51,10 +55,6 @@ def get_num_free_blocks(self) -> int: return len(self.free_blocks) -# Mapping: logical block number -> physical block. -BlockTable = List[PhysicalTokenBlock] - - class AllocStatus(enum.Enum): """Result for BlockSpaceManager.can_allocate From 1cb4ad8de98b672873dbfc1f246fbceae29234a5 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 29 Nov 2023 00:40:19 +0000 Subject: [PATCH 12/13] [FIX] Fix formatting error --- vllm/core/block_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index f6251f667c65c..8b26319b88cd3 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -6,7 +6,6 @@ from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device - # Mapping: logical block number -> physical block. BlockTable = List[PhysicalTokenBlock] From e19a64c7eff2085790dbf71851208fa2dd31ca4d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 28 Nov 2023 16:56:43 -0800 Subject: [PATCH 13/13] [FIX] Fix formatting error in main branch (#1822)