From a65e56602743eca063759aab282b2b0d343a4604 Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Wed, 5 Feb 2025 12:59:02 +0100
Subject: [PATCH 01/15] fixes dt bugs

---
 metal/src/kernels/matmul/ggml_gemm/mod.rs | 6 +++---
 metal/src/ops/gemm.rs                     | 2 +-
 metal/src/transform.rs                    | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/metal/src/kernels/matmul/ggml_gemm/mod.rs b/metal/src/kernels/matmul/ggml_gemm/mod.rs
index 00b59fd500..671847aeac 100644
--- a/metal/src/kernels/matmul/ggml_gemm/mod.rs
+++ b/metal/src/kernels/matmul/ggml_gemm/mod.rs
@@ -36,10 +36,10 @@ impl GemmKernel for GgmlGemm {
     fn is_supported_dts(&self, dts: &[DatumType]) -> TractResult<bool> {
         ensure!(dts.len() == 2);
 
-        if dts[0] == DatumType::F32 {
-            Ok(dts[1] == DatumType::F32)
+        if dts[1] == DatumType::F32 {
+            Ok(dts[0] == DatumType::F32)
         } else {
-            Ok(dts[0] == DatumType::F16 && matches!(dts[1], DatumType::F32 | DatumType::F16))
+            Ok(dts[1] == DatumType::F16 && matches!(dts[0], DatumType::F32 | DatumType::F16))
         }
     }
 
diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs
index fc82a27d8d..223de457fa 100644
--- a/metal/src/ops/gemm.rs
+++ b/metal/src/ops/gemm.rs
@@ -69,7 +69,7 @@ impl<K: GemmKernel> MetalEvalOp for MetalGemm<K> {
         let b = b_opaque
             .to_metal_tensor()
             .with_context(|| anyhow!("B tensor is not a metal tensor {:?}", b_opaque))?;
-        let c_dt = a.datum_type();
+        let c_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?;
         let c_shape = self.kernel.output_shape(a.shape(), b.shape());
         let c = crate::ops::make_tensor_for_node(session, node_id, c_dt, &c_shape)?;
         self.kernel.dispatch_eval(context, a, b, &c)?;
diff --git a/metal/src/transform.rs b/metal/src/transform.rs
index 090e680671..7eb9d5fef5 100644
--- a/metal/src/transform.rs
+++ b/metal/src/transform.rs
@@ -342,7 +342,7 @@ fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, dts: &[DatumType]) -> bool
         MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(dts),
         MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(dts),
         MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(dts),
-        MetalGemmImplKind::Ggml => MfaGemm.is_supported_dts(dts),
+        MetalGemmImplKind::Ggml => GgmlGemm.is_supported_dts(dts),
     };
     is_supported.unwrap_or(false)
 }

From 355175ca2f30122c20bc769382147225829865fb Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Fri, 7 Feb 2025 09:53:13 +0100
Subject: [PATCH 02/15] (Really) dirty implem for GGML Q4 support

---
 core/src/ops/einsum/as_matmul.rs              |  29 +++-
 .../kernels/matmul/ggml_gemm/ggml_mm_mv.metal | 144 +++++++++++++++++
 metal/src/kernels/matmul/ggml_gemm/mod.rs     | 152 +++++++++++++++---
 metal/src/kernels/matmul/mod.rs               |  33 ++--
 metal/src/ops/gemm.rs                         |  62 +++++--
 metal/src/tensor/mod.rs                       |   4 +-
 metal/src/tensor/owned.rs                     |  16 +-
 metal/src/transform.rs                        |  19 +--
 metal/src/utils.rs                            |  48 ++++++
 9 files changed, 445 insertions(+), 62 deletions(-)

diff --git a/core/src/ops/einsum/as_matmul.rs b/core/src/ops/einsum/as_matmul.rs
index c3f955f715..8bcdf11dd1 100644
--- a/core/src/ops/einsum/as_matmul.rs
+++ b/core/src/ops/einsum/as_matmul.rs
@@ -279,7 +279,7 @@ impl TypedOp for BasicMatMul {
         let [a, b] = inputs else {
             bail!("Expects 2 inputs");
         };
-        if a.datum_type.is_number() {
+        if a.datum_type.is_number() && b.datum_type.is_number() {
             ensure!(a.rank() == b.rank());
             ensure!(a.rank() >= 2);
             ensure!(
@@ -313,9 +313,32 @@ impl TypedOp for BasicMatMul {
                 .quantize_output
                 .unwrap_or(b.datum_type)
                 .fact(self.output_shape(&a_shape, &b.shape))))
+        } else if let Some(opf) = inputs[1]
+            .opaque_fact
+            .as_ref()
+            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
+
.or_else(|| { + inputs[1] + .konst + .as_ref() + .and_then(|k| k.to_scalar::().ok()) + .and_then(|o| o.downcast_ref::()) + .map(|v| &v.fact) + }) + { + let b_shape: ShapeFact = b + .shape + .iter() + .cloned() + .chain(opf.shape.iter().map(|d| d.to_dim())) + .collect(); + Ok(tvec!(self + .quantize_output + .unwrap_or(a.datum_type) + .fact(self.output_shape(&a.shape, &b_shape)))) } else { - todo!() - } + todo!() + } } as_op!(); diff --git a/metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal b/metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal index a016769a0c..c42c57c676 100644 --- a/metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal +++ b/metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal @@ -4,6 +4,14 @@ using namespace metal; +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +#define QK4_0 32 +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; + typedef struct { int32_t batch; int32_t m; @@ -207,6 +215,122 @@ typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); + + for (int i = 0; i < 8; i += 2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); + } + + return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); +} + +// putting them in the kernel cause a significant performance penalty +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// guard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template +void mul_vec_q_n_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nb = args.k/QK4_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr; + + const uint i12 = im%args.batch; + const uint i13 = im/args.batch; + + //const uint64_t offset0 = first_row*args. 
+ (i12/args.channel_broadcast_ratio)*args.b_strides[1] + (i13/args.batch_broadcast_ratio)*args.b_strides[0]; + const uint64_t offset1 = r1*args.a_strides[2] + (i12 )*args.a_strides[1] + (i13 )*args.a_strides[0]; + + //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_q_type * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint64_t offset0 = (first_row + row)*args.b_strides[2] + (i12/args.channel_broadcast_ratio)*args.b_strides[1] + (i13/args.batch_broadcast_ratio)*args.b_strides[0]; + + ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); + } + + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; + + device const float * yb = y + ix*QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + float sumy[2] = { 0.f, 0.f }; + +#pragma unroll + for (int i = 0; i < 8; i += 2) { + sumy[0] += yb[i + 0] + yb[i + 1]; + yl[i + 0] = yb[i + 0]; + yl[i + 1] = yb[i + 1]/256.f; + + sumy[1] += yb[i + 16] + yb[i + 17]; + yl[i + 8] = yb[i + 16]/16.f; + yl[i + 9] = yb[i + 17]/4096.f; + } + +#pragma unroll + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); + } + + yb += QK4_0 * 16; + } + + device float * dst_f32 = (device float *) dst + im*args.n*args.m + r1*args.n; + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.n) { + dst_f32[first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mv_q4_0_f32( + constant ggml_metal_kargs_mul & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 @@ -372,7 +496,27 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) reg = (type4x4)(*src); } +template +void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 
0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md; + reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md; + } + + reg = (type4x4) reg_f; +} + typedef decltype(kernel_mul_mm) mat_mm_t; template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; \ No newline at end of file diff --git a/metal/src/kernels/matmul/ggml_gemm/mod.rs b/metal/src/kernels/matmul/ggml_gemm/mod.rs index 671847aeac..3dae55e5bc 100644 --- a/metal/src/kernels/matmul/ggml_gemm/mod.rs +++ b/metal/src/kernels/matmul/ggml_gemm/mod.rs @@ -1,4 +1,5 @@ use crate::kernels::matmul::{GemmDispatchParams, GemmKernel}; +use crate::utils::is_q4_0; use crate::MetalTensor; use crate::{LibraryName, MetalContext}; use anyhow::{ensure, Result}; @@ -33,13 +34,16 @@ impl GemmKernel for GgmlGemm { "ggml" } - fn is_supported_dts(&self, dts: &[DatumType]) -> TractResult { - ensure!(dts.len() == 2); + fn is_supported_dts(&self, facts: &[TypedFact]) -> TractResult { + ensure!(facts.len() == 2); - if dts[1] == DatumType::F32 { - Ok(dts[0] == DatumType::F32) + if is_q4_0(facts[1].clone()) { + Ok(matches!(facts[0].datum_type, DatumType::F16 | DatumType::F32)) + } else if facts[1].datum_type == DatumType::F32 { + Ok(facts[0].datum_type == DatumType::F32) } else { - Ok(dts[1] == DatumType::F16 && matches!(dts[0], DatumType::F32 | DatumType::F16)) + Ok(facts[1].datum_type == DatumType::F16 + && matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16)) } } @@ -79,13 +83,24 @@ impl GemmKernel for GgmlGemm { a_el_size as u64, ]; - let b_el_size = dts[1].size_of(); - let b_strides = [ - (batch * n * k * b_el_size) as u64, - (n * k * b_el_size) as u64, - (k * b_el_size) as u64, - b_el_size as u64, - ]; + let b_strides = if dts[1] == DatumType::Opaque { + // Assume Opaque == Q4. 
TODO + let b_el_size = 18; + [ + (batch * n * (k / 32) * b_el_size) as u64, + (n * (k / 32) * b_el_size) as u64, + (b_el_size * (k / 32)) as u64, + b_el_size as u64, + ] + } else { + let b_el_size = dts[1].size_of(); + [ + (batch * n * k * b_el_size) as u64, + (n * k * b_el_size) as u64, + (b_el_size * k) as u64, + b_el_size as u64, + ] + }; let params = GgmlParams { batch: batch as i32, @@ -121,8 +136,7 @@ fn mv_kernel_name_and_dispatch_params( if dts[1] == DatumType::F32 { ensure!(dts[0] == DatumType::F32); Ok(("kernel_mul_mv_f32_f32".to_string(), (nth0, nth1, 4))) - } else { - ensure!(dts[1] == DatumType::F16); + } else if dts[1] == DatumType::F16 { if dts[0] == DatumType::F32 { if (params.m * params.batch) < 4 { Ok(("kernel_mul_mv_f16_f32_1row".to_string(), (nth0, nth1, nrows))) @@ -136,6 +150,9 @@ fn mv_kernel_name_and_dispatch_params( ensure!(dts[1] == DatumType::F16); Ok(("kernel_mul_mv_f16_f16".to_string(), (nth0, nth1, 4))) } + } else { + ensure!((dts[1] == DatumType::Opaque) && (dts[0] == DatumType::F32)); + Ok(("kernel_mul_mv_q4_0_f32".to_string(), (8, 8, 1))) } } @@ -159,16 +176,27 @@ fn dispatch_metal_ggml_gemv( let command_buffer = context.command_buffer(); command_buffer.encode(|encoder| { encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, std::mem::size_of::() as u64, ¶ms as *const _ as *const _); + encoder.set_bytes( + 0, + std::mem::size_of::() as u64, + ¶ms as *const _ as *const _, + ); encoder.set_buffer(1, Some(b_buffer), b_offset as NSUInteger); encoder.set_buffer(2, Some(a_buffer), a_offset as NSUInteger); encoder.set_buffer(3, Some(output), output_offset as NSUInteger); - let ny = (params.m as u64).div_ceil(nrows); - let grid_size = MTLSize { - width: params.n as u64, - height: ny, - depth: /* batch_size_out */ params.batch as u64, + let grid_size = if dts[1] != DatumType::Opaque { + MTLSize { + width: params.n as u64, + height: (params.m as u64).div_ceil(nrows), + depth: /* batch_size_out */ params.batch as u64, + } + } else { + MTLSize { + width: (params.n as u64).div_ceil(8), + height: params.m as u64, + depth: /* batch_size_out */ params.batch as u64, + } }; let group_size = MTLSize { width: nth0, height: nth1, depth: 1 }; @@ -191,19 +219,30 @@ fn dispatch_metal_ggml_gemm( output: &Buffer, output_offset: usize, ) -> Result<()> { - ensure!(matches!(dts[1], DatumType::F32 | DatumType::F16) && dts[0] == DatumType::F32); + // Warning: currently assuming opaque == q4_0, + ensure!( + matches!(dts[1], DatumType::F32 | DatumType::F16 | DatumType::Opaque) + && dts[0] == DatumType::F32 + ); - let i1_tname = MetalTensor::tname(dts[1])?; + let mut i1_tname = MetalTensor::tname(dts[1])?; let i2_tname = MetalTensor::tname(dts[0])?; + if i1_tname == "opaque" { + i1_tname = "q4_0"; + } let name = format!("kernel_mul_mm_{i1_tname}_{i2_tname}"); - + //dbg!(&name); let pipeline = context.shared_context().load_pipeline(LibraryName::Ggml, &name)?; let command_buffer = context.command_buffer(); command_buffer.encode(|encoder| { encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, std::mem::size_of::() as u64, ¶ms as *const _ as *const _); + encoder.set_bytes( + 0, + std::mem::size_of::() as u64, + ¶ms as *const _ as *const _, + ); encoder.set_buffer(1, Some(b_buffer), b_offset as NSUInteger); encoder.set_buffer(2, Some(a_buffer), a_offset as NSUInteger); encoder.set_buffer(3, Some(output), output_offset as NSUInteger); @@ -225,8 +264,13 @@ fn dispatch_metal_ggml_gemm( #[cfg(test)] mod tests { + use tract_core::ops::einsum::BasicMatMul; + use 
tract_linalg::frame::block_quant::{BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0}; + use super::*; use crate::kernels::matmul::tests::run_mmm_test_case; + use crate::kernels::matmul::GemmImpl; + use crate::IntoMetal; #[test] fn test_ggml_compilation() -> Result<()> { @@ -296,4 +340,64 @@ mod tests { )?; Ok(()) } + + fn run_mat_mul_q4_test(batch: usize, m: usize, k: usize, n: usize) -> TractResult<()> { + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT.with_borrow(|context| { + ensure!(k % 32 == 0); + let a_shape = [batch, m, k]; + let b_shape = [batch, n, k]; + + let a_data = (0..batch * k * m) + .map(|f| f as f32 / (batch * m * k) as f32) + .collect::>(); + + let a = Tensor::from_shape(&a_shape, &a_data)?; + + let b_data = (0..batch * n * k) + .map(|f| f as f32 / (batch * n * k) as f32) + .collect::>(); + + let mut b_quant = Q4_0.quant_f32(&b_data)?; + + let b_tensor = Tensor::from_shape(&b_shape, &b_data)?; + + crate::utils::tract_to_gguf_q4_0_packing(&mut b_quant)?; + + let b_q4_0_tensor = tensor0(Opaque(Arc::new(BlockQuantValue { + fact: BlockQuantFact { format: Box::new(Q4_0), shape: tvec![batch, n, k] }, + value: b_quant, + }))); + + let metal_output = GemmImpl::::new(false, true).eval( + context, + &a.clone().into_metal()?, + &b_q4_0_tensor.clone().into_metal()?, + )?; + + let matmul = BasicMatMul { + transpose_a: false, + transpose_b: true, + transpose_c: false, + quantize_output: None, + }; + + let output = args_1!(matmul.eval(tvec![a.into_tvalue(), b_tensor.into_tvalue()])?); + metal_output.to_cpu()?.close_enough(&output, Approximation::SuperApproximate)?; + Ok(()) + }) + }) + } + + #[test] + fn test_q4() -> TractResult<()> { + run_mat_mul_q4_test(32, 1, 32, 32)?; + run_mat_mul_q4_test(1, 32003, 2048, 1)?; + run_mat_mul_q4_test(4, 1, 2048, 32003)?; + run_mat_mul_q4_test(1, 1, 32, 32)?; + run_mat_mul_q4_test(1, 1, 64, 4)?; + run_mat_mul_q4_test(3, 1, 4096, 4096)?; + + Ok(()) + } } diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 1d12bc0551..0fe6b2db24 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -12,6 +12,7 @@ pub use mlx_gemm::MlxGemm; pub use mmm_tile_8x8::{metal_mmm_tile_8x8, mmm_tile_8x8}; pub use mps::MpsMatMul; +use crate::utils::resolve_tensor_shape; use crate::{MetalContext, MetalTensor}; use metal::Buffer; use num_traits::One; @@ -28,7 +29,7 @@ pub enum MetalGemmImplKind { impl Default for MetalGemmImplKind { fn default() -> Self { - Self::Mlx + Self::Ggml } } @@ -151,9 +152,11 @@ impl GemmDispatchParams { pub trait GemmKernel: fmt::Display + fmt::Debug + Clone + Default + Send + Sync { fn name() -> &'static str; - fn is_supported_dts(&self, dts: &[DatumType]) -> TractResult { - ensure!(dts.len() == 2); - Ok(matches!(dts[0], DatumType::F32 | DatumType::F16) && dts[0] == dts[1]) + fn is_supported_dts(&self, facts: &[TypedFact]) -> TractResult { + ensure!(facts.len() == 2); + ensure!(facts.iter().all(|f| f.datum_type != DatumType::Opaque)); + Ok(matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16) + && facts[0].datum_type == facts[1].datum_type) } fn output_dt(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult { @@ -205,14 +208,18 @@ impl GemmImpl { output } - pub fn output_facts(&self, a: &TypedFact, b: &TypedFact) -> TractResult> { - let out_shape = self.output_shape(&a.shape, &b.shape).to_vec(); - let out_dt = self.matmul.output_dt(a.datum_type().unwrap(), b.datum_type().unwrap())?; + pub fn output_facts( + &self, + shape: &[TDim], + a_dt: DatumType, + b_dt: DatumType, + 
) -> TractResult> { + let out_dt = self.matmul.output_dt(a_dt, b_dt)?; if out_dt == DatumType::F32 { - Ok(tvec!(f32::fact(out_shape))) + Ok(tvec!(f32::fact(shape))) } else { ensure!(out_dt == DatumType::F16); - Ok(tvec!(f16::fact(out_shape))) + Ok(tvec!(f16::fact(shape))) } } @@ -223,7 +230,7 @@ impl GemmImpl { b: &MetalTensor, ) -> TractResult { let c_dt = self.matmul.output_dt(a.datum_type(), b.datum_type())?; - let c_shape = self.output_shape(a.shape(), b.shape()); + let c_shape = self.output_shape(a.shape(), &resolve_tensor_shape(b)); let c = unsafe { MetalTensor::uninitialized_dt(c_dt, &c_shape)? }; self.dispatch_eval(context, a, b, &c)?; @@ -241,8 +248,8 @@ impl GemmImpl { a.retain_until_completion(); b.retain_until_completion(); c.retain_until_completion(); - - ensure!(c.shape() == self.output_shape(a.shape(), b.shape()).as_slice()); + let b_shape = resolve_tensor_shape(b); + ensure!(c.shape() == self.output_shape(a.shape(), &b_shape).as_slice()); if c.shape().iter().product::() == 0 { return Ok(()); @@ -254,7 +261,7 @@ impl GemmImpl { a.shape(), self.transpose_a, b.metal_offset(), - b.shape(), + &b_shape, self.transpose_b, c.metal_offset(), c.shape(), diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs index 223de457fa..e63c7bbd50 100644 --- a/metal/src/ops/gemm.rs +++ b/metal/src/ops/gemm.rs @@ -1,9 +1,11 @@ use crate::kernels::matmul::{GemmImpl, GemmKernel}; use crate::ops::MetalEvalOp; +use crate::utils::resolve_tensor_shape; use crate::{MetalContext, MetalTensorExt}; use anyhow::{bail, ensure}; use tract_core::internal::*; +use tract_core::tract_linalg::frame::block_quant::{BlockQuantFact, BlockQuantValue}; #[derive(Debug, Default, Clone)] pub struct MetalGemm { @@ -43,17 +45,56 @@ impl MetalGemm { let [a, b] = inputs else { bail!("Expects 2 inputs"); }; - ensure!(a.rank() == b.rank()); - ensure!(a.rank() >= 2); - ensure!( - a.shape[a.rank() - 2 + !self.transpose_a() as usize] - == b.shape[b.rank() - 2 + self.transpose_b() as usize] - ); - - self.kernel.output_facts(a, b) + + if a.datum_type.is_number() && b.datum_type.is_number() { + ensure!(a.rank() == b.rank()); + ensure!(a.rank() >= 2); + ensure!( + a.shape[a.rank() - 2 + !self.transpose_a() as usize] + == b.shape[b.rank() - 2 + self.transpose_b() as usize] + ); + let out_shape = self.kernel.output_shape(&a.shape, &b.shape); + Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?) + } else if let Some(opf) = inputs[0] + .opaque_fact + .as_ref() + .and_then(|of| of.downcast_ref::()) + .or_else(|| { + inputs[0] + .konst + .as_ref() + .and_then(|k| k.to_scalar::().ok()) + .and_then(|o| o.downcast_ref::()) + .map(|v| &v.fact) + }) + { + let a_shape: ShapeFact = + a.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect(); + + let out_shape = self.kernel.output_shape(&a_shape, &b.shape); + Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?) + } else if let Some(opf) = inputs[1] + .opaque_fact + .as_ref() + .and_then(|of| of.downcast_ref::()) + .or_else(|| { + inputs[1] + .konst + .as_ref() + .and_then(|k| k.to_scalar::().ok()) + .and_then(|o| o.downcast_ref::()) + .map(|v| &v.fact) + }) + { + let b_shape: ShapeFact = + b.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect(); + let out_shape = self.kernel.output_shape(&a.shape, &b_shape); + Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?) 
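+            // The TypedFact of a Q4_0 operand only carries the outer dims;
+            // the packed inner dims (e.g. [n, k]) live in the BlockQuantFact
+            // (each Q4_0 block stores 32 weights in 18 bytes: an f16 scale
+            // followed by 16 nibble bytes), hence the shape concatenation above.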
+ } else { + todo!() + } } } - impl MetalEvalOp for MetalGemm { fn metal_eval( &self, @@ -70,7 +111,8 @@ impl MetalEvalOp for MetalGemm { .to_metal_tensor() .with_context(|| anyhow!("B tensor is not a metal tensor {:?}", b_opaque))?; let c_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?; - let c_shape = self.kernel.output_shape(a.shape(), b.shape()); + + let c_shape = self.kernel.output_shape(a.shape(), &resolve_tensor_shape(&b)); let c = crate::ops::make_tensor_for_node(session, node_id, c_dt, &c_shape)?; self.kernel.dispatch_eval(context, a, b, &c)?; Ok(tvec![c.into_opaque_tensor().into_tvalue()]) diff --git a/metal/src/tensor/mod.rs b/metal/src/tensor/mod.rs index 2a054d2110..e242f0ee05 100644 --- a/metal/src/tensor/mod.rs +++ b/metal/src/tensor/mod.rs @@ -21,7 +21,7 @@ pub enum MetalTensor { } impl MetalTensor { - pub const SUPPORTED_DT: [DatumType; 11] = [ + pub const SUPPORTED_DT: [DatumType; 12] = [ DatumType::Bool, DatumType::F32, DatumType::F16, @@ -33,6 +33,7 @@ impl MetalTensor { DatumType::U32, DatumType::I64, DatumType::U64, + DatumType::Opaque, ]; pub fn tname(dt: DatumType) -> TractResult<&'static str> { @@ -48,6 +49,7 @@ impl MetalTensor { DatumType::I32 => "i32", DatumType::I64 => "i64", DatumType::Bool => "bool", + DatumType::Opaque => "opaque", _ => bail!("Unsupport dt {:?} for metal kernel function", dt), }) } diff --git a/metal/src/tensor/owned.rs b/metal/src/tensor/owned.rs index a936eb0063..4da61211a9 100644 --- a/metal/src/tensor/owned.rs +++ b/metal/src/tensor/owned.rs @@ -4,6 +4,7 @@ use metal::Buffer; use num_traits::AsPrimitive; use std::fmt::Display; use tract_core::internal::*; +use tract_linalg::frame::block_quant::BlockQuantValue; #[derive(Debug, Clone, Hash)] pub enum MValue { @@ -129,14 +130,25 @@ impl OwnedMetalTensor { /// Create a owned metal tensor from a cpu tensor. pub fn from_tensor>(tensor: T) -> Result { crate::METAL_CONTEXT.with_borrow(|ctxt| { - let m_value = tensor.into(); + let m_value: MValue = tensor.into(); let tensor_view = m_value.view(); ensure!( MetalTensor::is_supported_dt(tensor_view.datum_type()), "Tensor of {:?} is not copied. No Metal buffer can be allocated for it.", tensor_view.datum_type(), ); - let buffer = ctxt.buffer_from_slice(tensor_view.tensor.as_bytes()); + + let data_bytes = if tensor_view.datum_type() == DatumType::Opaque { + &tensor_view + .tensor + .to_scalar::() + .map(|od| od.downcast_ref::().unwrap())? + .value + } else { + tensor_view.tensor.as_bytes() + }; + + let buffer = ctxt.buffer_from_slice(data_bytes); Ok(OwnedMetalTensor { inner: m_value, metal: buffer }) }) } diff --git a/metal/src/transform.rs b/metal/src/transform.rs index 7eb9d5fef5..be59c71554 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -193,13 +193,14 @@ fn can_translate_to_metal_op( node: &TypedNode, gemm_impl: MetalGemmImplKind, ) -> TractResult { - let input_dts = source - .node_input_facts(node.id)? 
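+    // Q4_0 weights reach the transform as Opaque tensors, so the datum type
+    // alone no longer identifies a supported matmul: the checks below need the
+    // full TypedFacts (datum type plus the opaque BlockQuantFact).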
+ let input_facts = source.node_input_facts(node.id)?.iter().map(|f| (*f).clone()).collect_vec(); + let input_dts = input_facts .iter() .map(|f| f.as_metal_fact().map(|f| f.datum_type).unwrap_or(f.datum_type)) .collect_vec(); - let in_dts_metal_compatible = input_dts.iter().all(|dt| MetalTensor::is_supported_dt(*dt)); + let in_dts_metal_compatible = + input_facts.iter().all(|fact| MetalTensor::is_supported_dt(fact.datum_type)); Ok(in_dts_metal_compatible && (node @@ -211,7 +212,7 @@ fn can_translate_to_metal_op( || node.op_as::().is_some_and(|op| { !op.transpose_c && op.quantize_output.is_none() - && check_matmul_in_dts(gemm_impl, &input_dts) + && check_matmul_in_dts(gemm_impl, &input_facts) }) || node .op_as::() @@ -337,12 +338,12 @@ macro_rules! map_element_wise_ops { }; } -fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, dts: &[DatumType]) -> bool { +fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, in_facts: &[TypedFact]) -> bool { let is_supported = match gemm_impl { - MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(dts), - MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(dts), - MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(dts), - MetalGemmImplKind::Ggml => GgmlGemm.is_supported_dts(dts), + MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(in_facts), + MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(in_facts), + MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(in_facts), + MetalGemmImplKind::Ggml => GgmlGemm.is_supported_dts(in_facts), }; is_supported.unwrap_or(false) } diff --git a/metal/src/utils.rs b/metal/src/utils.rs index f000a9c626..5a17c98cf0 100644 --- a/metal/src/utils.rs +++ b/metal/src/utils.rs @@ -1,6 +1,9 @@ use crate::fact::{MetalFact, MetalOrigin, MetalTypedFactExt}; +use crate::MetalTensor; use num_traits::{AsPrimitive, Zero}; use tract_core::internal::*; +use tract_data::itertools::Itertools; +use tract_linalg::frame::block_quant::{BlockQuantFact, BlockQuantValue, Q4_0}; #[macro_export] macro_rules! 
impl_eval_op_for_metal_op {
@@ -95,6 +98,51 @@ where
         .collect::<TVec<T>>())
 }
 
+pub fn is_q4_0(fact: TypedFact) -> bool {
+    fact.opaque_fact.is_some_and(|of| {
+        of.downcast_ref::<BlockQuantFact>().map(|bqf| bqf.format.same_as(&Q4_0)).unwrap_or(false)
+    })
+}
+
+pub fn resolve_tensor_shape(a: &MetalTensor) -> Vec<usize> {
+    a.view()
+        .tensor
+        .to_scalar::<Opaque>()
+        .map(|od| {
+            od.downcast_ref::<BlockQuantValue>()
+                .map(|bqv| {
+                    a.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect_vec()
+                })
+                .unwrap_or(a.shape().to_vec())
+        })
+        .unwrap_or(a.shape().to_vec())
+}
+
+pub fn tract_to_gguf_q4_0_packing(data: &mut Blob) -> TractResult<()> {
+    let block_size = 18;
+    ensure!(data.layout().size() % block_size == 0);
+
+    let n_block = data.layout().size() / block_size;
+    let data_bytes = data.as_bytes_mut();
+
+    for b in 0..n_block {
+        let offset = b * block_size + 2;
+        let nibbles = &mut data_bytes[offset..offset + 16];
+        let second_part: &mut [u8; 8] = &mut [0; 8];
+        second_part.clone_from_slice(&nibbles[8..]);
+        for i in (0..16).rev() {
+            let lsb = if i % 2 == 0 { nibbles[i / 2] & 0x0F } else { nibbles[i / 2] & 0xF0 >> 4 };
+            let msb = if i % 2 == 0 {
+                (second_part[i / 2] & 0x0F) << 4
+            } else {
+                second_part[i / 2] & 0xF0
+            };
+            nibbles[i] = msb | lsb;
+        }
+    }
+    Ok(())
+}
+
 pub fn rescale_gpu_duration(
     pass_duration: u64,
     cpu_start: u64,

From a581fbb7674dff33147aac0aa5c1fba05a9dd15a Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Fri, 7 Feb 2025 10:32:21 +0100
Subject: [PATCH 03/15] repack Q4 const

---
 metal/src/tensor/owned.rs | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/metal/src/tensor/owned.rs b/metal/src/tensor/owned.rs
index 4da61211a9..44174b623b 100644
--- a/metal/src/tensor/owned.rs
+++ b/metal/src/tensor/owned.rs
@@ -138,18 +138,22 @@ impl OwnedMetalTensor {
             tensor_view.datum_type(),
         );
 
-        let data_bytes = if tensor_view.datum_type() == DatumType::Opaque {
-            &tensor_view
-                .tensor
-                .to_scalar::<Opaque>()
-                .map(|od| od.downcast_ref::<BlockQuantValue>().unwrap())?
-                .value
+        if tensor_view.datum_type() == DatumType::Opaque {
+            let mut q4_blob = tensor_view
+                .tensor
+                .to_scalar::<Opaque>()
+                .map(|od| od.downcast_ref::<BlockQuantValue>().unwrap())?
+                .value.clone();
+            crate::utils::tract_to_gguf_q4_0_packing(&mut q4_blob)?;
+
+            let buffer = ctxt.buffer_from_slice(&q4_blob);
+            Ok(OwnedMetalTensor { inner: m_value, metal: buffer })
         } else {
-            tensor_view.tensor.as_bytes()
-        };
+            let tensor_data = tensor_view.tensor.as_bytes();
 
-        let buffer = ctxt.buffer_from_slice(data_bytes);
-        Ok(OwnedMetalTensor { inner: m_value, metal: buffer })
+            let buffer = ctxt.buffer_from_slice(tensor_data);
+            Ok(OwnedMetalTensor { inner: m_value, metal: buffer })
+        }
     })
 }

From 19389ed95256d20f25efb9c461a4d73bd62bdda0 Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Fri, 7 Feb 2025 12:11:54 +0100
Subject: [PATCH 04/15] Revert "repack Q4 const"

This reverts commit 6af09949ded4a014b6723f0423fdd4b56fe8bce4.

---
 metal/src/tensor/owned.rs | 24 ++++++++++--------------
 1 file changed, 10 insertions(+), 14 deletions(-)

diff --git a/metal/src/tensor/owned.rs b/metal/src/tensor/owned.rs
index 44174b623b..4da61211a9 100644
--- a/metal/src/tensor/owned.rs
+++ b/metal/src/tensor/owned.rs
@@ -138,22 +138,18 @@ impl OwnedMetalTensor {
             tensor_view.datum_type(),
         );
 
-        if tensor_view.datum_type() == DatumType::Opaque {
-            let mut q4_blob = tensor_view
-                .tensor
-                .to_scalar::<Opaque>()
-                .map(|od| od.downcast_ref::<BlockQuantValue>().unwrap())?
- .value.clone(); - crate::utils::tract_to_gguf_q4_0_packing(&mut q4_blob)?; - - let buffer = ctxt.buffer_from_slice(&q4_blob); - Ok(OwnedMetalTensor { inner: m_value, metal: buffer }) + let data_bytes = if tensor_view.datum_type() == DatumType::Opaque { + &tensor_view + .tensor + .to_scalar::() + .map(|od| od.downcast_ref::().unwrap())? + .value } else { - let tensor_data = tensor_view.tensor.as_bytes(); + tensor_view.tensor.as_bytes() + }; - let buffer = ctxt.buffer_from_slice(tensor_data); - Ok(OwnedMetalTensor { inner: m_value, metal: buffer }) - } + let buffer = ctxt.buffer_from_slice(data_bytes); + Ok(OwnedMetalTensor { inner: m_value, metal: buffer }) }) } From f178faa27040f09b591727e47c7244229a263fea Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Fri, 7 Feb 2025 18:01:37 +0100 Subject: [PATCH 05/15] add proptest --- metal/src/kernels/matmul/ggml_gemm/mod.rs | 2 +- metal/src/kernels/matmul/mod.rs | 108 +++++++++++++++++----- metal/src/utils.rs | 4 +- 3 files changed, 88 insertions(+), 26 deletions(-) diff --git a/metal/src/kernels/matmul/ggml_gemm/mod.rs b/metal/src/kernels/matmul/ggml_gemm/mod.rs index 3dae55e5bc..88ec0502ed 100644 --- a/metal/src/kernels/matmul/ggml_gemm/mod.rs +++ b/metal/src/kernels/matmul/ggml_gemm/mod.rs @@ -37,7 +37,7 @@ impl GemmKernel for GgmlGemm { fn is_supported_dts(&self, facts: &[TypedFact]) -> TractResult { ensure!(facts.len() == 2); - if is_q4_0(facts[1].clone()) { + if is_q4_0(&facts[1]) { Ok(matches!(facts[0].datum_type, DatumType::F16 | DatumType::F32)) } else if facts[1].datum_type == DatumType::F32 { Ok(facts[0].datum_type == DatumType::F32) diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 0fe6b2db24..ca198eb2a6 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -154,7 +154,6 @@ pub trait GemmKernel: fmt::Display + fmt::Debug + Clone + Default + Send + Sync fn is_supported_dts(&self, facts: &[TypedFact]) -> TractResult { ensure!(facts.len() == 2); - ensure!(facts.iter().all(|f| f.datum_type != DatumType::Opaque)); Ok(matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16) && facts[0].datum_type == facts[1].datum_type) } @@ -315,6 +314,8 @@ mod tests { use proptest::collection::vec; use proptest::prelude::*; use tract_core::ops::einsum::BasicMatMul; + use tract_core::tract_data::itertools::Itertools; + use tract_core::tract_linalg::frame::block_quant::{BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0}; pub(crate) fn run_mmm_test_case( (batch, m, k, n): (usize, usize, usize, usize), @@ -618,6 +619,46 @@ mod tests { fn mmm_mlx_prop_f16(pb in any::>()) { prop_assert_eq!(pb.run().unwrap(), pb.reference().unwrap()) } + + #[test] + fn mmm_ggml_prop_f32(pb in >::arbitrary_with( + MmmProblemParams { + force_k_as_inner_axis: true, + q4_0_weights: false, + } + )) { + let output = pb.run().unwrap(); + let _ = output.close_enough(&pb.reference().unwrap(), Approximation::Approximate); + } + + #[test] + fn mmm_ggml_prop_f16(pb in >::arbitrary_with( + MmmProblemParams { + force_k_as_inner_axis: true, + q4_0_weights: false, + } + )) { + let output = pb.run().unwrap(); + let _ = output.close_enough(&pb.reference().unwrap(), Approximation::SuperApproximate); + } + + #[test] + fn mmm_ggml_prop_q4(pb in >::arbitrary_with( + MmmProblemParams { + force_k_as_inner_axis: true, + q4_0_weights: true, + } + )) { + let output = pb.run().unwrap(); + let _ = output.close_enough(&pb.reference().unwrap(), Approximation::SuperApproximate + ); + } + } + + #[derive(Default, Debug, Clone)] + pub 
struct MmmProblemParams { + pub force_k_as_inner_axis: bool, + pub q4_0_weights: bool, } #[derive(Debug, new)] @@ -634,6 +675,7 @@ mod tests { pub transpose_lhs: bool, pub rhs: Vec, pub transpose_rhs: bool, + pub q4_0: bool, pub _phantom: std::marker::PhantomData, } @@ -643,15 +685,18 @@ mod tests { F: Datum + Float, usize: AsPrimitive, { - type Parameters = (); + type Parameters = MmmProblemParams; type Strategy = BoxedStrategy; - fn arbitrary_with(_: ()) -> Self::Strategy { - (1usize..10, 1usize..20, 1usize..20, 1usize..20) - .prop_flat_map(|(b, m, k, n)| { + fn arbitrary_with(params: MmmProblemParams) -> Self::Strategy { + (1usize..4, 1usize..128, 1usize..512, 1usize..128) + .prop_flat_map(move |(b, m, mut k, n)| { + if params.q4_0_weights { k = k.div_ceil(32) * 32 }; + let lhs_len = b * m * k; let rhs_len = b * k * n; - let datum = (0usize..10).prop_map(|x| x.as_()); + + let datum = (0usize..10).prop_map(move |x| x.as_() / (b * m * k * n).as_()); ( Just(b), Just(m), @@ -663,16 +708,20 @@ mod tests { proptest::bool::ANY, ) }) - .prop_map(|(b, m, k, n, lhs, transpose_lhs, rhs, transpose_rhs)| Self { - b, - m, - k, - n, - lhs, - transpose_lhs, - rhs, - transpose_rhs, - _phantom: std::marker::PhantomData, + .prop_map(move |(b, m, k, n, lhs, mut transpose_lhs, rhs, mut transpose_rhs)| { + if params.force_k_as_inner_axis{ (transpose_lhs, transpose_rhs) = (false, true); } + Self { + b, + m, + k, + n, + lhs, + transpose_lhs, + rhs, + transpose_rhs, + q4_0: params.q4_0_weights, + _phantom: std::marker::PhantomData, + } }) .boxed() } @@ -684,7 +733,7 @@ mod tests { F: Datum + Float + std::ops::AddAssign, usize: AsPrimitive, { - pub fn reference(&self) -> Result> { + pub fn reference(&self) -> Result { let matmul = BasicMatMul { transpose_a: self.transpose_lhs, transpose_b: self.transpose_rhs, @@ -705,10 +754,10 @@ mod tests { let output = matmul.eval(tvec![lhs_tensor.into_tvalue(), rhs_tensor.into_tvalue()])?; - Ok(output[0].clone().into_tensor().as_slice::()?.to_vec()) + Ok(output[0].clone().into_tensor()) } - pub fn run(&self) -> Result> { + pub fn run(&self) -> Result { objc::rc::autoreleasepool(|| { crate::METAL_CONTEXT.with_borrow(|context| { let lhs = if self.transpose_lhs { @@ -717,15 +766,28 @@ mod tests { Tensor::from_shape(&[self.b, self.m, self.k], &self.lhs)?.into_metal()? }; let rhs = if self.transpose_rhs { - Tensor::from_shape(&[self.b, self.n, self.k], &self.rhs)?.into_metal()? + if !self.q4_0 { + Tensor::from_shape(&[self.b, self.n, self.k], &self.rhs)? + } + else { + dbg!(&self.rhs); + let mut b_quant = Q4_0.quant_f32(&self.rhs.clone().into_iter().map(|x| x.to_f32().unwrap()).collect_vec())?; + + crate::utils::tract_to_gguf_q4_0_packing(&mut b_quant)?; + + tensor0(Opaque(Arc::new(BlockQuantValue { + fact: BlockQuantFact { format: Box::new(Q4_0), shape: tvec![self.b, self.n, self.k] }, + value: b_quant, + }))) + } } else { - Tensor::from_shape(&[self.b, self.k, self.n], &self.rhs)?.into_metal()? - }; + Tensor::from_shape(&[self.b, self.k, self.n], &self.rhs)? + }.into_metal()?; let matmul = GemmImpl::::new(self.transpose_lhs, self.transpose_rhs); let c = matmul.eval(context, &lhs, &rhs)?; - Ok(c.to_cpu()?.as_slice::()?.to_vec()) + Ok(c.to_cpu()?) 
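+                    // to_cpu() copies the Metal result back to host memory so
+                    // callers can compare it against the BasicMatMul CPU
+                    // reference with close_enough.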
            })
        })
    }
diff --git a/metal/src/utils.rs b/metal/src/utils.rs
index 5a17c98cf0..651820e477 100644
--- a/metal/src/utils.rs
+++ b/metal/src/utils.rs
@@ -98,8 +98,8 @@ where
         .collect::<TVec<T>>())
 }
 
-pub fn is_q4_0(fact: TypedFact) -> bool {
-    fact.opaque_fact.is_some_and(|of| {
+pub fn is_q4_0(fact: &TypedFact) -> bool {
+    fact.opaque_fact.as_ref().is_some_and(|of| {
         of.downcast_ref::<BlockQuantFact>().map(|bqf| bqf.format.same_as(&Q4_0)).unwrap_or(false)
     })
 }

From 4eec7de8f45e5e834472580098364ddb35fb9216 Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Wed, 12 Feb 2025 13:45:39 +0100
Subject: [PATCH 06/15] fix q4 bug and allow input swap

---
 metal/src/kernels/matmul/mod.rs |  61 ++++++++++++-------
 metal/src/transform.rs          | 100 ++++++++++++++++++++++--------
 metal/src/utils.rs              |   2 +-
 3 files changed, 118 insertions(+), 45 deletions(-)

diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs
index ca198eb2a6..e00be3bb34 100644
--- a/metal/src/kernels/matmul/mod.rs
+++ b/metal/src/kernels/matmul/mod.rs
@@ -311,11 +311,12 @@ mod tests {
     use derive_new::new;
     use num_traits::AsPrimitive;
     use num_traits::Float;
-    use proptest::collection::vec;
     use proptest::prelude::*;
     use tract_core::ops::einsum::BasicMatMul;
     use tract_core::tract_data::itertools::Itertools;
-    use tract_core::tract_linalg::frame::block_quant::{BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0};
+    use tract_core::tract_linalg::frame::block_quant::{
+        BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0,
+    };
 
     pub(crate) fn run_mmm_test_case(
         (batch, m, k, n): (usize, usize, usize, usize),
@@ -617,7 +618,8 @@ mod tests {
 
         #[test]
        fn mmm_mlx_prop_f16(pb in any::<MatmulProblem<f16, MlxGemm>>()) {
-            prop_assert_eq!(pb.run().unwrap(), pb.reference().unwrap())
+            let output = pb.run().unwrap();
+            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::Approximate);
        }
 
        #[test]
@@ -639,7 +641,7 @@ mod tests {
            }
        )) {
            let output = pb.run().unwrap();
-            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::SuperApproximate);
+            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::Approximate);
        }
 
        #[test]
@@ -650,7 +652,7 @@ mod tests {
            }
        )) {
            let output = pb.run().unwrap();
-            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::SuperApproximate
+            let _ = output.close_enough(&pb.reference().unwrap(), Approximation::Approximate
            );
        }
    }
@@ -689,27 +691,35 @@ mod tests {
        type Strategy = BoxedStrategy<Self>;
 
        fn arbitrary_with(params: MmmProblemParams) -> Self::Strategy {
-            (1usize..4, 1usize..128, 1usize..512, 1usize..128)
+            (1usize..4, 1usize..128, 1usize..256, 1usize..128)
                .prop_flat_map(move |(b, m, mut k, n)| {
-                    if params.q4_0_weights { k = k.div_ceil(32) * 32 };
+                    if params.q4_0_weights {
+                        k = k.div_ceil(32) * 32
+                    };
 
-                    let lhs_len = b * m * k;
-                    let rhs_len = b * k * n;
+                    let mut rng = rand::thread_rng();
+                    let lhs_data: Vec<F> = (0..b * m * k) // b * m * k random values
+                        .map(|_| F::from(rng.gen_range(0.0..1.0)).unwrap()) // Generate a random float in [0.0, 1.0)
+                        .collect();
 
-                    let datum = (0usize..10).prop_map(move |x| x.as_() / (b * m * k * n).as_());
+                    let rhs_data: Vec<F> = (0..b * n * k) // b * n * k random values
+                        .map(|_| F::from(rng.gen_range(0.0..1.0)).unwrap()) // Generate a random float in [0.0, 1.0)
+                        .collect();
                    (
                        Just(b),
                        Just(m),
                        Just(k),
                        Just(n),
-                        vec(datum.clone(), lhs_len..=lhs_len),
+                        Just(lhs_data),
                        proptest::bool::ANY,
-                        vec(datum, rhs_len..=rhs_len),
+                        Just(rhs_data),
                        proptest::bool::ANY,
                    )
                })
                .prop_map(move |(b, m, k, n, lhs, mut transpose_lhs, rhs, mut transpose_rhs)| {
                    if
params.force_k_as_inner_axis{ (transpose_lhs, transpose_rhs) = (false, true); } + if params.force_k_as_inner_axis { + (transpose_lhs, transpose_rhs) = (false, true); + } Self { b, m, @@ -721,7 +731,7 @@ mod tests { transpose_rhs, q4_0: params.q4_0_weights, _phantom: std::marker::PhantomData, - } + } }) .boxed() } @@ -768,21 +778,30 @@ mod tests { let rhs = if self.transpose_rhs { if !self.q4_0 { Tensor::from_shape(&[self.b, self.n, self.k], &self.rhs)? - } - else { - dbg!(&self.rhs); - let mut b_quant = Q4_0.quant_f32(&self.rhs.clone().into_iter().map(|x| x.to_f32().unwrap()).collect_vec())?; + } else { + let mut b_quant = Q4_0.quant_f32( + &self + .rhs + .clone() + .into_iter() + .map(|x| x.to_f32().unwrap()) + .collect_vec(), + )?; crate::utils::tract_to_gguf_q4_0_packing(&mut b_quant)?; tensor0(Opaque(Arc::new(BlockQuantValue { - fact: BlockQuantFact { format: Box::new(Q4_0), shape: tvec![self.b, self.n, self.k] }, + fact: BlockQuantFact { + format: Box::new(Q4_0), + shape: tvec![self.b, self.n, self.k], + }, value: b_quant, }))) } } else { Tensor::from_shape(&[self.b, self.k, self.n], &self.rhs)? - }.into_metal()?; + } + .into_metal()?; let matmul = GemmImpl::::new(self.transpose_lhs, self.transpose_rhs); diff --git a/metal/src/transform.rs b/metal/src/transform.rs index be59c71554..f75ca44af8 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -27,6 +27,7 @@ use tract_core::ops::element_wise::ElementWiseOp; use tract_core::ops::konst::Const; use tract_core::ops::logic::Comp; use tract_core::ops::nn::{Reduce, Softmax as CoreSoftmax}; +use tract_core::tract_linalg::frame::block_quant::BlockQuantValue; use tract_core::transform::ModelTransform; use tract_itertools::Itertools; @@ -343,7 +344,7 @@ fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, in_facts: &[TypedFact]) -> MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(in_facts), MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(in_facts), MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(in_facts), - MetalGemmImplKind::Ggml => GgmlGemm.is_supported_dts(in_facts), + MetalGemmImplKind::Ggml => Ok(GgmlGemm.is_supported_dts(in_facts).unwrap_or(false) || GgmlGemm.is_supported_dts(&[in_facts[1].clone(), in_facts[0].clone()]).unwrap_or(false)), }; is_supported.unwrap_or(false) } @@ -356,29 +357,45 @@ fn convert_matmul_to_metal( op: &BasicMatMul, gemm_impl: MetalGemmImplKind, ) -> TractResult> { - let matmul: Box = match gemm_impl { + let mut matmul_output = match gemm_impl { MetalGemmImplKind::Mlx => { - Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) + let op = ops::MetalGemm::::new(op.transpose_a, op.transpose_b); + target.wire_node(node.name.clone(), op, inputs)? } MetalGemmImplKind::Mps => { - Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) + let op = ops::MetalGemm::::new(op.transpose_a, op.transpose_b); + target.wire_node(node.name.clone(), op, inputs)? } MetalGemmImplKind::Mfa => { - Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) + let op = ops::MetalGemm::::new(op.transpose_a, op.transpose_b); + target.wire_node(node.name.clone(), op, inputs)? 
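+            // Mlx/Mps/Mfa take the transpose flags as-is; the Ggml arm below
+            // instead normalizes the problem to "A untransposed, B transposed"
+            // and may swap the operands when only the reversed order is
+            // supported (e.g. Q4_0 weights arriving as input 0).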
} MetalGemmImplKind::Ggml => { - let input_facts: tract_smallvec::SmallVec<[&TypedFact; 4]> = + let mut input_facts: tract_smallvec::SmallVec<[&TypedFact; 4]> = model.node_input_facts(node.id)?; + let mut swap_inputs = false; + if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()]).unwrap_or(false) && + GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()]).unwrap_or(false) + { + println!("Swap inputs"); + input_facts.swap(0, 1); + inputs.swap(0, 1); + swap_inputs = true; + } + + let a_pos = swap_inputs as usize; + let b_pos = 1 - swap_inputs as usize; if op.transpose_a { - let rank = input_facts[0].rank(); - let perm_a_op: Box = - Box::new(ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move( + assert!(input_facts[a_pos].datum_type != DatumType::Opaque, "Cannot transpose Opaque tensor"); + + let rank = input_facts[a_pos].rank(); + let perm_a_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move( rank - 2, rank - 1, - ))); + )); let perm_a_name = node.name.clone() + ".perm_a"; - inputs[0] = target.wire_node(perm_a_name, perm_a_op, &[inputs[0]])?[0]; + inputs[a_pos] = target.wire_node(perm_a_name, perm_a_op, &[inputs[a_pos]])?[0]; } if input_facts[0].datum_type == DatumType::F16 { @@ -388,25 +405,36 @@ fn convert_matmul_to_metal( } if !op.transpose_b { - let rank = input_facts[1].rank(); - let perm_b_op: Box = - Box::new(ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move( + assert!(input_facts[b_pos].datum_type != DatumType::Opaque, "Cannot transpose Opaque tensor 1 "); + + let rank = input_facts[b_pos].rank(); + let perm_b_op= ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move( rank - 2, rank - 1, - ))); + )); let perm_b_name = node.name.clone() + ".perm_b"; - inputs[1] = target.wire_node(perm_b_name, perm_b_op, &[inputs[1]])?[0]; + inputs[b_pos] = target.wire_node(perm_b_name, perm_b_op, &[inputs[b_pos]])?[0]; } - Box::new(ops::MetalGemm::::new(false, true)) + let op = ops::MetalGemm::::new(false, true); + let mut matmul_output = target.wire_node(node.name.clone(), op, inputs)?; + + if swap_inputs { + let out_fact = target.outlet_fact(matmul_output[0])?; + let rank = &out_fact.opaque_fact.clone().map(|fact| fact.clarify_dt_shape().unwrap().1.len()).unwrap(); + + let perm_out_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move( + rank - 2, + rank - 1, + )); + matmul_output = target.wire_node(node.name.clone() + ".perm_out", perm_out_op, &matmul_output)?; + } + matmul_output } }; - let new_in_facts = [target.outlet_fact(inputs[0])?, target.outlet_fact(inputs[1])?]; - - let out_fact = &matmul.output_facts(&new_in_facts)?[0]; + let out_fact = target.outlet_fact(matmul_output[0])?; let out_dt = out_fact.to_metal_fact().map(|f| f.datum_type).unwrap_or(out_fact.datum_type); - let mut matmul_output = target.wire_node(node.name.clone(), matmul, inputs)?; let expected_dt = model.node_output_facts(node.id)?[0].datum_type; if out_dt != expected_dt { @@ -446,8 +474,34 @@ fn convert_logic_ops_to_metal(op: &Comp) -> ops::MetalBinOp { } fn convert_const(op: &Const) -> TractResult { - let metal_fact = MetalFact::from_cpu(Arc::clone(&op.0).into())?; - let metal_const = op.0.clone().into_metal()?.into_opaque_tensor().into_arc_tensor(); + let (tensor, metal_fact) = if op + .0 + .clone() + .view() + .tensor + .to_scalar::() + .is_ok_and(|of| of.downcast_ref::().is_some()) + { + let mut curr_bqv = + op.0.clone() + .view() + .tensor + .to_scalar::()? 
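+                // tract stores Q4_0 nibbles in its own order; the payload is
+                // repacked below into the GGUF nibble layout the ggml kernels
+                // expect before it is uploaded to the GPU.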
+ .downcast_ref::() + .unwrap() + .clone(); + + crate::utils::tract_to_gguf_q4_0_packing(&mut (curr_bqv.value))?; + + let bqv = BlockQuantValue { value: curr_bqv.clone().value, fact: curr_bqv.clone().fact }; + ( + tensor1(&[Opaque(Arc::new(bqv))]).into_arc_tensor(), + MetalFact::from_cpu(Arc::clone(&op.0).into())?, + ) + } else { + (op.0.clone(), MetalFact::from_cpu(Arc::clone(&op.0).into())?) + }; + let metal_const = tensor.into_metal()?.into_opaque_tensor().into_arc_tensor(); Ok(Const::new_with_opaque_fact(metal_const, Box::new(metal_fact))) } diff --git a/metal/src/utils.rs b/metal/src/utils.rs index 651820e477..7f25e18ab7 100644 --- a/metal/src/utils.rs +++ b/metal/src/utils.rs @@ -131,7 +131,7 @@ pub fn tract_to_gguf_q4_0_packing(data: &mut Blob) -> TractResult<()> { let second_part: &mut [u8; 8] = &mut [0; 8]; second_part.clone_from_slice(&nibbles[8..]); for i in (0..16).rev() { - let lsb = if i % 2 == 0 { nibbles[i / 2] & 0x0F } else { nibbles[i / 2] & 0xF0 >> 4 }; + let lsb = if i % 2 == 0 { nibbles[i / 2] & 0x0F } else { (nibbles[i / 2] & 0xF0) >> 4 }; let msb = if i % 2 == 0 { (second_part[i / 2] & 0x0F) << 4 } else { From 7146285a463c7dff06066495f4a0eacfd621b13d Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 12 Feb 2025 15:06:03 +0100 Subject: [PATCH 07/15] fixed no batch const --- metal/src/transform.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/metal/src/transform.rs b/metal/src/transform.rs index f75ca44af8..14861cca8a 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -378,7 +378,6 @@ fn convert_matmul_to_metal( if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()]).unwrap_or(false) && GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()]).unwrap_or(false) { - println!("Swap inputs"); input_facts.swap(0, 1); inputs.swap(0, 1); swap_inputs = true; @@ -405,7 +404,7 @@ fn convert_matmul_to_metal( } if !op.transpose_b { - assert!(input_facts[b_pos].datum_type != DatumType::Opaque, "Cannot transpose Opaque tensor 1 "); + assert!(input_facts[b_pos].datum_type != DatumType::Opaque, "Cannot transpose Opaque tensor"); let rank = input_facts[b_pos].rank(); let perm_b_op= ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move( @@ -494,8 +493,14 @@ fn convert_const(op: &Const) -> TractResult { crate::utils::tract_to_gguf_q4_0_packing(&mut (curr_bqv.value))?; let bqv = BlockQuantValue { value: curr_bqv.clone().value, fact: curr_bqv.clone().fact }; + let tensor = if op.0.rank() == 0 { + tensor0(Opaque(Arc::new(bqv))) + } else { + tensor1(&[Opaque(Arc::new(bqv))]) + }; + ( - tensor1(&[Opaque(Arc::new(bqv))]).into_arc_tensor(), + tensor.into_arc_tensor(), MetalFact::from_cpu(Arc::clone(&op.0).into())?, ) } else { From f943cabe2f050efcd3e13459ac8940489374e9ee Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 12 Feb 2025 16:05:27 +0100 Subject: [PATCH 08/15] some code cleaning --- metal/src/kernels/matmul/ggml_gemm/mod.rs | 17 ++++------ metal/src/kernels/matmul/mod.rs | 24 +++++++++----- metal/src/ops/gemm.rs | 38 ++++++----------------- metal/src/transform.rs | 8 ++--- metal/src/utils.rs | 28 ++++++++++------- 5 files changed, 54 insertions(+), 61 deletions(-) diff --git a/metal/src/kernels/matmul/ggml_gemm/mod.rs b/metal/src/kernels/matmul/ggml_gemm/mod.rs index 88ec0502ed..a82c9d2fae 100644 --- a/metal/src/kernels/matmul/ggml_gemm/mod.rs +++ b/metal/src/kernels/matmul/ggml_gemm/mod.rs @@ -1,5 +1,5 @@ use crate::kernels::matmul::{GemmDispatchParams, GemmKernel}; 
-use crate::utils::is_q4_0; +use crate::utils::as_q40_fact; use crate::MetalTensor; use crate::{LibraryName, MetalContext}; use anyhow::{ensure, Result}; @@ -34,17 +34,12 @@ impl GemmKernel for GgmlGemm { "ggml" } - fn is_supported_dts(&self, facts: &[TypedFact]) -> TractResult { - ensure!(facts.len() == 2); + fn is_supported_dts(&self, facts: &[TypedFact]) -> bool { + assert!(facts.len() == 2, "Ggml: Expected 2 inputs for Matmul"); - if is_q4_0(&facts[1]) { - Ok(matches!(facts[0].datum_type, DatumType::F16 | DatumType::F32)) - } else if facts[1].datum_type == DatumType::F32 { - Ok(facts[0].datum_type == DatumType::F32) - } else { - Ok(facts[1].datum_type == DatumType::F16 - && matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16)) - } + (as_q40_fact(&facts[1]).is_some() && matches!(facts[0].datum_type, DatumType::F16 | DatumType::F32)) || + ((facts[1].datum_type == DatumType::F32) && (facts[0].datum_type == DatumType::F32)) || + ((facts[1].datum_type == DatumType::F16) && matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16)) } fn output_dt(&self, _a_dt: DatumType, _b_dt: DatumType) -> TractResult { diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index e00be3bb34..844ef395ea 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -12,7 +12,7 @@ pub use mlx_gemm::MlxGemm; pub use mmm_tile_8x8::{metal_mmm_tile_8x8, mmm_tile_8x8}; pub use mps::MpsMatMul; -use crate::utils::resolve_tensor_shape; +use crate::utils::as_q40_tensor; use crate::{MetalContext, MetalTensor}; use metal::Buffer; use num_traits::One; @@ -152,10 +152,10 @@ impl GemmDispatchParams { pub trait GemmKernel: fmt::Display + fmt::Debug + Clone + Default + Send + Sync { fn name() -> &'static str; - fn is_supported_dts(&self, facts: &[TypedFact]) -> TractResult { - ensure!(facts.len() == 2); - Ok(matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16) - && facts[0].datum_type == facts[1].datum_type) + fn is_supported_dts(&self, facts: &[TypedFact]) -> bool { + assert!(facts.len() == 2, "Expected 2 inputs for matmul"); + matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16) + && facts[0].datum_type == facts[1].datum_type } fn output_dt(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult { @@ -228,8 +228,13 @@ impl GemmImpl { a: &MetalTensor, b: &MetalTensor, ) -> TractResult { + let b_shape = as_q40_tensor(&b).map(|bqv| { + b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect() + }) + .unwrap_or(b.shape().to_vec()); + let c_dt = self.matmul.output_dt(a.datum_type(), b.datum_type())?; - let c_shape = self.output_shape(a.shape(), &resolve_tensor_shape(b)); + let c_shape = self.output_shape(a.shape(), &b_shape); let c = unsafe { MetalTensor::uninitialized_dt(c_dt, &c_shape)? 
}; self.dispatch_eval(context, a, b, &c)?; @@ -247,7 +252,12 @@ impl GemmImpl { a.retain_until_completion(); b.retain_until_completion(); c.retain_until_completion(); - let b_shape = resolve_tensor_shape(b); + + let b_shape = as_q40_tensor(&b).map(|bqv| { + b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect() + }) + .unwrap_or(b.shape().to_vec()); + ensure!(c.shape() == self.output_shape(a.shape(), &b_shape).as_slice()); if c.shape().iter().product::() == 0 { diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs index e63c7bbd50..93b94e556d 100644 --- a/metal/src/ops/gemm.rs +++ b/metal/src/ops/gemm.rs @@ -1,11 +1,10 @@ use crate::kernels::matmul::{GemmImpl, GemmKernel}; use crate::ops::MetalEvalOp; -use crate::utils::resolve_tensor_shape; +use crate::utils::{as_q40_fact, as_q40_tensor}; use crate::{MetalContext, MetalTensorExt}; use anyhow::{bail, ensure}; use tract_core::internal::*; -use tract_core::tract_linalg::frame::block_quant::{BlockQuantFact, BlockQuantValue}; #[derive(Debug, Default, Clone)] pub struct MetalGemm { @@ -55,36 +54,14 @@ impl MetalGemm { ); let out_shape = self.kernel.output_shape(&a.shape, &b.shape); Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?) - } else if let Some(opf) = inputs[0] - .opaque_fact - .as_ref() - .and_then(|of| of.downcast_ref::()) - .or_else(|| { - inputs[0] - .konst - .as_ref() - .and_then(|k| k.to_scalar::().ok()) - .and_then(|o| o.downcast_ref::()) - .map(|v| &v.fact) - }) + } else if let Some(opf) = as_q40_fact(inputs[0]) { let a_shape: ShapeFact = a.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect(); let out_shape = self.kernel.output_shape(&a_shape, &b.shape); Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?) 
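+            // as_q40_fact accepts both forms a Q4_0 input can take: a fact
+            // carrying an opaque BlockQuantFact, or a konst Opaque scalar
+            // wrapping a BlockQuantValue.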
- } else if let Some(opf) = inputs[1] - .opaque_fact - .as_ref() - .and_then(|of| of.downcast_ref::()) - .or_else(|| { - inputs[1] - .konst - .as_ref() - .and_then(|k| k.to_scalar::().ok()) - .and_then(|o| o.downcast_ref::()) - .map(|v| &v.fact) - }) + } else if let Some(opf) = as_q40_fact(inputs[1]) { let b_shape: ShapeFact = b.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect(); @@ -110,9 +87,14 @@ impl MetalEvalOp for MetalGemm { let b = b_opaque .to_metal_tensor() .with_context(|| anyhow!("B tensor is not a metal tensor {:?}", b_opaque))?; - let c_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?; + + let b_shape = as_q40_tensor(&b).map(|bqv| { + b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect() + }) + .unwrap_or(b.shape().to_vec()); - let c_shape = self.kernel.output_shape(a.shape(), &resolve_tensor_shape(&b)); + let c_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?; + let c_shape = self.kernel.output_shape(a.shape(), &b_shape); let c = crate::ops::make_tensor_for_node(session, node_id, c_dt, &c_shape)?; self.kernel.dispatch_eval(context, a, b, &c)?; Ok(tvec![c.into_opaque_tensor().into_tvalue()]) diff --git a/metal/src/transform.rs b/metal/src/transform.rs index 14861cca8a..81916be4ec 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -344,9 +344,9 @@ fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, in_facts: &[TypedFact]) -> MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(in_facts), MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(in_facts), MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(in_facts), - MetalGemmImplKind::Ggml => Ok(GgmlGemm.is_supported_dts(in_facts).unwrap_or(false) || GgmlGemm.is_supported_dts(&[in_facts[1].clone(), in_facts[0].clone()]).unwrap_or(false)), + MetalGemmImplKind::Ggml => GgmlGemm.is_supported_dts(in_facts) || GgmlGemm.is_supported_dts(&[in_facts[1].clone(), in_facts[0].clone()]), }; - is_supported.unwrap_or(false) + is_supported } fn convert_matmul_to_metal( @@ -375,8 +375,8 @@ fn convert_matmul_to_metal( model.node_input_facts(node.id)?; let mut swap_inputs = false; - if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()]).unwrap_or(false) && - GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()]).unwrap_or(false) + if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()]) && + GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()]) { input_facts.swap(0, 1); inputs.swap(0, 1); diff --git a/metal/src/utils.rs b/metal/src/utils.rs index 7f25e18ab7..773b65b9f9 100644 --- a/metal/src/utils.rs +++ b/metal/src/utils.rs @@ -2,7 +2,6 @@ use crate::fact::{MetalFact, MetalOrigin, MetalTypedFactExt}; use crate::MetalTensor; use num_traits::{AsPrimitive, Zero}; use tract_core::internal::*; -use tract_data::itertools::Itertools; use tract_linalg::frame::block_quant::{BlockQuantFact, BlockQuantValue, Q4_0}; #[macro_export] @@ -98,24 +97,31 @@ where .collect::>()) } -pub fn is_q4_0(fact: &TypedFact) -> bool { - fact.opaque_fact.as_ref().is_some_and(|of| { - of.downcast_ref::().map(|bqf| bqf.format.same_as(&Q4_0)).unwrap_or(false) +pub fn as_q40_fact(fact: &TypedFact) -> Option { + fact.opaque_fact + .as_ref() + .and_then(|of| of.downcast_ref::()) + .map(|bqf| { if bqf.format.same_as(&Q4_0) { Some(bqf.clone()) } else { None }}).flatten() + .or_else(|| { + fact + .konst + .as_ref() + .and_then(|k| k.to_scalar::().ok()) + .and_then(|o| o.downcast_ref::()) + 
+                .map(|v| &v.fact)
+                .map(|bqf| { if bqf.format.same_as(&Q4_0) { Some(bqf.clone()) } else { None }}).flatten()
     })
 }
 
-pub fn resolve_tensor_shape(a: &MetalTensor) -> Vec<usize> {
+pub fn as_q40_tensor(a: &MetalTensor) -> Option<BlockQuantValue> {
     a.view()
         .tensor
         .to_scalar::<Opaque>()
+        .ok()
         .map(|od| {
             od.downcast_ref::<BlockQuantValue>()
-                .map(|bqv| {
-                    a.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect_vec()
-                })
-                .unwrap_or(a.shape().to_vec())
-        })
-        .unwrap_or(a.shape().to_vec())
+                .map(|bqf| { if bqf.fact.format.same_as(&Q4_0) { Some(bqf.clone()) } else { None }}).flatten()
+        }).flatten()
 }
 
 pub fn tract_to_gguf_q4_0_packing(data: &mut Blob) -> TractResult<()> {

From c00f406e405dd42426faf0ebaaef2c238f06de80 Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Thu, 13 Feb 2025 10:33:29 +0100
Subject: [PATCH 09/15] do NOT clone BlockQuantFact/Value

---
 metal/src/utils.rs | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/metal/src/utils.rs b/metal/src/utils.rs
index 773b65b9f9..81a40aacda 100644
--- a/metal/src/utils.rs
+++ b/metal/src/utils.rs
@@ -97,11 +97,11 @@ where
         .collect::<Vec<_>>())
 }
 
-pub fn as_q40_fact(fact: &TypedFact) -> Option<BlockQuantFact> {
+pub fn as_q40_fact(fact: &TypedFact) -> Option<&BlockQuantFact> {
     fact.opaque_fact
         .as_ref()
         .and_then(|of| of.downcast_ref::<BlockQuantFact>())
-        .map(|bqf| { if bqf.format.same_as(&Q4_0) { Some(bqf.clone()) } else { None }}).flatten()
+        .map(|bqf| { if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None }}).flatten()
         .or_else(|| {
             fact
                 .konst
                 .as_ref()
                 .and_then(|k| k.to_scalar::<Opaque>().ok())
                 .and_then(|o| o.downcast_ref::<BlockQuantValue>())
                 .map(|v| &v.fact)
-                .map(|bqf| { if bqf.format.same_as(&Q4_0) { Some(bqf.clone()) } else { None }}).flatten()
+                .map(|bqf| { if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None }}).flatten()
         })
 }
 
-pub fn as_q40_tensor(a: &MetalTensor) -> Option<BlockQuantValue> {
+pub fn as_q40_tensor(a: &MetalTensor) -> Option<&BlockQuantValue> {
     a.view()
         .tensor
         .to_scalar::<Opaque>()
         .ok()
         .map(|od| {
             od.downcast_ref::<BlockQuantValue>()
-                .map(|bqf| { if bqf.fact.format.same_as(&Q4_0) { Some(bqf.clone()) } else { None }}).flatten()
+                .map(|bqf| { if bqf.fact.format.same_as(&Q4_0) { Some(bqf) } else { None }}).flatten()
         }).flatten()
 }

From 1713d099cdd1d7071e3b49fbda08b14186a03895 Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Thu, 13 Feb 2025 13:47:15 +0100
Subject: [PATCH 10/15] detect q40 in generic dispatch_eval

---
 metal/src/kernels/matmul/basic/mod.rs     |  1 +
 metal/src/kernels/matmul/ggml_gemm/mod.rs | 23 ++++++++++++-----------
 metal/src/kernels/matmul/mfa/mod.rs       |  1 +
 metal/src/kernels/matmul/mlx_gemm/mod.rs  |  1 +
 metal/src/kernels/matmul/mod.rs           | 21 +++++++++++++++++++++
 metal/src/kernels/matmul/mps/mod.rs       |  1 +
 6 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/metal/src/kernels/matmul/basic/mod.rs b/metal/src/kernels/matmul/basic/mod.rs
index bda2160fe3..a3b58f3fc3 100644
--- a/metal/src/kernels/matmul/basic/mod.rs
+++ b/metal/src/kernels/matmul/basic/mod.rs
@@ -34,6 +34,7 @@ impl GemmKernel for BasicMatMul {
             transpose_b,
             b_offset,
             c_offset,
+            ..
         } = params;
 
         ensure!(
diff --git a/metal/src/kernels/matmul/ggml_gemm/mod.rs b/metal/src/kernels/matmul/ggml_gemm/mod.rs
index a82c9d2fae..935bae4526 100644
--- a/metal/src/kernels/matmul/ggml_gemm/mod.rs
+++ b/metal/src/kernels/matmul/ggml_gemm/mod.rs
@@ -64,12 +64,12 @@ impl GemmKernel for GgmlGemm {
             a_offset,
             transpose_b,
             b_offset,
+            q40_b,
             c_offset,
         } = params;
 
         ensure!(!transpose_a && transpose_b);
 
-        // Kernel output is transposed so we switch the inputs
         let a_el_size = dts[0].size_of();
         let a_strides = [
             (batch * m * k * a_el_size) as u64,
@@ -78,8 +78,7 @@ impl GemmKernel for GgmlGemm {
             a_el_size as u64,
         ];
 
-        let b_strides = if dts[1] == DatumType::Opaque {
-            // Assume Opaque == Q4. TODO
+        let b_strides = if q40_b {
             let b_el_size = 18;
             [
                 (batch * n * (k / 32) * b_el_size) as u64,
@@ -110,11 +109,11 @@ impl GemmKernel for GgmlGemm {
 
         if (dts[0] == DatumType::F32) && (k % 32 == 0) && (k >= 64) && (m > 4) {
             dispatch_metal_ggml_gemm(
-                context, dts, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer, c_offset,
+                context, dts, q40_b, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer, c_offset,
             )?;
         } else {
             dispatch_metal_ggml_gemv(
-                context, dts, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer, c_offset,
+                context, dts, q40_b, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer, c_offset,
             )?;
         }
 
@@ -124,6 +123,7 @@ impl GemmKernel for GgmlGemm {
 
 fn mv_kernel_name_and_dispatch_params(
     dts: &[DatumType],
+    q40_b: bool,
     params: &GgmlParams,
 ) -> Result<(String, (u64, u64, u64))> {
     let (nth0, nth1, nrows): (u64, u64, u64) = (32, 1, 1);
@@ -146,7 +146,7 @@ fn mv_kernel_name_and_dispatch_params(
             Ok(("kernel_mul_mv_f16_f16".to_string(), (nth0, nth1, 4)))
         }
     } else {
-        ensure!((dts[1] == DatumType::Opaque) && (dts[0] == DatumType::F32));
+        ensure!((q40_b) && (dts[0] == DatumType::F32));
         Ok(("kernel_mul_mv_q4_0_f32".to_string(), (8, 8, 1)))
     }
 }
@@ -155,6 +155,7 @@ fn mv_kernel_name_and_dispatch_params(
 fn dispatch_metal_ggml_gemv(
     context: &MetalContext,
     dts: [DatumType; 3],
+    q40_b: bool,
     params: GgmlParams,
     a_offset: usize,
     a_buffer: &Buffer,
@@ -163,7 +164,7 @@ fn dispatch_metal_ggml_gemv(
     output: &Buffer,
     output_offset: usize,
 ) -> Result<()> {
-    let (name, (nth0, nth1, nrows)) = mv_kernel_name_and_dispatch_params(&dts, &params)?;
+    let (name, (nth0, nth1, nrows)) = mv_kernel_name_and_dispatch_params(&dts, q40_b, &params)?;
     //dbg!(&name);
     let pipeline = context.shared_context().load_pipeline(LibraryName::Ggml, &name)?;
@@ -180,7 +181,7 @@ fn dispatch_metal_ggml_gemv(
     encoder.set_buffer(2, Some(a_buffer), a_offset as NSUInteger);
     encoder.set_buffer(3, Some(output), output_offset as NSUInteger);
 
-    let grid_size = if dts[1] != DatumType::Opaque {
+    let grid_size = if !q40_b {
         MTLSize {
             width: params.n as u64,
             height: (params.m as u64).div_ceil(nrows),
@@ -206,6 +207,7 @@ fn dispatch_metal_ggml_gemv(
 fn dispatch_metal_ggml_gemm(
     context: &MetalContext,
     dts: [DatumType; 3],
+    q40_b: bool,
     params: GgmlParams,
     a_offset: usize,
     a_buffer: &Buffer,
@@ -214,16 +216,15 @@ fn dispatch_metal_ggml_gemm(
     output: &Buffer,
     output_offset: usize,
 ) -> Result<()> {
-    // Warning: currently assuming opaque == q4_0,
     ensure!(
-        matches!(dts[1], DatumType::F32 | DatumType::F16 | DatumType::Opaque)
+        (matches!(dts[1], DatumType::F32 | DatumType::F16) || q40_b)
             && dts[0] == DatumType::F32
     );
 
     let mut i1_tname = MetalTensor::tname(dts[1])?;
     let i2_tname = MetalTensor::tname(dts[0])?;
 
-    if i1_tname == "opaque" {
+    if q40_b {
         i1_tname = "q4_0";
     }
 
     let name = format!("kernel_mul_mm_{i1_tname}_{i2_tname}");
diff --git a/metal/src/kernels/matmul/mfa/mod.rs b/metal/src/kernels/matmul/mfa/mod.rs
index d646f53195..4e7eef19d9 100644
--- a/metal/src/kernels/matmul/mfa/mod.rs
+++ b/metal/src/kernels/matmul/mfa/mod.rs
@@ -39,6 +39,7 @@ impl GemmKernel for MfaGemm {
             transpose_b,
             b_offset,
             c_offset,
+            ..
         } = params;
 
         let a_strides = if transpose_a {
diff --git a/metal/src/kernels/matmul/mlx_gemm/mod.rs b/metal/src/kernels/matmul/mlx_gemm/mod.rs
index b60db21f1b..fb025f0533 100644
--- a/metal/src/kernels/matmul/mlx_gemm/mod.rs
+++ b/metal/src/kernels/matmul/mlx_gemm/mod.rs
@@ -73,6 +73,7 @@ impl GemmKernel for MlxGemm {
             transpose_b,
             b_offset,
             c_offset,
+            ..
         } = params;
 
         let a_strides = if transpose_a {
diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs
index 844ef395ea..a8976d8087 100644
--- a/metal/src/kernels/matmul/mod.rs
+++ b/metal/src/kernels/matmul/mod.rs
@@ -44,6 +44,7 @@ pub struct GemmDispatchParams {
     pub a_offset: usize,
     pub transpose_b: bool,
     pub b_offset: usize,
+    pub q40_b: bool,
     pub c_offset: usize,
 }
 
@@ -57,6 +58,7 @@ impl GemmDispatchParams {
         b_offset: usize,
         b_shape: &[usize],
         transpose_b: bool,
+        q40_b: bool,
         c_offset: usize,
         c_shape: &[usize],
     ) -> TractResult<Vec<GemmDispatchParams>> {
@@ -87,6 +89,7 @@ impl GemmDispatchParams {
                 a_offset,
                 transpose_b,
                 b_offset,
+                q40_b,
                 c_offset,
             }]),
             // bkm, 1kn -> bmn
@@ -103,6 +106,7 @@ impl GemmDispatchParams {
                     a_offset: a_offset + a_batch_idx * m * k * dts[0].size_of(),
                     transpose_b,
                     b_offset,
+                    q40_b,
                     c_offset: c_offset + a_batch_idx * m * n * dts[2].size_of(),
                 })
                 .collect()),
@@ -122,6 +126,7 @@ impl GemmDispatchParams {
                     a_offset,
                     transpose_b,
                     b_offset: b_offset + b_batch_idx * n * k * dts[1].size_of(),
+                    q40_b,
                     c_offset: c_offset + b_batch_idx * m * n * dts[2].size_of(),
                 })
                 .collect()),
@@ -142,6 +147,7 @@ impl GemmDispatchParams {
                     a_offset,
                     transpose_b,
                     b_offset,
+                    q40_b,
                     c_offset,
                 }])
             }
@@ -272,6 +278,7 @@ impl GemmImpl {
             b.metal_offset(),
             &b_shape,
             self.transpose_b,
+            as_q40_tensor(&b).is_some(),
            c.metal_offset(),
            c.shape(),
        )?;
@@ -412,6 +419,7 @@ mod tests {
                 0,
                 &[1, k, n],
                 false,
+                false,
                 0,
                 &[1, m, n],
             )?,
@@ -425,6 +433,7 @@ mod tests {
                 a_offset: 0,
                 transpose_b: false,
                 b_offset: 0,
+                q40_b: false,
                 c_offset: 0,
             }]
         );
@@ -438,6 +447,7 @@ mod tests {
                 0,
                 &[10, k, n],
                 false,
+                false,
                 0,
                 &[10, m, n],
             )?,
@@ -451,6 +461,7 @@ mod tests {
                 a_offset: 0,
                 transpose_b: false,
                 b_offset: 0,
+                q40_b: false,
                 c_offset: 0,
             }]
         );
@@ -464,6 +475,7 @@ mod tests {
                 0,
                 &[2, k, n],
                 false,
+                false,
                 10,
                 &[2, m, n],
             )?,
@@ -478,6 +490,7 @@ mod tests {
                     a_offset: 0,
                     transpose_b: false,
                     b_offset: 0,
+                    q40_b: false,
                     c_offset: 10,
                 },
                 GemmDispatchParams {
@@ -490,6 +503,7 @@ mod tests {
                     a_offset: 0,
                     transpose_b: false,
                     b_offset: 1 * n * k * dt.size_of(),
+                    q40_b: false,
                     c_offset: 10 + m * n * dt.size_of(),
                 }
             ]
@@ -504,6 +518,7 @@ mod tests {
                 0,
                 &[2, k, n],
                 false,
+                false,
                 100,
                 &[2, m, n],
             )?,
@@ -517,6 +532,7 @@ mod tests {
                 a_offset: 0,
                 transpose_b: false,
                 b_offset: 0,
+                q40_b: false,
                 c_offset: 100,
             }]
         );
@@ -530,6 +546,7 @@ mod tests {
                 0,
                 &[1, k, n],
                 false,
+                false,
                 100,
                 &[2, m, n],
             )?,
@@ -544,6 +561,7 @@ mod tests {
                     a_offset: 0,
                     transpose_b: false,
                     b_offset: 0,
+                    q40_b: false,
                     c_offset: 100,
                 },
                 GemmDispatchParams {
@@ -556,6 +574,7 @@ mod tests {
                     a_offset: 1 * m * k * dt.size_of(),
                     transpose_b: false,
                     b_offset: 0,
+                    q40_b: false,
                     c_offset: 100 + 1 * m * n * dt.size_of(),
                 }
             ]
@@ -570,6 +589,7 @@ mod tests {
                 10,
                 &[1, k, n],
                 false,
+                false,
                 0,
                 &[10, m, n],
             )?,
@@ -583,6 +603,7 @@ mod tests {
                 a_offset: 0,
                 transpose_b: false,
                 b_offset: 10,
+                q40_b: false,
                 c_offset: 0,
             }]
         );
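For reference, the Q4_0 stride arithmetic hard-coded in the hunks above (b_el_size = 18, k / 32 blocks per row) follows from the GGUF Q4_0 block layout: 32 weights per block, stored as a 2-byte f16 scale plus 16 bytes of packed 4-bit quants. A minimal standalone sketch of that arithmetic; q4_0_strides is an illustrative name, not a function from the patch:

    // Byte strides for a Q4_0 weight tensor of logical shape [batch, n, k].
    // Assumes the GGUF Q4_0 layout described above; illustrative only.
    const Q4_0_BLOCK_WEIGHTS: usize = 32; // weights per block
    const Q4_0_BLOCK_BYTES: usize = 18; // 2 (f16 scale) + 32 / 2 (nibbles)

    fn q4_0_strides(batch: usize, n: usize, k: usize) -> [u64; 4] {
        assert!(k % Q4_0_BLOCK_WEIGHTS == 0, "rows must be whole Q4_0 blocks");
        let row_bytes = k / Q4_0_BLOCK_WEIGHTS * Q4_0_BLOCK_BYTES;
        [
            (batch * n * row_bytes) as u64, // whole weight tensor
            (n * row_bytes) as u64,         // one batch of n rows
            row_bytes as u64,               // one row of k quantized weights
            Q4_0_BLOCK_BYTES as u64,        // one 32-weight block
        ]
    }

    fn main() {
        // n = 8 rows of k = 64 weights -> 2 blocks (36 bytes) per row.
        assert_eq!(q4_0_strides(1, 8, 64), [288, 288, 36, 18]);
    }

The whole-block granularity is also why the kernels only take the quantized path when k % 32 == 0.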
diff --git a/metal/src/kernels/matmul/mps/mod.rs b/metal/src/kernels/matmul/mps/mod.rs
index 33efac3266..0b8917d690 100644
--- a/metal/src/kernels/matmul/mps/mod.rs
+++ b/metal/src/kernels/matmul/mps/mod.rs
@@ -53,6 +53,7 @@ impl GemmKernel for MpsMatMul {
             transpose_b,
             b_offset,
             c_offset,
+            ..
         } = params;
 
         let dt = dts[0];

From 9693d32b9f607049fbaf11ed617b3d3339ad4312 Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Fri, 14 Feb 2025 15:40:39 +0100
Subject: [PATCH 11/15] reformat

---
 metal/src/kernels/matmul/ggml_gemm/mod.rs | 17 ++---
 metal/src/kernels/matmul/mod.rs           | 18 +++---
 metal/src/ops/gemm.rs                     | 15 ++---
 metal/src/transform.rs                    | 76 +++++++++++------------
 metal/src/utils.rs                        | 31 ++++-----
 5 files changed, 79 insertions(+), 78 deletions(-)

diff --git a/metal/src/kernels/matmul/ggml_gemm/mod.rs b/metal/src/kernels/matmul/ggml_gemm/mod.rs
index 935bae4526..13e28ed556 100644
--- a/metal/src/kernels/matmul/ggml_gemm/mod.rs
+++ b/metal/src/kernels/matmul/ggml_gemm/mod.rs
@@ -37,9 +37,11 @@ impl GemmKernel for GgmlGemm {
     fn is_supported_dts(&self, facts: &[TypedFact]) -> bool {
         assert!(facts.len() == 2, "Ggml: Expected 2 inputs for Matmul");
 
-        (as_q40_fact(&facts[1]).is_some() && matches!(facts[0].datum_type, DatumType::F16 | DatumType::F32)) ||
-        ((facts[1].datum_type == DatumType::F32) && (facts[0].datum_type == DatumType::F32)) ||
-        ((facts[1].datum_type == DatumType::F16) && matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16))
+        (as_q40_fact(&facts[1]).is_some()
+            && matches!(facts[0].datum_type, DatumType::F16 | DatumType::F32))
+            || ((facts[1].datum_type == DatumType::F32) && (facts[0].datum_type == DatumType::F32))
+            || ((facts[1].datum_type == DatumType::F16)
+                && matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16))
     }
 
     fn output_dt(&self, _a_dt: DatumType, _b_dt: DatumType) -> TractResult<DatumType> {
@@ -109,11 +111,13 @@ impl GemmKernel for GgmlGemm {
 
         if (dts[0] == DatumType::F32) && (k % 32 == 0) && (k >= 64) && (m > 4) {
             dispatch_metal_ggml_gemm(
-                context, dts, q40_b, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer, c_offset,
+                context, dts, q40_b, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer,
+                c_offset,
             )?;
         } else {
             dispatch_metal_ggml_gemv(
-                context, dts, q40_b, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer, c_offset,
+                context, dts, q40_b, params, a_offset, a_buffer, b_offset, b_buffer, c_buffer,
+                c_offset,
             )?;
         }
 
@@ -217,8 +221,7 @@ fn dispatch_metal_ggml_gemm(
     output_offset: usize,
 ) -> Result<()> {
     ensure!(
-        (matches!(dts[1], DatumType::F32 | DatumType::F16) || q40_b)
-            && dts[0] == DatumType::F32
+        (matches!(dts[1], DatumType::F32 | DatumType::F16) || q40_b) && dts[0] == DatumType::F32
     );
 
     let mut i1_tname = MetalTensor::tname(dts[1])?;
diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs
index a8976d8087..e604c420bb 100644
--- a/metal/src/kernels/matmul/mod.rs
+++ b/metal/src/kernels/matmul/mod.rs
@@ -234,10 +234,9 @@ impl GemmImpl {
         a: &MetalTensor,
         b: &MetalTensor,
     ) -> TractResult<MetalTensor> {
-        let b_shape = as_q40_tensor(&b).map(|bqv| {
-            b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect()
-        })
-        .unwrap_or(b.shape().to_vec());
+        let b_shape = as_q40_tensor(&b)
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
+            .unwrap_or(b.shape().to_vec());
 
         let c_dt = self.matmul.output_dt(a.datum_type(), b.datum_type())?;
         let c_shape = self.output_shape(a.shape(), &b_shape);
@@ -259,11 +258,10 @@ impl GemmImpl {
         b.retain_until_completion();
         c.retain_until_completion();
 
-        let b_shape = as_q40_tensor(&b).map(|bqv| {
-            b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect()
-        })
-        .unwrap_or(b.shape().to_vec());
-
+        let b_shape = as_q40_tensor(&b)
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
+            .unwrap_or(b.shape().to_vec());
+
         ensure!(c.shape() == self.output_shape(a.shape(), &b_shape).as_slice());
 
         if c.shape().iter().product::<usize>() == 0 {
@@ -732,7 +730,7 @@ mod tests {
             let lhs_data: Vec<F> = (0..b * m * k) // Create a vector with 10 elements
                 .map(|_| F::from(rng.gen_range(0.0..1.0)).unwrap()) // Generate a random float in [0.0, 1.0)
                 .collect();
-            
+
             let rhs_data: Vec<F> = (0..b * n * k) // Create a vector with 10 elements
                 .map(|_| F::from(rng.gen_range(0.0..1.0)).unwrap()) // Generate a random float in [0.0, 1.0)
                 .collect();
diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs
index 93b94e556d..edf45ae5f8 100644
--- a/metal/src/ops/gemm.rs
+++ b/metal/src/ops/gemm.rs
@@ -54,15 +54,13 @@ impl MetalGemm {
             );
             let out_shape = self.kernel.output_shape(&a.shape, &b.shape);
             Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
-        } else if let Some(opf) = as_q40_fact(inputs[0])
-        {
+        } else if let Some(opf) = as_q40_fact(inputs[0]) {
             let a_shape: ShapeFact =
                 a.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect();
             let out_shape = self.kernel.output_shape(&a_shape, &b.shape);
             Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
-        } else if let Some(opf) = as_q40_fact(inputs[1])
-        {
+        } else if let Some(opf) = as_q40_fact(inputs[1]) {
             let b_shape: ShapeFact =
                 b.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect();
             let out_shape = self.kernel.output_shape(&a.shape, &b_shape);
@@ -87,11 +85,10 @@ impl MetalEvalOp for MetalGemm {
         let b = b_opaque
             .to_metal_tensor()
             .with_context(|| anyhow!("B tensor is not a metal tensor {:?}", b_opaque))?;
-
-        let b_shape = as_q40_tensor(&b).map(|bqv| {
-            b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect()
-        })
-        .unwrap_or(b.shape().to_vec());
+
+        let b_shape = as_q40_tensor(&b)
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
+            .unwrap_or(b.shape().to_vec());
 
         let c_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?;
         let c_shape = self.kernel.output_shape(a.shape(), &b_shape);
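The b_shape computations above all follow one rule: a Q4_0 weight travels as an Opaque tensor whose BlockQuantValue carries the inner matrix shape, and the logical shape is the Opaque tensor's outer shape concatenated with that inner shape. A minimal sketch of the rule on plain usize dims (the real code reads a MetalTensor and a BlockQuantFact):

    // Logical shape of a block-quantized operand: outer ++ inner.
    // Sketch only; names are illustrative.
    fn resolve_q40_shape(outer: &[usize], inner: &[usize]) -> Vec<usize> {
        outer.iter().copied().chain(inner.iter().copied()).collect()
    }

    fn main() {
        // A scalar Opaque holding a 32x64 Q4_0 matrix keeps its inner shape:
        assert_eq!(resolve_q40_shape(&[], &[32, 64]), vec![32, 64]);
        // A rank-1 Opaque wrapper prepends its outer dim:
        assert_eq!(resolve_q40_shape(&[1], &[32, 64]), vec![1, 32, 64]);
    }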
diff --git a/metal/src/transform.rs b/metal/src/transform.rs
index 81916be4ec..1c3da27094 100644
--- a/metal/src/transform.rs
+++ b/metal/src/transform.rs
@@ -14,6 +14,7 @@ use crate::rewrite_rules::{
     BasicSilu,
 };
 use crate::tensor::MetalTensorExt;
+use crate::utils::as_q40_fact;
 use crate::{IntoMetal, MetalFact, MetalTensor};
 use std::borrow::Cow;
 use std::fmt::Debug;
@@ -27,7 +28,7 @@ use tract_core::ops::element_wise::ElementWiseOp;
 use tract_core::ops::konst::Const;
 use tract_core::ops::logic::Comp;
 use tract_core::ops::nn::{Reduce, Softmax as CoreSoftmax};
-use tract_core::tract_linalg::frame::block_quant::BlockQuantValue;
+use tract_core::tract_linalg::frame::block_quant::{BlockQuantValue, Q4_0};
 use tract_core::transform::ModelTransform;
 use tract_itertools::Itertools;
 
@@ -344,7 +345,10 @@ fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, in_facts: &[TypedFact]) ->
         MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(in_facts),
         MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(in_facts),
         MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(in_facts),
-        MetalGemmImplKind::Ggml => GgmlGemm.is_supported_dts(in_facts) || GgmlGemm.is_supported_dts(&[in_facts[1].clone(), in_facts[0].clone()]),
+        MetalGemmImplKind::Ggml => {
+            GgmlGemm.is_supported_dts(in_facts)
+                || GgmlGemm.is_supported_dts(&[in_facts[1].clone(), in_facts[0].clone()])
+        }
     };
     is_supported
 }
@@ -375,8 +379,8 @@ fn convert_matmul_to_metal(
         model.node_input_facts(node.id)?;
 
     let mut swap_inputs = false;
-    if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()]) &&
-        GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()])
+    if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()])
+        && GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()])
     {
         input_facts.swap(0, 1);
         inputs.swap(0, 1);
@@ -386,13 +390,13 @@ fn convert_matmul_to_metal(
     let a_pos = swap_inputs as usize;
     let b_pos = 1 - swap_inputs as usize;
     if op.transpose_a {
-        assert!(input_facts[a_pos].datum_type != DatumType::Opaque, "Cannot transpose Opaque tensor");
+        assert!(!as_q40_fact(input_facts[a_pos]).is_some(), "Cannot transpose Q40 tensor");
 
         let rank = input_facts[a_pos].rank();
         let perm_a_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
-                rank - 2,
-                rank - 1,
-            ));
+            rank - 2,
+            rank - 1,
+        ));
         let perm_a_name = node.name.clone() + ".perm_a";
         inputs[a_pos] = target.wire_node(perm_a_name, perm_a_op, &[inputs[a_pos]])?[0];
     }
@@ -404,13 +408,13 @@ fn convert_matmul_to_metal(
     }
 
     if !op.transpose_b {
-        assert!(input_facts[b_pos].datum_type != DatumType::Opaque, "Cannot transpose Opaque tensor");
+        assert!(!as_q40_fact(input_facts[b_pos]).is_some(), "Cannot transpose Q40 tensor");
 
         let rank = input_facts[b_pos].rank();
-        let perm_b_op= ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
-                rank - 2,
-                rank - 1,
-            ));
+        let perm_b_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
+            rank - 2,
+            rank - 1,
+        ));
         let perm_b_name = node.name.clone() + ".perm_b";
         inputs[b_pos] = target.wire_node(perm_b_name, perm_b_op, &[inputs[b_pos]])?[0];
     }
@@ -419,13 +423,21 @@ fn convert_matmul_to_metal(
 
     if swap_inputs {
         let out_fact = target.outlet_fact(matmul_output[0])?;
-        let rank = &out_fact.opaque_fact.clone().map(|fact| fact.clarify_dt_shape().unwrap().1.len()).unwrap();
+        let rank = &out_fact
+            .opaque_fact
+            .clone()
+            .map(|fact| fact.clarify_dt_shape().unwrap().1.len())
+            .unwrap();
 
         let perm_out_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
             rank - 2,
             rank - 1,
         ));
-        matmul_output = target.wire_node(node.name.clone() + ".perm_out", perm_out_op, &matmul_output)?;
+        matmul_output = target.wire_node(
+            node.name.clone() + ".perm_out",
+            perm_out_op,
+            &matmul_output,
+        )?;
     }
     matmul_output
 }
@@ -473,34 +485,22 @@ fn convert_logic_ops_to_metal(op: &Comp) -> ops::MetalBinOp {
 }
 
 fn convert_const(op: &Const) -> TractResult<Const> {
-    let (tensor, metal_fact) = if op
-        .0
-        .clone()
-        .view()
-        .tensor
-        .to_scalar::<Opaque>()
-        .is_ok_and(|of| of.downcast_ref::<BlockQuantValue>().is_some())
-    {
-        let mut curr_bqv =
-            op.0.clone()
-                .view()
-                .tensor
-                .to_scalar::<Opaque>()?
-                .downcast_ref::<BlockQuantValue>()
-                .unwrap()
-                .clone();
-
-        crate::utils::tract_to_gguf_q4_0_packing(&mut (curr_bqv.value))?;
-
-        let bqv = BlockQuantValue { value: curr_bqv.clone().value, fact: curr_bqv.clone().fact };
-        let tensor = if op.0.rank() == 0 {
-            tensor0(Opaque(Arc::new(bqv)))
-        } else {
-            tensor1(&[Opaque(Arc::new(bqv))])
-        };
-
+    let (tensor, metal_fact) = if let Some(curr_bqv) = op
+        .0
+        .view()
+        .tensor
+        .to_scalar::<Opaque>()
+        .ok()
+        .and_then(|of| of.downcast_ref::<BlockQuantValue>())
+        .map(|bqv| if bqv.fact.format.same_as(&Q4_0) { Some(bqv) } else { None })
+        .flatten()
+    {
+        let mut bqv = curr_bqv.clone();
+        crate::utils::tract_to_gguf_q4_0_packing(&mut (bqv.value))?;
+
+        let bqv = BlockQuantValue { value: bqv.value, fact: bqv.fact };
         (
-            tensor.into_arc_tensor(),
+            tensor0(Opaque(Arc::new(bqv))).broadcast_into_rank(op.0.rank())?.into_arc_tensor(),
             MetalFact::from_cpu(Arc::clone(&op.0).into())?,
         )
     } else {
diff --git a/metal/src/utils.rs b/metal/src/utils.rs
index 81a40aacda..41bcec1c6b 100644
--- a/metal/src/utils.rs
+++ b/metal/src/utils.rs
@@ -99,18 +99,19 @@ where
 
 pub fn as_q40_fact(fact: &TypedFact) -> Option<&BlockQuantFact> {
     fact.opaque_fact
-        .as_ref()
-        .and_then(|of| of.downcast_ref::<BlockQuantFact>())
-        .map(|bqf| { if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None }}).flatten()
-        .or_else(|| {
-            fact
-                .konst
-                .as_ref()
-                .and_then(|k| k.to_scalar::<Opaque>().ok())
-                .and_then(|o| o.downcast_ref::<BlockQuantValue>())
-                .map(|v| &v.fact)
-                .map(|bqf| { if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None }}).flatten()
-        })
+        .as_ref()
+        .and_then(|of| of.downcast_ref::<BlockQuantFact>())
+        .map(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
+        .flatten()
+        .or_else(|| {
+            fact.konst
+                .as_ref()
+                .and_then(|k| k.to_scalar::<Opaque>().ok())
+                .and_then(|o| o.downcast_ref::<BlockQuantValue>())
+                .map(|v| &v.fact)
+                .map(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
+                .flatten()
+        })
 }
 
 pub fn as_q40_tensor(a: &MetalTensor) -> Option<&BlockQuantValue> {
@@ -120,8 +121,10 @@ pub fn as_q40_tensor(a: &MetalTensor) -> Option<&BlockQuantValue> {
         .ok()
         .map(|od| {
             od.downcast_ref::<BlockQuantValue>()
-                .map(|bqf| { if bqf.fact.format.same_as(&Q4_0) { Some(bqf) } else { None }}).flatten()
-        }).flatten()
+                .map(|bqv| if bqv.fact.format.same_as(&Q4_0) { Some(bqv) } else { None })
+                .flatten()
+        })
+        .flatten()
 }
 
 pub fn tract_to_gguf_q4_0_packing(data: &mut Blob) -> TractResult<()> {

From 916c9cdead6d96ad903ff4bb9b68c8ebe93a08cc Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Fri, 14 Feb 2025 15:50:09 +0100
Subject: [PATCH 12/15] Tensor input for as q40_tensor

---
 metal/src/kernels/matmul/mod.rs |  7 ++++---
 metal/src/ops/gemm.rs           |  2 +-
 metal/src/transform.rs          | 13 ++-----------
 metal/src/utils.rs              |  6 ++----
 4 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs
index e604c420bb..59bb7d3722 100644
--- a/metal/src/kernels/matmul/mod.rs
+++ b/metal/src/kernels/matmul/mod.rs
@@ -234,7 +234,7 @@ impl GemmImpl {
         a: &MetalTensor,
         b: &MetalTensor,
     ) -> TractResult<MetalTensor> {
-        let b_shape = as_q40_tensor(&b)
+        let b_shape = as_q40_tensor(b.view().tensor)
             .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
             .unwrap_or(b.shape().to_vec());
 
@@ -258,7 +258,8 @@ impl GemmImpl {
         b.retain_until_completion();
         c.retain_until_completion();
 
-        let b_shape = as_q40_tensor(&b)
+        let q40_b = as_q40_tensor(b.view().tensor);
+        let b_shape = q40_b
             .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
             .unwrap_or(b.shape().to_vec());
 
@@ -276,7 +277,7 @@ impl GemmImpl {
             b.metal_offset(),
             &b_shape,
             self.transpose_b,
-            as_q40_tensor(&b).is_some(),
+            q40_b.is_some(),
             c.metal_offset(),
             c.shape(),
         )?;
diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs
index edf45ae5f8..6b051533f1 100644
--- a/metal/src/ops/gemm.rs
+++ b/metal/src/ops/gemm.rs
@@ -86,7 +86,7 @@ impl MetalEvalOp for MetalGemm {
             .to_metal_tensor()
             .with_context(|| anyhow!("B tensor is not a metal tensor {:?}", b_opaque))?;
 
-        let b_shape = as_q40_tensor(&b)
+        let b_shape = as_q40_tensor(b.view().tensor)
             .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
             .unwrap_or(b.shape().to_vec());
 
diff --git a/metal/src/transform.rs b/metal/src/transform.rs
index 1c3da27094..1b268f99ef 100644
--- a/metal/src/transform.rs
+++ b/metal/src/transform.rs
@@ -14,7 +14,7 @@ use crate::rewrite_rules::{
     BasicSilu,
 };
 use crate::tensor::MetalTensorExt;
-use crate::utils::as_q40_fact;
+use crate::utils::{as_q40_fact, as_q40_tensor};
 use crate::{IntoMetal, MetalFact, MetalTensor};
 use std::borrow::Cow;
 use std::fmt::Debug;
@@ -485,16 +485,7 @@ fn convert_logic_ops_to_metal(op: &Comp) -> ops::MetalBinOp {
 }
 
 fn convert_const(op: &Const) -> TractResult<Const> {
-    let (tensor, metal_fact) = if let Some(curr_bqv) = op
-        .0
-        .view()
-        .tensor
-        .to_scalar::<Opaque>()
-        .ok()
-        .and_then(|of| of.downcast_ref::<BlockQuantValue>())
-        .map(|bqv| if bqv.fact.format.same_as(&Q4_0) { Some(bqv) } else { None })
-        .flatten()
-    {
+    let (tensor, metal_fact) = if let Some(curr_bqv) = as_q40_tensor(op.0.view().tensor) {
         let mut bqv = curr_bqv.clone();
         crate::utils::tract_to_gguf_q4_0_packing(&mut (bqv.value))?;
 
diff --git a/metal/src/utils.rs b/metal/src/utils.rs
index 41bcec1c6b..6ffe3312d3 100644
--- a/metal/src/utils.rs
+++ b/metal/src/utils.rs
@@ -114,10 +114,8 @@ pub fn as_q40_fact(fact: &TypedFact) -> Option<&BlockQuantFact> {
         })
 }
 
-pub fn as_q40_tensor(a: &MetalTensor) -> Option<&BlockQuantValue> {
-    a.view()
-        .tensor
-        .to_scalar::<Opaque>()
+pub fn as_q40_tensor(a: &Tensor) -> Option<&BlockQuantValue> {
+    a.to_scalar::<Opaque>()
         .ok()
         .map(|od| {
             od.downcast_ref::<BlockQuantValue>()

From 467d66bef9e82867014f63d80b062ef75bb1d5ad Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Fri, 14 Feb 2025 16:09:56 +0100
Subject: [PATCH 13/15] safer from_tensor

---
 metal/src/tensor/owned.rs | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/metal/src/tensor/owned.rs b/metal/src/tensor/owned.rs
index 4da61211a9..4cb23654c3 100644
--- a/metal/src/tensor/owned.rs
+++ b/metal/src/tensor/owned.rs
@@ -1,10 +1,10 @@
+use crate::utils::as_q40_tensor;
 use crate::MetalTensor;
 use anyhow::Result;
 use metal::Buffer;
 use num_traits::AsPrimitive;
 use std::fmt::Display;
 use tract_core::internal::*;
-use tract_linalg::frame::block_quant::BlockQuantValue;
 
 #[derive(Debug, Clone, Hash)]
 pub enum MValue {
@@ -138,15 +138,9 @@ impl OwnedMetalTensor {
             tensor_view.datum_type(),
         );
 
-        let data_bytes = if tensor_view.datum_type() == DatumType::Opaque {
-            &tensor_view
-                .tensor
-                .to_scalar::<Opaque>()
-                .map(|od| od.downcast_ref::<BlockQuantValue>().unwrap())?
-                .value
-        } else {
-            tensor_view.tensor.as_bytes()
-        };
+        let data_bytes = as_q40_tensor(tensor_view.tensor)
+            .map(|bqv| bqv.value.as_bytes())
+            .unwrap_or(tensor_view.tensor.as_bytes());
 
         let buffer = ctxt.buffer_from_slice(data_bytes);
         Ok(OwnedMetalTensor { inner: m_value, metal: buffer })

From 1a5de3612e780d9eb43cd78324799da9bc845964 Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Fri, 14 Feb 2025 16:16:58 +0100
Subject: [PATCH 14/15] fix clippy

---
 metal/src/kernels/matmul/mod.rs |  4 ++--
 metal/src/ops/gemm.rs           |  2 +-
 metal/src/transform.rs          | 11 +++++------
 metal/src/utils.rs              | 22 ++++++++++------------
 4 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs
index 59bb7d3722..f29b90bcf9 100644
--- a/metal/src/kernels/matmul/mod.rs
+++ b/metal/src/kernels/matmul/mod.rs
@@ -235,7 +235,7 @@ impl GemmImpl {
         b: &MetalTensor,
     ) -> TractResult<MetalTensor> {
         let b_shape = as_q40_tensor(b.view().tensor)
-            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().copied()).collect())
             .unwrap_or(b.shape().to_vec());
 
         let c_dt = self.matmul.output_dt(a.datum_type(), b.datum_type())?;
@@ -260,7 +260,7 @@ impl GemmImpl {
         let q40_b = as_q40_tensor(b.view().tensor);
         let b_shape = q40_b
-            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().copied()).collect())
             .unwrap_or(b.shape().to_vec());
 
         ensure!(c.shape() == self.output_shape(a.shape(), &b_shape).as_slice());
diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs
index 6b051533f1..84f62b679c 100644
--- a/metal/src/ops/gemm.rs
+++ b/metal/src/ops/gemm.rs
@@ -87,7 +87,7 @@ impl MetalEvalOp for MetalGemm {
 
         let b_shape = as_q40_tensor(b.view().tensor)
-            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().map(|d| *d)).collect())
+            .map(|bqv| b.shape().iter().cloned().chain(bqv.fact.shape.iter().copied()).collect())
             .unwrap_or(b.shape().to_vec());
 
         let c_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?;
diff --git a/metal/src/transform.rs b/metal/src/transform.rs
index 1b268f99ef..e45c14560d 100644
--- a/metal/src/transform.rs
+++ b/metal/src/transform.rs
@@ -28,7 +28,7 @@ use tract_core::ops::element_wise::ElementWiseOp;
 use tract_core::ops::konst::Const;
 use tract_core::ops::logic::Comp;
 use tract_core::ops::nn::{Reduce, Softmax as CoreSoftmax};
-use tract_core::tract_linalg::frame::block_quant::{BlockQuantValue, Q4_0};
+use tract_core::tract_linalg::frame::block_quant::BlockQuantValue;
 use tract_core::transform::ModelTransform;
 use tract_itertools::Itertools;
 
@@ -341,7 +341,7 @@ macro_rules! map_element_wise_ops {
 }
 
 fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, in_facts: &[TypedFact]) -> bool {
-    let is_supported = match gemm_impl {
+    match gemm_impl {
         MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(in_facts),
         MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(in_facts),
         MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(in_facts),
@@ -349,8 +349,7 @@ fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, in_facts: &[TypedFact]) ->
             GgmlGemm.is_supported_dts(in_facts)
                 || GgmlGemm.is_supported_dts(&[in_facts[1].clone(), in_facts[0].clone()])
         }
-    };
-    is_supported
+    }
 }
 
 fn convert_matmul_to_metal(
@@ -390,7 +389,7 @@ fn convert_matmul_to_metal(
     let a_pos = swap_inputs as usize;
     let b_pos = 1 - swap_inputs as usize;
     if op.transpose_a {
-        assert!(!as_q40_fact(input_facts[a_pos]).is_some(), "Cannot transpose Q40 tensor");
+        assert!(as_q40_fact(input_facts[a_pos]).is_none(), "Cannot transpose Q40 tensor");
 
         let rank = input_facts[a_pos].rank();
@@ -408,7 +407,7 @@ fn convert_matmul_to_metal(
 
     if !op.transpose_b {
-        assert!(!as_q40_fact(input_facts[b_pos]).is_some(), "Cannot transpose Q40 tensor");
+        assert!(as_q40_fact(input_facts[b_pos]).is_none(), "Cannot transpose Q40 tensor");
 
         let rank = input_facts[b_pos].rank();
         let perm_b_op = ops::change_axes::MetalAxisOp::from_tract_core(AxisOp::Move(
diff --git a/metal/src/utils.rs b/metal/src/utils.rs
index 6ffe3312d3..a8569eea26 100644
--- a/metal/src/utils.rs
+++ b/metal/src/utils.rs
@@ -1,5 +1,4 @@
 use crate::fact::{MetalFact, MetalOrigin, MetalTypedFactExt};
-use crate::MetalTensor;
 use num_traits::{AsPrimitive, Zero};
 use tract_core::internal::*;
 use tract_linalg::frame::block_quant::{BlockQuantFact, BlockQuantValue, Q4_0};
@@ -101,28 +100,27 @@ pub fn as_q40_fact(fact: &TypedFact) -> Option<&BlockQuantFact> {
     fact.opaque_fact
         .as_ref()
         .and_then(|of| of.downcast_ref::<BlockQuantFact>())
-        .map(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
-        .flatten()
+        .and_then(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
        .or_else(|| {
            fact.konst
                .as_ref()
                .and_then(|k| k.to_scalar::<Opaque>().ok())
                .and_then(|o| o.downcast_ref::<BlockQuantValue>())
                .map(|v| &v.fact)
-                .map(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
-                .flatten()
+                .and_then(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
        })
 }
 
 pub fn as_q40_tensor(a: &Tensor) -> Option<&BlockQuantValue> {
-    a.to_scalar::<Opaque>()
-        .ok()
-        .map(|od| {
-            od.downcast_ref::<BlockQuantValue>()
-                .map(|bqv| if bqv.fact.format.same_as(&Q4_0) { Some(bqv) } else { None })
-                .flatten()
-        })
-        .flatten()
+    a.to_scalar::<Opaque>().ok().and_then(|od| {
+        od.downcast_ref::<BlockQuantValue>().and_then(|bqv| {
+            if bqv.fact.format.same_as(&Q4_0) {
+                Some(bqv)
+            } else {
+                None
+            }
+        })
+    })
 }
 
 pub fn tract_to_gguf_q4_0_packing(data: &mut Blob) -> TractResult<()> {
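The map(...).flatten() to and_then(...) rewrites above keep behavior identical. The same Q4_0 gate could also be written with Option::filter; a sketch of that alternative, with Probe as a hypothetical stand-in for BlockQuantValue (illustrative only, not what the patch ships):

    #[derive(Debug, PartialEq)]
    struct Probe {
        is_q4_0: bool,
    }

    // Keep the candidate only if it is a Q4_0 payload.
    fn as_q40(candidate: Option<&Probe>) -> Option<&Probe> {
        candidate.filter(|p| p.is_q4_0)
    }

    fn main() {
        let q4 = Probe { is_q4_0: true };
        let other = Probe { is_q4_0: false };
        assert_eq!(as_q40(Some(&q4)), Some(&q4));
        assert_eq!(as_q40(Some(&other)), None);
        assert_eq!(as_q40(None), None);
    }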

From 604714bfd97e1aac5570cb83cd2f14532ae77aab Mon Sep 17 00:00:00 2001
From: LouisChourakiSonos
Date: Fri, 14 Feb 2025 17:30:17 +0100
Subject: [PATCH 15/15] address PR comments

---
 metal/src/kernels/matmul/ggml_gemm/mod.rs | 95 +++++++++--------------
 metal/src/kernels/matmul/mod.rs           | 24 ++----
 2 files changed, 44 insertions(+), 75 deletions(-)

diff --git a/metal/src/kernels/matmul/ggml_gemm/mod.rs b/metal/src/kernels/matmul/ggml_gemm/mod.rs
index 13e28ed556..c47054a12d 100644
--- a/metal/src/kernels/matmul/ggml_gemm/mod.rs
+++ b/metal/src/kernels/matmul/ggml_gemm/mod.rs
@@ -6,6 +6,7 @@ use anyhow::{ensure, Result};
 use metal::{Buffer, MTLSize, NSUInteger};
 use std::fmt;
 use tract_core::internal::*;
+use DatumType::{F16, F32};
 
 #[derive(Debug)]
 #[repr(C)]
@@ -37,15 +38,17 @@ impl GemmKernel for GgmlGemm {
     fn is_supported_dts(&self, facts: &[TypedFact]) -> bool {
         assert!(facts.len() == 2, "Ggml: Expected 2 inputs for Matmul");
 
-        (as_q40_fact(&facts[1]).is_some()
-            && matches!(facts[0].datum_type, DatumType::F16 | DatumType::F32))
-            || ((facts[1].datum_type == DatumType::F32) && (facts[0].datum_type == DatumType::F32))
-            || ((facts[1].datum_type == DatumType::F16)
-                && matches!(facts[0].datum_type, DatumType::F32 | DatumType::F16))
+        let regular_types_support = matches!(
+            (facts[0].datum_type, facts[1].datum_type),
+            (F32, F32) | (F16, F16) | (F16, F32)
+        );
+
+        regular_types_support
+            || (as_q40_fact(&facts[1]).is_some() && matches!(facts[0].datum_type, F16 | F32))
     }
 
     fn output_dt(&self, _a_dt: DatumType, _b_dt: DatumType) -> TractResult<DatumType> {
-        Ok(DatumType::F32)
+        Ok(F32)
     }
 
     fn dispatch_eval(
@@ -109,7 +112,7 @@ impl GemmKernel for GgmlGemm {
             batch_broadcast_ratio: 1,
         };
 
-        if (dts[0] == DatumType::F32) && (k % 32 == 0) && (k >= 64) && (m > 4) {
+        if (dts[0] == F32) && (k % 32 == 0) && (k >= 64) && (m > 4) {
             dispatch_metal_ggml_gemm(
@@ -132,11 +135,11 @@ fn mv_kernel_name_and_dispatch_params(
 ) -> Result<(String, (u64, u64, u64))> {
     let (nth0, nth1, nrows): (u64, u64, u64) = (32, 1, 1);
 
-    if dts[1] == DatumType::F32 {
-        ensure!(dts[0] == DatumType::F32);
+    if dts[1] == F32 {
+        ensure!(dts[0] == F32);
         Ok(("kernel_mul_mv_f32_f32".to_string(), (nth0, nth1, 4)))
-    } else if dts[1] == DatumType::F16 {
-        if dts[0] == DatumType::F32 {
+    } else if dts[1] == F16 {
+        if dts[0] == F32 {
             if (params.m * params.batch) < 4 {
                 Ok(("kernel_mul_mv_f16_f32_1row".to_string(), (nth0, nth1, nrows)))
             } else if (params.k >= 128) && (params.k % 4 == 0) && (params.n >= 8) {
@@ -146,11 +149,11 @@ fn mv_kernel_name_and_dispatch_params(
             }
         } else {
             // Never used in practice since we upcast input[0] to f32
-            ensure!(dts[1] == DatumType::F16);
+            ensure!(dts[1] == F16);
             Ok(("kernel_mul_mv_f16_f16".to_string(), (nth0, nth1, 4)))
         }
     } else {
-        ensure!((q40_b) && (dts[0] == DatumType::F32));
+        ensure!((q40_b) && (dts[0] == F32));
         Ok(("kernel_mul_mv_q4_0_f32".to_string(), (8, 8, 1)))
     }
 }
@@ -220,9 +223,7 @@ fn dispatch_metal_ggml_gemm(
     output_offset: usize,
 ) -> Result<()> {
-    ensure!(
-        (matches!(dts[1], DatumType::F32 | DatumType::F16) || q40_b) && dts[0] == DatumType::F32
-    );
+    ensure!((matches!(dts[1], F32 | F16) || q40_b) && dts[0] == F32);
 
     let mut i1_tname = MetalTensor::tname(dts[1])?;
@@ -280,63 +281,39 @@ mod tests {
 
     #[test]
     fn test_mat_mul() -> TractResult<()> {
-        run_mmm_test_case::<GgmlGemm>((1, 5, 64, 2), false, true, DatumType::F32, DatumType::F32)?;
-        run_mmm_test_case::<GgmlGemm>((2, 1, 32, 2), false, true, DatumType::F32, DatumType::F32)?;
-        run_mmm_test_case::<GgmlGemm>((1, 5, 64, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>(
-            (3, 8, 64, 200),
-            false,
-            true,
-            DatumType::F32,
-            DatumType::F16,
-        )?;
-        run_mmm_test_case::<GgmlGemm>(
-            (10, 25, 512, 320),
-            false,
-            true,
-            DatumType::F32,
-            DatumType::F16,
-        )?;
+        run_mmm_test_case::<GgmlGemm>((1, 5, 64, 2), false, true, F32, F32)?;
+        run_mmm_test_case::<GgmlGemm>((2, 1, 32, 2), false, true, F32, F32)?;
+        run_mmm_test_case::<GgmlGemm>((1, 5, 64, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((3, 8, 64, 200), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((10, 25, 512, 320), false, true, F32, F16)?;
         Ok(())
     }
 
     #[test]
    fn test_mat_vec() -> TractResult<()> {
         // f32_f32
-        run_mmm_test_case::<GgmlGemm>((1, 8, 32, 3), false, true, DatumType::F32, DatumType::F32)?;
-        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, DatumType::F32, DatumType::F32)?;
-        run_mmm_test_case::<GgmlGemm>((2, 4, 128, 8), false, true, DatumType::F32, DatumType::F32)?;
+        run_mmm_test_case::<GgmlGemm>((1, 8, 32, 3), false, true, F32, F32)?;
+        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, F32, F32)?;
+        run_mmm_test_case::<GgmlGemm>((2, 4, 128, 8), false, true, F32, F32)?;
 
         // f16_f32_1row
-        run_mmm_test_case::<GgmlGemm>((1, 1, 32, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((1, 3, 62, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((1, 3, 2, 9), false, true, DatumType::F32, DatumType::F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 1, 32, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 3, 62, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 3, 2, 9), false, true, F32, F16)?;
 
         // f16_f32_L4
-        run_mmm_test_case::<GgmlGemm>((2, 2, 128, 8), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>(
-            (4, 4, 156, 30),
-            false,
-            true,
-            DatumType::F32,
-            DatumType::F16,
-        )?;
+        run_mmm_test_case::<GgmlGemm>((2, 2, 128, 8), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((4, 4, 156, 30), false, true, F32, F16)?;
 
         // f16_f32
-        run_mmm_test_case::<GgmlGemm>((1, 4, 32, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, DatumType::F32, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((4, 4, 128, 7), false, true, DatumType::F32, DatumType::F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 4, 32, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, F32, F16)?;
+        run_mmm_test_case::<GgmlGemm>((4, 4, 128, 7), false, true, F32, F16)?;
 
         // f16_f16
-        run_mmm_test_case::<GgmlGemm>((1, 1, 2, 1), false, true, DatumType::F16, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, DatumType::F16, DatumType::F16)?;
-        run_mmm_test_case::<GgmlGemm>(
-            (2, 16, 128, 9),
-            false,
-            true,
-            DatumType::F16,
-            DatumType::F16,
-        )?;
+        run_mmm_test_case::<GgmlGemm>((1, 1, 2, 1), false, true, F16, F16)?;
+        run_mmm_test_case::<GgmlGemm>((1, 4, 61, 2), false, true, F16, F16)?;
+        run_mmm_test_case::<GgmlGemm>((2, 16, 128, 9), false, true, F16, F16)?;
         Ok(())
     }
diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs
index f29b90bcf9..c6036b6257 100644
--- a/metal/src/kernels/matmul/mod.rs
+++ b/metal/src/kernels/matmul/mod.rs
@@ -220,12 +220,8 @@ impl GemmImpl {
         b_dt: DatumType,
     ) -> TractResult<TVec<TypedFact>> {
         let out_dt = self.matmul.output_dt(a_dt, b_dt)?;
-        if out_dt == DatumType::F32 {
-            Ok(tvec!(f32::fact(shape)))
-        } else {
-            ensure!(out_dt == DatumType::F16);
-            Ok(tvec!(f16::fact(shape)))
-        }
+        ensure!([DatumType::F16, DatumType::F32].contains(&out_dt));
+        Ok(tvec!(out_dt.fact(shape)))
     }
 
     pub fn eval(
@@ -327,6 +323,7 @@ mod tests {
     use derive_new::new;
     use num_traits::AsPrimitive;
     use num_traits::Float;
+    use proptest::collection::vec;
     use proptest::prelude::*;
     use tract_core::ops::einsum::BasicMatMul;
     use tract_core::tract_data::itertools::Itertools;
@@ -727,22 +724,17 @@ mod tests {
                     k = k.div_ceil(32) * 32
                 };
 
-                let mut rng = rand::thread_rng();
-                let lhs_data: Vec<F> = (0..b * m * k) // Create a vector with 10 elements
-                    .map(|_| F::from(rng.gen_range(0.0..1.0)).unwrap()) // Generate a random float in [0.0, 1.0)
-                    .collect();
-
-                let rhs_data: Vec<F> = (0..b * n * k) // Create a vector with 10 elements
-                    .map(|_| F::from(rng.gen_range(0.0..1.0)).unwrap()) // Generate a random float in [0.0, 1.0)
-                    .collect();
+                let lhs_len = b * m * k;
+                let rhs_len = b * n * k;
+                let datum =
(0usize..100).prop_map(|x| x.as_()); ( Just(b), Just(m), Just(k), Just(n), - Just(lhs_data), + vec(datum.clone(), lhs_len..=lhs_len), proptest::bool::ANY, - Just(rhs_data), + vec(datum, rhs_len..=rhs_len), proptest::bool::ANY, ) })
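The last hunk above moves test-data generation from rand::thread_rng() into proptest strategies built with prop_flat_map and vec(datum, len..=len), so the generated lengths stay tied to (b, m, k, n) and failing cases shrink and replay deterministically. A minimal sketch of the same pattern, assuming only the proptest crate; the test name and bounds are illustrative:

    use proptest::collection::vec;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn generated_data_matches_len(
            // Derive the data vector from the length strategy, as the patch does.
            (len, data) in (1usize..8).prop_flat_map(|len| {
                (Just(len), vec((0u32..100).prop_map(|x| x as f32), len..=len))
            })
        ) {
            prop_assert_eq!(data.len(), len);
        }
    }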