Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support mla #775

Merged
merged 7 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ async def benchmark(
# multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' backend.")
test_output_len = 10
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
Expand Down
5 changes: 5 additions & 0 deletions csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);

void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
95 changes: 95 additions & 0 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,51 @@ __global__ void reshape_and_cache_flash_kernel(
}
}
}

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;

auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx = block_idx * block_stride +
block_offset * (kv_lora_rank + pe_dim) + i +
offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
}
}
};

copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}

} // namespace vllm

// KV_T is the stored data type of kv-cache.
Expand Down Expand Up @@ -343,6 +388,56 @@ void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE_FLASH);
}

// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));

void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);

TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);

int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);

dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}

namespace vllm {

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
Expand Down
9 changes: 9 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);

// Concat kv_c and k_pe and cache them.
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);

// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
Expand Down
84 changes: 45 additions & 39 deletions scripts/run_example_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import argparse
import os

model_path = "/data/models/DeepSeek-R1/"
# model_path = "deepseek-ai/DeepSeek-V2-Lite"

# Parse the command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="/software/data/DeepSeek-R1/", help="The model path.")
parser.add_argument("--model", type=str, default=model_path, help="The model path.")
#parser.add_argument("--model", type=str, default="/data/models/DeepSeek-R1/", help="The model path.")
parser.add_argument("--tokenizer", type=str, default="deepseek-ai/DeepSeek-R1", help="The model path.")
parser.add_argument("--tokenizer", type=str, default=model_path, help="The model path.")
#parser.add_argument("--model", type=str, default="/data/models/DeepSeek-R1-bf16-small/", help="The model path.")
#parser.add_argument("--tokenizer", type=str, default="opensourcerelease/DeepSeek-R1-bf16", help="The model path.")
parser.add_argument("--tp_size", type=int, default=8, help="The number of threads.")
Expand All @@ -16,44 +19,47 @@
os.environ["VLLM_SKIP_WARMUP"] = "true"
os.environ["HABANA_VISIBLE_DEVICES"] = "ALL"
os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"
os.environ["VLLM_RAY_DISABLE_LOG_TO_DRIVER"] = "1"
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
# os.environ["VLLM_RAY_DISABLE_LOG_TO_DRIVER"] = "1"
# os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
os.environ["VLLM_MOE_N_SLICE"] = "8"
os.environ["VLLM_MLA_DISABLE_REQUANTIZATION"] = "1"

if __name__ == "__main__":

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0, max_tokens=50)
model = args.model
if args.tp_size == 1:
llm = LLM(
model=model,
tokenizer=args.tokenizer,
trust_remote_code=True,
dtype="bfloat16",
)
else:
llm = LLM(
model=model,
tokenizer=args.tokenizer,
tensor_parallel_size=args.tp_size,
distributed_executor_backend='ray',
trust_remote_code=True,
max_model_len=1024,
dtype="bfloat16",
)
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0, max_tokens=50)
model = args.model
if args.tp_size == 1:
llm = LLM(
model=model,
tokenizer=args.tokenizer,
trust_remote_code=True,
dtype="bfloat16",
max_model_len=1024,
)
else:
llm = LLM(
model=model,
tokenizer=args.tokenizer,
tensor_parallel_size=args.tp_size,
distributed_executor_backend='mp',
trust_remote_code=True,
max_model_len=1024,
dtype="bfloat16",
)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
12 changes: 6 additions & 6 deletions scripts/run_static-online.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
tp_parrallel=8
bs=32
bs=96
in_len=1024
out_len=1024
multi_step=1
Expand All @@ -10,12 +10,13 @@ VLLM_DECODE_BLOCK_BUCKET_MAX=$((total_len * bs / 128 + 128))

# model="/data/models/DeepSeek-R1/"
# tokenizer="/data/models/DeepSeek-R1/"
model="/software/data/DeepSeek-R1/"
tokenizer="/software/data/DeepSeek-R1/"
model="/data/models/DeepSeek-R1/"
tokenizer="/data/models/DeepSeek-R1/"
model_name="DeepSeek-R1"

HABANA_VISIBLE_DEVICES="ALL" \
VLLM_MOE_N_SLICE=8 \
VLLM_MOE_N_SLICE=4 \
VLLM_MLA_DISABLE_REQUANTIZATION=1 \
PT_HPU_ENABLE_LAZY_COLLECTIVES="true" \
VLLM_RAY_DISABLE_LOG_TO_DRIVER="1" \
RAY_IGNORE_UNHANDLED_ERRORS="1" \
Expand All @@ -37,7 +38,6 @@ python -m vllm.entrypoints.openai.api_server \
--use-v2-block-manager \
--num_scheduler_steps ${multi_step}\
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--distributed_executor_backend ray \
--gpu_memory_utilization 0.9 \
--trust_remote_code 2>&1 | tee benchmark_logs/serving.log &
Expand All @@ -53,7 +53,7 @@ done
sleep 5s
echo ${pid}

num_prompts=32
num_prompts=300
request_rate=1
start_time=$(date +%s)
echo "Start to benchmark"
Expand Down
89 changes: 89 additions & 0 deletions tests/kernels/test_triton_decode_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
import torch

from vllm.attention.ops.triton_decode_attention import decode_attention_fwd


def cdiv(a, b):
return (a + b - 1) // b


@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L # This represents the number of tokens already in the sequence
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8

num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1),
device="cuda")
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()

# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

b_seq_len = torch.full((B, ), seq_len, device="cuda")

attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)

# Call the original implementation.
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)

# Page size can be larger than 1.
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

o1 = torch.zeros_like(o)

decode_attention_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)

assert torch.allclose(o, o1)
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ class DummyPlatform(CudaPlatform):
device_name = "DummyDevice"

def get_attn_backend_cls(self, backend_name, head_size, dtype,
kv_cache_dtype, block_size, use_v1):
kv_cache_dtype, block_size, use_v1, use_mla):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
2 changes: 1 addition & 1 deletion tests/weight_loading/models.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
awq, casperhansen/mixtral-instruct-awq, main
Expand Down
Loading
Loading