Commit

kv cache quantization
davidpissarra committed Jul 19, 2024
1 parent 070546e commit 9490426
Showing 3 changed files with 66 additions and 51 deletions.
107 changes: 56 additions & 51 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -853,7 +853,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
/*!
* \brief The KV data managed by the KV cache.
* The array has `num_layers` NDArrays, each of them
- * has layout (num_pages, 2, num_heads, page_size, head_dim).
+ * has layout (num_pages, 2, num_heads, page_size, num_storage).
* Along on the "2" dimension, index 0 stands for K and 1 stands for V.
*/
Array<NDArray> pages_;
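
Note on the hunk above: the last dimension of each per-layer page buffer changes from `head_dim` (in the model dtype) to `num_storage` (in `kv_storage_dtype`), which is the hook this commit adds for storing K/V in a quantized representation. A rough sizing sketch under that layout; the helper below is illustrative and not part of the commit:

```python
import numpy as np

def kv_pages_nbytes(num_total_pages, num_kv_heads, page_size, num_storage, kv_storage_dtype):
    # Per-layer layout after this commit: (num_total_pages, 2, num_kv_heads, page_size, num_storage)
    elems = num_total_pages * 2 * num_kv_heads * page_size * num_storage
    return elems * np.dtype(kv_storage_dtype).itemsize

# Unquantized baseline (as in the updated tests below): num_storage == head_dim and
# kv_storage_dtype == dtype, so the footprint matches the previous layout.
print(kv_pages_nbytes(512, 8, 16, 128, "float16"))
```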
@@ -985,10 +985,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size,
bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta,
- DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy,
- PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
- PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window,
- PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask,
+ int64_t num_storage, DLDataType dtype, DLDataType kv_storage_dtype, Device device,
+ PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill,
+ PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window,
+ PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged,
+ PackedFunc f_attention_prefill_with_tree_mask,
Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -1030,8 +1031,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
device_(device) {
pages_.reserve(num_layers);
for (int i = 0; i < num_layers; ++i) {
- pages_.push_back(
- NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, head_dim}, dtype, device));
+ pages_.push_back(NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, num_storage},
+ kv_storage_dtype, device));
}
// Allocate the host memory.
Device preferred_host_device = GetPreferredHostDevice(device);
@@ -1673,8 +1674,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
NDArray o_data, double attn_score_scaling_factor) final {
// Part 1. Shape and dtype check.
NDArray pages = pages_[layer_id];
- CHECK(qkv_data.DataType() == pages.DataType());
- CHECK(o_data.DataType() == pages.DataType());

// qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, head_dim)
// o_data: (num_total_length, num_qo_heads, head_dim)
@@ -2433,7 +2432,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);

TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
+ CHECK(args.size() == 27 || args.size() == 28 || args.size() == 29)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
int64_t num_layers = args[1];
@@ -2443,31 +2442,33 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
int rope_mode = args[5];
double rotary_scale = args[6];
double rotary_theta = args[7];
- NDArray init = args[8];
- PackedFunc f_transpose_append = args[9];
- PackedFunc f_attention_prefill = args[10];
- PackedFunc f_attention_decode = args[11];
- PackedFunc f_attention_prefill_sliding_window = args[12];
- PackedFunc f_attention_decode_sliding_window = args[13];
- PackedFunc f_attention_prefill_ragged = args[14];
- PackedFunc f_attention_prefill_ragged_begin_forward = args[15];
- PackedFunc f_attention_prefill_ragged_end_forward = args[16];
- PackedFunc f_attention_prefill_begin_forward = args[17];
- PackedFunc f_attention_prefill_end_forward = args[18];
- PackedFunc f_attention_decode_begin_forward = args[19];
- PackedFunc f_attention_decode_end_forward = args[20];
- PackedFunc f_merge_inplace = args[21];
- PackedFunc f_split_rotary = args[22];
- PackedFunc f_copy_single_page = args[23];
- Optional<PackedFunc> f_debug_get_kv = args[24];
+ int64_t num_storage = args[8];
+ NDArray init = args[9];
+ NDArray kv_storage_init = args[10];
+ PackedFunc f_transpose_append = args[11];
+ PackedFunc f_attention_prefill = args[12];
+ PackedFunc f_attention_decode = args[13];
+ PackedFunc f_attention_prefill_sliding_window = args[14];
+ PackedFunc f_attention_decode_sliding_window = args[15];
+ PackedFunc f_attention_prefill_ragged = args[16];
+ PackedFunc f_attention_prefill_ragged_begin_forward = args[17];
+ PackedFunc f_attention_prefill_ragged_end_forward = args[18];
+ PackedFunc f_attention_prefill_begin_forward = args[19];
+ PackedFunc f_attention_prefill_end_forward = args[20];
+ PackedFunc f_attention_decode_begin_forward = args[21];
+ PackedFunc f_attention_decode_end_forward = args[22];
+ PackedFunc f_merge_inplace = args[23];
+ PackedFunc f_split_rotary = args[24];
+ PackedFunc f_copy_single_page = args[25];
+ Optional<PackedFunc> f_debug_get_kv = args[26];
PackedFunc f_compact_copy{nullptr};
PackedFunc f_attention_prefill_with_tree_mask{nullptr};

- if (args.size() >= 26) {
- f_compact_copy = args[25].AsObjectRef<PackedFunc>();
+ if (args.size() >= 28) {
+ f_compact_copy = args[27].AsObjectRef<PackedFunc>();
}
- if (args.size() >= 27) {
- f_attention_prefill_with_tree_mask = args[26].AsObjectRef<PackedFunc>();
+ if (args.size() >= 29) {
+ f_attention_prefill_with_tree_mask = args[28].AsObjectRef<PackedFunc>();
}

CHECK_EQ(cache_config.size(), 5);
@@ -2484,8 +2485,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
- rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
- std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
+ rotary_scale, rotary_theta, num_storage, init->dtype, kv_storage_init->dtype,
+ init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+ std::move(f_attention_prefill), std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask),
@@ -2500,7 +2502,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")

TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
+ CHECK(args.size() == 21 || args.size() == 22 || args.size() == 23)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
int64_t num_layers = args[1];
@@ -2510,25 +2512,27 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
int rope_mode = args[5];
double rotary_scale = args[6];
double rotary_theta = args[7];
- NDArray init = args[8];
- PackedFunc f_transpose_append = args[9];
- PackedFunc f_attention_prefill = args[10];
- PackedFunc f_attention_decode = args[11];
- PackedFunc f_attention_prefill_sliding_window = args[12];
- PackedFunc f_attention_decode_sliding_window = args[13];
- PackedFunc f_attention_prefill_ragged = args[14];
- PackedFunc f_merge_inplace = args[15];
- PackedFunc f_split_rotary = args[16];
- PackedFunc f_copy_single_page = args[17];
- Optional<PackedFunc> f_debug_get_kv = args[18];
+ int64_t num_storage = args[8];
+ NDArray init = args[9];
+ NDArray kv_storage_init = args[10];
+ PackedFunc f_transpose_append = args[11];
+ PackedFunc f_attention_prefill = args[12];
+ PackedFunc f_attention_decode = args[13];
+ PackedFunc f_attention_prefill_sliding_window = args[14];
+ PackedFunc f_attention_decode_sliding_window = args[15];
+ PackedFunc f_attention_prefill_ragged = args[16];
+ PackedFunc f_merge_inplace = args[17];
+ PackedFunc f_split_rotary = args[18];
+ PackedFunc f_copy_single_page = args[19];
+ Optional<PackedFunc> f_debug_get_kv = args[20];
PackedFunc f_compact_copy{nullptr};
PackedFunc f_attention_prefill_with_tree_mask{nullptr};

- if (args.size() >= 20) {
- f_compact_copy = args[19].AsObjectRef<PackedFunc>();
+ if (args.size() >= 22) {
+ f_compact_copy = args[21].AsObjectRef<PackedFunc>();
}
- if (args.size() >= 21) {
- f_attention_prefill_with_tree_mask = args[20].AsObjectRef<PackedFunc>();
+ if (args.size() >= 23) {
+ f_attention_prefill_with_tree_mask = args[22].AsObjectRef<PackedFunc>();
}

CHECK_EQ(cache_config.size(), 5);
@@ -2545,8 +2549,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
- rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
- std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
+ rotary_scale, rotary_theta, num_storage, init->dtype, kv_storage_init->dtype,
+ init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+ std::move(f_attention_prefill), std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask), //
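
For reference, a sketch of the positional argument layout that the updated `vm.builtin.paged_attention_kv_cache_create` registration above now parses: `num_storage` is inserted at index 8 and a second init tensor carrying the storage dtype at index 10, shifting every later argument by two. The helper name and grouping below are illustrative; only the positions come from the diff.

```python
import tvm

def make_create_args(cache_config, num_layers, num_qo_heads, num_kv_heads, head_dim,
                     rope_mode, rope_scale, rope_theta, num_storage,
                     dtype, kv_storage_dtype, device, packed_funcs):
    # packed_funcs: the PackedFuncs expected from index 11 onward, in the order
    # listed in the registration above (f_transpose_append, prefill/decode kernels, ...).
    return [
        tvm.runtime.ShapeTuple(cache_config),               # args[0]
        num_layers, num_qo_heads, num_kv_heads, head_dim,   # args[1..4]
        rope_mode, rope_scale, rope_theta,                  # args[5..7]
        num_storage,                                        # args[8]  (new)
        tvm.nd.empty((), dtype, device=device),             # args[9]  model KV dtype carrier
        tvm.nd.empty((), kv_storage_dtype, device=device),  # args[10] storage dtype carrier (new)
        *packed_funcs,                                      # args[11..26], plus optional args[27]/args[28]
    ]
```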
@@ -344,6 +344,9 @@ def set_global_func():

def create_kv_cache(rope_mode):
support_sliding_window = 0
+ num_storage = head_dim
+ kv_storage_dtype = dtype
+
cache = fcreate(
tvm.runtime.ShapeTuple(
[
@@ -361,7 +364,9 @@
rope_mode,
rope_scale,
rope_theta,
+ num_storage,
tvm.nd.empty((), dtype, device=device),
+ tvm.nd.empty((), kv_storage_dtype, device=device),
ftranspose_append,
fattention_prefill,
fattention_decode,
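
The test above keeps the unquantized configuration (`num_storage = head_dim`, `kv_storage_dtype = dtype`). For intuition only, a quantized configuration would typically shrink `num_storage` by packing values and appending scales; the scheme sketched below is made up for illustration and is not the kernel's actual format, which is not part of this diff.

```python
def quantized_num_storage(head_dim, bits=4, word_bits=32, group_size=32, scale_bits=16):
    # Hypothetical packing: `bits`-bit values packed into `word_bits`-wide words,
    # plus one `scale_bits` scale per `group_size` values.
    packed_words = head_dim * bits // word_bits
    scale_words = max((head_dim // group_size) * scale_bits // word_bits, 1)
    return packed_words + scale_words

num_storage = quantized_num_storage(128)  # 16 value words + 2 scale words = 18
kv_storage_dtype = "uint32"               # hypothetical storage dtype
```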
@@ -142,6 +142,9 @@ def set_global_func(head_dim, dtype):


def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
+ num_storage = head_dim
+ kv_storage_dtype = dtype
+
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced")
cache = fcreate(
tvm.runtime.ShapeTuple(
@@ -160,7 +163,9 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
rope_mode,
rope_scale,
rope_theta,
+ num_storage,
tvm.nd.empty((), dtype, device=device),
+ tvm.nd.empty((), kv_storage_dtype, device=device),
ftranspose_append,
fattn_prefill,
fattn_decode,
