Breaking API changes from vllm upstream #16

Closed · 13 commits
6 changes: 3 additions & 3 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -4,7 +4,7 @@

 import torch

-from vllm import attention_ops
+from vllm._C import ops

 NUM_BLOCKS = 1024
 PARTITION_SIZE = 512
@@ -98,7 +98,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:

     for _ in range(num_iters):
         if version == "v1":
-            attention_ops.paged_attention_v1(
+            ops.paged_attention_v1(
                 output,
                 query,
                 key_cache,
@@ -112,7 +112,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                 alibi_slopes,
             )
         elif version == "v2":
-            attention_ops.paged_attention_v2(
+            ops.paged_attention_v2(
                 output,
                 exp_sums,
                 max_logits,
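
The user-facing break for the attention kernels is visible above: the standalone `vllm.attention_ops` extension module is gone, and callers go through the `ops` submodule of the single `vllm._C` extension instead. A minimal migration sketch (arguments elided, matching the benchmark call sites):

    # Before this PR: one compiled extension module per kernel family.
    # from vllm import attention_ops
    # attention_ops.paged_attention_v1(output, query, key_cache, ...)

    # After this PR: a single vllm._C extension with an `ops` submodule.
    from vllm._C import ops
    # ops.paged_attention_v1(output, query, key_cache, ...)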
28 changes: 0 additions & 28 deletions csrc/activation.cpp

This file was deleted.

42 changes: 0 additions & 42 deletions csrc/attention.cpp

This file was deleted.

19 changes: 0 additions & 19 deletions csrc/cache.cpp → csrc/cache.h
@@ -26,22 +26,3 @@ void gather_cached_kv(
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "swap_blocks",
-    &swap_blocks,
-    "Swap in (out) the cache blocks from src to dst");
-  m.def(
-    "copy_blocks",
-    &copy_blocks,
-    "Copy the cache blocks from src to dst");
-  m.def(
-    "reshape_and_cache",
-    &reshape_and_cache,
-    "Reshape the key and value tensors and cache them");
-  m.def(
-    "gather_cached_kv",
-    &gather_cached_kv,
-    "Gather key and value from the cache into contiguous QKV tensors");
-}
13 changes: 0 additions & 13 deletions csrc/cuda_utils.cpp

This file was deleted.

5 changes: 5 additions & 0 deletions csrc/cuda_utils.h
@@ -0,0 +1,5 @@
+#include <torch/extension.h>
+
+int get_device_attribute(
+  int attribute,
+  int device_id);
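
Only the declaration of `get_device_attribute` lives in this new header; its Python binding is registered in csrc/pybind.cpp below. A hedged usage sketch — the first argument is a raw CUDA `cudaDeviceAttr` enum value, and 97 (`cudaDevAttrMaxSharedMemoryPerBlockOptin`) is an assumed example, not something this PR prescribes:

    import torch
    from vllm._C import cuda_utils

    # Assumed example value: 97 == cudaDevAttrMaxSharedMemoryPerBlockOptin
    # in the CUDA runtime's cudaDeviceAttr enum. Signature per the binding:
    # get_device_attribute(attribute, device_id).
    max_shared_mem = cuda_utils.get_device_attribute(97, torch.cuda.current_device())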
24 changes: 0 additions & 24 deletions csrc/layernorm.cpp

This file was deleted.

75 changes: 75 additions & 0 deletions csrc/ops.h
@@ -0,0 +1,75 @@
+#include <torch/extension.h>
+
+void paged_attention_v1(
+  torch::Tensor& out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void rms_norm(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& weight,
+  float epsilon);
+
+void fused_add_rms_norm(
+  torch::Tensor& input,
+  torch::Tensor& residual,
+  torch::Tensor& weight,
+  float epsilon);
+
+void rotary_embedding(
+  torch::Tensor& positions,
+  torch::Tensor& query,
+  torch::Tensor& key,
+  int head_size,
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
+
+void silu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+torch::Tensor awq_gemm(
+  torch::Tensor _in_feats,
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters);
+
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table);
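
Everything in this header is bound under `vllm._C.ops` in csrc/pybind.cpp below. The signatures share a destination-passing convention: outputs are preallocated by the caller and passed by reference. A hedged sketch of what that looks like from Python for `rms_norm`, with shapes and dtype assumed purely for illustration:

    import torch
    from vllm._C import ops

    # Assumed shapes for illustration: 8 tokens, hidden size 4096.
    x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    weight = torch.ones(4096, dtype=torch.float16, device="cuda")
    out = torch.empty_like(x)

    # Destination-passing style: the kernel writes into the preallocated `out`;
    # the last argument is the epsilon from the header signature.
    ops.rms_norm(out, x, weight, 1e-6)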
16 changes: 0 additions & 16 deletions csrc/pos_encoding.cpp

This file was deleted.

80 changes: 80 additions & 0 deletions csrc/pybind.cpp
@@ -0,0 +1,80 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

// Attention ops
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");

// Activation ops
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");

// Layernorm
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");

ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");

// Rotary embedding
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

// Quantization ops
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"gather_cached_kv",
&gather_cached_kv,
"Gather key and value from the cache into contiguous QKV tensors");

// Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
cuda_utils.def(
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
}
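
The net effect of this file is one extension module with three namespaces. From Python, the registrations above surface as follows (submodule and function names taken directly from the `def_submodule`/`def` calls):

    from vllm._C import ops, cache_ops, cuda_utils

    # ops:        paged_attention_v1/v2, silu_and_mul, gelu_new, gelu_fast,
    #             rms_norm, fused_add_rms_norm, rotary_embedding,
    #             awq_gemm, squeezellm_gemm
    # cache_ops:  swap_blocks, copy_blocks, reshape_and_cache, gather_cached_kv
    # cuda_utils: get_device_attribute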
19 changes: 0 additions & 19 deletions csrc/quantization.cpp

This file was deleted.

1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -73,6 +73,7 @@ Documentation

    models/supported_models
    models/adding_model
+   models/engine_args
 
 .. toctree::
    :maxdepth: 1