[RoPE] Add minimal RoPE f32/f32x4 pack impl (#80)
* [RoPE]: Minimal version of the RoPE implementation. Adds f32/f32x4 pack kernels.

* Update rope.cu

* Update rope.py

* Update README.md

* Update rope.py

* Update rope.cu

* Update README.md

---------

Co-authored-by: DefTruth <[email protected]>
bear-zd and DefTruth authored Oct 15, 2024
1 parent ba4998d commit 2906e78
Showing 5 changed files with 287 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -92,6 +92,8 @@
| ✔️ [safe_softmax_f16x8_pack_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [online_safe_softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [online_safe_softmax_f32x4_pack](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [rope_f32](./rope/rope.cu)|f32|f32|[link](./rope/)|⭐️⭐️|
| ✔️ [rope_f32x4_pack](./rope/rope.cu)|f32|f32|[link](./rope/)|⭐️⭐️|
| ✔️ [layer_norm_f32](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f32x4](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
10 changes: 10 additions & 0 deletions rope/.gitignore
@@ -0,0 +1,10 @@
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp

50 changes: 50 additions & 0 deletions rope/README.md
@@ -0,0 +1,50 @@
# RoPE

## 0x00 Overview

A basic version of RoPE: the minimal form of the RoPE used in Llama (the rotation rule implemented by the kernels is sketched below).

It contains the following:

- [X] rope_f32_kernel
- [X] rope_f32x4_kernel (float4 vectorized version)
- [X] PyTorch bindings
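
Both kernels apply the same rotation to adjacent element pairs (this is also what `naive_rope` in `rope.py` reproduces via complex multiplication). For token position $m$, pair index $i$, and head dimension $d$:

$$
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
\qquad
\theta_i = 10000^{-2i/d}
$$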


## Test

```bash
# Test only the Ada architecture. If unset, all architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 rope.py
```

Output:

```bash
----------------------------------------------------------------------------------------------------
M=4096, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.006247ms
out_f32x4_pack: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.005484ms
out_f32_th: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.734866ms
----------------------------------------------------------------------------------------------------
M=4096, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:0.010335ms
out_f32x4_pack: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:0.008714ms
out_f32_th: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:1.447463ms
----------------------------------------------------------------------------------------------------
M=8192, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:0.010288ms
out_f32x4_pack: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:0.008750ms
out_f32_th: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:1.434934ms
----------------------------------------------------------------------------------------------------
M=8192, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:0.018394ms
out_f32x4_pack: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:0.015330ms
out_f32_th: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:2.518094ms
----------------------------------------------------------------------------------------------------
```
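
For reference, once the extension has been JIT-compiled with `torch.utils.cpp_extension.load` (exactly as `rope.py` does), the bound functions can also be called directly; a minimal usage sketch:

```python
import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension from rope.cu (same mechanism as rope.py).
lib = load(name="rope", sources=["rope.cu"], extra_cuda_cflags=["-O3"])

# Both kernels expect contiguous f32 CUDA tensors of shape [seq_len, hidden_size]
# and write the rotated values into a preallocated output tensor.
x = torch.randn(4096, 512, device="cuda", dtype=torch.float32).contiguous()
out = torch.zeros_like(x)

lib.rope_f32(x, out)         # plain f32 kernel
lib.rope_f32x4_pack(x, out)  # float4-vectorized variant (hidden_size must be a multiple of 4)
```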
120 changes: 120 additions & 0 deletions rope/rope.cu
@@ -0,0 +1,120 @@
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>

#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define BLOCK_SIZE 256
#define theta 10000.0f

// One thread per (x1, x2) pair; N = hidden_size / 2 pairs per token.
__global__ void rope_f32_kernel(float* x, float* out, int seq_len, int N){
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= seq_len * N) return; // guard the tail when seq_len * N is not a multiple of BLOCK_SIZE
  float x1 = x[idx * 2];
  float x2 = x[idx * 2 + 1];
  int token_pos = idx / N; // position in the sequence
  int token_idx = idx % N; // pair index within the token
  // inv_freq = theta^(-2*i/d) with d = 2*N; the float cast avoids integer division truncating to zero.
  float inv_freq = 1.0f / powf(theta, (2.0f * token_idx) / (2 * N));
  float sin_v = sinf(token_pos * inv_freq);
  float cos_v = cosf(token_pos * inv_freq);
  float out1 = x1 * cos_v - x2 * sin_v;
  float out2 = x1 * sin_v + x2 * cos_v;
  out[idx * 2] = out1;
  out[idx * 2 + 1] = out2;
}

// Alternative indexing scheme: one block per token position, one thread per pair.
__global__ void rope_f32_v2_kernel(float* x, float* out, int seq_len, int N){
  int token_pos = blockIdx.x;
  int tid = threadIdx.x; // pair index within the token
  float x1 = x[token_pos * N * 2 + tid * 2];
  float x2 = x[token_pos * N * 2 + tid * 2 + 1];
  float inv_freq = 1.0f / powf(theta, (2.0f * tid) / (2 * N));
  float sin_v = sinf(token_pos * inv_freq);
  float cos_v = cosf(token_pos * inv_freq);
  float out1 = x1 * cos_v - x2 * sin_v;
  float out2 = x1 * sin_v + x2 * cos_v;
  out[token_pos * N * 2 + tid * 2] = out1;
  out[token_pos * N * 2 + tid * 2 + 1] = out2;
}

// Packed variant: each thread loads one float4, i.e. two adjacent pairs; N = hidden_size / 4 chunks per token.
__global__ void rope_f32x4_pack_kernel(float* x, float* out, int seq_len, int N){
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= seq_len * N) return; // guard the tail
  float4 x_v = FLOAT4(x[idx * 4]);
  int token_pos = idx / N;
  int token_idx = idx % N; // float4 index within the token; covers pair indices 2*token_idx and 2*token_idx + 1
  // Head dimension d = 4*N, so the exponents are 2*(2*token_idx)/d and 2*(2*token_idx + 1)/d.
  float inv_freq_f = 1.0f / powf(theta, (4.0f * token_idx) / (4 * N));
  float inv_freq_s = 1.0f / powf(theta, (4.0f * token_idx + 2.0f) / (4 * N));
  float sin_f_v = sinf(token_pos * inv_freq_f);
  float cos_f_v = cosf(token_pos * inv_freq_f);
  float sin_s_v = sinf(token_pos * inv_freq_s);
  float cos_s_v = cosf(token_pos * inv_freq_s);
  float4 out_v;
  out_v.x = x_v.x * cos_f_v - x_v.y * sin_f_v;
  out_v.y = x_v.x * sin_f_v + x_v.y * cos_f_v;
  out_v.z = x_v.z * cos_s_v - x_v.w * sin_s_v;
  out_v.w = x_v.z * sin_s_v + x_v.w * cos_s_v;
  FLOAT4(out[idx * 4]) = out_v;
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                     \
  if (((T).options().dtype() != (th_type))) {                    \
    std::cout << "Tensor Info:" << (T).options() << std::endl;   \
    throw std::runtime_error("values must be " #th_type);        \
  }

void rope_f32(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 2);
  dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE);
  dim3 block(BLOCK_SIZE);
  rope_f32_kernel<<<grid, block>>>(
      x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

void rope_f32_v2(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 2);
  dim3 grid(seq_len);
  dim3 block(N);
  rope_f32_v2_kernel<<<grid, block>>>(
      x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

void rope_f32x4_pack(torch::Tensor x, torch::Tensor out) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(out, torch::kFloat32)
  int seq_len = x.size(0);
  int hidden_size = x.size(1);
  int N = (int)(hidden_size / 4);
  dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE);
  dim3 block(BLOCK_SIZE);
  rope_f32x4_pack_kernel<<<grid, block>>>(
      x.data_ptr<float>(), out.data_ptr<float>(), seq_len, N);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(rope_f32)
  TORCH_BINDING_COMMON_EXTENSION(rope_f32_v2)
  TORCH_BINDING_COMMON_EXTENSION(rope_f32x4_pack)
}
105 changes: 105 additions & 0 deletions rope/rope.py
@@ -0,0 +1,105 @@
import torch
import time
import math
from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional
from typing import Tuple
import torch.nn as nn
import torch.nn.functional as F
torch.set_grad_enabled(False)

# Load the CUDA kernel as a python module
lib = load(
    name="rope",
    sources=["rope.cu"],
    extra_cuda_cflags=[
        "-O3",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ],
    extra_cflags=["-std=c++17"],
)


def run_benchmark(
    perf_func: callable,
    a: torch.Tensor,
    tag: str,
    out: Optional[torch.Tensor] = None,
    warmup: int = 2,
    iters: int = 20,
    show_all: bool = False,
):
    if out is not None:
        out.fill_(0)
    if out is not None:
        for i in range(warmup):
            perf_func(a, out)
    else:
        for i in range(warmup):
            _ = perf_func(a)

    torch.cuda.synchronize()
    start = time.time()
    # iters
    if out is not None:
        for i in range(iters):
            perf_func(a, out)
    else:
        for i in range(iters):
            out = perf_func(a)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"out_{tag}"
    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
    out_val = [round(v, 8) for v in out_val]
    out_val = [f"{v:<12}" for v in out_val]
    print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms")
    if show_all:
        print(out)
    return out.clone(), mean_time


def naive_rope(
    x: torch.Tensor,
    theta: float = 10000.0,
) -> torch.Tensor:
    dim = x.shape[-1]
    seq_len = x.shape[-2]
    # get the shape of x (ignore the head dimension).
    # x: [seq_len, dim]
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)
    # x_: [seq_len, dim//2, 2]
    x_ = torch.view_as_complex(x_)
    # pack neighboring elements into one complex number
    # x_: [seq_len, dim//2], e.g. tensor([(1.6116-0.5772j), ...])
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs).float().cuda()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # freqs_cis: unit complex numbers carrying the per-(position, pair) rotation angles
    xq_out = torch.view_as_real(x_ * freqs_cis).flatten(1)
    # rotate each pair and flatten back to [seq_len, dim]
    return xq_out.type_as(x)

print("-" * 100)
M = [4096, 8192]
N = [512, 1024]
MN = [[m, n] for m in M for n in N]
for M,N in MN:
print(" " * 40 + f"M={M}, N={N}")
print("-" * 100)
x = torch.randn((M, N)).cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
run_benchmark(lib.rope_f32, x, "f32", out)
run_benchmark(lib.rope_f32x4_pack, x, "f32x4_pack", out)
run_benchmark(naive_rope, x, "f32_th")
print("-" * 100)
