diff --git a/crates/luminal_metal/src/quantized.rs b/crates/luminal_metal/src/quantized.rs
index 466ab782..aa826e54 100644
--- a/crates/luminal_metal/src/quantized.rs
+++ b/crates/luminal_metal/src/quantized.rs
@@ -33,14 +33,14 @@ impl<T: MetalFloat> QuantizedMatmul<T> {
     fn new(device: Device, queue: CommandQueue) -> Self {
         let type_name = T::type_name();
         Self {
-            matmul_pipeline: compile_function("matmul", "
+            matmul_pipeline: compile_function("matmul", &format!("
 using namespace metal;
 #define QK8_0 32
 #define NB_Q8_0 8
-typedef struct {
+typedef struct {{
     half d; // delta
     int8_t qs[QK8_0]; // quants
-} block_q8_0;
+}} block_q8_0;
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -53,37 +53,39 @@ typedef struct {
 #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
 #define SG_MAT_ROW 8
 
-void dequantize_q8_0(device const block_q8_0 *xb, short il, thread half4x4 & reg) {
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread half4x4 & reg) {{
     device const int8_t * qs = ((device const int8_t *)xb->qs);
     const half d = xb->d;
 
-    for (int i = 0; i < 16; i++) {
+    for (int i = 0; i < 16; i++) {{
         reg[i/4][i%4] = (qs[i + 16*il] * d);
-    }
-}
+    }}
+}}
 
-kernel void matmul(device const uchar * src0 [[buffer(0)]],
-                   device const uchar * src1 [[buffer(1)]],
-                   device float * dst [[buffer(2)]],
-                   constant int64_t & ne00 [[buffer(3)]], // k
-                   constant int64_t & ne02 [[buffer(4)]], // Always 1
-                   constant uint64_t & nb01 [[buffer(5)]], // k * 1.0625 (avg bytes per weight)
-                   constant uint64_t & nb02 [[buffer(6)]], // k * 1.0625 * n
-                   constant int64_t & ne12 [[buffer(7)]], // Always 1
-                   constant uint64_t & nb10 [[buffer(8)]], // 4 (bytes in float)
-                   constant uint64_t & nb11 [[buffer(9)]], // 4 * k
-                   constant uint64_t & nb12 [[buffer(10)]], // 4 * k * m
-                   constant int64_t & ne0 [[buffer(11)]], // n
-                   constant int64_t & ne1 [[buffer(12)]], // m
-                   constant uint & r2 [[buffer(13)]], // 1
-                   constant uint & r3 [[buffer(14)]], // 1
-                   threadgroup uchar * shared_memory [[threadgroup(0)]],
-                   uint3 tgpig[[threadgroup_position_in_grid]],
-                   uint tiitg[[thread_index_in_threadgroup]],
-                   uint sgitg[[simdgroup_index_in_threadgroup]]) {
+kernel void matmul(
+    device const uchar * src0 [[buffer(0)]],
+    device const uchar * src1 [[buffer(1)]],
+    device {type_name} * dst [[buffer(2)]],
+    constant int64_t & ne00 [[buffer(3)]], // k
+    constant int64_t & ne02 [[buffer(4)]], // Always 1
+    constant uint64_t & nb01 [[buffer(5)]], // k * 1.0625 (avg bytes per weight)
+    constant uint64_t & nb02 [[buffer(6)]], // k * 1.0625 * n
+    constant int64_t & ne12 [[buffer(7)]], // Always 1
+    constant uint64_t & nb10 [[buffer(8)]], // bytes of dtype
+    constant uint64_t & nb11 [[buffer(9)]], // bytes of dtype * k
+    constant uint64_t & nb12 [[buffer(10)]], // bytes of dtype * k * m
+    constant int64_t & ne0 [[buffer(11)]], // n
+    constant int64_t & ne1 [[buffer(12)]], // m
+    constant uint & r2 [[buffer(13)]], // 1
+    constant uint & r3 [[buffer(14)]], // 1
+    threadgroup uchar * shared_memory [[threadgroup(0)]],
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint tiitg[[thread_index_in_threadgroup]],
+    uint sgitg[[simdgroup_index_in_threadgroup]]
+) {{
 
     threadgroup half * sa = (threadgroup half *)(shared_memory);
-    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+    threadgroup {type_name} * sb = (threadgroup {type_name} *)(shared_memory + 4096);
 
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
@@ -101,11 +103,11 @@ kernel void matmul(device const uchar * src0 [[buffer(0)]],
 
     // Simdgroup a b and c mat storage
     simdgroup_half8x8 ma[4];
-    simdgroup_float8x8 mb[2];
-    simdgroup_float8x8 c_res[8]; // Accumulate into c so init to 0
-    for (int i = 0; i < 8; i++){
-        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
-    }
+    simdgroup_{type_name}8x8 mb[2];
+    simdgroup_{type_name}8x8 c_res[8]; // Accumulate into c so init to 0
+    for (int i = 0; i < 8; i++){{
+        c_res[i] = make_filled_simdgroup_matrix<{type_name}, 8>(0.0);
+    }}
 
     short il = (tiitg % THREAD_PER_ROW);
 
@@ -116,25 +118,25 @@ kernel void matmul(device const uchar * src0 [[buffer(0)]],
     ushort offset1 = il/nl;
 
     device const block_q8_0 * x = (device const block_q8_0 *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
-    device const float * y = (device const float *)(src1
+    device const {type_name} * y = (device const {type_name} *)(src1
         + nb12 * im
         + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
-    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {{
         // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_q8_0(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         #pragma unroll(16)
-        for (int i = 0; i < 16; i++) {
+        for (int i = 0; i < 16; i++) {{
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
             + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
             + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
-        }
+        }}
 
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+        *(threadgroup {type_name}2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device {type_name}2x4 *)y);
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x = (il < 2) ? x + (2+nl-1)/nl : x;
@@ -144,57 +146,57 @@ kernel void matmul(device const uchar * src0 [[buffer(0)]],
 
         // load matrices from threadgroup memory and conduct outer products
         threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+        threadgroup {type_name} * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
 
         #pragma unroll(4)
-        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {{
             #pragma unroll(4)
-            for (int i = 0; i < 4; i++) {
+            for (int i = 0; i < 4; i++) {{
                 simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
-            }
+            }}
             simdgroup_barrier(mem_flags::mem_none);
             #pragma unroll(2)
-            for (int i = 0; i < 2; i++) {
+            for (int i = 0; i < 2; i++) {{
                 simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
-            }
+            }}
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
 
             #pragma unroll(8)
-            for (int i = 0; i < 8; i++){
+            for (int i = 0; i < 8; i++){{
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
-            }
-        }
-    }
+            }}
+        }}
+    }}
 
-    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {{
+        device {type_name} * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
             + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
-        for (int i = 0; i < 8; i++) {
+        for (int i = 0; i < 8; i++) {{
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
-        }
-    } else {
+        }}
+    }} else {{
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup {type_name} * temp_str = ((threadgroup {type_name} *)shared_memory) \
             + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
-        for (int i = 0; i < 8; i++) {
+        for (int i = 0; i < 8; i++) {{
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
-        }
+        }}
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg == 0) {
-            for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+        device {type_name} * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {{
+            for (int i = 0; i < n_rows; i++) {{
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {{
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
-                }
-            }
-        }
-    }
-}", &device),
+                }}
+            }}
+        }}
+    }}
+}}"), &device),
             matvec_pipeline: compile_function("matvec", &format!("
 using namespace metal;
 #define QK8_0 32
@@ -350,9 +352,9 @@ impl<T: MetalFloat> MetalKernel for QuantizedMatmul<T> {
         encoder.set_i64(5, (k as f32 * 1.0625) as i64);
         encoder.set_i64(6, (k as f32 * 1.0625 * n as f32) as i64);
         encoder.set_i64(7, 1);
-        encoder.set_i64(8, 4);
-        encoder.set_i64(9, 4 * k as i64);
-        encoder.set_i64(10, 4 * k as i64 * m as i64);
+        encoder.set_i64(8, size_of::<T>() as i64);
+        encoder.set_i64(9, (size_of::<T>() * k) as i64);
+        encoder.set_i64(10, (size_of::<T>() * k * m) as i64);
         encoder.set_i64(11, n as i64);
         encoder.set_i64(12, m as i64);
         encoder.set_u32(13, 1);
@@ -574,7 +576,7 @@ mod tests {
     };
     use luminal::{
         prelude::*,
-        tests::{assert_close, random_vec_rng},
+        tests::{assert_close, assert_close_precision, random_vec_rng},
     };
     use metal_rs::{Device, MTLResourceOptions};
     use rand::{thread_rng, Rng};
@@ -690,4 +692,52 @@ mod tests {
         let d_c = d_b.matmul(d_a.permute());
         assert_close(&out.data(), &d_c.as_vec());
     }
+
+    #[test]
+    fn test_quantized_matmul_fp16() {
+        let mut rng = thread_rng();
+        let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
+        let inp_mat_data = random_vec_rng(1024 * 16, &mut rng);
+        let mut cx = Graph::new();
+        let weights = cx.tensor::<R2<512, 1024>>();
+        let inp_mat = cx.tensor::<R2<16, 1024>>().set(inp_mat_data.clone());
+        let mut out = inp_mat.matmul(weights.permute()).retrieve();
+
+        // "Load" weights in 8bit
+        let blocks = mat_data
+            .chunks_exact(32)
+            .map(|chunk| {
+                let mut array = [0; 32];
+                for (i, n) in chunk.iter().enumerate() {
+                    array[i] = *n;
+                }
+                BlockQ8_0 {
+                    _d: f16::from_f32(1.0),
+                    _qs: array,
+                }
+            })
+            .collect::<Vec<_>>();
+        let dev = Device::system_default().unwrap();
+        cx.tensors
+            .insert((weights.id, 0), quantized_buffer(&blocks, &dev));
+
+        cx.compile(
+            MetalQuantizedCompiler::<f16>::new(vec![weights.id]),
+            &mut out,
+        );
+        cx.execute();
+
+        let cpu = dfdx::tensor::Cpu::default();
+        let d_a = cpu.tensor_from_vec(
+            mat_data.into_iter().map(|i| i as f32).collect::<Vec<_>>(),
+            (dfdx::shapes::Const::<512>, dfdx::shapes::Const::<1024>),
+        );
+        let d_b = cpu.tensor_from_vec(
+            inp_mat_data,
+            (dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>),
+        );
+        let d_c = d_b.matmul(d_a.permute());
+        assert_close_precision(&out.data(), &d_c.as_vec(), 0);
+        // This is imprecise currently because we accumulate in fp16 in the matmul. TODO: accumulate in fp32 and convert before saving to dest
+    }
 }