Commit

fp16 quantized q8 matmul
jafioti committed Mar 3, 2024
1 parent 75dd9a2 commit cb07523
Showing 1 changed file with 118 additions and 68 deletions.
186 changes: 118 additions & 68 deletions crates/luminal_metal/src/quantized.rs
@@ -33,14 +33,14 @@ impl<T: MetalFloat> QuantizedMatmul<T> {
fn new(device: Device, queue: CommandQueue) -> Self {
let type_name = T::type_name();
Self {
matmul_pipeline: compile_function("matmul", "
matmul_pipeline: compile_function("matmul", &format!("
using namespace metal;
#define QK8_0 32
#define NB_Q8_0 8
typedef struct {
typedef struct {{
half d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
}} block_q8_0;
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -53,37 +53,39 @@ typedef struct {
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
#define SG_MAT_ROW 8
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread half4x4 & reg) {
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread half4x4 & reg) {{
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
for (int i = 0; i < 16; i++) {
for (int i = 0; i < 16; i++) {{
reg[i/4][i%4] = (qs[i + 16*il] * d);
}
}
}}
}}
kernel void matmul(device const uchar * src0 [[buffer(0)]],
device const uchar * src1 [[buffer(1)]],
device float * dst [[buffer(2)]],
constant int64_t & ne00 [[buffer(3)]], // k
constant int64_t & ne02 [[buffer(4)]], // Always 1
constant uint64_t & nb01 [[buffer(5)]], // k * 1.0625 (avg bytes per weight)
constant uint64_t & nb02 [[buffer(6)]], // k * 1.0625 * n
constant int64_t & ne12 [[buffer(7)]], // Always 1
constant uint64_t & nb10 [[buffer(8)]], // 4 (bytes in float)
constant uint64_t & nb11 [[buffer(9)]], // 4 * k
constant uint64_t & nb12 [[buffer(10)]], // 4 * k * m
constant int64_t & ne0 [[buffer(11)]], // n
constant int64_t & ne1 [[buffer(12)]], // m
constant uint & r2 [[buffer(13)]], // 1
constant uint & r3 [[buffer(14)]], // 1
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel void matmul(
device const uchar * src0 [[buffer(0)]],
device const uchar * src1 [[buffer(1)]],
device {type_name} * dst [[buffer(2)]],
constant int64_t & ne00 [[buffer(3)]], // k
constant int64_t & ne02 [[buffer(4)]], // Always 1
constant uint64_t & nb01 [[buffer(5)]], // k * 1.0625 (avg bytes per weight)
constant uint64_t & nb02 [[buffer(6)]], // k * 1.0625 * n
constant int64_t & ne12 [[buffer(7)]], // Always 1
constant uint64_t & nb10 [[buffer(8)]], // bytes of dtype
constant uint64_t & nb11 [[buffer(9)]], // bytes of dtype * k
constant uint64_t & nb12 [[buffer(10)]], // bytes of dtype * k * m
constant int64_t & ne0 [[buffer(11)]], // n
constant int64_t & ne1 [[buffer(12)]], // m
constant uint & r2 [[buffer(13)]], // 1
constant uint & r3 [[buffer(14)]], // 1
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]
) {{
threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
threadgroup {type_name} * sb = (threadgroup {type_name} *)(shared_memory + 4096);
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
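
Note on the stride comments above: the "k * 1.0625 (avg bytes per weight)" figure for nb01 follows directly from the Q8_0 layout, where each block carries one fp16 delta plus 32 int8 quants, i.e. 34 bytes per 32 weights. A minimal host-side sketch of that arithmetic, assuming the half crate the tests already use; the struct here is illustrative and is not the crate's own BlockQ8_0:

// Sketch: where "k * 1.0625 bytes per weight" comes from.
// Field layout mirrors the shader's block_q8_0.
#[repr(C)]
struct BlockQ8_0Sketch {
    d: half::f16,  // delta (2 bytes)
    qs: [i8; 32],  // quants (32 bytes)
}

fn main() {
    // 34 bytes per block of 32 weights => 1.0625 bytes per weight,
    // so a row of k weights occupies (k / 32) * 34 bytes.
    let bytes_per_block = core::mem::size_of::<BlockQ8_0Sketch>();
    assert_eq!(bytes_per_block, 34);
    let k = 1024usize;
    assert_eq!((k / 32) * bytes_per_block, (k as f32 * 1.0625) as usize);
}
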
Expand All @@ -101,11 +103,11 @@ kernel void matmul(device const uchar * src0 [[buffer(0)]],
// Simdgroup a b and c mat storage
simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8]; // Accumulate into c so init to 0
for (int i = 0; i < 8; i++){
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
simdgroup_{type_name}8x8 mb[2];
simdgroup_{type_name}8x8 c_res[8]; // Accumulate into c so init to 0
for (int i = 0; i < 8; i++){{
c_res[i] = make_filled_simdgroup_matrix<{type_name}, 8>(0.0);
}}
short il = (tiitg % THREAD_PER_ROW);
Expand All @@ -116,25 +118,25 @@ kernel void matmul(device const uchar * src0 [[buffer(0)]],
ushort offset1 = il/nl;
device const block_q8_0 * x = (device const block_q8_0 *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
device const {type_name} * y = (device const {type_name} *)(src1
+ nb12 * im
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {{
// load data and store to threadgroup memory
half4x4 temp_a;
dequantize_q8_0(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll(16)
for (int i = 0; i < 16; i++) {
for (int i = 0; i < 16; i++) {{
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
}
}}
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
*(threadgroup {type_name}2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device {type_name}2x4 *)y);
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2+nl-1)/nl : x;
Expand All @@ -144,57 +146,57 @@ kernel void matmul(device const uchar * src0 [[buffer(0)]],
// load matrices from threadgroup memory and conduct outer products
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
threadgroup {type_name} * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
#pragma unroll(4)
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {{
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
for (int i = 0; i < 4; i++) {{
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
}
}}
simdgroup_barrier(mem_flags::mem_none);
#pragma unroll(2)
for (int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {{
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
}
}}
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
#pragma unroll(8)
for (int i = 0; i < 8; i++){
for (int i = 0; i < 8; i++){{
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
}
}
}
}}
}}
}}
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {{
device {type_name} * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
for (int i = 0; i < 8; i++) {
for (int i = 0; i < 8; i++) {{
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
}
} else {
}}
}} else {{
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
threadgroup {type_name} * temp_str = ((threadgroup {type_name} *)shared_memory) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
for (int i = 0; i < 8; i++) {
for (int i = 0; i < 8; i++) {{
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
}
}}
threadgroup_barrier(mem_flags::mem_threadgroup);
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
if (sgitg == 0) {
for (int i = 0; i < n_rows; i++) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
device {type_name} * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
if (sgitg == 0) {{
for (int i = 0; i < n_rows; i++) {{
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {{
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
}
}
}
}
}", &device),
}}
}}
}}
}}
}}"), &device),
matvec_pipeline: compile_function("matvec", &format!("
using namespace metal;
#define QK8_0 32
Expand Down Expand Up @@ -350,9 +352,9 @@ impl<T> MetalKernel for QuantizedMatmul<T> {
encoder.set_i64(5, (k as f32 * 1.0625) as i64);
encoder.set_i64(6, (k as f32 * 1.0625 * n as f32) as i64);
encoder.set_i64(7, 1);
encoder.set_i64(8, 4);
encoder.set_i64(9, 4 * k as i64);
encoder.set_i64(10, 4 * k as i64 * m as i64);
encoder.set_i64(8, size_of::<T>() as i64);
encoder.set_i64(9, (size_of::<T>() * k) as i64);
encoder.set_i64(10, (size_of::<T>() * k * m) as i64);
encoder.set_i64(11, n as i64);
encoder.set_i64(12, m as i64);
encoder.set_u32(13, 1);
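
Deriving the nb10/nb11/nb12 strides from size_of::<T>() is what lets the same encoder code serve both fp32 and fp16 inputs. A hypothetical spot-check of the values it produces for T = f16 with the 16x1024 input used by the new test below (not part of the commit):

fn main() {
    let (k, m) = (1024i64, 16i64);
    let elem = core::mem::size_of::<half::f16>() as i64; // 2 bytes for fp16
    assert_eq!(elem * k, 2_048); // buffer 9: bytes per row of the fp16 input
    assert_eq!(elem * k * m, 32_768); // buffer 10: bytes for the whole fp16 input
}
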
@@ -574,7 +576,7 @@ mod tests {
};
use luminal::{
prelude::*,
tests::{assert_close, random_vec_rng},
tests::{assert_close, assert_close_precision, random_vec_rng},
};
use metal_rs::{Device, MTLResourceOptions};
use rand::{thread_rng, Rng};
@@ -690,4 +692,52 @@ mod tests {
let d_c = d_b.matmul(d_a.permute());
assert_close(&out.data(), &d_c.as_vec());
}

#[test]
fn test_quantized_matmul_fp16() {
let mut rng = thread_rng();
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
let inp_mat_data = random_vec_rng(1024 * 16, &mut rng);
let mut cx = Graph::new();
let weights = cx.tensor::<R2<512, 1024>>();
let inp_mat = cx.tensor::<R2<16, 1024>>().set(inp_mat_data.clone());
let mut out = inp_mat.matmul(weights.permute()).retrieve();

// "Load" weights in 8bit
let blocks = mat_data
.chunks_exact(32)
.map(|chunk| {
let mut array = [0; 32];
for (i, n) in chunk.iter().enumerate() {
array[i] = *n;
}
BlockQ8_0 {
_d: f16::from_f32(1.0),
_qs: array,
}
})
.collect::<Vec<_>>();
let dev = Device::system_default().unwrap();
cx.tensors
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));

cx.compile(
MetalQuantizedCompiler::<f16>::new(vec![weights.id]),
&mut out,
);
cx.execute();

let cpu = dfdx::tensor::Cpu::default();
let d_a = cpu.tensor_from_vec(
mat_data.into_iter().map(|i| i as f32).collect::<Vec<_>>(),
(dfdx::shapes::Const::<512>, dfdx::shapes::Const::<1024>),
);
let d_b = cpu.tensor_from_vec(
inp_mat_data,
(dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>),
);
let d_c = d_b.matmul(d_a.permute());
assert_close_precision(&out.data(), &d_c.as_vec(), 0);
// This is imprecise currently because we accumulate in fp16 in the matmul. TODO: accumulate in fp32 and convert before saving to dest
}
}
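
On the precision TODO above: with fp16 accumulation, each of the k = 1024 partial sums in a dot product is rounded back to half precision, so error compounds and the fp16 test can only assert loose agreement. A small standalone illustration, assuming the half crate; the values are illustrative and not taken from the kernel:

// Illustration only: simulating fp16 accumulation by rounding every partial
// sum back to half precision, versus keeping the running sum in fp32.
use half::f16;

fn main() {
    let xs = vec![0.1f32; 1024];
    let fp32_sum: f32 = xs.iter().sum();
    let fp16_sum = xs
        .iter()
        .fold(f16::from_f32(0.0), |acc, &x| f16::from_f32(f32::from(acc) + x));
    // fp32_sum stays close to 102.4, while the fp16 accumulator picks up
    // rounding error on every step and lands visibly off.
    println!("fp32: {fp32_sum}, fp16: {}", f32::from(fp16_sum));
}
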
