Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port most vLLM kernels to ROCm #1313

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
pcmoritz marked this conversation as resolved.
Show resolved Hide resolved
#include "dispatch_utils.h"

namespace vllm {
Expand All @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
}
}
Expand Down Expand Up @@ -57,7 +58,7 @@ __global__ void activation_kernel(
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
Expand Down
22 changes: 11 additions & 11 deletions csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}

// Warp leaders store the data to shared memory.
Expand All @@ -59,11 +59,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}

// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
return VLLM_SHFL_SYNC(sum, 0);
}

// TODO(woosuk): Merge the last two dimensions of the grid.
Expand Down Expand Up @@ -220,7 +220,7 @@ __device__ void paged_attention_kernel(
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
Expand All @@ -232,10 +232,10 @@ __device__ void paged_attention_kernel(
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
qk_max = VLLM_SHFL_SYNC(qk_max, 0);

// Get the sum of the exp values.
float exp_sum = 0.f;
Expand Down Expand Up @@ -320,7 +320,7 @@ __device__ void paged_attention_kernel(
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
Expand Down Expand Up @@ -486,7 +486,7 @@ __global__ void paged_attention_v2_reduce_kernel(
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
Expand All @@ -496,10 +496,10 @@ __global__ void paged_attention_v2_reduce_kernel(
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
max_logit = VLLM_SHFL_SYNC(max_logit, 0);

// Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
Expand Down Expand Up @@ -534,7 +534,7 @@ __global__ void paged_attention_v2_reduce_kernel(

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
cudaFuncSetAttribute( \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
(void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
Expand Down
3 changes: 2 additions & 1 deletion csrc/attention/attention_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
#pragma once

#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>
Expand All @@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
Expand Down
25 changes: 22 additions & 3 deletions csrc/attention/dtype_bfloat16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {
Expand Down Expand Up @@ -98,7 +107,17 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return a + b;
#ifndef USE_ROCM
return a + b;
#else
// See https://github.com/RadeonOpenCompute/ROCm/issues/2534
hip_bfloat16 A, B;
__hip_bfloat16 c;
A.data = a.data;
B.data = b.data;
c.data = (A + B).data;
return c;
#endif
#endif
}

Expand Down
68 changes: 63 additions & 5 deletions csrc/attention/dtype_float16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
#include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {
Expand Down Expand Up @@ -64,28 +68,58 @@ struct FloatVec<uint4> {
// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
uint32_t b;
#ifndef USE_ROCM
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = a;
tmp.u16[1] = a;
b = tmp.u32;
#endif
return b;
}

inline __device__ float half_to_float(uint16_t h) {
float f;
#ifndef USE_ROCM
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u32 = v;
float2 ret;
ret.x = half_to_float(tmp.u16[0]);
ret.y = half_to_float(tmp.u16[1]);
return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#ifndef USE_ROCM
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
return tmp.u16[0];
}

Expand All @@ -94,26 +128,38 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
uint32_t u32;
uint16_t u16[2];
} tmp;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
tmp.u16[0] = float_to_half(f.x);
tmp.u16[1] = float_to_half(f.y);
#endif
return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
uint16_t c;
#ifndef USE_ROCM
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
uint32_t c;
#ifndef USE_ROCM
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}

Expand Down Expand Up @@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
#ifndef USE_ROCM
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
#ifndef USE_ROCM
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}

Expand Down Expand Up @@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return d;
}

Expand Down
13 changes: 7 additions & 6 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#include <algorithm>
Expand Down Expand Up @@ -28,8 +29,8 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination");
}

void *src_ptr = src.data_ptr();
void *dst_ptr = dst.data_ptr();
char *src_ptr = static_cast<char*>(src.data_ptr());
char *dst_ptr = static_cast<char*>(dst.data_ptr());

const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Expand Down Expand Up @@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel(
+ head_offset * block_size
+ block_offset;

key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
}
}

Expand Down Expand Up @@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized(
src_key_indices[j] = src_key_idx;
src_value_indices[j] = src_value_idx;

keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
}

#pragma unroll
Expand Down
19 changes: 19 additions & 0 deletions csrc/cuda_compat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane);
#else
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif
4 changes: 4 additions & 0 deletions csrc/cuda_utils_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

int get_device_attribute(
int attribute,
int device_id)
Expand Down
Loading
Loading