Skip to content

Commit

Permalink
Added metal q8 matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Mar 3, 2024
1 parent 94dd04d commit 75dd9a2
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 24 deletions.
1 change: 0 additions & 1 deletion crates/luminal_metal/src/elementwise_fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,6 @@ impl<T> MetalKernel for FusedElementwiseOp<T> {
);

// Execute
println!("Out: {:?}", out_size);
encoder.dispatch_1d(out_size);
encoder.end_encoding();
}
Expand Down
203 changes: 180 additions & 23 deletions crates/luminal_metal/src/quantized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,168 @@ impl<T: MetalFloat> QuantizedMatmul<T> {
fn new(device: Device, queue: CommandQueue) -> Self {
let type_name = T::type_name();
Self {
matmul_pipeline: compile_function("matmul", &format!("
matmul_pipeline: compile_function("matmul", "
using namespace metal;
#define QK8_0 32
#define NB_Q8_0 8
typedef struct {{
typedef struct {
half d; // delta
int8_t qs[QK8_0]; // quants
}} block_q8_0;
} block_q8_0;
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
#define THREAD_PER_BLOCK 128
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
#define SG_MAT_ROW 8
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread half4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = (qs[i + 16*il] * d);
}
}
kernel void matmul(device const uchar * src0 [[buffer(0)]],
device const uchar * src1 [[buffer(1)]],
device float * dst [[buffer(2)]],
constant int64_t & ne00 [[buffer(3)]], // k
constant int64_t & ne02 [[buffer(4)]], // Always 1
constant uint64_t & nb01 [[buffer(5)]], // k * 1.0625 (avg bytes per weight)
constant uint64_t & nb02 [[buffer(6)]], // k * 1.0625 * n
constant int64_t & ne12 [[buffer(7)]], // Always 1
constant uint64_t & nb10 [[buffer(8)]], // 4 (bytes in float)
constant uint64_t & nb11 [[buffer(9)]], // 4 * k
constant uint64_t & nb12 [[buffer(10)]], // 4 * k * m
constant int64_t & ne0 [[buffer(11)]], // n
constant int64_t & ne1 [[buffer(12)]], // m
constant uint & r2 [[buffer(13)]], // 1
constant uint & r3 [[buffer(14)]], // 1
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel void matmul(
device block_q8_0* x [[buffer(0)]], // Quantized 2D matrix (KxN)
device {type_name}* y [[buffer(1)]], // Float src matrix (MxK)
device {type_name}* dst [[buffer(2)]], // Float dest matrix (MxN)
constant uint& M [[buffer(3)]],
constant uint& K [[buffer(4)]], // Must be >= 32
constant uint& N [[buffer(5)]], // Must be >= 4
constant uint& mat_batch_stride [[buffer(6)]], // x batch stride
constant uint& vec_batch_stride [[buffer(7)]], // y batch stride
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index_in_simdgroup[[thread_index_in_simdgroup]],
uint simdgroup_index_in_threadgroup [[simdgroup_index_in_threadgroup]], // 2 simdgroups in a threadgroup
threadgroup uchar* shared_memory [[threadgroup(0)]]
) {{
threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
}}
"), &device),
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
const uint im = tgpig.z;
const short nl = 2;
// if this block is of 64x32 shape or smaller
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
// Simdgroup a b and c mat storage
simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8]; // Accumulate into c so init to 0
for (int i = 0; i < 8; i++){
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
short il = (tiitg % THREAD_PER_ROW);
const uint i12 = im%ne12;
const uint i13 = im/ne12;
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
ushort offset1 = il/nl;
device const block_q8_0 * x = (device const block_q8_0 *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
+ nb12 * im
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
half4x4 temp_a;
dequantize_q8_0(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll(16)
for (int i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
}
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2+nl-1)/nl : x;
y += BLOCK_SIZE_K;
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
#pragma unroll(4)
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
}
simdgroup_barrier(mem_flags::mem_none);
#pragma unroll(2)
for (int i = 0; i < 2; i++) {
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
}
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
#pragma unroll(8)
for (int i = 0; i < 8; i++){
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
}
}
}
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
if (sgitg == 0) {
for (int i = 0; i < n_rows; i++) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
}
}
}
}
}", &device),
matvec_pipeline: compile_function("matvec", &format!("
using namespace metal;
#define QK8_0 32
Expand Down Expand Up @@ -91,7 +226,8 @@ kernel void matvec(
const int first_row = (threadgroup_position_in_grid.x * num_simdgroups_per_threadgroup + simdgroup_index_in_threadgroup) * num_rows;
// Offsets
x += first_row * num_quants_per_row + threadgroup_position_in_grid.z * (mat_batch_stride / 32);
// x += first_row * num_quants_per_row + threadgroup_position_in_grid.z * (mat_batch_stride / 32); // Batch offset
x += first_row * num_quants_per_row; // No batch offset
y += threadgroup_position_in_grid.z * vec_batch_stride;
dst += (threadgroup_position_in_grid.z * dest_vec_size);
Expand Down Expand Up @@ -189,7 +325,7 @@ impl<T> MetalKernel for QuantizedMatmul<T> {

let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
if batch_size == 1 {
if m == 1 && batch_size == 1 {
// Matvec
encoder.set_compute_pipeline_state(&self.matvec_pipeline);
encoder.set_buffer(0, Some(inputs[1].0), 0); // Matrix
Expand All @@ -200,11 +336,32 @@ impl<T> MetalKernel for QuantizedMatmul<T> {
encoder.set_u32(5, 0); // Matrix batch stride
encoder.set_u32(6, k as u32); // Vector batch stride
encoder.dispatch_thread_groups(
MTLSize::new(n.div_ceil(8) as u64, 1, m as u64),
MTLSize::new(n.div_ceil(8) as u64, 1, 1),
MTLSize::new(8, 8, 1),
);
} else {
todo!()
// Matmul
encoder.set_compute_pipeline_state(&self.matmul_pipeline);
encoder.set_buffer(0, Some(inputs[1].0), 0); // Weight matrix
encoder.set_buffer(1, Some(inputs[0].0), 0); // Input matrix
encoder.set_buffer(2, Some(output_buffers[0]), 0); // Dest matrix
encoder.set_i64(3, k as i64);
encoder.set_i64(4, 1);
encoder.set_i64(5, (k as f32 * 1.0625) as i64);
encoder.set_i64(6, (k as f32 * 1.0625 * n as f32) as i64);
encoder.set_i64(7, 1);
encoder.set_i64(8, 4);
encoder.set_i64(9, 4 * k as i64);
encoder.set_i64(10, 4 * k as i64 * m as i64);
encoder.set_i64(11, n as i64);
encoder.set_i64(12, m as i64);
encoder.set_u32(13, 1);
encoder.set_u32(14, 1);
encoder.set_threadgroup_memory_length(0, 8192);
encoder.dispatch_thread_groups(
MTLSize::new((m as u64 + 31) / 32, (n as u64 + 63) / 64, 1),
MTLSize::new(128, 1, 1),
);
}
encoder.end_encoding();
}
Expand Down

0 comments on commit 75dd9a2

Please sign in to comment.