[RoPE] Add minimal RoPE f32/f32x4 pack impl (#80)
* [RoPE]: Minimal version of RoPE implementation. Add f32/x4.
* Update rope.cu
* Update rope.py
* Update README.md
* Update rope.py
* Update rope.cu
* Update README.md

Co-authored-by: DefTruth <[email protected]>
Showing 5 changed files with 287 additions and 0 deletions.
.gitignore
@@ -0,0 +1,10 @@
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp
README.md
@@ -0,0 +1,50 @@
# RoPE

## 0x00 Description

A basic RoPE version: a minimal implementation of the RoPE used in Llama.

It includes the following (a usage sketch follows the list):

- [X] rope_f32_kernel
- [X] rope_f32x4_kernel (float4 vectorized version)
- [X] PyTorch bindings
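
For reference, here is a minimal usage sketch (not part of the committed files) showing how the bound kernels can be called from Python. It assumes `rope.cu` is in the working directory and a CUDA-enabled PyTorch build is available, mirroring what `rope.py` does:

```python
import torch
from torch.utils.cpp_extension import load

# Build the extension from rope.cu (same mechanism as rope.py).
lib = load(name="rope", sources=["rope.cu"], extra_cuda_cflags=["-O3"])

# x: [seq_len, hidden_size], fp32 and contiguous; out receives the rotated values.
x = torch.randn(4096, 512, device="cuda", dtype=torch.float32).contiguous()
out = torch.zeros_like(x)

# rope_f32 rotates each adjacent pair (x[2i], x[2i+1]) at position pos by the
# angle pos / theta^(2i/hidden_size), i.e. the Llama-style RoPE convention.
lib.rope_f32(x, out)
print(out[0, :4])
```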

## Test

```bash
# Test only the Ada architecture. If unset, all architectures (Volta, Ampere, Ada, Hopper, ...)
# are compiled by default, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 rope.py
```

Output:

```bash
----------------------------------------------------------------------------------------------------
M=4096, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.006247ms
out_f32x4_pack: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.005484ms
out_f32_th: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.734866ms
----------------------------------------------------------------------------------------------------
M=4096, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:0.010335ms
out_f32x4_pack: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:0.008714ms
out_f32_th: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:1.447463ms
----------------------------------------------------------------------------------------------------
M=8192, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:0.010288ms
out_f32x4_pack: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:0.008750ms
out_f32_th: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:1.434934ms
----------------------------------------------------------------------------------------------------
M=8192, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:0.018394ms
out_f32x4_pack: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:0.015330ms
out_f32_th: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:2.518094ms
----------------------------------------------------------------------------------------------------
```
rope.cu
@@ -0,0 +1,120 @@
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>

#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define BLOCK_SIZE 256
#define theta 10000.0f

// RoPE fp32 kernel: each thread rotates one adjacent pair (x[2*idx], x[2*idx+1]).
// N is the number of pairs per token, so the hidden size is 2*N.
__global__ void rope_f32_kernel(float* x, float* out, int seq_len, int N){
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= seq_len * N) return; // guard threads past the last pair
  float x1 = x[idx * 2];
  float x2 = x[idx * 2 + 1];
  int token_pos = idx / N; // token position in the sequence
  int token_idx = idx % N; // pair index i within the token
  // Llama-style frequency 1/theta^(2*i/dim) = 1/theta^(i/N); cast to float
  // so the exponent is not truncated by integer division.
  float exp_v = 1.0f / powf(theta, (float)token_idx / (float)N);
  float sin_v = sinf(token_pos * exp_v);
  float cos_v = cosf(token_pos * exp_v);
  float out1 = x1 * cos_v - x2 * sin_v;
  float out2 = x1 * sin_v + x2 * cos_v;
  out[idx * 2] = out1;
  out[idx * 2 + 1] = out2;
}

// Another indexing scheme for RoPE: one block per token, one thread per pair.
__global__ void rope_f32_v2_kernel(float* x, float* out, int seq_len, int N){
  int token_pos = blockIdx.x; // token position
  int tid = threadIdx.x;      // pair index i within the token
  float x1 = x[token_pos * N * 2 + tid * 2];
  float x2 = x[token_pos * N * 2 + tid * 2 + 1];
  // same frequency as rope_f32_kernel: 1/theta^(i/N), computed in float
  float exp_v = 1.0f / powf(theta, (float)tid / (float)N);
  float sin_v = sinf(token_pos * exp_v);
  float cos_v = cosf(token_pos * exp_v);
  float out1 = x1 * cos_v - x2 * sin_v;
  float out2 = x1 * sin_v + x2 * cos_v;
  out[token_pos * N * 2 + tid * 2] = out1;
  out[token_pos * N * 2 + tid * 2 + 1] = out2;
}

// Vectorized version: each thread loads one float4, i.e. two adjacent pairs.
// N is the number of float4 chunks per token, so the hidden size is 4*N.
__global__ void rope_f32x4_pack_kernel(float* x, float* out, int seq_len, int N){
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= seq_len * N) return; // guard threads past the last float4
  float4 x_v = FLOAT4(x[idx * 4]);
  int token_pos = idx / N; // token position in the sequence
  int token_idx = idx % N; // float4 index within the token
  // The float4 holds pairs p0 = 2*token_idx and p1 = 2*token_idx + 1.
  // Frequency for pair p is 1/theta^(2*p/dim) with dim = 4*N, i.e. 1/theta^(p/(2*N)).
  float exp_f_v = 1.0f / powf(theta, (float)(token_idx * 2) / (float)(N * 2));
  float exp_s_v = 1.0f / powf(theta, (float)(token_idx * 2 + 1) / (float)(N * 2));
  float sin_f_v = sinf(token_pos * exp_f_v);
  float cos_f_v = cosf(token_pos * exp_f_v);
  float sin_s_v = sinf(token_pos * exp_s_v);
  float cos_s_v = cosf(token_pos * exp_s_v);
  float4 out_v;
  out_v.x = x_v.x * cos_f_v - x_v.y * sin_f_v;
  out_v.y = x_v.x * sin_f_v + x_v.y * cos_f_v;
  out_v.z = x_v.z * cos_s_v - x_v.w * sin_s_v;
  out_v.w = x_v.z * sin_s_v + x_v.w * cos_s_v;
  FLOAT4(out[idx * 4]) = out_v;
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                     \
  if (((T).options().dtype() != (th_type))) {                    \
    std::cout << "Tensor Info:" << (T).options() << std::endl;   \
    throw std::runtime_error("values must be " #th_type);        \
  }

void rope_f32(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 2); // pairs per token
  dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE);
  dim3 block(BLOCK_SIZE);
  rope_f32_kernel<<<grid, block>>>(
    x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

void rope_f32_v2(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 2); // pairs per token, one thread each
  dim3 grid(seq_len);
  dim3 block(N);
  rope_f32_v2_kernel<<<grid, block>>>(
    x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

void rope_f32x4_pack(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 4); // float4 chunks per token
  dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE);
  dim3 block(BLOCK_SIZE);
  rope_f32x4_pack_kernel<<<grid, block>>>(
    x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(rope_f32)
  TORCH_BINDING_COMMON_EXTENSION(rope_f32_v2)
  TORCH_BINDING_COMMON_EXTENSION(rope_f32x4_pack)
}
rope.py
@@ -0,0 +1,105 @@
import torch
import time
import math
from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional
from typing import Tuple
import torch.nn as nn
import torch.nn.functional as F

torch.set_grad_enabled(False)

# Load the CUDA kernel as a python module
lib = load(
    name="rope",
    sources=["rope.cu"],
    extra_cuda_cflags=[
        "-O3",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ],
    extra_cflags=["-std=c++17"],
)


def run_benchmark(
    perf_func: callable,
    a: torch.Tensor,
    tag: str,
    out: Optional[torch.Tensor] = None,
    warmup: int = 2,
    iters: int = 20,
    show_all: bool = False,
):
    if out is not None:
        out.fill_(0)
    if out is not None:
        for i in range(warmup):
            perf_func(a, out)
    else:
        for i in range(warmup):
            _ = perf_func(a)

    torch.cuda.synchronize()
    start = time.time()
    # iters
    if out is not None:
        for i in range(iters):
            perf_func(a, out)
    else:
        for i in range(iters):
            out = perf_func(a)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"out_{tag}"
    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
    out_val = [round(v, 8) for v in out_val]
    out_val = [f"{v:<12}" for v in out_val]
    print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms")
    if show_all:
        print(out)
    return out.clone(), mean_time


def naive_rope(
    x: torch.Tensor,
    theta: float = 10000.0,
) -> torch.Tensor:
    # x: [seq_len, dim] in this minimal version (batch/head dims are ignored).
    dim = x.shape[-1]
    seq_len = x.shape[-2]
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)
    # x_: [seq_len, dim//2, 2]
    x_ = torch.view_as_complex(x_)
    # pack neighboring elements into one complex number,
    # x_: [seq_len, dim//2], e.g. tensor([(1.6116-0.5772j), ...])
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(seq_len, device=freqs.device)
    # rotation angles: outer product of positions and per-pair frequencies
    freqs = torch.outer(t, freqs).float().cuda()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # rotate each pair and flatten back to [seq_len, dim]
    xq_out = torch.view_as_real(x_ * freqs_cis).flatten(1)
    return xq_out.type_as(x)
print("-" * 100) | ||
M = [4096, 8192] | ||
N = [512, 1024] | ||
MN = [[m, n] for m in M for n in N] | ||
for M,N in MN: | ||
print(" " * 40 + f"M={M}, N={N}") | ||
print("-" * 100) | ||
x = torch.randn((M, N)).cuda().float().contiguous() | ||
out = torch.zeros_like(x).cuda().float().contiguous() | ||
run_benchmark(lib.rope_f32, x, "f32", out) | ||
run_benchmark(lib.rope_f32x4_pack, x, "f32x4_pack", out) | ||
run_benchmark(naive_rope, x, "f32_th") | ||
print("-" * 100) |