Skip to content

Commit

Permalink
comments
Browse files — browse the repository at this point in the history
Signed-off-by: Woosuk Kwon <[email protected]>
  • Loading branch information
WoosukKwon committed Dec 23, 2024
1 parent 0420fb2 commit 3fdbd8e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
31 changes: 17 additions & 14 deletions csrc/prepare_inputs/copy_subranges.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,34 @@ __global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n) {
// NOTE(woosuk): Here, we skip most of the error checking to minimize the
// CPU overheads. We assume that the caller will pass the correct inputs.

// Check tensor properties
TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");
// TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
// TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
// TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
// TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
// TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
// TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");

auto src_sizes = matrix_src.sizes();
auto diff_sizes = matrix_diff.sizes();
auto tgt_sizes = matrix_tgt.sizes();

TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");
// TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
// TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
// TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");

int64_t N = src_sizes[0];
int64_t M = src_sizes[1];

TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
"matrix_tgt must have same shape as matrix_src");
// TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
// TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
// TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
// "matrix_tgt must have same shape as matrix_src");

TORCH_CHECK(n <= N, "n must be <= N");
// TORCH_CHECK(n <= N, "n must be <= N");

const int* d_matrix_src = matrix_src.data_ptr<int>();
const int* d_matrix_diff = matrix_diff.data_ptr<int>();
Expand Down
3 changes: 3 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ def copy_subranges(
tgt_matrix: torch.Tensor,
num_subranges: int,
) -> None:
# NOTE(woosuk): We use `torch.ops._C.copy_subranges.default` instead of
# `torch.ops._C.copy_subranges` to avoid unnecessary CPU overheads from
# the dispatcher.
torch.ops._C.copy_subranges.default(src_matrix, diff_matrix, tgt_matrix,
num_subranges)

Expand Down

0 comments on commit 3fdbd8e

Please sign in to comment.