diff --git a/core/src/ops/einsum/as_matmul.rs b/core/src/ops/einsum/as_matmul.rs
index c3f955f715..8bcdf11dd1 100644
--- a/core/src/ops/einsum/as_matmul.rs
+++ b/core/src/ops/einsum/as_matmul.rs
@@ -279,7 +279,7 @@ impl TypedOp for BasicMatMul {
         let [a, b] = inputs else {
             bail!("Expects 2 inputs");
         };
-        if a.datum_type.is_number() {
+        if a.datum_type.is_number() && b.datum_type.is_number() {
             ensure!(a.rank() == b.rank());
             ensure!(a.rank() >= 2);
             ensure!(
@@ -313,9 +313,32 @@ impl TypedOp for BasicMatMul {
                 .quantize_output
                 .unwrap_or(b.datum_type)
                 .fact(self.output_shape(&a_shape, &b.shape))))
+        } else if let Some(opf) = inputs[1]
+            .opaque_fact
+            .as_ref()
+            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
+            .or_else(|| {
+                inputs[1]
+                    .konst
+                    .as_ref()
+                    .and_then(|k| k.to_scalar::<Opaque>().ok())
+                    .and_then(|o| o.downcast_ref::<BlockQuantValue>())
+                    .map(|v| &v.fact)
+            })
+        {
+            let b_shape: ShapeFact = b
+                .shape
+                .iter()
+                .cloned()
+                .chain(opf.shape.iter().map(|d| d.to_dim()))
+                .collect();
+            Ok(tvec!(self
+                .quantize_output
+                .unwrap_or(a.datum_type)
+                .fact(self.output_shape(&a.shape, &b_shape))))
         } else {
-            todo!()
-        }
+            todo!()
+        }
     }

     as_op!();
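The new branch recovers the logical shape of a block-quantized operand by concatenating the Opaque container's shape with the shape stored in its BlockQuantFact. A minimal sketch of that rule, using only types that appear in this patch (the free-standing helper and its name are illustrative, not part of the change):

use tract_core::internal::*;
use tract_linalg::frame::block_quant::BlockQuantFact;

// The shape BasicMatMul reasons about is the container shape (empty for a
// scalar Opaque constant) extended by the packed payload's shape.
fn logical_shape(container: &[TDim], bqf: &BlockQuantFact) -> ShapeFact {
    container.iter().cloned().chain(bqf.shape.iter().map(|d| d.to_dim())).collect()
}
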
diff --git a/metal/src/kernels/matmul/basic/mod.rs b/metal/src/kernels/matmul/basic/mod.rs
index bda2160fe3..a3b58f3fc3 100644
--- a/metal/src/kernels/matmul/basic/mod.rs
+++ b/metal/src/kernels/matmul/basic/mod.rs
@@ -34,6 +34,7 @@ impl GemmKernel for BasicMatMul {
             transpose_b,
             b_offset,
             c_offset,
+            ..
         } = params;

         ensure!(
diff --git a/metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal b/metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal
index a016769a0c..c42c57c676 100644
--- a/metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal
+++ b/metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal
@@ -4,6 +4,14 @@
 using namespace metal;

+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+#define QK4_0 32
+typedef struct {
+    half d;                // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+
 typedef struct {
     int32_t batch;
     int32_t m;
@@ -207,6 +215,122 @@ typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t;

 template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4;

+// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
+
+    for (int i = 0; i < 8; i += 2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+
+    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
+}
+
+// putting them in the kernel causes a significant performance penalty
+#define N_DST 4       // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+// Note: this is a template, but strictly speaking it only applies to
+//       quantizations where the block size is 32. It also does not
+//       guard against the number of rows not being divisible by
+//       N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+void mul_vec_q_n_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        threadgroup char * shmem,
+        uint3 tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    const int nb = args.k/QK4_0;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    const uint i12 = im%args.batch;
+    const uint i13 = im/args.batch;
+
+    //const uint64_t offset0 = first_row*args. + (i12/args.channel_broadcast_ratio)*args.b_strides[1] + (i13/args.batch_broadcast_ratio)*args.b_strides[0];
+    const uint64_t offset1 = r1*args.a_strides[2] + (i12)*args.a_strides[1] + (i13)*args.a_strides[0];
+
+    //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
+    device const float * y = (device const float *) (src1 + offset1);
+
+    // pointers to src0 rows
+    device const block_q_type * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint64_t offset0 = (first_row + row)*args.b_strides[2] + (i12/args.channel_broadcast_ratio)*args.b_strides[1] + (i13/args.batch_broadcast_ratio)*args.b_strides[0];
+
+        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
+    }
+
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const short ix = (tiisg/2);
+    const short il = (tiisg%2)*8;
+
+    device const float * yb = y + ix*QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        float sumy[2] = { 0.f, 0.f };
+
+#pragma unroll
+        for (int i = 0; i < 8; i += 2) {
+            sumy[0] += yb[i + 0] + yb[i + 1];
+            yl[i + 0] = yb[i + 0];
+            yl[i + 1] = yb[i + 1]/256.f;
+
+            sumy[1] += yb[i + 16] + yb[i + 17];
+            yl[i + 8] = yb[i + 16]/16.f;
+            yl[i + 9] = yb[i + 17]/4096.f;
+        }
+
+#pragma unroll
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
+        }
+
+        yb += QK4_0 * 16;
+    }
+
+    device float * dst_f32 = (device float *) dst + im*args.n*args.m + r1*args.n;
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+
+        if (tiisg == 0 && first_row + row < args.n) {
+            dst_f32[first_row + row] = tot;
+        }
+    }
+}
+
+kernel void kernel_mul_mv_q4_0_f32(
+        constant ggml_metal_kargs_mul & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
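A note on the launch geometry this kernel assumes (inferred from the host-side dispatch in ggml_gemm/mod.rs below, not stated in the patch itself): each threadgroup is launched with 8×8 = 64 threads, i.e. N_SIMDGROUP = 2 simdgroups of N_SIMDWIDTH = 32 threads, and each simdgroup accumulates N_DST = 4 output rows, so one threadgroup covers 8 rows and the host dispatches ceil(n/8) threadgroups along x. Within a simdgroup, thread tiisg handles half a Q4_0 block per iteration: ix = tiisg/2 selects the block and il = (tiisg%2)*8 selects which half; simd_sum then reduces the 32 partial sums per row before the lane-0 thread writes the result.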
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
@@ -372,7 +496,27 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
     reg = (type4x4)(*src);
 }

+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
+        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
+    }
+
+    reg = (type4x4) reg_f;
+}
+
 typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;

 template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
\ No newline at end of file
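For reference, this is the GGUF Q4_0 block layout the kernels above consume: every 18-byte block holds one f16 scale d followed by 16 bytes in which byte i carries element i in its low nibble and element i+16 in its high nibble; a 4-bit quant q decodes to (q - 8) * d. A minimal CPU-side Rust sketch of the same decode, assuming the half crate for f16 (illustrative, not part of the patch):

use half::f16;

const QK4_0: usize = 32;

// Dequantize one GGUF-packed Q4_0 block: 2-byte f16 delta + 16 nibble bytes.
fn dequant_block_q4_0(block: &[u8; 18]) -> [f32; QK4_0] {
    let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
    let qs = &block[2..];
    let mut out = [0f32; QK4_0];
    for i in 0..16 {
        out[i] = ((qs[i] & 0x0F) as i32 - 8) as f32 * d; // low nibble: element i
        out[i + 16] = ((qs[i] >> 4) as i32 - 8) as f32 * d; // high nibble: element i + 16
    }
    out
}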
diff --git a/metal/src/kernels/matmul/ggml_gemm/mod.rs b/metal/src/kernels/matmul/ggml_gemm/mod.rs
index 00b59fd500..c47054a12d 100644
--- a/metal/src/kernels/matmul/ggml_gemm/mod.rs
+++ b/metal/src/kernels/matmul/ggml_gemm/mod.rs
@@ -1,10 +1,12 @@
 use crate::kernels::matmul::{GemmDispatchParams, GemmKernel};
+use crate::utils::as_q40_fact;
 use crate::MetalTensor;
 use crate::{LibraryName, MetalContext};
 use anyhow::{ensure, Result};
 use metal::{Buffer, MTLSize, NSUInteger};
 use std::fmt;
 use tract_core::internal::*;
+use DatumType::{F16, F32};

 #[derive(Debug)]
 #[repr(C)]
@@ -33,18 +35,20 @@ impl GemmKernel for GgmlGemm {
         "ggml"
     }

-    fn is_supported_dts(&self, dts: &[DatumType]) -> TractResult<bool> {
-        ensure!(dts.len() == 2);
+    fn is_supported_dts(&self, facts: &[TypedFact]) -> bool {
+        assert!(facts.len() == 2, "Ggml: Expected 2 inputs for Matmul");

-        if dts[0] == DatumType::F32 {
-            Ok(dts[1] == DatumType::F32)
-        } else {
-            Ok(dts[0] == DatumType::F16 && matches!(dts[1], DatumType::F32 | DatumType::F16))
-        }
+        let regular_types_support = matches!(
+            (facts[0].datum_type, facts[1].datum_type),
+            (F32, F32) | (F16, F16) | (F16, F32)
+        );
+
+        regular_types_support
+            || (as_q40_fact(&facts[1]).is_some() && matches!(facts[0].datum_type, F16 | F32))
     }

     fn output_dt(&self, _a_dt: DatumType, _b_dt: DatumType) -> TractResult<DatumType> {
-        Ok(DatumType::F32)
+        Ok(F32)
     }

     fn dispatch_eval(
@@ -65,12 +69,12 @@ impl GemmKernel for GgmlGemm {
             a_offset,
             transpose_b,
             b_offset,
+            q40_b,
             c_offset,
         } = params;

         ensure!(!transpose_a && transpose_b);

-        // Kernel output is transposed so we switch the inputs
         let a_el_size = dts[0].size_of();
         let a_strides = [
             (batch * m * k * a_el_size) as u64,
@@ -79,13 +83,23 @@ impl GemmKernel for GgmlGemm {
             a_el_size as u64,
         ];

-        let b_el_size = dts[1].size_of();
-        let b_strides = [
-            (batch * n * k * b_el_size) as u64,
-            (n * k * b_el_size) as u64,
-            (k * b_el_size) as u64,
-            b_el_size as u64,
-        ];
+        let b_strides = if q40_b {
+            let b_el_size = 18;
+            [
+                (batch * n * (k / 32) * b_el_size) as u64,
+                (n * (k / 32) * b_el_size) as u64,
+                (b_el_size * (k / 32)) as u64,
+                b_el_size as u64,
+            ]
+        } else {
+            let b_el_size = dts[1].size_of();
+            [
+                (batch * n * k * b_el_size) as u64,
+                (n * k * b_el_size) as u64,
+                (b_el_size * k) as u64,
+                b_el_size as u64,
+            ]
+        };

         let params = GgmlParams {
             batch: batch as i32,
@@ -98,13 +112,15 @@ impl GemmKernel for GgmlGemm {
             batch_broadcast_ratio: 1,
         };

-        if (dts[0] == DatumType::F32) && (k % 32 == 0) && (k >= 64) && (m > 4) {
+        if (dts[0] == F32) && (k % 32 == 0) && (k >= 64) && (m > 4) {
             dispatch_metal_ggml_gemm(
-                context, dts, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer, c_offset,
+                context, dts, q40_b, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer,
+                c_offset,
             )?;
         } else {
             dispatch_metal_ggml_gemv(
-                context, dts, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer, c_offset,
+                context, dts, q40_b, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer,
+                c_offset,
             )?;
         }

@@ -114,16 +130,16 @@ impl GemmKernel for GgmlGemm {

 fn mv_kernel_name_and_dispatch_params(
     dts: &[DatumType],
+    q40_b: bool,
     params: &GgmlParams,
 ) -> Result<(String, (u64, u64, u64))> {
     let (nth0, nth1, nrows): (u64, u64, u64) = (32, 1, 1);

-    if dts[1] == DatumType::F32 {
-        ensure!(dts[0] == DatumType::F32);
+    if dts[1] == F32 {
+        ensure!(dts[0] == F32);
         Ok(("kernel_mul_mv_f32_f32".to_string(), (nth0, nth1, 4)))
-    } else {
-        ensure!(dts[1] == DatumType::F16);
-        if dts[0] == DatumType::F32 {
+    } else if dts[1] == F16 {
+        if dts[0] == F32 {
             if (params.m * params.batch) < 4 {
                 Ok(("kernel_mul_mv_f16_f32_1row".to_string(), (nth0, nth1, nrows)))
             } else if (params.k >= 128) && (params.k % 4 == 0) && (params.n >= 8) {
@@ -133,9 +149,12 @@ fn mv_kernel_name_and_dispatch_params(
             }
         } else {
             // Never used in practice since we upcast input[0] to f32
-            ensure!(dts[1] == DatumType::F16);
+            ensure!(dts[1] == F16);
             Ok(("kernel_mul_mv_f16_f16".to_string(), (nth0, nth1, 4)))
         }
+    } else {
+        ensure!((q40_b) && (dts[0] == F32));
+        Ok(("kernel_mul_mv_q4_0_f32".to_string(), (8, 8, 1)))
     }
 }
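The Q4_0 stride arithmetic above follows directly from the block layout: a row of k elements packs into k/32 blocks of 18 bytes each (2-byte f16 delta plus 16 bytes of nibbles). A quick worked check (illustrative helper, not part of the patch):

// Bytes occupied by one Q4_0-packed row of k weights.
fn q4_0_row_bytes(k: usize) -> usize {
    assert!(k % 32 == 0, "Q4_0 rows must be a multiple of the block size");
    (k / 32) * 18
}

// e.g. k = 2048 gives 64 blocks, i.e. 1152 bytes per row:
// assert_eq!(q4_0_row_bytes(2048), 1152);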
@@ -143,6 +162,7 @@ fn mv_kernel_name_and_dispatch_params(
 fn dispatch_metal_ggml_gemv(
     context: &MetalContext,
     dts: [DatumType; 3],
+    q40_b: bool,
     params: GgmlParams,
     a_offset: usize,
     a_buffer: &Buffer,
@@ -151,7 +171,7 @@ fn dispatch_metal_ggml_gemv(
     output: &Buffer,
     output_offset: usize,
 ) -> Result<()> {
-    let (name, (nth0, nth1, nrows)) = mv_kernel_name_and_dispatch_params(&dts, &params)?;
+    let (name, (nth0, nth1, nrows)) = mv_kernel_name_and_dispatch_params(&dts, q40_b, &params)?;
     //dbg!(&name);
     let pipeline = context.shared_context().load_pipeline(LibraryName::Ggml, &name)?;

@@ -159,16 +179,27 @@ fn dispatch_metal_ggml_gemv(
     let command_buffer = context.command_buffer();
     command_buffer.encode(|encoder| {
         encoder.set_compute_pipeline_state(&pipeline);
-        encoder.set_bytes(0, std::mem::size_of::<GgmlParams>() as u64, &params as *const _ as *const _);
+        encoder.set_bytes(
+            0,
+            std::mem::size_of::<GgmlParams>() as u64,
+            &params as *const _ as *const _,
+        );
         encoder.set_buffer(1, Some(b_buffer), b_offset as NSUInteger);
         encoder.set_buffer(2, Some(a_buffer), a_offset as NSUInteger);
         encoder.set_buffer(3, Some(output), output_offset as NSUInteger);

-        let ny = (params.m as u64).div_ceil(nrows);
-        let grid_size = MTLSize {
-            width: params.n as u64,
-            height: ny,
-            depth: /* batch_size_out */ params.batch as u64,
+        let grid_size = if !q40_b {
+            MTLSize {
+                width: params.n as u64,
+                height: (params.m as u64).div_ceil(nrows),
+                depth: /* batch_size_out */ params.batch as u64,
+            }
+        } else {
+            MTLSize {
+                width: (params.n as u64).div_ceil(8),
+                height: params.m as u64,
+                depth: /* batch_size_out */ params.batch as u64,
+            }
         };
         let group_size = MTLSize { width: nth0, height: nth1, depth: 1 };

@@ -183,6 +214,7 @@ fn dispatch_metal_ggml_gemv(
 fn dispatch_metal_ggml_gemm(
     context: &MetalContext,
     dts: [DatumType; 3],
+    q40_b: bool,
     params: GgmlParams,
     a_offset: usize,
     a_buffer: &Buffer,
@@ -191,19 +223,26 @@ fn dispatch_metal_ggml_gemm(
     output: &Buffer,
     output_offset: usize,
 ) -> Result<()> {
-    ensure!(matches!(dts[1], DatumType::F32 | DatumType::F16) && dts[0] == DatumType::F32);
+    ensure!((matches!(dts[1], F32 | F16) || q40_b) && dts[0] == F32);

-    let i1_tname = MetalTensor::tname(dts[1])?;
+    let mut i1_tname = MetalTensor::tname(dts[1])?;
     let i2_tname = MetalTensor::tname(dts[0])?;
+    if q40_b {
+        i1_tname = "q4_0";
+    }

     let name = format!("kernel_mul_mm_{i1_tname}_{i2_tname}");
-
+    //dbg!(&name);
     let pipeline = context.shared_context().load_pipeline(LibraryName::Ggml, &name)?;

     let command_buffer = context.command_buffer();
     command_buffer.encode(|encoder| {
         encoder.set_compute_pipeline_state(&pipeline);
-        encoder.set_bytes(0, std::mem::size_of::<GgmlParams>() as u64, &params as *const _ as *const _);
+        encoder.set_bytes(
+            0,
+            std::mem::size_of::<GgmlParams>() as u64,
+            &params as *const _ as *const _,
+        );
         encoder.set_buffer(1, Some(b_buffer), b_offset as NSUInteger);
         encoder.set_buffer(2, Some(a_buffer), a_offset as NSUInteger);
         encoder.set_buffer(3, Some(output), output_offset as NSUInteger);
@@ -225,8 +264,13 @@ fn dispatch_metal_ggml_gemm(

 #[cfg(test)]
 mod tests {
+    use tract_core::ops::einsum::BasicMatMul;
+    use tract_linalg::frame::block_quant::{BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0};
+
     use super::*;
     use crate::kernels::matmul::tests::run_mmm_test_case;
+    use crate::kernels::matmul::GemmImpl;
+    use crate::IntoMetal;

     #[test]
     fn test_ggml_compilation() -> Result<()> {
@@ -237,63 +281,99 @@ mod tests {

     #[test]
     fn test_mat_mul() -> TractResult<()> {
-        run_mmm_test_case::<GgmlGemm>((1, 5, 64, 2), false, true, DatumType::F32, DatumType::F32)?;
-        run_mmm_test_case::<GgmlGemm>((2, 1, 32, 2), false, true, DatumType::F32, DatumType::F32)?;
-        run_mmm_test_case::<GgmlGemm>((1, 5, 64, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>(
-            (3, 8, 64, 200),
-            false,
-            true,
-            DatumType::F32,
-            DatumType::F16,
-        )?;
-        run_mmm_test_case::<GgmlGemm>(
-            (10, 25, 512, 320),
-            false,
-            true,
-            DatumType::F32,
-            DatumType::F16,
-        )?;
+        run_mmm_test_case::<GgmlGemm>((1, 5, 64, 2), false, true, F32, F32)?;
+        run_mmm_test_case::<GgmlGemm>((2, 1, 32, 2), false, true, F32, F32)?;
+        run_mmm_test_case::<GgmlGemm>((1, 5, 64, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((3, 8, 64, 200), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((10, 25, 512, 320), false, true, F32, F16)?;
         Ok(())
     }

     #[test]
     fn test_mat_vec() -> TractResult<()> {
         // f32_f32
-        run_mmm_test_case::<GgmlGemm>((1, 8, 32, 3), false, true, DatumType::F32, DatumType::F32)?;
-        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, DatumType::F32, DatumType::F32)?;
-        run_mmm_test_case::<GgmlGemm>((2, 4, 128, 8), false, true, DatumType::F32, DatumType::F32)?;
+        run_mmm_test_case::<GgmlGemm>((1, 8, 32, 3), false, true, F32, F32)?;
+        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, F32, F32)?;
+        run_mmm_test_case::<GgmlGemm>((2, 4, 128, 8), false, true, F32, F32)?;

         // f16_f32_1row
-        run_mmm_test_case::<GgmlGemm>((1, 1, 32, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((1, 3, 62, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((1, 3, 2, 9), false, true, DatumType::F32, DatumType::F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 1, 32, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 3, 62, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 3, 2, 9), false, true, F32, F16)?;

         // f16_f32_L4
-        run_mmm_test_case::<GgmlGemm>((2, 2, 128, 8), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>(
-            (4, 4, 156, 30),
-            false,
-            true,
-            DatumType::F32,
-            DatumType::F16,
-        )?;
+        run_mmm_test_case::<GgmlGemm>((2, 2, 128, 8), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((4, 4, 156, 30), false, true, F32, F16)?;

         // f16_f32
-        run_mmm_test_case::<GgmlGemm>((1, 4, 32, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((4, 4, 128, 7), false, true, DatumType::F32, DatumType::F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 4, 32, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((4, 4, 128, 7), false, true, F32, F16)?;

         // f16_f16
-        run_mmm_test_case::<GgmlGemm>((1, 1, 2, 1), false, true, DatumType::F16, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, DatumType::F16, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>(
-            (2, 16, 128, 9),
-            false,
-            true,
-            DatumType::F16,
-            DatumType::F16,
-        )?;
+        run_mmm_test_case::<GgmlGemm>((1, 1, 2, 1), false, true, F16, F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, F16, F16)?;
+        run_mmm_test_case::<GgmlGemm>((2, 16, 128, 9), false, true, F16, F16)?;
+        Ok(())
+    }
+
+    fn run_mat_mul_q4_test(batch: usize, m: usize, k: usize, n: usize) -> TractResult<()> {
+        objc::rc::autoreleasepool(|| {
+            crate::METAL_CONTEXT.with_borrow(|context| {
+                ensure!(k % 32 == 0);
+                let a_shape = [batch, m, k];
+                let b_shape = [batch, n, k];
+
+                let a_data = (0..batch * k * m)
+                    .map(|f| f as f32 / (batch * m * k) as f32)
+                    .collect::<Vec<f32>>();
+
+                let a = Tensor::from_shape(&a_shape, &a_data)?;
+
+                let b_data = (0..batch * n * k)
+                    .map(|f| f as f32 / (batch * n * k) as f32)
+                    .collect::<Vec<f32>>();
+
+                let mut b_quant = Q4_0.quant_f32(&b_data)?;
+
+                let b_tensor = Tensor::from_shape(&b_shape, &b_data)?;
+
+                crate::utils::tract_to_gguf_q4_0_packing(&mut b_quant)?;
+
+                let b_q4_0_tensor = tensor0(Opaque(Arc::new(BlockQuantValue {
+                    fact: BlockQuantFact { format: Box::new(Q4_0), shape: tvec![batch, n, k] },
+                    value: b_quant,
+                })));
+
+                let metal_output = GemmImpl::<GgmlGemm>::new(false, true).eval(
+                    context,
+                    &a.clone().into_metal()?,
+                    &b_q4_0_tensor.clone().into_metal()?,
+                )?;
+
+                let matmul = BasicMatMul {
+                    transpose_a: false,
+                    transpose_b: true,
+                    transpose_c: false,
+                    quantize_output: None,
+                };
+
+                let output = args_1!(matmul.eval(tvec![a.into_tvalue(), b_tensor.into_tvalue()])?);
+                metal_output.to_cpu()?.close_enough(&output, Approximation::SuperApproximate)?;
+                Ok(())
+            })
+        })
+    }
+
+    #[test]
+    fn test_q4() -> TractResult<()> {
+        run_mat_mul_q4_test(32, 1, 32, 32)?;
+        run_mat_mul_q4_test(1, 32003, 2048, 1)?;
+        run_mat_mul_q4_test(4, 1, 2048, 32003)?;
+        run_mat_mul_q4_test(1, 1, 32, 32)?;
+        run_mat_mul_q4_test(1, 1, 64, 4)?;
+        run_mat_mul_q4_test(3, 1, 4096, 4096)?;
         Ok(())
     }
 }
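The weight-preparation path the Q4_0 tests above exercise can be summed up in one helper: quantize f32 weights into tract's Q4_0 blob, shuffle the nibbles into GGML/GGUF order with tract_to_gguf_q4_0_packing (defined in metal/src/utils.rs later in this diff), and wrap the result in an Opaque BlockQuantValue. A sketch composed from calls that appear in the patch (the helper itself is illustrative):

use tract_core::internal::*;
use tract_linalg::frame::block_quant::{BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0};

// Quantize, repack for the GGML kernels, and box as an Opaque scalar tensor.
fn prepare_q4_0_weights(weights: &[f32], shape: &[usize]) -> TractResult<Tensor> {
    let mut blob = Q4_0.quant_f32(weights)?;
    crate::utils::tract_to_gguf_q4_0_packing(&mut blob)?;
    Ok(tensor0(Opaque(Arc::new(BlockQuantValue {
        fact: BlockQuantFact { format: Box::new(Q4_0), shape: shape.iter().copied().collect() },
        value: blob,
    }))))
}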
diff --git a/metal/src/kernels/matmul/mfa/mod.rs b/metal/src/kernels/matmul/mfa/mod.rs
index d646f53195..4e7eef19d9 100644
--- a/metal/src/kernels/matmul/mfa/mod.rs
+++ b/metal/src/kernels/matmul/mfa/mod.rs
@@ -39,6 +39,7 @@ impl GemmKernel for MfaGemm {
             transpose_b,
             b_offset,
             c_offset,
+            ..
         } = params;

         let a_strides = if transpose_a {
diff --git a/metal/src/kernels/matmul/mlx_gemm/mod.rs b/metal/src/kernels/matmul/mlx_gemm/mod.rs
index b60db21f1b..fb025f0533 100644
--- a/metal/src/kernels/matmul/mlx_gemm/mod.rs
+++ b/metal/src/kernels/matmul/mlx_gemm/mod.rs
@@ -73,6 +73,7 @@ impl GemmKernel for MlxGemm {
             transpose_b,
             b_offset,
             c_offset,
+            ..
         } = params;

         let a_strides = if transpose_a {
diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs
index 1d12bc0551..c6036b6257 100644
--- a/metal/src/kernels/matmul/mod.rs
+++ b/metal/src/kernels/matmul/mod.rs
@@ -12,6 +12,7 @@ pub use mlx_gemm::MlxGemm;
 pub use mmm_tile_8x8::{metal_mmm_tile_8x8, mmm_tile_8x8};
 pub use mps::MpsMatMul;

+use crate::utils::as_q40_tensor;
 use crate::{MetalContext, MetalTensor};
 use metal::Buffer;
 use num_traits::One;
@@ -28,7 +29,7 @@ pub enum MetalGemmImplKind {

 impl Default for MetalGemmImplKind {
     fn default() -> Self {
-        Self::Mlx
+        Self::Ggml
     }
 }

@@ -43,6 +44,7 @@ pub struct GemmDispatchParams {
     pub a_offset: usize,
     pub transpose_b: bool,
     pub b_offset: usize,
+    pub q40_b: bool,
     pub c_offset: usize,
 }

@@ -56,6 +58,7 @@ impl GemmDispatchParams {
         b_offset: usize,
         b_shape: &[usize],
         transpose_b: bool,
+        q40_b: bool,
         c_offset: usize,
         c_shape: &[usize],
     ) -> TractResult<Vec<GemmDispatchParams>> {
@@ -86,6 +89,7 @@ impl GemmDispatchParams {
                 a_offset,
                 transpose_b,
                 b_offset,
+                q40_b,
                 c_offset,
             }]),
             // bkm, 1kn -> bmn
@@ -102,6 +106,7 @@ impl GemmDispatchParams {
                     a_offset: a_offset + a_batch_idx * m * k * dts[0].size_of(),
                     transpose_b,
                     b_offset,
+                    q40_b,
                     c_offset: c_offset + a_batch_idx * m * n * dts[2].size_of(),
                 })
                 .collect()),
@@ -121,6 +126,7 @@ impl GemmDispatchParams {
                     a_offset,
                     transpose_b,
                     b_offset: b_offset + b_batch_idx * n * k * dts[1].size_of(),
+                    q40_b,
                     c_offset: c_offset + b_batch_idx * m * n * dts[2].size_of(),
                 })
                 .collect()),
@@ -141,6 +147,7 @@ impl GemmDispatchParams {
                     a_offset,
                     transpose_b,
                     b_offset,
+                    q40_b,
                     c_offset,
                 }])
             }
@@ -151,9 +158,10 @@ impl GemmDispatchParams {
 pub trait GemmKernel: fmt::Display + fmt::Debug + Clone + Default + Send + Sync {
     fn name() -> &'static str;

-    fn is_supported_dts(&self, dts: &[DatumType]) -> TractResult<bool> {
-        ensure!(dts.len() == 2);
-        Ok(matches!(dts[0], DatumType::F32 | DatumType::F16) && dts[0] == dts[1])
+    fn is_supported_dts(&self, facts: &[TypedFact]) -> bool {
+        assert!(facts.len() == 2, "Expected 2 inputs for matmul");
+        matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16)
+            && facts[0].datum_type == facts[1].datum_type
     }

     fn output_dt(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
@@ -205,15 +213,15 @@ impl<K: GemmKernel> GemmImpl<K> {
         output
     }

-    pub fn output_facts(&self, a: &TypedFact, b: &TypedFact) -> TractResult<TVec<TypedFact>> {
-        let out_shape = self.output_shape(&a.shape, &b.shape).to_vec();
-        let out_dt = self.matmul.output_dt(a.datum_type().unwrap(), b.datum_type().unwrap())?;
-        if out_dt == DatumType::F32 {
-            Ok(tvec!(f32::fact(out_shape)))
-        } else {
-            ensure!(out_dt == DatumType::F16);
-            Ok(tvec!(f16::fact(out_shape)))
-        }
+    pub fn output_facts(
+        &self,
+        shape: &[TDim],
+        a_dt: DatumType,
+        b_dt: DatumType,
+    ) -> TractResult<TVec<TypedFact>> {
+        let out_dt = self.matmul.output_dt(a_dt, b_dt)?;
+        ensure!([DatumType::F16, DatumType::F32].contains(&out_dt));
+        Ok(tvec!(out_dt.fact(shape)))
     }

     pub fn eval(
@@ -222,8 +230,12 @@ impl<K: GemmKernel> GemmImpl<K> {
         a: &MetalTensor,
         b: &MetalTensor,
     ) -> TractResult<MetalTensor> {
+        let b_shape = as_q40_tensor(b.view().tensor)
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().copied()).collect())
+            .unwrap_or(b.shape().to_vec());
+
         let c_dt = self.matmul.output_dt(a.datum_type(), b.datum_type())?;
-        let c_shape = self.output_shape(a.shape(), b.shape());
+        let c_shape = self.output_shape(a.shape(), &b_shape);
         let c = unsafe { MetalTensor::uninitialized_dt(c_dt, &c_shape)? };
         self.dispatch_eval(context, a, b, &c)?;

         Ok(c)
     }

@@ -242,7 +254,12 @@ impl<K: GemmKernel> GemmImpl<K> {
         b.retain_until_completion();
         c.retain_until_completion();

-        ensure!(c.shape() == self.output_shape(a.shape(), b.shape()).as_slice());
+        let q40_b = as_q40_tensor(b.view().tensor);
+        let b_shape = q40_b
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().copied()).collect())
+            .unwrap_or(b.shape().to_vec());
+
+        ensure!(c.shape() == self.output_shape(a.shape(), &b_shape).as_slice());

         if c.shape().iter().product::<usize>() == 0 {
             return Ok(());
@@ -254,8 +271,9 @@ impl<K: GemmKernel> GemmImpl<K> {
             a.shape(),
             self.transpose_a,
             b.metal_offset(),
-            b.shape(),
+            &b_shape,
             self.transpose_b,
+            q40_b.is_some(),
             c.metal_offset(),
             c.shape(),
         )?;
@@ -308,6 +326,10 @@ mod tests {
     use proptest::collection::vec;
     use proptest::prelude::*;
    use tract_core::ops::einsum::BasicMatMul;
+    use tract_core::tract_data::itertools::Itertools;
+    use tract_core::tract_linalg::frame::block_quant::{
+        BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0,
+    };

     pub(crate) fn run_mmm_test_case<K: GemmKernel>(
         (batch, m, k, n): (usize, usize, usize, usize),
@@ -393,6 +415,7 @@
                 0,
                 &[1, k, n],
                 false,
+                false,
                 0,
                 &[1, m, n],
             )?,
@@ -406,6 +429,7 @@
                 a_offset: 0,
                 transpose_b: false,
                 b_offset: 0,
+                q40_b: false,
                 c_offset: 0,
             }]
         );
@@ -419,6 +443,7 @@
                 0,
                 &[10, k, n],
                 false,
+                false,
                 0,
                 &[10, m, n],
             )?,
@@ -432,6 +457,7 @@
                 a_offset: 0,
                 transpose_b: false,
                 b_offset: 0,
+                q40_b: false,
                 c_offset: 0,
             }]
         );
@@ -445,6 +471,7 @@
                 0,
                 &[2, k, n],
                 false,
+                false,
                 10,
                 &[2, m, n],
             )?,
@@ -459,6 +486,7 @@
                     a_offset: 0,
                     transpose_b: false,
                     b_offset: 0,
+                    q40_b: false,
                     c_offset: 10,
                 },
                 GemmDispatchParams {
@@ -471,6 +499,7 @@
                     a_offset: 0,
                    transpose_b: false,
                     b_offset: 1 * n * k * dt.size_of(),
+                    q40_b: false,
                     c_offset: 10 + m * n * dt.size_of(),
                 }
             ]
@@ -485,6 +514,7 @@
                 0,
                 &[2, k, n],
                 false,
+                false,
                 100,
                 &[2, m, n],
             )?,
@@ -498,6 +528,7 @@
                 a_offset: 0,
                 transpose_b: false,
                 b_offset: 0,
+                q40_b: false,
                 c_offset: 100,
             }]
         );
@@ -511,6 +542,7 @@
                 0,
                 &[1, k, n],
                 false,
+                false,
                 100,
                 &[2, m, n],
             )?,
@@ -525,6 +557,7 @@
                     a_offset: 0,
                     transpose_b: false,
                     b_offset: 0,
+                    q40_b: false,
                     c_offset: 100,
                 },
                 GemmDispatchParams {
@@ -537,6 +570,7 @@
                     a_offset: 1 * m * k * dt.size_of(),
                     transpose_b: false,
                     b_offset: 0,
+                    q40_b: false,
                     c_offset: 100 + 1 * m * n * dt.size_of(),
                 }
             ]
@@ -551,6 +585,7 @@
                 10,
                 &[1, k, n],
                 false,
+                false,
                 0,
                 &[10, m, n],
             )?,
@@ -564,6 +599,7 @@
                 a_offset: 0,
                 transpose_b: false,
                 b_offset: 10,
+                q40_b: false,
                 c_offset: 0,
             }]
         );
@@ -609,10 +645,51 @@ mod tests {

         #[test]
         fn mmm_mlx_prop_f16(pb in any::<MmmProblem<MlxGemm, f16>>()) {
-            prop_assert_eq!(pb.run().unwrap(), pb.reference().unwrap())
+            let output = pb.run().unwrap();
+            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::Approximate);
+        }
+        #[test]
+        fn mmm_ggml_prop_f32(pb in <MmmProblem<GgmlGemm, f32>>::arbitrary_with(
+            MmmProblemParams {
+                force_k_as_inner_axis: true,
+                q4_0_weights: false,
+            }
+        )) {
+            let output = pb.run().unwrap();
+            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::Approximate);
+        }
+
+        #[test]
+        fn mmm_ggml_prop_f16(pb in <MmmProblem<GgmlGemm, f16>>::arbitrary_with(
+            MmmProblemParams {
+                force_k_as_inner_axis: true,
+                q4_0_weights: false,
+            }
+        )) {
+            let output = pb.run().unwrap();
+            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::Approximate);
+        }
+
+        #[test]
+        fn mmm_ggml_prop_q4(pb in <MmmProblem<GgmlGemm, f32>>::arbitrary_with(
+            MmmProblemParams {
+                force_k_as_inner_axis: true,
+                q4_0_weights: true,
+            }
+        )) {
+            let output = pb.run().unwrap();
+            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::Approximate);
+        }
     }

+    #[derive(Default, Debug, Clone)]
+    pub struct MmmProblemParams {
+        pub force_k_as_inner_axis: bool,
+        pub q4_0_weights: bool,
+    }
+
     #[derive(Debug, new)]
     pub struct MmmProblem<K: GemmKernel, F>
     where
         F: Datum + Float,
         usize: AsPrimitive<F>,
     {
@@ -627,6 +704,7 @@ mod tests {
         pub transpose_lhs: bool,
         pub rhs: Vec<F>,
         pub transpose_rhs: bool,
+        pub q4_0: bool,
         pub _phantom: std::marker::PhantomData<K>,
     }

@@ -636,15 +714,19 @@ mod tests {
         F: Datum + Float,
         usize: AsPrimitive<F>,
     {
-        type Parameters = ();
+        type Parameters = MmmProblemParams;
         type Strategy = BoxedStrategy<Self>;

-        fn arbitrary_with(_: ()) -> Self::Strategy {
-            (1usize..10, 1usize..20, 1usize..20, 1usize..20)
-                .prop_flat_map(|(b, m, k, n)| {
+        fn arbitrary_with(params: MmmProblemParams) -> Self::Strategy {
+            (1usize..4, 1usize..128, 1usize..256, 1usize..128)
+                .prop_flat_map(move |(b, m, mut k, n)| {
+                    if params.q4_0_weights {
+                        k = k.div_ceil(32) * 32
+                    };
+
                     let lhs_len = b * m * k;
-                    let rhs_len = b * k * n;
-                    let datum = (0usize..10).prop_map(|x| x.as_());
+                    let rhs_len = b * n * k;
+                    let datum = (0usize..100).prop_map(|x| x.as_());
                     (
                         Just(b),
                         Just(m),
@@ -656,16 +738,22 @@
                         proptest::bool::ANY,
                     )
                 })
-                .prop_map(|(b, m, k, n, lhs, transpose_lhs, rhs, transpose_rhs)| Self {
-                    b,
-                    m,
-                    k,
-                    n,
-                    lhs,
-                    transpose_lhs,
-                    rhs,
-                    transpose_rhs,
-                    _phantom: std::marker::PhantomData,
+                .prop_map(move |(b, m, k, n, lhs, mut transpose_lhs, rhs, mut transpose_rhs)| {
+                    if params.force_k_as_inner_axis {
+                        (transpose_lhs, transpose_rhs) = (false, true);
+                    }
+                    Self {
+                        b,
+                        m,
+                        k,
+                        n,
+                        lhs,
+                        transpose_lhs,
+                        rhs,
+                        transpose_rhs,
+                        q4_0: params.q4_0_weights,
+                        _phantom: std::marker::PhantomData,
+                    }
                 })
                 .boxed()
         }
@@ -677,7 +765,7 @@ impl<K: GemmKernel, F> MmmProblem<K, F>
     where
         F: Datum + Float + std::ops::AddAssign,
         usize: AsPrimitive<F>,
     {
-        pub fn reference(&self) -> Result<Vec<F>> {
+        pub fn reference(&self) -> Result<Tensor> {
             let matmul = BasicMatMul {
                 transpose_a: self.transpose_lhs,
                 transpose_b: self.transpose_rhs,
@@ -698,10 +786,10 @@
             let output = matmul.eval(tvec![lhs_tensor.into_tvalue(), rhs_tensor.into_tvalue()])?;

-            Ok(output[0].clone().into_tensor().as_slice::<F>()?.to_vec())
+            Ok(output[0].clone().into_tensor())
         }

-        pub fn run(&self) -> Result<Vec<F>> {
+        pub fn run(&self) -> Result<Tensor> {
             objc::rc::autoreleasepool(|| {
                 crate::METAL_CONTEXT.with_borrow(|context| {
                     let lhs = if self.transpose_lhs {
@@ -710,15 +798,37 @@
                         Tensor::from_shape(&[self.b, self.m, self.k], &self.lhs)?.into_metal()?
                     };
                     let rhs = if self.transpose_rhs {
-                        Tensor::from_shape(&[self.b, self.n, self.k], &self.rhs)?.into_metal()?
+                        if !self.q4_0 {
+                            Tensor::from_shape(&[self.b, self.n, self.k], &self.rhs)?
+                        } else {
+                            let mut b_quant = Q4_0.quant_f32(
+                                &self
+                                    .rhs
+                                    .clone()
+                                    .into_iter()
+                                    .map(|x| x.to_f32().unwrap())
+                                    .collect_vec(),
+                            )?;
+
+                            crate::utils::tract_to_gguf_q4_0_packing(&mut b_quant)?;
+
+                            tensor0(Opaque(Arc::new(BlockQuantValue {
+                                fact: BlockQuantFact {
+                                    format: Box::new(Q4_0),
+                                    shape: tvec![self.b, self.n, self.k],
+                                },
+                                value: b_quant,
+                            })))
+                        }
                     } else {
-                        Tensor::from_shape(&[self.b, self.k, self.n], &self.rhs)?.into_metal()?
-                    };
+                        Tensor::from_shape(&[self.b, self.k, self.n], &self.rhs)?
+                    }
+                    .into_metal()?;

                     let matmul = GemmImpl::<K>::new(self.transpose_lhs, self.transpose_rhs);

                     let c = matmul.eval(context, &lhs, &rhs)?;
-                    Ok(c.to_cpu()?.as_slice::<F>()?.to_vec())
+                    Ok(c.to_cpu()?)
                 })
             })
     }
diff --git a/metal/src/kernels/matmul/mps/mod.rs b/metal/src/kernels/matmul/mps/mod.rs
index 33efac3266..0b8917d690 100644
--- a/metal/src/kernels/matmul/mps/mod.rs
+++ b/metal/src/kernels/matmul/mps/mod.rs
@@ -53,6 +53,7 @@ impl GemmKernel for MpsMatMul {
             transpose_b,
             b_offset,
             c_offset,
+            ..
         } = params;

         let dt = dts[0];
diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs
index fc82a27d8d..84f62b679c 100644
--- a/metal/src/ops/gemm.rs
+++ b/metal/src/ops/gemm.rs
@@ -1,6 +1,7 @@
 use crate::kernels::matmul::{GemmImpl, GemmKernel};
 use crate::ops::MetalEvalOp;
+use crate::utils::{as_q40_fact, as_q40_tensor};
 use crate::{MetalContext, MetalTensorExt};
 use anyhow::{bail, ensure};
 use tract_core::internal::*;
@@ -43,17 +44,32 @@ impl<K: GemmKernel> MetalGemm<K> {
         let [a, b] = inputs else {
             bail!("Expects 2 inputs");
         };
-        ensure!(a.rank() == b.rank());
-        ensure!(a.rank() >= 2);
-        ensure!(
-            a.shape[a.rank() - 2 + !self.transpose_a() as usize]
-                == b.shape[b.rank() - 2 + self.transpose_b() as usize]
-        );
-
-        self.kernel.output_facts(a, b)
+
+        if a.datum_type.is_number() && b.datum_type.is_number() {
+            ensure!(a.rank() == b.rank());
+            ensure!(a.rank() >= 2);
+            ensure!(
+                a.shape[a.rank() - 2 + !self.transpose_a() as usize]
+                    == b.shape[b.rank() - 2 + self.transpose_b() as usize]
+            );
+            let out_shape = self.kernel.output_shape(&a.shape, &b.shape);
+            Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
+        } else if let Some(opf) = as_q40_fact(inputs[0]) {
+            let a_shape: ShapeFact =
+                a.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect();

+            let out_shape = self.kernel.output_shape(&a_shape, &b.shape);
+            Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
+        } else if let Some(opf) = as_q40_fact(inputs[1]) {
+            let b_shape: ShapeFact =
+                b.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect();
+            let out_shape = self.kernel.output_shape(&a.shape, &b_shape);
+            Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
+        } else {
+            todo!()
+        }
     }
 }

-
 impl<K: GemmKernel> MetalEvalOp for MetalGemm<K> {
     fn metal_eval(
         &self,
@@ -69,8 +85,13 @@ impl<K: GemmKernel> MetalEvalOp for MetalGemm<K> {
         let b = b_opaque
             .to_metal_tensor()
             .with_context(|| anyhow!("B tensor is not a metal tensor {:?}", b_opaque))?;
-        let c_dt = a.datum_type();
-        let c_shape = self.kernel.output_shape(a.shape(), b.shape());
+
+        let b_shape = as_q40_tensor(b.view().tensor)
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().copied()).collect())
+            .unwrap_or(b.shape().to_vec());
+
+        let c_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?;
+        let c_shape = self.kernel.output_shape(a.shape(), &b_shape);
         let c = crate::ops::make_tensor_for_node(session, node_id, c_dt, &c_shape)?;
         self.kernel.dispatch_eval(context, a, b, &c)?;
         Ok(tvec![c.into_opaque_tensor().into_tvalue()])
diff --git a/metal/src/tensor/mod.rs b/metal/src/tensor/mod.rs
index 2a054d2110..e242f0ee05 100644
--- a/metal/src/tensor/mod.rs
+++ b/metal/src/tensor/mod.rs
@@ -21,7 +21,7 @@ pub enum MetalTensor {
 }

 impl MetalTensor {
-    pub const SUPPORTED_DT: [DatumType; 11] = [
+    pub const SUPPORTED_DT: [DatumType; 12] = [
         DatumType::Bool,
         DatumType::F32,
         DatumType::F16,
@@ -33,6 +33,7 @@ impl MetalTensor {
         DatumType::U32,
         DatumType::I64,
         DatumType::U64,
+        DatumType::Opaque,
     ];

     pub fn tname(dt: DatumType) -> TractResult<&'static str> {
@@ -48,6 +49,7 @@ impl MetalTensor {
             DatumType::I32 => "i32",
             DatumType::I64 => "i64",
             DatumType::Bool => "bool",
+            DatumType::Opaque => "opaque",
             _ => bail!("Unsupported dt {:?} for metal kernel function", dt),
         })
     }
diff --git a/metal/src/tensor/owned.rs b/metal/src/tensor/owned.rs
index a936eb0063..4cb23654c3 100644
--- a/metal/src/tensor/owned.rs
+++ b/metal/src/tensor/owned.rs
@@ -1,3 +1,4 @@
+use crate::utils::as_q40_tensor;
 use crate::MetalTensor;
 use anyhow::Result;
 use metal::Buffer;
@@ -129,14 +130,19 @@ impl OwnedMetalTensor {
     /// Create an owned metal tensor from a cpu tensor.
     pub fn from_tensor<T: Into<MValue>>(tensor: T) -> Result<Self> {
         crate::METAL_CONTEXT.with_borrow(|ctxt| {
-            let m_value = tensor.into();
+            let m_value: MValue = tensor.into();
             let tensor_view = m_value.view();
             ensure!(
                 MetalTensor::is_supported_dt(tensor_view.datum_type()),
                 "Tensor of {:?} is not copied. No Metal buffer can be allocated for it.",
                 tensor_view.datum_type(),
             );
-            let buffer = ctxt.buffer_from_slice(tensor_view.tensor.as_bytes());
+
+            let data_bytes = as_q40_tensor(tensor_view.tensor)
+                .map(|bqv| bqv.value.as_bytes())
+                .unwrap_or(tensor_view.tensor.as_bytes());
+
+            let buffer = ctxt.buffer_from_slice(data_bytes);
             Ok(OwnedMetalTensor { inner: m_value, metal: buffer })
         })
     }
diff --git a/metal/src/transform.rs b/metal/src/transform.rs
index 090e680671..e45c14560d 100644
--- a/metal/src/transform.rs
+++ b/metal/src/transform.rs
@@ -14,6 +14,7 @@ use crate::rewrite_rules::{
     BasicSilu,
 };
 use crate::tensor::MetalTensorExt;
+use crate::utils::{as_q40_fact, as_q40_tensor};
 use crate::{IntoMetal, MetalFact, MetalTensor};
 use std::borrow::Cow;
 use std::fmt::Debug;
@@ -27,6 +28,7 @@ use tract_core::ops::element_wise::ElementWiseOp;
 use tract_core::ops::konst::Const;
 use tract_core::ops::logic::Comp;
 use tract_core::ops::nn::{Reduce, Softmax as CoreSoftmax};
+use tract_core::tract_linalg::frame::block_quant::BlockQuantValue;
 use tract_core::transform::ModelTransform;
 use tract_itertools::Itertools;

@@ -193,13 +195,14 @@ fn can_translate_to_metal_op(
     node: &TypedNode,
     gemm_impl: MetalGemmImplKind,
 ) -> TractResult<bool> {
-    let input_dts = source
-        .node_input_facts(node.id)?
+    let input_facts = source.node_input_facts(node.id)?.iter().map(|f| (*f).clone()).collect_vec();
+    let input_dts = input_facts
         .iter()
         .map(|f| f.as_metal_fact().map(|f| f.datum_type).unwrap_or(f.datum_type))
         .collect_vec();

-    let in_dts_metal_compatible = input_dts.iter().all(|dt| MetalTensor::is_supported_dt(*dt));
+    let in_dts_metal_compatible =
+        input_facts.iter().all(|fact| MetalTensor::is_supported_dt(fact.datum_type));

     Ok(in_dts_metal_compatible
         && (node
@@ -211,7 +214,7 @@ fn can_translate_to_metal_op(
             || node.op_as::<BasicMatMul>().is_some_and(|op| {
                 !op.transpose_c
                     && op.quantize_output.is_none()
-                    && check_matmul_in_dts(gemm_impl, &input_dts)
+                    && check_matmul_in_dts(gemm_impl, &input_facts)
             })
             || node
                .op_as::()
@@ -337,14 +340,16 @@ macro_rules! map_element_wise_ops {
     };
 }

-fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, dts: &[DatumType]) -> bool {
-    let is_supported = match gemm_impl {
-        MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(dts),
-        MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(dts),
-        MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(dts),
-        MetalGemmImplKind::Ggml => MfaGemm.is_supported_dts(dts),
-    };
-    is_supported.unwrap_or(false)
+fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, in_facts: &[TypedFact]) -> bool {
+    match gemm_impl {
+        MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(in_facts),
+        MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(in_facts),
+        MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(in_facts),
+        MetalGemmImplKind::Ggml => {
+            GgmlGemm.is_supported_dts(in_facts)
+                || GgmlGemm.is_supported_dts(&[in_facts[1].clone(), in_facts[0].clone()])
+        }
+    }
 }
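Why checking the swapped fact order is sound: the GGML kernels expect the (possibly quantized) weight operand in the B slot with transpose_b set, so when the Q4_0 operand arrives as A, the converter below computes with swapped inputs and transposes the result, relying on the identity (A·Bᵀ)ᵀ = B·Aᵀ — hence the trailing .perm_out axis move it wires in. A tiny self-contained check of the identity (illustrative, not part of the patch):

fn t(m: [[f32; 2]; 2]) -> [[f32; 2]; 2] {
    [[m[0][0], m[1][0]], [m[0][1], m[1][1]]]
}

fn mm(a: [[f32; 2]; 2], b: [[f32; 2]; 2]) -> [[f32; 2]; 2] {
    let mut c = [[0f32; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

// For any a, b: mm(a, t(b)) == t(mm(b, t(a))), i.e. (A·Bᵀ)ᵀ = B·Aᵀ.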
fn convert_matmul_to_metal(
    model: &TypedModel,
    node: &TypedNode,
    target: &mut TypedModel,
    inputs: &mut [OutletId],
    op: &BasicMatMul,
    gemm_impl: MetalGemmImplKind,
) -> TractResult<TVec<OutletId>> {
-    let matmul: Box<dyn TypedOp> = match gemm_impl {
+    let mut matmul_output = match gemm_impl {
         MetalGemmImplKind::Mlx => {
-            Box::new(ops::MetalGemm::<MlxGemm>::new(op.transpose_a, op.transpose_b))
+            let op = ops::MetalGemm::<MlxGemm>::new(op.transpose_a, op.transpose_b);
+            target.wire_node(node.name.clone(), op, inputs)?
         }
         MetalGemmImplKind::Mps => {
-            Box::new(ops::MetalGemm::<MpsMatMul>::new(op.transpose_a, op.transpose_b))
+            let op = ops::MetalGemm::<MpsMatMul>::new(op.transpose_a, op.transpose_b);
+            target.wire_node(node.name.clone(), op, inputs)?
         }
         MetalGemmImplKind::Mfa => {
-            Box::new(ops::MetalGemm::<MfaGemm>::new(op.transpose_a, op.transpose_b))
+            let op = ops::MetalGemm::<MfaGemm>::new(op.transpose_a, op.transpose_b);
+            target.wire_node(node.name.clone(), op, inputs)?
         }
         MetalGemmImplKind::Ggml => {
-            let input_facts: tract_smallvec::SmallVec<[&TypedFact; 4]> =
+            let mut input_facts: tract_smallvec::SmallVec<[&TypedFact; 4]> =
                 model.node_input_facts(node.id)?;

+            let mut swap_inputs = false;
+            if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()])
+                && GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()])
+            {
+                input_facts.swap(0, 1);
+                inputs.swap(0, 1);
+                swap_inputs = true;
+            }
+
+            let a_pos = swap_inputs as usize;
+            let b_pos = 1 - swap_inputs as usize;
             if op.transpose_a {
-                let rank = input_facts[0].rank();
-                let perm_a_op: Box<dyn TypedOp> =
-                    Box::new(ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
-                        rank - 2,
-                        rank - 1,
-                    )));
+                assert!(as_q40_fact(input_facts[b_pos]).is_none(), "Cannot transpose Q40 tensor");
+
+                let rank = input_facts[a_pos].rank();
+                let perm_a_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
+                    rank - 2,
+                    rank - 1,
+                ));
                 let perm_a_name = node.name.clone() + ".perm_a";
-                inputs[0] = target.wire_node(perm_a_name, perm_a_op, &[inputs[0]])?[0];
+                inputs[a_pos] = target.wire_node(perm_a_name, perm_a_op, &[inputs[a_pos]])?[0];
             }

             if input_facts[0].datum_type == DatumType::F16 {
@@ -387,25 +407,44 @@ fn convert_matmul_to_metal(
             }

             if !op.transpose_b {
-                let rank = input_facts[1].rank();
-                let perm_b_op: Box<dyn TypedOp> =
-                    Box::new(ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
-                        rank - 2,
-                        rank - 1,
-                    )));
+                assert!(as_q40_fact(input_facts[b_pos]).is_none(), "Cannot transpose Q40 tensor");
+
+                let rank = input_facts[b_pos].rank();
+                let perm_b_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
+                    rank - 2,
+                    rank - 1,
+                ));
                 let perm_b_name = node.name.clone() + ".perm_b";
-                inputs[1] = target.wire_node(perm_b_name, perm_b_op, &[inputs[1]])?[0];
+                inputs[b_pos] = target.wire_node(perm_b_name, perm_b_op, &[inputs[b_pos]])?[0];
+            }
+            let op = ops::MetalGemm::<GgmlGemm>::new(false, true);
+            let mut matmul_output = target.wire_node(node.name.clone(), op, inputs)?;
+
+            if swap_inputs {
+                let out_fact = target.outlet_fact(matmul_output[0])?;
+                let rank = &out_fact
+                    .opaque_fact
+                    .clone()
+                    .map(|fact| fact.clarify_dt_shape().unwrap().1.len())
+                    .unwrap();
+
+                let perm_out_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
+                    rank - 2,
+                    rank - 1,
+                ));
+                matmul_output = target.wire_node(
+                    node.name.clone() + ".perm_out",
+                    perm_out_op,
+                    &matmul_output,
+                )?;
             }
-            Box::new(ops::MetalGemm::<GgmlGemm>::new(false, true))
+            matmul_output
         }
     };

-    let new_in_facts = [target.outlet_fact(inputs[0])?, target.outlet_fact(inputs[1])?];
-
-    let out_fact = &matmul.output_facts(&new_in_facts)?[0];
+    let out_fact = target.outlet_fact(matmul_output[0])?;
     let out_dt = out_fact.to_metal_fact().map(|f| f.datum_type).unwrap_or(out_fact.datum_type);

-    let mut matmul_output = target.wire_node(node.name.clone(), matmul, inputs)?;
     let expected_dt = model.node_output_facts(node.id)?[0].datum_type;
     if out_dt != expected_dt {
@@ -445,8 +484,19 @@ fn convert_logic_ops_to_metal(op: &Comp) -> ops::MetalBinOp {
 }

 fn convert_const(op: &Const) -> TractResult<Const> {
-    let metal_fact = MetalFact::from_cpu(Arc::clone(&op.0).into())?;
-    let metal_const = op.0.clone().into_metal()?.into_opaque_tensor().into_arc_tensor();
+    let (tensor, metal_fact) = if let Some(curr_bqv) = as_q40_tensor(op.0.view().tensor) {
+        let mut bqv = curr_bqv.clone();
+        crate::utils::tract_to_gguf_q4_0_packing(&mut (bqv.value))?;
+
+        let bqv = BlockQuantValue { value: bqv.value, fact: bqv.fact };
+        (
+            tensor0(Opaque(Arc::new(bqv))).broadcast_into_rank(op.0.rank())?.into_arc_tensor(),
+            MetalFact::from_cpu(Arc::clone(&op.0).into())?,
+        )
+    } else {
+        (op.0.clone(), MetalFact::from_cpu(Arc::clone(&op.0).into())?)
+    };
+    let metal_const = tensor.into_metal()?.into_opaque_tensor().into_arc_tensor();
     Ok(Const::new_with_opaque_fact(metal_const, Box::new(metal_fact)))
 }
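A worked example of the nibble shuffle performed by tract_to_gguf_q4_0_packing in utils.rs below (layouts inferred from that code): within one 32-element block, tract packs consecutive pairs (byte j holds x[2j] in the low nibble and x[2j+1] in the high nibble), while the GGML/GGUF kernels expect byte i to pair element i with element i+16. A functional sketch of the same transformation on a single block (illustrative; the patched helper does it in place over a whole Blob):

// Re-pack the 16 nibble bytes of one block from tract order to GGUF order.
fn repack_block(tract: &[u8; 16]) -> [u8; 16] {
    // Nibble i of the block in tract order: even indices sit in low nibbles.
    let nib = |i: usize| if i % 2 == 0 { tract[i / 2] & 0x0F } else { tract[i / 2] >> 4 };
    let mut gguf = [0u8; 16];
    for i in 0..16 {
        gguf[i] = nib(i) | (nib(i + 16) << 4); // element i low, element i+16 high
    }
    gguf
}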
diff --git a/metal/src/utils.rs b/metal/src/utils.rs
index f000a9c626..a8569eea26 100644
--- a/metal/src/utils.rs
+++ b/metal/src/utils.rs
@@ -1,6 +1,7 @@
 use crate::fact::{MetalFact, MetalOrigin, MetalTypedFactExt};
 use num_traits::{AsPrimitive, Zero};
 use tract_core::internal::*;
+use tract_linalg::frame::block_quant::{BlockQuantFact, BlockQuantValue, Q4_0};

 #[macro_export]
 macro_rules! impl_eval_op_for_metal_op {
@@ -95,6 +96,58 @@ where
         .collect::>())
 }

+pub fn as_q40_fact(fact: &TypedFact) -> Option<&BlockQuantFact> {
+    fact.opaque_fact
+        .as_ref()
+        .and_then(|of| of.downcast_ref::<BlockQuantFact>())
+        .and_then(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
+        .or_else(|| {
+            fact.konst
+                .as_ref()
+                .and_then(|k| k.to_scalar::<Opaque>().ok())
+                .and_then(|o| o.downcast_ref::<BlockQuantValue>())
+                .map(|v| &v.fact)
+                .and_then(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
+        })
+}
+
+pub fn as_q40_tensor(a: &Tensor) -> Option<&BlockQuantValue> {
+    a.to_scalar::<Opaque>().ok().and_then(|od| {
+        od.downcast_ref::<BlockQuantValue>().and_then(|bqv| {
+            if bqv.fact.format.same_as(&Q4_0) {
+                Some(bqv)
+            } else {
+                None
+            }
+        })
+    })
+}
+
+pub fn tract_to_gguf_q4_0_packing(data: &mut Blob) -> TractResult<()> {
+    let block_size = 18;
+    ensure!(data.layout().size() % block_size == 0);
+
+    let n_block = data.layout().size() / block_size;
+    let data_bytes = data.as_bytes_mut();
+
+    for b in 0..n_block {
+        // skip the 2-byte f16 scale, then shuffle the 16 nibble bytes in place
+        let offset = b * block_size + 2;
+        let nibbles = &mut data_bytes[offset..offset + 16];
+        let second_part: &mut [u8; 8] = &mut [0; 8];
+        second_part.clone_from_slice(&nibbles[8..]);
+        // walk backwards so the source nibbles are not clobbered before use
+        for i in (0..16).rev() {
+            let lsb = if i % 2 == 0 { nibbles[i / 2] & 0x0F } else { (nibbles[i / 2] & 0xF0) >> 4 };
+            let msb = if i % 2 == 0 {
+                (second_part[i / 2] & 0x0F) << 4
+            } else {
+                second_part[i / 2] & 0xF0
+            };
+            nibbles[i] = msb | lsb;
+        }
+    }
+    Ok(())
+}
+
 pub fn rescale_gpu_duration(
     pass_duration: u64,
     cpu_start: u64,