From 0c83d5ba1fb28ec53d816008ec2414d459c39024 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Mon, 27 Jan 2025 14:12:25 +0100 Subject: [PATCH 1/9] Pass all datum types to metal matmul ops --- metal/src/kernels/matmul/basic/mod.rs | 7 +++- metal/src/kernels/matmul/mfa/mod.rs | 12 ++++-- metal/src/kernels/matmul/mlx_gemm/mod.rs | 29 ++++++++------ metal/src/kernels/matmul/mod.rs | 50 ++++++++++++++---------- metal/src/kernels/matmul/mps/mod.rs | 11 ++++-- 5 files changed, 68 insertions(+), 41 deletions(-) diff --git a/metal/src/kernels/matmul/basic/mod.rs b/metal/src/kernels/matmul/basic/mod.rs index 0206e46397..5ab957da11 100644 --- a/metal/src/kernels/matmul/basic/mod.rs +++ b/metal/src/kernels/matmul/basic/mod.rs @@ -28,7 +28,7 @@ impl GemmKernel for BasicMatMul { c_buffer: &Buffer, ) -> TractResult<()> { let GemmDispatchParams { - dt, + dts, batch, m, k, @@ -39,6 +39,11 @@ impl GemmKernel for BasicMatMul { b_offset, c_offset, } = params; + + ensure!(dts[0] == dts[1]); + ensure!(dts[0] == dts[2]); + + let dt = dts[0]; for b_idx in 0..batch { let a_offset = a_offset + b_idx * m * k * dt.size_of(); let b_offset = b_offset + b_idx * n * k * dt.size_of(); diff --git a/metal/src/kernels/matmul/mfa/mod.rs b/metal/src/kernels/matmul/mfa/mod.rs index 041cea790a..70333aa310 100644 --- a/metal/src/kernels/matmul/mfa/mod.rs +++ b/metal/src/kernels/matmul/mfa/mod.rs @@ -33,7 +33,7 @@ impl GemmKernel for MfaGemm { c_buffer: &Buffer, ) -> TractResult<()> { let GemmDispatchParams { - dt, + dts, batch, m, k, @@ -58,7 +58,7 @@ impl GemmKernel for MfaGemm { dispatch_metal_mfa_gemm( context, - dt, + dts, (batch, m, n, k), unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) }, a_offset, @@ -80,7 +80,7 @@ impl GemmKernel for MfaGemm { #[allow(clippy::too_many_arguments)] pub fn dispatch_metal_mfa_gemm( context: &MetalContext, - dt: DatumType, + dts: [DatumType; 3], (b, m, n, k): (usize, usize, usize, usize), lhs_stride: &[usize], lhs_offset: usize, @@ -167,7 +167,11 @@ pub fn dispatch_metal_mfa_gemm( (211, Value::U16(n_splits)), (50_001, Value::Bool(fused_bias)), ])); - + + ensure!(dts[0] == dts[1]); + ensure!(dts[0] == dts[2]); + + let dt = dts[0]; let name = match dt { DatumType::F32 => "sgemm", DatumType::F16 => "hgemm", diff --git a/metal/src/kernels/matmul/mlx_gemm/mod.rs b/metal/src/kernels/matmul/mlx_gemm/mod.rs index 7ea42145ff..950734f721 100644 --- a/metal/src/kernels/matmul/mlx_gemm/mod.rs +++ b/metal/src/kernels/matmul/mlx_gemm/mod.rs @@ -67,7 +67,7 @@ impl GemmKernel for MlxGemm { c_buffer: &Buffer, ) -> TractResult<()> { let GemmDispatchParams { - dt, + dts, batch, m, k, @@ -93,7 +93,7 @@ impl GemmKernel for MlxGemm { if m == 1 || n == 1 { dispatch_metal_mlx_gemv( context, - dt, + dts, (batch, m, n, k), unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) }, a_offset, @@ -109,7 +109,7 @@ impl GemmKernel for MlxGemm { } else { dispatch_metal_mlx_gemm( context, - dt, + dts, (batch, m, n, k), unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) }, a_offset, @@ -132,7 +132,7 @@ impl GemmKernel for MlxGemm { #[allow(clippy::too_many_arguments)] pub fn dispatch_metal_mlx_gemv( context: &MetalContext, - dt: DatumType, + dts: [DatumType; 3], (b, m, n, k): (usize, usize, usize, usize), a_strides: &[usize], a_offset: usize, @@ -148,7 +148,8 @@ pub fn dispatch_metal_mlx_gemv( ensure!(m == 1 || n == 1); assert!(a_strides.len() >= 2 && b_strides.len() >= 2); assert!(a_strides.len() >= 2); - ensure!(matches!(dt, DatumType::F32 | DatumType::F16)); + ensure!(matches!(dts[0], DatumType::F32 | DatumType::F16)); + ensure!(dts[0] == dts[1] && dts[0] == dts[2]); let lda = if a_trans { m } else { k }; let ldb = if b_trans { k } else { n }; @@ -201,7 +202,7 @@ pub fn dispatch_metal_mlx_gemv( let t_mat = if mv_trans { "t_" } else { "" }; - let tname = MetalTensor::tname(dt)?; + let tname = MetalTensor::tname(dts[0])?; let name = format!("gemv_{t_mat}{tname}_bm{bm}_bn{bn}_sm{sm}_sn{sn}_tm{tm}_tn{tn}_nc0_axpby0"); let pipeline = context.shared_context().load_pipeline(LibraryName::MlxGemv, &name)?; @@ -269,7 +270,7 @@ pub fn dispatch_metal_mlx_gemv( #[allow(clippy::too_many_arguments)] pub fn dispatch_metal_mlx_gemm( context: &MetalContext, - dt: DatumType, + dts: [DatumType; 3], (b, m, n, k): (usize, usize, usize, usize), lhs_stride: &[usize], lhs_offset: usize, @@ -285,7 +286,8 @@ pub fn dispatch_metal_mlx_gemm( ) -> Result<()> { assert!(rhs_stride.len() >= 2); assert!(lhs_stride.len() >= 2); - ensure!(matches!(dt, DatumType::F32 | DatumType::F16)); + ensure!(matches!(dts[0], DatumType::F32 | DatumType::F16)); + ensure!(dts[0] == dts[1] && dts[0] == dts[2]); let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; @@ -358,7 +360,7 @@ pub fn dispatch_metal_mlx_gemm( let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b]; - let name = kernel_name_gemm(dt, a_trans, b_trans)?; + let name = kernel_name_gemm(dts, a_trans, b_trans)?; let pipeline = context.shared_context().load_pipeline_with_constants( LibraryName::MlxGemm, @@ -420,12 +422,13 @@ pub fn dispatch_metal_mlx_gemm( Ok(()) } -pub fn kernel_name_gemm(dt: DatumType, transpose_a: bool, transpose_b: bool) -> Result { +pub fn kernel_name_gemm(dts: [DatumType; 3], transpose_a: bool, transpose_b: bool) -> Result { let t_a = if transpose_a { "t" } else { "n" }; let t_b = if transpose_b { "t" } else { "n" }; - ensure!(matches!(dt, DatumType::F32 | DatumType::F16)); - let tname = MetalTensor::tname(dt)?; - Ok(format!("gemm_{t_a}{t_b}_{tname}_{tname}_32_32_16_2_2")) + + let i_tname = MetalTensor::tname(dts[0])?; + let o_tname = MetalTensor::tname(dts[2])?; + Ok(format!("gemm_{t_a}{t_b}_{i_tname}_{o_tname}_32_32_16_2_2")) } #[cfg(test)] diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 83d3850fff..48cf099cc6 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -31,7 +31,7 @@ impl Default for MetalGemmImplKind { #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] pub struct GemmDispatchParams { - pub dt: DatumType, + pub dts: [DatumType; 3], pub batch: usize, pub m: usize, pub k: usize, @@ -46,7 +46,7 @@ pub struct GemmDispatchParams { impl GemmDispatchParams { #[allow(clippy::too_many_arguments)] pub fn compute_dispatches_params( - dt: DatumType, + dts: [DatumType; 3], a_offset: usize, a_shape: &[usize], transpose_a: bool, @@ -74,7 +74,7 @@ impl GemmDispatchParams { // bmk, 1kn -> bmn // bmk, 1nk -> bmn (a_batch, 1) if a_batch != 1 && !transpose_a => Ok(vec![GemmDispatchParams { - dt, + dts, batch: 1, m: m * a_batch, n, @@ -90,16 +90,16 @@ impl GemmDispatchParams { // As many dispatches as batch dimension. (a_batch, 1) if a_batch != 1 => Ok((0..a_batch) .map(|a_batch_idx| GemmDispatchParams { - dt, + dts, batch: 1, m, n, k, transpose_a, - a_offset: a_offset + a_batch_idx * m * k * dt.size_of(), + a_offset: a_offset + a_batch_idx * m * k * dts[0].size_of(), transpose_b, b_offset, - c_offset: c_offset + a_batch_idx * m * n * dt.size_of(), + c_offset: c_offset + a_batch_idx * m * n * dts[2].size_of(), }) .collect()), // 1mk, bkn -> bmn @@ -109,7 +109,7 @@ impl GemmDispatchParams { // As many dispatch as batch dimension. (1, b_batch) if b_batch != 1 => Ok((0..b_batch) .map(|b_batch_idx| GemmDispatchParams { - dt, + dts, batch: 1, m, n, @@ -117,8 +117,8 @@ impl GemmDispatchParams { transpose_a, a_offset, transpose_b, - b_offset: b_offset + b_batch_idx * n * k * dt.size_of(), - c_offset: c_offset + b_batch_idx * m * n * dt.size_of(), + b_offset: b_offset + b_batch_idx * n * k * dts[0].size_of(), + c_offset: c_offset + b_batch_idx * m * n * dts[2].size_of(), }) .collect()), // bmk, bkn -> bmn @@ -129,7 +129,7 @@ impl GemmDispatchParams { ensure!(a_batch == b_batch); Ok(vec![GemmDispatchParams { - dt, + dts, batch: a_batch, m, n, @@ -226,7 +226,7 @@ impl GemmImpl { } let dispatches = GemmDispatchParams::compute_dispatches_params( - c.datum_type(), + [a.datum_type(), b.datum_type(), c.datum_type()], a.metal_offset(), a.shape(), self.transpose_a, @@ -330,7 +330,9 @@ mod tests { let (m, k, n) = (2, 3, 4); assert_eq!( GemmDispatchParams::compute_dispatches_params( + [dt, dt, + dt], 0, &[1, m, k], false, @@ -341,7 +343,7 @@ mod tests { &[1, m, n], )?, vec![GemmDispatchParams { - dt, + dts: [dt; 3], batch: 1, m, n, @@ -356,7 +358,9 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( + [dt, dt, + dt], 0, &[10, m, k], false, @@ -367,7 +371,7 @@ mod tests { &[10, m, n], )?, vec![GemmDispatchParams { - dt, + dts: [dt; 3], batch: 10, m, n, @@ -382,7 +386,9 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( + [dt, dt, + dt], 0, &[1, m, k], false, @@ -394,7 +400,7 @@ mod tests { )?, vec![ GemmDispatchParams { - dt, + dts: [dt; 3], batch: 1, m, n, @@ -406,7 +412,7 @@ mod tests { c_offset: 10, }, GemmDispatchParams { - dt, + dts: [dt; 3], batch: 1, m, n, @@ -422,7 +428,9 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( + [dt, dt, + dt], 0, &[2, k, m], true, @@ -433,7 +441,7 @@ mod tests { &[2, m, n], )?, vec![GemmDispatchParams { - dt, + dts: [dt; 3], batch: 2, m, n, @@ -448,7 +456,9 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( + [dt, dt, + dt], 0, &[2, k, m], true, @@ -460,7 +470,7 @@ mod tests { )?, vec![ GemmDispatchParams { - dt, + dts: [dt; 3], batch: 1, m, n, @@ -472,7 +482,7 @@ mod tests { c_offset: 100, }, GemmDispatchParams { - dt, + dts: [dt; 3], batch: 1, m, n, @@ -488,7 +498,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - dt, + [dt; 3], 0, &[10, m, k], false, @@ -499,7 +509,7 @@ mod tests { &[10, m, n], )?, vec![GemmDispatchParams { - dt, + dts: [dt; 3], batch: 1, m: 10 * m, n, diff --git a/metal/src/kernels/matmul/mps/mod.rs b/metal/src/kernels/matmul/mps/mod.rs index df5c221e0d..2deca71b1d 100644 --- a/metal/src/kernels/matmul/mps/mod.rs +++ b/metal/src/kernels/matmul/mps/mod.rs @@ -47,7 +47,7 @@ impl GemmKernel for MpsMatMul { c_buffer: &Buffer, ) -> TractResult<()> { let GemmDispatchParams { - dt, + dts, batch, m, k, @@ -59,10 +59,15 @@ impl GemmKernel for MpsMatMul { c_offset, } = params; - let data_type = match params.dt { + ensure!(dts[0] == dts[1]); + ensure!(dts[0] == dts[2]); + + let dt = dts[0]; + + let data_type = match dt { DatumType::F32 => MPSDataType::Float32, DatumType::F16 => MPSDataType::Float16, - _ => bail!("Unsupported datum type for MpsMatMul {:?}", params.dt), + _ => bail!("Unsupported datum type for MpsMatMul {:?}", dt), }; for b_idx in 0..batch { From 923a6155abb2cdd5fe706ceed0b724e7b1362365 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Mon, 27 Jan 2025 17:11:49 +0100 Subject: [PATCH 2/9] Add cast if matmul outputs different type than expected --- metal/src/kernels/matmul/mod.rs | 17 ++++++- metal/src/ops/gemm.rs | 9 +--- metal/src/transform.rs | 84 +++++++++++++++++++-------------- 3 files changed, 65 insertions(+), 45 deletions(-) diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 48cf099cc6..1bb1f85cd7 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -150,6 +150,17 @@ pub trait GemmKernel: fmt::Display + fmt::Debug + Clone + Default + Send + Sync fn is_supported_dt(&self, dt: DatumType) -> bool; + fn output_facts(&self, a: &TypedFact, b:&TypedFact, c_shape: Vec) -> TractResult> { + if a.datum_type == f16::datum_type() { + ensure!(b.datum_type == f16::datum_type()); + Ok(tvec!(f16::fact(c_shape))) + } else { + ensure!(a.datum_type == f32::datum_type()); + ensure!(b.datum_type == f32::datum_type()); + Ok(tvec!(f32::fact(c_shape))) + } + } + fn dispatch_eval( &self, context: &MetalContext, @@ -192,6 +203,11 @@ impl GemmImpl { output } + pub fn output_facts(&self, a: &TypedFact, b: &TypedFact) -> TractResult> { + let out_shape = self.output_shape(&a.shape, &b.shape).to_vec(); + self.matmul.output_facts(a, b, out_shape) + } + pub fn eval( &self, context: &MetalContext, @@ -218,7 +234,6 @@ impl GemmImpl { b.retain_until_completion(); c.retain_until_completion(); - ensure!(c.datum_type() == a.datum_type()); ensure!(c.shape() == self.output_shape(a.shape(), b.shape()).as_slice()); if c.shape().iter().product::() == 0 { diff --git a/metal/src/ops/gemm.rs b/metal/src/ops/gemm.rs index 17881dbed5..fc82a27d8d 100644 --- a/metal/src/ops/gemm.rs +++ b/metal/src/ops/gemm.rs @@ -50,14 +50,7 @@ impl MetalGemm { == b.shape[b.rank() - 2 + self.transpose_b() as usize] ); - if a.datum_type == f16::datum_type() { - ensure!(b.datum_type == f16::datum_type()); - Ok(tvec!(f16::fact(self.kernel.output_shape(&a.shape, &b.shape)))) - } else { - ensure!(a.datum_type == f32::datum_type()); - ensure!(b.datum_type == f32::datum_type()); - Ok(tvec!(f32::fact(self.kernel.output_shape(&a.shape, &b.shape)))) - } + self.kernel.output_facts(a, b) } } diff --git a/metal/src/transform.rs b/metal/src/transform.rs index 6897d176ea..ae4c3607eb 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -198,74 +198,76 @@ impl Translate, TypedFact, Box> for Met .iter() .all(|f| MetalTensor::is_supported_dt(f.datum_type) || f.as_metal_fact().is_some()); - let new_metal_op: Option> = if !in_dts_metal_compatible { + let new_metal_ops: Option>> = if !in_dts_metal_compatible { None } else if let Some(op) = node.op_as::() { - convert_element_wise_ops_to_metal(op).map(|o| -> Box { Box::new(o) }) + convert_element_wise_ops_to_metal(op).map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(op) = node.op_as::() { - convert_bin_ops_to_metal(&op.0).map(|o| -> Box { Box::new(o) }) + convert_bin_ops_to_metal(&op.0).map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(op) = node.op_as::() { - Some(Box::new(convert_logic_ops_to_metal(op))) + Some(vec![Box::new(convert_logic_ops_to_metal(op))]) } else if let Some(op) = node.op_as::() { convert_matmul_to_metal(source, node, op, self.gemm_impl)? } else if let Some(op) = node.op_as::() { - Some(Box::new(ops::MetalMultiBroadcastTo::new(op.shape.clone()))) + Some(vec![Box::new(ops::MetalMultiBroadcastTo::new(op.shape.clone()))]) } else if let Some(op) = node.op_as::() { - convert_const(op)?.map(|o| -> Box { Box::new(o) }) + convert_const(op)?.map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(op) = node.op_as::() { check_in_dts_are_supported(source, node.id, ops::MetalCast::is_supported_dt)? .then(|| ops::MetalCast::new(op.to)) .flatten() - .map(|o| -> Box { Box::new(o) }) + .map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(op) = node.op_as::() { let in_fact = source.node_input_facts(node.id)?[0]; - Some(Box::new(ops::MetalAxisOp::from_tract_core_with_fact(op.clone(), in_fact))) + Some(vec![Box::new(ops::MetalAxisOp::from_tract_core_with_fact(op.clone(), in_fact))]) } else if let Some(op) = node.op_as::() { - Some(Box::new(ops::MetalSlice::from_tract_core(op.clone()))) + Some(vec![Box::new(ops::MetalSlice::from_tract_core(op.clone()))]) } else if let Some(op) = node.op_as::() { - Some(Box::new(ops::MetalConcat::from_tract_core(op))) + Some(vec![Box::new(ops::MetalConcat::from_tract_core(op))]) } else if let Some(op) = node.op_as::() { check_in_dts_are_supported(source, node.id, Reducer::is_supported_dt)? .then(|| ops::MetalReduce::from_tract_core(op).ok()) .flatten() - .map(|o| -> Box { Box::new(o) }) + .map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(op) = node.op_as::() { check_in_dts_are_supported(source, node.id, Softmax::is_supported_dt)? .then(|| ops::MetalSoftmax::from_tract_core(op).ok()) .flatten() - .map(|o| -> Box { Box::new(o) }) + .map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(op) = node.op_as::() { check_in_dts_are_supported(source, node.id, ScaledMaskedSoftmax::is_supported_dt)? .then(|| ops::MetalScaledMaskedSoftmax { scale: op.scale.clone() }) - .map(|o| -> Box { Box::new(o) }) + .map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(op) = node.op_as::() { check_in_dts_are_supported(source, node.id, RmsNorm::is_supported_dt)? .then(|| ops::MetalRmsNorm::new(op.axis, op.eps.clone())) - .map(|o| -> Box { Box::new(o) }) + .map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(_op) = node.op_as::() { check_in_dts_are_supported(source, node.id, RotateHalf::is_supported_dt)? .then_some(ops::MetalRotateHalf) - .map(|o| -> Box { Box::new(o) }) + .map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(_op) = node.op_as::() { check_in_dts_are_supported(source, node.id, ApplyRope::is_supported_dt)? .then_some(ops::MetalApplyRope) - .map(|o| -> Box { Box::new(o) }) + .map(|o| -> Vec> { vec![Box::new(o)] }) } else if let Some(_op) = node.op_as::() { check_in_dts_are_supported(source, node.id, Silu::is_supported_dt)? - .then(|| -> Box { Box::new(ops::MetalSilu) }) + .then(|| -> Vec> { vec![Box::new(ops::MetalSilu)] }) } else if let Some(_op) = node.op_as::() { check_in_dts_are_supported(source, node.id, NewGelu::is_supported_dt)? - .then(|| -> Box { Box::new(ops::MetalNewGelu) }) + .then(|| -> Vec> { vec![Box::new(ops::MetalNewGelu)] }) } else { None }; - match new_metal_op { - Some(metal_op) => { - let gpu_inputs = + match new_metal_ops { + Some(metal_ops) => { + let mut gpu_inputs = self.sync_inputs_if_required(target, node, mapping, MetalSyncKind::ToGpu)?; - let target_node_outlet_ids = target.wire_node(&node.name, metal_op, &gpu_inputs)?; - self.sync_model_outputs_if_required(source, node, target, target_node_outlet_ids) + for op in metal_ops { + gpu_inputs = target.wire_node(node.name.clone() + "." + &op.name(), op, &gpu_inputs)?; + } + self.sync_model_outputs_if_required(source, node, target, gpu_inputs) } None => { let cpu_inputs = @@ -314,23 +316,33 @@ fn convert_matmul_to_metal( node: &TypedNode, op: &BasicMatMul, gemm_impl: MetalGemmImplKind, -) -> Result>> { +) -> Result>>> { if !op.transpose_c && op.quantize_output.is_none() - && (model.node_input_facts(node.id)?.iter().all(|f| f.datum_type == f32::datum_type()) - || model.node_input_facts(node.id)?.iter().all(|f| f.datum_type == f16::datum_type())) + // && gemm_impl::is_supported_dts() { - match gemm_impl { - MetalGemmImplKind::Mlx => { - Ok(Some(Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)))) - } - MetalGemmImplKind::Mps => { - Ok(Some(Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)))) - } - MetalGemmImplKind::Mfa => { - Ok(Some(Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)))) - } + let matmul: Box = match gemm_impl { + MetalGemmImplKind::Mlx => { + Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) + } + MetalGemmImplKind::Mps => { + Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) + } + MetalGemmImplKind::Mfa => { + Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) + } + }; + + let out_dt = matmul.output_facts(&model.node_input_facts(node.id)?)?[0].datum_type; + let mut ops = vec![matmul]; + + dbg!(model.node_output_facts(node.id)?[0]); + if out_dt != model.node_output_facts(node.id)?[0].datum_type { + ensure!(ops::MetalCast::is_supported_dt(out_dt), "Matmul output type cannot be casted to expected type"); + let cast_op = ops::MetalCast::new(model.node_output_facts(node.id)?[0].datum_type).unwrap(); + ops.push(Box::new(cast_op)); } + Ok(Some(ops)) } else { Ok(None) } From 1ede20e6827d924419487a0ff5e273df9e894bf9 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Tue, 28 Jan 2025 16:40:32 +0100 Subject: [PATCH 3/9] refactor translate node to allow 1->many + fmt --- metal/src/kernels/matmul/basic/mod.rs | 2 +- metal/src/kernels/matmul/mfa/mod.rs | 4 +- metal/src/kernels/matmul/mlx_gemm/mod.rs | 6 +- metal/src/kernels/matmul/mod.rs | 27 +- metal/src/kernels/matmul/mps/mod.rs | 2 +- metal/src/memory/schema.rs | 2 +- metal/src/tensor/mod.rs | 4 +- metal/src/transform.rs | 320 +++++++++++++---------- 8 files changed, 204 insertions(+), 163 deletions(-) diff --git a/metal/src/kernels/matmul/basic/mod.rs b/metal/src/kernels/matmul/basic/mod.rs index 5ab957da11..1ab4ab6d89 100644 --- a/metal/src/kernels/matmul/basic/mod.rs +++ b/metal/src/kernels/matmul/basic/mod.rs @@ -42,7 +42,7 @@ impl GemmKernel for BasicMatMul { ensure!(dts[0] == dts[1]); ensure!(dts[0] == dts[2]); - + let dt = dts[0]; for b_idx in 0..batch { let a_offset = a_offset + b_idx * m * k * dt.size_of(); diff --git a/metal/src/kernels/matmul/mfa/mod.rs b/metal/src/kernels/matmul/mfa/mod.rs index 70333aa310..99628eda73 100644 --- a/metal/src/kernels/matmul/mfa/mod.rs +++ b/metal/src/kernels/matmul/mfa/mod.rs @@ -167,10 +167,10 @@ pub fn dispatch_metal_mfa_gemm( (211, Value::U16(n_splits)), (50_001, Value::Bool(fused_bias)), ])); - + ensure!(dts[0] == dts[1]); ensure!(dts[0] == dts[2]); - + let dt = dts[0]; let name = match dt { DatumType::F32 => "sgemm", diff --git a/metal/src/kernels/matmul/mlx_gemm/mod.rs b/metal/src/kernels/matmul/mlx_gemm/mod.rs index 950734f721..c17a0cfe22 100644 --- a/metal/src/kernels/matmul/mlx_gemm/mod.rs +++ b/metal/src/kernels/matmul/mlx_gemm/mod.rs @@ -422,7 +422,11 @@ pub fn dispatch_metal_mlx_gemm( Ok(()) } -pub fn kernel_name_gemm(dts: [DatumType; 3], transpose_a: bool, transpose_b: bool) -> Result { +pub fn kernel_name_gemm( + dts: [DatumType; 3], + transpose_a: bool, + transpose_b: bool, +) -> Result { let t_a = if transpose_a { "t" } else { "n" }; let t_b = if transpose_b { "t" } else { "n" }; diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 1bb1f85cd7..33edac307c 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -150,7 +150,12 @@ pub trait GemmKernel: fmt::Display + fmt::Debug + Clone + Default + Send + Sync fn is_supported_dt(&self, dt: DatumType) -> bool; - fn output_facts(&self, a: &TypedFact, b:&TypedFact, c_shape: Vec) -> TractResult> { + fn output_facts( + &self, + a: &TypedFact, + b: &TypedFact, + c_shape: Vec, + ) -> TractResult> { if a.datum_type == f16::datum_type() { ensure!(b.datum_type == f16::datum_type()); Ok(tvec!(f16::fact(c_shape))) @@ -345,9 +350,7 @@ mod tests { let (m, k, n) = (2, 3, 4); assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, - dt, - dt], + [dt, dt, dt], 0, &[1, m, k], false, @@ -373,9 +376,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, - dt, - dt], + [dt, dt, dt], 0, &[10, m, k], false, @@ -401,9 +402,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, - dt, - dt], + [dt, dt, dt], 0, &[1, m, k], false, @@ -443,9 +442,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, - dt, - dt], + [dt, dt, dt], 0, &[2, k, m], true, @@ -471,9 +468,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, - dt, - dt], + [dt, dt, dt], 0, &[2, k, m], true, diff --git a/metal/src/kernels/matmul/mps/mod.rs b/metal/src/kernels/matmul/mps/mod.rs index 2deca71b1d..768c46e41a 100644 --- a/metal/src/kernels/matmul/mps/mod.rs +++ b/metal/src/kernels/matmul/mps/mod.rs @@ -61,7 +61,7 @@ impl GemmKernel for MpsMatMul { ensure!(dts[0] == dts[1]); ensure!(dts[0] == dts[2]); - + let dt = dts[0]; let data_type = match dt { diff --git a/metal/src/memory/schema.rs b/metal/src/memory/schema.rs index c687f8344f..4c59d8d36c 100644 --- a/metal/src/memory/schema.rs +++ b/metal/src/memory/schema.rs @@ -65,7 +65,7 @@ pub fn eval_metal_mem_req_for_nodes( .is_some_and(|op| op.kind == MetalSyncKind::ToCpu) }) }); - + !cpu_sync_in_next_nodes && facts.iter().any(|it| it.to_metal_fact().map(|it| it.is_from_gpu()).unwrap_or(false)) }); diff --git a/metal/src/tensor/mod.rs b/metal/src/tensor/mod.rs index 2597738aa8..d3e933870c 100644 --- a/metal/src/tensor/mod.rs +++ b/metal/src/tensor/mod.rs @@ -21,7 +21,7 @@ pub enum MetalTensor { } impl MetalTensor { - pub const SUPPORTED_DT: [DatumType; 11] = [ + pub const SUPPORTED_DT: [DatumType; 12] = [ DatumType::Bool, DatumType::F32, DatumType::F16, @@ -33,6 +33,7 @@ impl MetalTensor { DatumType::U32, DatumType::I64, DatumType::U64, + DatumType::TDim ]; pub fn tname(dt: DatumType) -> TractResult<&'static str> { @@ -48,6 +49,7 @@ impl MetalTensor { DatumType::I32 => "i32", DatumType::I64 => "i64", DatumType::Bool => "bool", + DatumType::TDim => "Tdim", _ => bail!("Unsupport dt {:?} for metal kernel function", dt), }) } diff --git a/metal/src/transform.rs b/metal/src/transform.rs index ae4c3607eb..e991ca9d09 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -185,6 +185,63 @@ impl MetalTransform { } } +fn can_translate_op_to_metal_op(source: &TypedModel, node: &TypedNode) -> TractResult { + let in_dts_metal_compatible = source + .node_input_facts(node.id)? + .iter() + .all(|f| MetalTensor::is_supported_dt(f.datum_type) || f.as_metal_fact().is_some()); + + Ok(in_dts_metal_compatible + && (node + .op_as::() + .is_some_and(|op| map_element_wise_ops_to_metal(op).is_some()) + || node + .op_as::() + .is_some_and(|op| convert_bin_ops_to_metal(&op.0).is_some()) + || node.op_as::().is_some() + || node.op_as::().is_some() + || node + .op_as::() + .is_some_and(|op| !op.transpose_c && op.quantize_output.is_none()) + || node + .op_as::() + .is_some_and(|op| !MetalTensor::is_supported_dt(op.0.datum_type())) + || node.op_as::().is_some_and(|op| { + check_in_dts_are_supported(source, node.id, ops::MetalCast::is_supported_dt) + .is_ok_and(|_| ops::MetalCast::new(op.to).is_some()) + }) + || node.op_as::().is_some() + || node.op_as::().is_some() + || node.op_as::().is_some() + || node.op_as::().is_some_and(|op| { + check_in_dts_are_supported(source, node.id, Reducer::is_supported_dt) + .is_ok_and(|_| ops::MetalReduce::from_tract_core(op).is_ok()) + }) + || node.op_as::().is_some_and(|op| { + check_in_dts_are_supported(source, node.id, Softmax::is_supported_dt) + .is_ok_and(|_| ops::MetalSoftmax::from_tract_core(op).is_ok()) + }) + || node.op_as::().is_some_and(|_| { + check_in_dts_are_supported(source, node.id, ScaledMaskedSoftmax::is_supported_dt) + .is_ok() + }) + || node.op_as::().is_some_and(|_| { + check_in_dts_are_supported(source, node.id, RmsNorm::is_supported_dt).is_ok() + }) + || node.op_as::().is_some_and(|_| { + check_in_dts_are_supported(source, node.id, RotateHalf::is_supported_dt).is_ok() + }) + || node.op_as::().is_some_and(|_| { + check_in_dts_are_supported(source, node.id, ApplyRope::is_supported_dt).is_ok() + }) + || node.op_as::().is_some_and(|_| { + check_in_dts_are_supported(source, node.id, Silu::is_supported_dt).is_ok() + }) + || node.op_as::().is_some_and(|_| { + check_in_dts_are_supported(source, node.id, NewGelu::is_supported_dt).is_ok() + }))) +} + impl Translate, TypedFact, Box> for MetalTransform { fn translate_node( &self, @@ -193,87 +250,102 @@ impl Translate, TypedFact, Box> for Met target: &mut TypedModel, mapping: &HashMap, ) -> TractResult> { - let in_dts_metal_compatible = source - .node_input_facts(node.id)? - .iter() - .all(|f| MetalTensor::is_supported_dt(f.datum_type) || f.as_metal_fact().is_some()); - - let new_metal_ops: Option>> = if !in_dts_metal_compatible { - None - } else if let Some(op) = node.op_as::() { - convert_element_wise_ops_to_metal(op).map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(op) = node.op_as::() { - convert_bin_ops_to_metal(&op.0).map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(op) = node.op_as::() { - Some(vec![Box::new(convert_logic_ops_to_metal(op))]) - } else if let Some(op) = node.op_as::() { - convert_matmul_to_metal(source, node, op, self.gemm_impl)? - } else if let Some(op) = node.op_as::() { - Some(vec![Box::new(ops::MetalMultiBroadcastTo::new(op.shape.clone()))]) - } else if let Some(op) = node.op_as::() { - convert_const(op)?.map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, ops::MetalCast::is_supported_dt)? - .then(|| ops::MetalCast::new(op.to)) - .flatten() - .map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(op) = node.op_as::() { - let in_fact = source.node_input_facts(node.id)?[0]; - Some(vec![Box::new(ops::MetalAxisOp::from_tract_core_with_fact(op.clone(), in_fact))]) - } else if let Some(op) = node.op_as::() { - Some(vec![Box::new(ops::MetalSlice::from_tract_core(op.clone()))]) - } else if let Some(op) = node.op_as::() { - Some(vec![Box::new(ops::MetalConcat::from_tract_core(op))]) - } else if let Some(op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, Reducer::is_supported_dt)? - .then(|| ops::MetalReduce::from_tract_core(op).ok()) - .flatten() - .map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, Softmax::is_supported_dt)? - .then(|| ops::MetalSoftmax::from_tract_core(op).ok()) - .flatten() - .map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, ScaledMaskedSoftmax::is_supported_dt)? - .then(|| ops::MetalScaledMaskedSoftmax { scale: op.scale.clone() }) - .map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, RmsNorm::is_supported_dt)? - .then(|| ops::MetalRmsNorm::new(op.axis, op.eps.clone())) - .map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(_op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, RotateHalf::is_supported_dt)? - .then_some(ops::MetalRotateHalf) - .map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(_op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, ApplyRope::is_supported_dt)? - .then_some(ops::MetalApplyRope) - .map(|o| -> Vec> { vec![Box::new(o)] }) - } else if let Some(_op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, Silu::is_supported_dt)? - .then(|| -> Vec> { vec![Box::new(ops::MetalSilu)] }) - } else if let Some(_op) = node.op_as::() { - check_in_dts_are_supported(source, node.id, NewGelu::is_supported_dt)? - .then(|| -> Vec> { vec![Box::new(ops::MetalNewGelu)] }) + let translatable = can_translate_op_to_metal_op(source, node)?; + + if translatable { + let gpu_inputs = + self.sync_inputs_if_required(target, node, mapping, MetalSyncKind::ToGpu)?; + + let outlet_ids: TVec = if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + map_element_wise_ops_to_metal(op).unwrap(), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + convert_bin_ops_to_metal(&op.0).unwrap(), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node(node.name.clone(), convert_logic_ops_to_metal(op), &gpu_inputs)? + } else if let Some(op) = node.op_as::() { + convert_matmul_to_metal(source, node, target, &gpu_inputs, op, self.gemm_impl)? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + ops::MetalMultiBroadcastTo::new(op.shape.clone()), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node(node.name.clone(), convert_const(op)?, &gpu_inputs)? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + ops::MetalCast::new(op.to).unwrap(), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + let in_fact = source.node_input_facts(node.id)?[0]; + target.wire_node( + node.name.clone(), + ops::MetalAxisOp::from_tract_core_with_fact(op.clone(), in_fact), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + ops::MetalSlice::from_tract_core(op.clone()), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + ops::MetalConcat::from_tract_core(op), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + ops::MetalReduce::from_tract_core(op).unwrap(), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + ops::MetalSoftmax::from_tract_core(op).unwrap(), + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + ops::MetalScaledMaskedSoftmax { scale: op.scale.clone() }, + &gpu_inputs, + )? + } else if let Some(op) = node.op_as::() { + target.wire_node( + node.name.clone(), + ops::MetalRmsNorm::new(op.axis, op.eps.clone()), + &gpu_inputs, + )? + } else if let Some(_op) = node.op_as::() { + target.wire_node(node.name.clone(), ops::MetalRotateHalf, &gpu_inputs)? + } else if let Some(_op) = node.op_as::() { + target.wire_node(node.name.clone(), ops::MetalApplyRope, &gpu_inputs)? + } else if let Some(_op) = node.op_as::() { + target.wire_node(node.name.clone(), ops::MetalSilu, &gpu_inputs)? + } else if let Some(_op) = node.op_as::() { + target.wire_node(node.name.clone(), ops::MetalNewGelu, &gpu_inputs)? + } else { + bail!("Failed to translate a supported Metal Op") + }; + + self.sync_model_outputs_if_required(source, node, target, outlet_ids) } else { - None - }; - - match new_metal_ops { - Some(metal_ops) => { - let mut gpu_inputs = - self.sync_inputs_if_required(target, node, mapping, MetalSyncKind::ToGpu)?; - for op in metal_ops { - gpu_inputs = target.wire_node(node.name.clone() + "." + &op.name(), op, &gpu_inputs)?; - } - self.sync_model_outputs_if_required(source, node, target, gpu_inputs) - } - None => { - let cpu_inputs = - self.sync_inputs_if_required(target, node, mapping, MetalSyncKind::ToCpu)?; - target.wire_node(&node.name, node.op.clone(), &cpu_inputs) - } + let cpu_inputs = + self.sync_inputs_if_required(target, node, mapping, MetalSyncKind::ToCpu)?; + target.wire_node(&node.name, node.op.clone(), &cpu_inputs) } } } @@ -314,38 +386,35 @@ macro_rules! map_element_wise_ops { fn convert_matmul_to_metal( model: &TypedModel, node: &TypedNode, + target: &mut TypedModel, + inputs: &[OutletId], op: &BasicMatMul, gemm_impl: MetalGemmImplKind, -) -> Result>>> { - if !op.transpose_c - && op.quantize_output.is_none() - // && gemm_impl::is_supported_dts() - { - let matmul: Box = match gemm_impl { - MetalGemmImplKind::Mlx => { - Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) - } - MetalGemmImplKind::Mps => { - Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) - } - MetalGemmImplKind::Mfa => { - Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) - } - }; - - let out_dt = matmul.output_facts(&model.node_input_facts(node.id)?)?[0].datum_type; - let mut ops = vec![matmul]; - - dbg!(model.node_output_facts(node.id)?[0]); - if out_dt != model.node_output_facts(node.id)?[0].datum_type { - ensure!(ops::MetalCast::is_supported_dt(out_dt), "Matmul output type cannot be casted to expected type"); - let cast_op = ops::MetalCast::new(model.node_output_facts(node.id)?[0].datum_type).unwrap(); - ops.push(Box::new(cast_op)); +) -> TractResult> { + let matmul: Box = match gemm_impl { + MetalGemmImplKind::Mlx => { + Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) + } + MetalGemmImplKind::Mps => { + Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) + } + MetalGemmImplKind::Mfa => { + Box::new(ops::MetalGemm::::new(op.transpose_a, op.transpose_b)) } - Ok(Some(ops)) - } else { - Ok(None) + }; + + let out_dt = matmul.output_facts(&model.node_input_facts(node.id)?)?[0].datum_type; + let mut matmul_output = target.wire_node(node.name.clone(), matmul, inputs)?; + + if out_dt != model.node_output_facts(node.id)?[0].datum_type { + ensure!( + ops::MetalCast::is_supported_dt(out_dt), + "Matmul output type cannot be casted to expected type" + ); + let cast_op = ops::MetalCast::new(model.node_output_facts(node.id)?[0].datum_type).unwrap(); + matmul_output = target.wire_node(node.name.clone() + ".cast", cast_op, &matmul_output)? } + Ok(matmul_output) } #[allow(clippy::borrowed_box)] @@ -398,16 +467,13 @@ pub fn bin_ops_to_metal( .transpose() } -fn convert_const(op: &Const) -> TractResult> { - if !MetalTensor::is_supported_dt(op.0.datum_type()) { - return Ok(None); - } +fn convert_const(op: &Const) -> TractResult { let metal_fact = MetalFact::from_cpu(Arc::clone(&op.0).into())?; let metal_const = op.0.clone().into_metal()?.into_opaque_tensor().into_arc_tensor(); - Ok(Some(Const::new_with_opaque_fact(metal_const, Box::new(metal_fact)))) + Ok(Const::new_with_opaque_fact(metal_const, Box::new(metal_fact))) } -fn convert_element_wise_ops_to_metal(op: &ElementWiseOp) -> Option { +fn map_element_wise_ops_to_metal(op: &ElementWiseOp) -> Option { map_element_wise_ops!([ (tract_core::ops::math::Abs, Abs), (tract_core::ops::math::Exp, Exp), @@ -437,29 +503,3 @@ fn convert_element_wise_ops_to_metal(op: &ElementWiseOp) -> Option Result> { - if op.1.is_some() { - return Ok(None); - } - - let input_facts = model.node_input_facts(node.id)?; - let dt = input_facts[0].datum_type; - - // All input must have the same datum type and it has to be supported. - if model.node_input_facts(node.id)?.iter().any(|f| f.datum_type != dt) - || !crate::kernels::ElementWiseOps::is_supported_dt(dt) - { - return Ok(None); - } - - convert_element_wise_ops_to_metal(op) - .map(|metal_op| TypedModelPatch::replace_single_op(model, node, &node.inputs, metal_op)) - .transpose() -} From 28dac41cb891c75222a6622e26fc091292fa648b Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Tue, 28 Jan 2025 17:22:46 +0100 Subject: [PATCH 4/9] refactor translate_node --- metal/src/tensor/mod.rs | 2 +- metal/src/transform.rs | 158 +++++++++++----------------------------- 2 files changed, 45 insertions(+), 115 deletions(-) diff --git a/metal/src/tensor/mod.rs b/metal/src/tensor/mod.rs index d3e933870c..b374b9542a 100644 --- a/metal/src/tensor/mod.rs +++ b/metal/src/tensor/mod.rs @@ -33,7 +33,7 @@ impl MetalTensor { DatumType::U32, DatumType::I64, DatumType::U64, - DatumType::TDim + DatumType::TDim, ]; pub fn tname(dt: DatumType) -> TractResult<&'static str> { diff --git a/metal/src/transform.rs b/metal/src/transform.rs index e991ca9d09..499397afe0 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -185,7 +185,7 @@ impl MetalTransform { } } -fn can_translate_op_to_metal_op(source: &TypedModel, node: &TypedNode) -> TractResult { +fn can_translate_to_metal_op(source: &TypedModel, node: &TypedNode) -> TractResult { let in_dts_metal_compatible = source .node_input_facts(node.id)? .iter() @@ -195,9 +195,7 @@ fn can_translate_op_to_metal_op(source: &TypedModel, node: &TypedNode) -> TractR && (node .op_as::() .is_some_and(|op| map_element_wise_ops_to_metal(op).is_some()) - || node - .op_as::() - .is_some_and(|op| convert_bin_ops_to_metal(&op.0).is_some()) + || node.op_as::().is_some_and(|op| map_bin_ops_to_metal(&op.0).is_some()) || node.op_as::().is_some() || node.op_as::().is_some() || node @@ -250,97 +248,55 @@ impl Translate, TypedFact, Box> for Met target: &mut TypedModel, mapping: &HashMap, ) -> TractResult> { - let translatable = can_translate_op_to_metal_op(source, node)?; + let translatable = can_translate_to_metal_op(source, node)?; if translatable { let gpu_inputs = self.sync_inputs_if_required(target, node, mapping, MetalSyncKind::ToGpu)?; - let outlet_ids: TVec = if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - map_element_wise_ops_to_metal(op).unwrap(), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - convert_bin_ops_to_metal(&op.0).unwrap(), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node(node.name.clone(), convert_logic_ops_to_metal(op), &gpu_inputs)? - } else if let Some(op) = node.op_as::() { + let outlet_ids: TVec = if let Some(op) = node.op_as::() { convert_matmul_to_metal(source, node, target, &gpu_inputs, op, self.gemm_impl)? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - ops::MetalMultiBroadcastTo::new(op.shape.clone()), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node(node.name.clone(), convert_const(op)?, &gpu_inputs)? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - ops::MetalCast::new(op.to).unwrap(), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - let in_fact = source.node_input_facts(node.id)?[0]; - target.wire_node( - node.name.clone(), - ops::MetalAxisOp::from_tract_core_with_fact(op.clone(), in_fact), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - ops::MetalSlice::from_tract_core(op.clone()), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - ops::MetalConcat::from_tract_core(op), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - ops::MetalReduce::from_tract_core(op).unwrap(), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - ops::MetalSoftmax::from_tract_core(op).unwrap(), - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - ops::MetalScaledMaskedSoftmax { scale: op.scale.clone() }, - &gpu_inputs, - )? - } else if let Some(op) = node.op_as::() { - target.wire_node( - node.name.clone(), - ops::MetalRmsNorm::new(op.axis, op.eps.clone()), - &gpu_inputs, - )? - } else if let Some(_op) = node.op_as::() { - target.wire_node(node.name.clone(), ops::MetalRotateHalf, &gpu_inputs)? - } else if let Some(_op) = node.op_as::() { - target.wire_node(node.name.clone(), ops::MetalApplyRope, &gpu_inputs)? - } else if let Some(_op) = node.op_as::() { - target.wire_node(node.name.clone(), ops::MetalSilu, &gpu_inputs)? - } else if let Some(_op) = node.op_as::() { - target.wire_node(node.name.clone(), ops::MetalNewGelu, &gpu_inputs)? } else { - bail!("Failed to translate a supported Metal Op") + let op: Box = if let Some(op) = node.op_as::() { + Box::new(map_element_wise_ops_to_metal(op).unwrap()) + } else if let Some(op) = node.op_as::() { + Box::new(map_bin_ops_to_metal(&op.0).unwrap()) + } else if let Some(op) = node.op_as::() { + Box::new(convert_logic_ops_to_metal(op)) + } else if let Some(op) = node.op_as::() { + Box::new(ops::MetalMultiBroadcastTo::new(op.shape.clone())) + } else if let Some(op) = node.op_as::() { + Box::new(convert_const(op)?) + } else if let Some(op) = node.op_as::() { + Box::new(ops::MetalCast::new(op.to).unwrap()) + } else if let Some(op) = node.op_as::() { + let in_fact = source.node_input_facts(node.id)?[0]; + Box::new(ops::MetalAxisOp::from_tract_core_with_fact(op.clone(), in_fact)) + } else if let Some(op) = node.op_as::() { + Box::new(ops::MetalSlice::from_tract_core(op.clone())) + } else if let Some(op) = node.op_as::() { + Box::new(ops::MetalConcat::from_tract_core(op)) + } else if let Some(op) = node.op_as::() { + Box::new(ops::MetalReduce::from_tract_core(op).unwrap()) + } else if let Some(op) = node.op_as::() { + Box::new(ops::MetalSoftmax::from_tract_core(op).unwrap()) + } else if let Some(op) = node.op_as::() { + Box::new(ops::MetalScaledMaskedSoftmax { scale: op.scale.clone() }) + } else if let Some(op) = node.op_as::() { + Box::new(ops::MetalRmsNorm::new(op.axis, op.eps.clone())) + } else if let Some(_op) = node.op_as::() { + Box::new(ops::MetalRotateHalf) + } else if let Some(_op) = node.op_as::() { + Box::new(ops::MetalApplyRope) + } else if let Some(_op) = node.op_as::() { + Box::new(ops::MetalSilu) + } else if let Some(_op) = node.op_as::() { + Box::new(ops::MetalNewGelu) + } else { + bail!("Failed to translate a supported Metal Op") + }; + target.wire_node(node.name.clone(), op, &gpu_inputs)? }; - self.sync_model_outputs_if_required(source, node, target, outlet_ids) } else { let cpu_inputs = @@ -418,7 +374,7 @@ fn convert_matmul_to_metal( } #[allow(clippy::borrowed_box)] -fn convert_bin_ops_to_metal(op: &Box) -> Option { +fn map_bin_ops_to_metal(op: &Box) -> Option { map_bin_ops!([ (tract_core::ops::math::Mul, Mul), (tract_core::ops::math::Add, Add), @@ -441,32 +397,6 @@ fn convert_logic_ops_to_metal(op: &Comp) -> ops::MetalBinOp { } } -pub fn bin_ops_to_metal( - _ctx: &(), - model: &TypedModel, - node: &TypedNode, - _node_name: &str, - op: &TypedBinOp, -) -> Result> { - if op.1.is_some() { - return Ok(None); - } - - let input_facts = model.node_input_facts(node.id)?; - let dt = input_facts[0].datum_type; - - // All input must have the same datum type and it has to be supported. - if model.node_input_facts(node.id)?.iter().any(|f| f.datum_type != dt) - || !crate::kernels::BinOps::is_supported_dt(dt) - { - return Ok(None); - } - - convert_bin_ops_to_metal(&op.0) - .map(|metal_op| TypedModelPatch::replace_single_op(model, node, &node.inputs, metal_op)) - .transpose() -} - fn convert_const(op: &Const) -> TractResult { let metal_fact = MetalFact::from_cpu(Arc::clone(&op.0).into())?; let metal_const = op.0.clone().into_metal()?.into_opaque_tensor().into_arc_tensor(); From 9020ae6cbaa7687377bf81b44d53183b77d8b136 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 29 Jan 2025 09:49:22 +0100 Subject: [PATCH 5/9] refactored translation check --- metal/src/kernels/matmul/basic/mod.rs | 4 - metal/src/kernels/matmul/mfa/mod.rs | 4 - metal/src/kernels/matmul/mlx_gemm/mod.rs | 4 - metal/src/kernels/matmul/mod.rs | 9 +-- metal/src/kernels/matmul/mps/mod.rs | 4 - metal/src/transform.rs | 94 ++++++++++++------------ 6 files changed, 53 insertions(+), 66 deletions(-) diff --git a/metal/src/kernels/matmul/basic/mod.rs b/metal/src/kernels/matmul/basic/mod.rs index 1ab4ab6d89..3d9da4885b 100644 --- a/metal/src/kernels/matmul/basic/mod.rs +++ b/metal/src/kernels/matmul/basic/mod.rs @@ -15,10 +15,6 @@ impl GemmKernel for BasicMatMul { "basic" } - fn is_supported_dt(&self, dt: DatumType) -> bool { - Self::tname(dt).is_ok() - } - fn dispatch_eval( &self, context: &MetalContext, diff --git a/metal/src/kernels/matmul/mfa/mod.rs b/metal/src/kernels/matmul/mfa/mod.rs index 99628eda73..eb80add720 100644 --- a/metal/src/kernels/matmul/mfa/mod.rs +++ b/metal/src/kernels/matmul/mfa/mod.rs @@ -20,10 +20,6 @@ impl GemmKernel for MfaGemm { "mfa" } - fn is_supported_dt(&self, dt: DatumType) -> bool { - matches!(dt, DatumType::F32 | DatumType::F16) - } - fn dispatch_eval( &self, context: &MetalContext, diff --git a/metal/src/kernels/matmul/mlx_gemm/mod.rs b/metal/src/kernels/matmul/mlx_gemm/mod.rs index c17a0cfe22..e7c2081d30 100644 --- a/metal/src/kernels/matmul/mlx_gemm/mod.rs +++ b/metal/src/kernels/matmul/mlx_gemm/mod.rs @@ -54,10 +54,6 @@ impl GemmKernel for MlxGemm { "mlx" } - fn is_supported_dt(&self, dt: DatumType) -> bool { - matches!(dt, DatumType::F32 | DatumType::F16) - } - fn dispatch_eval( &self, context: &MetalContext, diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 33edac307c..513c40bace 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -148,7 +148,10 @@ impl GemmDispatchParams { pub trait GemmKernel: fmt::Display + fmt::Debug + Clone + Default + Send + Sync { fn name() -> &'static str; - fn is_supported_dt(&self, dt: DatumType) -> bool; + fn is_supported_dts(&self, dts: &[DatumType]) -> TractResult { + ensure!(dts.len() == 2); + Ok(matches!(dts[0], DatumType::F32 | DatumType::F16) && dts[0] == dts[1]) + } fn output_facts( &self, @@ -194,10 +197,6 @@ impl GemmImpl { Self { transpose_a, transpose_b, matmul: M::default() } } - pub fn is_supported_dt(&self, dt: DatumType) -> bool { - self.matmul.is_supported_dt(dt) - } - pub fn output_shape(&self, a: &[D], b: &[D]) -> TVec { let rank = a.len(); let mut output: TVec = (0..rank - 2) diff --git a/metal/src/kernels/matmul/mps/mod.rs b/metal/src/kernels/matmul/mps/mod.rs index 768c46e41a..eb84a3e6e6 100644 --- a/metal/src/kernels/matmul/mps/mod.rs +++ b/metal/src/kernels/matmul/mps/mod.rs @@ -34,10 +34,6 @@ impl GemmKernel for MpsMatMul { "mps" } - fn is_supported_dt(&self, dt: DatumType) -> bool { - matches!(dt, DatumType::F32 | DatumType::F16) - } - fn dispatch_eval( &self, context: &MetalContext, diff --git a/metal/src/transform.rs b/metal/src/transform.rs index 499397afe0..0344ffcc19 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -1,6 +1,6 @@ use crate::fact::MetalTypedFactExt; use crate::kernels::array::RotateHalf; -use crate::kernels::matmul::{MetalGemmImplKind, MfaGemm, MlxGemm, MpsMatMul}; +use crate::kernels::matmul::{GemmKernel, MetalGemmImplKind, MfaGemm, MlxGemm, MpsMatMul}; use crate::kernels::nn::{ ApplyRope, NewGelu, Reducer, RmsNorm, ScaledMaskedSoftmax, Silu, Softmax, }; @@ -13,7 +13,6 @@ use crate::rewrite_rules::{ }; use crate::tensor::MetalTensorExt; use crate::{IntoMetal, MetalFact, MetalTensor}; -use anyhow::Result; use std::borrow::Cow; use std::fmt::Debug; use tract_core::internal::translator::Translate; @@ -27,6 +26,7 @@ use tract_core::ops::konst::Const; use tract_core::ops::logic::Comp; use tract_core::ops::nn::{Reduce, Softmax as CoreSoftmax}; use tract_core::transform::ModelTransform; +use tract_itertools::Itertools; impl MetalGemmImplKind { pub fn variants() -> Vec { @@ -185,11 +185,18 @@ impl MetalTransform { } } -fn can_translate_to_metal_op(source: &TypedModel, node: &TypedNode) -> TractResult { - let in_dts_metal_compatible = source +fn can_translate_to_metal_op( + source: &TypedModel, + node: &TypedNode, + gemm_impl: MetalGemmImplKind, +) -> TractResult { + let input_dts = source .node_input_facts(node.id)? .iter() - .all(|f| MetalTensor::is_supported_dt(f.datum_type) || f.as_metal_fact().is_some()); + .map(|f| f.as_metal_fact().map(|f| f.datum_type).unwrap_or(f.datum_type)) + .collect_vec(); + + let in_dts_metal_compatible = input_dts.iter().all(|dt| MetalTensor::is_supported_dt(*dt)); Ok(in_dts_metal_compatible && (node @@ -198,46 +205,45 @@ fn can_translate_to_metal_op(source: &TypedModel, node: &TypedNode) -> TractResu || node.op_as::().is_some_and(|op| map_bin_ops_to_metal(&op.0).is_some()) || node.op_as::().is_some() || node.op_as::().is_some() - || node - .op_as::() - .is_some_and(|op| !op.transpose_c && op.quantize_output.is_none()) + || node.op_as::().is_some_and(|op| { + !op.transpose_c + && op.quantize_output.is_none() + && check_matmul_in_dts(gemm_impl, &input_dts) + }) || node .op_as::() .is_some_and(|op| !MetalTensor::is_supported_dt(op.0.datum_type())) || node.op_as::().is_some_and(|op| { - check_in_dts_are_supported(source, node.id, ops::MetalCast::is_supported_dt) - .is_ok_and(|_| ops::MetalCast::new(op.to).is_some()) + ops::MetalCast::is_supported_dt(input_dts[0]) + && ops::MetalCast::new(op.to).is_some() }) || node.op_as::().is_some() || node.op_as::().is_some() || node.op_as::().is_some() || node.op_as::().is_some_and(|op| { - check_in_dts_are_supported(source, node.id, Reducer::is_supported_dt) - .is_ok_and(|_| ops::MetalReduce::from_tract_core(op).is_ok()) + Reducer::is_supported_dt(input_dts[0]) + && ops::MetalReduce::from_tract_core(op).is_ok() }) || node.op_as::().is_some_and(|op| { - check_in_dts_are_supported(source, node.id, Softmax::is_supported_dt) - .is_ok_and(|_| ops::MetalSoftmax::from_tract_core(op).is_ok()) + Softmax::is_supported_dt(input_dts[0]) + && ops::MetalSoftmax::from_tract_core(op).is_ok() }) - || node.op_as::().is_some_and(|_| { - check_in_dts_are_supported(source, node.id, ScaledMaskedSoftmax::is_supported_dt) - .is_ok() - }) - || node.op_as::().is_some_and(|_| { - check_in_dts_are_supported(source, node.id, RmsNorm::is_supported_dt).is_ok() - }) - || node.op_as::().is_some_and(|_| { - check_in_dts_are_supported(source, node.id, RotateHalf::is_supported_dt).is_ok() - }) - || node.op_as::().is_some_and(|_| { - check_in_dts_are_supported(source, node.id, ApplyRope::is_supported_dt).is_ok() - }) - || node.op_as::().is_some_and(|_| { - check_in_dts_are_supported(source, node.id, Silu::is_supported_dt).is_ok() - }) - || node.op_as::().is_some_and(|_| { - check_in_dts_are_supported(source, node.id, NewGelu::is_supported_dt).is_ok() - }))) + || node + .op_as::() + .is_some_and(|_| ScaledMaskedSoftmax::is_supported_dt(input_dts[0])) + || node + .op_as::() + .is_some_and(|_| RmsNorm::is_supported_dt(input_dts[0])) + || node + .op_as::() + .is_some_and(|_| RotateHalf::is_supported_dt(input_dts[0])) + || node + .op_as::() + .is_some_and(|_| ApplyRope::is_supported_dt(input_dts[0])) + || node.op_as::().is_some_and(|_| Silu::is_supported_dt(input_dts[0])) + || node + .op_as::() + .is_some_and(|_| NewGelu::is_supported_dt(input_dts[0])))) } impl Translate, TypedFact, Box> for MetalTransform { @@ -248,7 +254,7 @@ impl Translate, TypedFact, Box> for Met target: &mut TypedModel, mapping: &HashMap, ) -> TractResult> { - let translatable = can_translate_to_metal_op(source, node)?; + let translatable = can_translate_to_metal_op(source, node, self.gemm_impl)?; if translatable { let gpu_inputs = @@ -306,17 +312,6 @@ impl Translate, TypedFact, Box> for Met } } -fn check_in_dts_are_supported( - model: &TypedModel, - node_id: usize, - is_supported_dt: impl Fn(DatumType) -> bool, -) -> TractResult { - Ok(model.node_input_facts(node_id)?.iter().all(|f| { - (is_supported_dt)(f.datum_type) - || f.as_metal_fact().map(|f| (is_supported_dt)(f.datum_type)).unwrap_or(false) - })) -} - macro_rules! map_bin_ops { ([$(($tract_bin_op:path, $metal_bin_op:ident)),* $(,)?]) => { |op: &Box| { @@ -339,6 +334,15 @@ macro_rules! map_element_wise_ops { }; } +fn check_matmul_in_dts(gemm_impl: MetalGemmImplKind, dts: &[DatumType]) -> bool { + let is_supported = match gemm_impl { + MetalGemmImplKind::Mlx => MlxGemm.is_supported_dts(dts), + MetalGemmImplKind::Mps => MpsMatMul.is_supported_dts(dts), + MetalGemmImplKind::Mfa => MfaGemm.is_supported_dts(dts), + }; + is_supported.unwrap_or(false) +} + fn convert_matmul_to_metal( model: &TypedModel, node: &TypedNode, From f4fc1e4264ff8cb2d83021c2d7ed21215582c3cf Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 29 Jan 2025 13:10:14 +0100 Subject: [PATCH 6/9] Simplifies dt handling for moni-type matmuls --- metal/src/kernels/matmul/basic/mod.rs | 1 + metal/src/kernels/matmul/mfa/mod.rs | 10 ++++----- metal/src/kernels/matmul/mlx_gemm/mod.rs | 27 +++++++++--------------- metal/src/kernels/matmul/mod.rs | 10 ++++----- 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/metal/src/kernels/matmul/basic/mod.rs b/metal/src/kernels/matmul/basic/mod.rs index 3d9da4885b..6de7fa0b8f 100644 --- a/metal/src/kernels/matmul/basic/mod.rs +++ b/metal/src/kernels/matmul/basic/mod.rs @@ -38,6 +38,7 @@ impl GemmKernel for BasicMatMul { ensure!(dts[0] == dts[1]); ensure!(dts[0] == dts[2]); + ensure!(Self::tname(dts[0]).is_ok()); let dt = dts[0]; for b_idx in 0..batch { diff --git a/metal/src/kernels/matmul/mfa/mod.rs b/metal/src/kernels/matmul/mfa/mod.rs index eb80add720..465c4663fb 100644 --- a/metal/src/kernels/matmul/mfa/mod.rs +++ b/metal/src/kernels/matmul/mfa/mod.rs @@ -52,9 +52,11 @@ impl GemmKernel for MfaGemm { natural_strides(&[batch, k, n]) }; + ensure!(dts[0] == dts[1]); + ensure!(dts[0] == dts[2]); dispatch_metal_mfa_gemm( context, - dts, + dts[0], (batch, m, n, k), unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) }, a_offset, @@ -76,7 +78,7 @@ impl GemmKernel for MfaGemm { #[allow(clippy::too_many_arguments)] pub fn dispatch_metal_mfa_gemm( context: &MetalContext, - dts: [DatumType; 3], + dt: DatumType, (b, m, n, k): (usize, usize, usize, usize), lhs_stride: &[usize], lhs_offset: usize, @@ -164,10 +166,6 @@ pub fn dispatch_metal_mfa_gemm( (50_001, Value::Bool(fused_bias)), ])); - ensure!(dts[0] == dts[1]); - ensure!(dts[0] == dts[2]); - - let dt = dts[0]; let name = match dt { DatumType::F32 => "sgemm", DatumType::F16 => "hgemm", diff --git a/metal/src/kernels/matmul/mlx_gemm/mod.rs b/metal/src/kernels/matmul/mlx_gemm/mod.rs index e7c2081d30..fe19d0a373 100644 --- a/metal/src/kernels/matmul/mlx_gemm/mod.rs +++ b/metal/src/kernels/matmul/mlx_gemm/mod.rs @@ -86,10 +86,12 @@ impl GemmKernel for MlxGemm { natural_strides(&[batch, k, n]) }; + ensure!(matches!(dts[0], DatumType::F32 | DatumType::F16)); + ensure!(dts[0] == dts[1] && dts[0] == dts[2]); if m == 1 || n == 1 { dispatch_metal_mlx_gemv( context, - dts, + dts[0], (batch, m, n, k), unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) }, a_offset, @@ -105,7 +107,7 @@ impl GemmKernel for MlxGemm { } else { dispatch_metal_mlx_gemm( context, - dts, + dts[0], (batch, m, n, k), unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) }, a_offset, @@ -128,7 +130,7 @@ impl GemmKernel for MlxGemm { #[allow(clippy::too_many_arguments)] pub fn dispatch_metal_mlx_gemv( context: &MetalContext, - dts: [DatumType; 3], + dt: DatumType, (b, m, n, k): (usize, usize, usize, usize), a_strides: &[usize], a_offset: usize, @@ -144,8 +146,6 @@ pub fn dispatch_metal_mlx_gemv( ensure!(m == 1 || n == 1); assert!(a_strides.len() >= 2 && b_strides.len() >= 2); assert!(a_strides.len() >= 2); - ensure!(matches!(dts[0], DatumType::F32 | DatumType::F16)); - ensure!(dts[0] == dts[1] && dts[0] == dts[2]); let lda = if a_trans { m } else { k }; let ldb = if b_trans { k } else { n }; @@ -198,7 +198,7 @@ pub fn dispatch_metal_mlx_gemv( let t_mat = if mv_trans { "t_" } else { "" }; - let tname = MetalTensor::tname(dts[0])?; + let tname = MetalTensor::tname(dt)?; let name = format!("gemv_{t_mat}{tname}_bm{bm}_bn{bn}_sm{sm}_sn{sn}_tm{tm}_tn{tn}_nc0_axpby0"); let pipeline = context.shared_context().load_pipeline(LibraryName::MlxGemv, &name)?; @@ -266,7 +266,7 @@ pub fn dispatch_metal_mlx_gemv( #[allow(clippy::too_many_arguments)] pub fn dispatch_metal_mlx_gemm( context: &MetalContext, - dts: [DatumType; 3], + dts: DatumType, (b, m, n, k): (usize, usize, usize, usize), lhs_stride: &[usize], lhs_offset: usize, @@ -282,8 +282,6 @@ pub fn dispatch_metal_mlx_gemm( ) -> Result<()> { assert!(rhs_stride.len() >= 2); assert!(lhs_stride.len() >= 2); - ensure!(matches!(dts[0], DatumType::F32 | DatumType::F16)); - ensure!(dts[0] == dts[1] && dts[0] == dts[2]); let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; @@ -418,17 +416,12 @@ pub fn dispatch_metal_mlx_gemm( Ok(()) } -pub fn kernel_name_gemm( - dts: [DatumType; 3], - transpose_a: bool, - transpose_b: bool, -) -> Result { +pub fn kernel_name_gemm(dt: DatumType, transpose_a: bool, transpose_b: bool) -> Result { let t_a = if transpose_a { "t" } else { "n" }; let t_b = if transpose_b { "t" } else { "n" }; - let i_tname = MetalTensor::tname(dts[0])?; - let o_tname = MetalTensor::tname(dts[2])?; - Ok(format!("gemm_{t_a}{t_b}_{i_tname}_{o_tname}_32_32_16_2_2")) + let tname = MetalTensor::tname(dt)?; + Ok(format!("gemm_{t_a}{t_b}_{tname}_{tname}_32_32_16_2_2")) } #[cfg(test)] diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 513c40bace..db8beda1a1 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -349,7 +349,7 @@ mod tests { let (m, k, n) = (2, 3, 4); assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, dt, dt], + [dt; 3], 0, &[1, m, k], false, @@ -375,7 +375,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, dt, dt], + [dt; 3], 0, &[10, m, k], false, @@ -401,7 +401,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, dt, dt], + [dt; 3], 0, &[1, m, k], false, @@ -441,7 +441,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, dt, dt], + [dt; 3], 0, &[2, k, m], true, @@ -467,7 +467,7 @@ mod tests { assert_eq!( GemmDispatchParams::compute_dispatches_params( - [dt, dt, dt], + [dt; 3], 0, &[2, k, m], true, From d22256c9ccaad98ee35d2024357f8921393b557a Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 29 Jan 2025 13:33:16 +0100 Subject: [PATCH 7/9] Fix typo for Const --- metal/src/transform.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metal/src/transform.rs b/metal/src/transform.rs index 0344ffcc19..e1822d1b7b 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -212,7 +212,7 @@ fn can_translate_to_metal_op( }) || node .op_as::() - .is_some_and(|op| !MetalTensor::is_supported_dt(op.0.datum_type())) + .is_some_and(|op| MetalTensor::is_supported_dt(op.0.datum_type())) || node.op_as::().is_some_and(|op| { ops::MetalCast::is_supported_dt(input_dts[0]) && ops::MetalCast::new(op.to).is_some() From 31e4c08db4c0d780d5c5366448553efae639a2fb Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 29 Jan 2025 17:09:00 +0100 Subject: [PATCH 8/9] code clean --- metal/src/kernels/matmul/basic/mod.rs | 15 ++++++++++++--- metal/src/kernels/matmul/mfa/mod.rs | 15 +++++++++++++-- metal/src/kernels/matmul/mlx_gemm/mod.rs | 23 +++++++++++++++++------ metal/src/kernels/matmul/mps/mod.rs | 17 +++++++++++++---- metal/src/transform.rs | 10 +++++----- 5 files changed, 60 insertions(+), 20 deletions(-) diff --git a/metal/src/kernels/matmul/basic/mod.rs b/metal/src/kernels/matmul/basic/mod.rs index 6de7fa0b8f..3ba5c3a35e 100644 --- a/metal/src/kernels/matmul/basic/mod.rs +++ b/metal/src/kernels/matmul/basic/mod.rs @@ -36,9 +36,18 @@ impl GemmKernel for BasicMatMul { c_offset, } = params; - ensure!(dts[0] == dts[1]); - ensure!(dts[0] == dts[2]); - ensure!(Self::tname(dts[0]).is_ok()); + ensure!( + Self::tname(dts[0]).is_ok(), + "Unsupported datum type for Metal BasicMatmul {:?}", + dts[0] + ); + ensure!( + dts[0] == dts[1] && dts[0] == dts[2], + "Metal BasicMatmul only supports homogenous datum types. I: {:?}, {:?}. O: {:?}", + dts[0], + dts[1], + dts[2] + ); let dt = dts[0]; for b_idx in 0..batch { diff --git a/metal/src/kernels/matmul/mfa/mod.rs b/metal/src/kernels/matmul/mfa/mod.rs index 465c4663fb..d646f53195 100644 --- a/metal/src/kernels/matmul/mfa/mod.rs +++ b/metal/src/kernels/matmul/mfa/mod.rs @@ -52,8 +52,19 @@ impl GemmKernel for MfaGemm { natural_strides(&[batch, k, n]) }; - ensure!(dts[0] == dts[1]); - ensure!(dts[0] == dts[2]); + ensure!( + matches!(dts[0], DatumType::F32 | DatumType::F16), + "Unsupported datum type for Mfa {:?}", + dts[0] + ); + ensure!( + dts[0] == dts[1] && dts[0] == dts[2], + "Mfa only supports homogeneous datum types. I: {:?}, {:?}. O: {:?}", + dts[0], + dts[1], + dts[2] + ); + dispatch_metal_mfa_gemm( context, dts[0], diff --git a/metal/src/kernels/matmul/mlx_gemm/mod.rs b/metal/src/kernels/matmul/mlx_gemm/mod.rs index fe19d0a373..7e5c1891dd 100644 --- a/metal/src/kernels/matmul/mlx_gemm/mod.rs +++ b/metal/src/kernels/matmul/mlx_gemm/mod.rs @@ -86,8 +86,19 @@ impl GemmKernel for MlxGemm { natural_strides(&[batch, k, n]) }; - ensure!(matches!(dts[0], DatumType::F32 | DatumType::F16)); - ensure!(dts[0] == dts[1] && dts[0] == dts[2]); + ensure!( + matches!(dts[0], DatumType::F32 | DatumType::F16), + "Unsupported datum type for MlxGemm {:?}", + dts[0] + ); + ensure!( + dts[0] == dts[1] && dts[0] == dts[2], + "MlxGemm only supports homogeneous datum types. I: {:?}, {:?}. O: {:?}", + dts[0], + dts[1], + dts[2] + ); + if m == 1 || n == 1 { dispatch_metal_mlx_gemv( context, @@ -144,8 +155,8 @@ pub fn dispatch_metal_mlx_gemv( output_offset: usize, ) -> Result<()> { ensure!(m == 1 || n == 1); - assert!(a_strides.len() >= 2 && b_strides.len() >= 2); - assert!(a_strides.len() >= 2); + ensure!(a_strides.len() >= 2 && b_strides.len() >= 2); + ensure!(a_strides.len() >= 2); let lda = if a_trans { m } else { k }; let ldb = if b_trans { k } else { n }; @@ -280,8 +291,8 @@ pub fn dispatch_metal_mlx_gemm( output_offset: usize, debug: bool, ) -> Result<()> { - assert!(rhs_stride.len() >= 2); - assert!(lhs_stride.len() >= 2); + ensure!(rhs_stride.len() >= 2); + ensure!(lhs_stride.len() >= 2); let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; diff --git a/metal/src/kernels/matmul/mps/mod.rs b/metal/src/kernels/matmul/mps/mod.rs index eb84a3e6e6..79f79e6288 100644 --- a/metal/src/kernels/matmul/mps/mod.rs +++ b/metal/src/kernels/matmul/mps/mod.rs @@ -55,17 +55,26 @@ impl GemmKernel for MpsMatMul { c_offset, } = params; - ensure!(dts[0] == dts[1]); - ensure!(dts[0] == dts[2]); - let dt = dts[0]; - let data_type = match dt { DatumType::F32 => MPSDataType::Float32, DatumType::F16 => MPSDataType::Float16, _ => bail!("Unsupported datum type for MpsMatMul {:?}", dt), }; + ensure!( + dts[0] == dts[1], + "MpsMatmul: Input datum types are different {:?} != {:?}", + dts[0], + dts[1] + ); + ensure!( + dts[0] == dts[2], + "MpsMatmul: Input/Output datum types are different {:?} != {:?}", + dts[0], + dts[2] + ); + for b_idx in 0..batch { let a_offset = a_offset + b_idx * m * k * dt.size_of(); let b_offset = b_offset + b_idx * n * k * dt.size_of(); diff --git a/metal/src/transform.rs b/metal/src/transform.rs index e1822d1b7b..2ba2ceb6e3 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -203,8 +203,8 @@ fn can_translate_to_metal_op( .op_as::() .is_some_and(|op| map_element_wise_ops_to_metal(op).is_some()) || node.op_as::().is_some_and(|op| map_bin_ops_to_metal(&op.0).is_some()) - || node.op_as::().is_some() - || node.op_as::().is_some() + || node.op_is::() + || node.op_is::() || node.op_as::().is_some_and(|op| { !op.transpose_c && op.quantize_output.is_none() @@ -217,9 +217,9 @@ fn can_translate_to_metal_op( ops::MetalCast::is_supported_dt(input_dts[0]) && ops::MetalCast::new(op.to).is_some() }) - || node.op_as::().is_some() - || node.op_as::().is_some() - || node.op_as::().is_some() + || node.op_is::() + || node.op_is::() + || node.op_is::() || node.op_as::().is_some_and(|op| { Reducer::is_supported_dt(input_dts[0]) && ops::MetalReduce::from_tract_core(op).is_ok() From b4af8600fdd6f6aa078cec8e4ff5484babd687e8 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 29 Jan 2025 17:21:54 +0100 Subject: [PATCH 9/9] Fix typos and unecessary fix --- metal/src/kernels/matmul/mlx_gemm/mod.rs | 4 ++-- metal/src/kernels/matmul/mod.rs | 2 +- metal/src/tensor/mod.rs | 4 +--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/metal/src/kernels/matmul/mlx_gemm/mod.rs b/metal/src/kernels/matmul/mlx_gemm/mod.rs index 7e5c1891dd..562f5544a2 100644 --- a/metal/src/kernels/matmul/mlx_gemm/mod.rs +++ b/metal/src/kernels/matmul/mlx_gemm/mod.rs @@ -277,7 +277,7 @@ pub fn dispatch_metal_mlx_gemv( #[allow(clippy::too_many_arguments)] pub fn dispatch_metal_mlx_gemm( context: &MetalContext, - dts: DatumType, + dt: DatumType, (b, m, n, k): (usize, usize, usize, usize), lhs_stride: &[usize], lhs_offset: usize, @@ -365,7 +365,7 @@ pub fn dispatch_metal_mlx_gemm( let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b]; - let name = kernel_name_gemm(dts, a_trans, b_trans)?; + let name = kernel_name_gemm(dt, a_trans, b_trans)?; let pipeline = context.shared_context().load_pipeline_with_constants( LibraryName::MlxGemm, diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index db8beda1a1..77d2a8a1b0 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -117,7 +117,7 @@ impl GemmDispatchParams { transpose_a, a_offset, transpose_b, - b_offset: b_offset + b_batch_idx * n * k * dts[0].size_of(), + b_offset: b_offset + b_batch_idx * n * k * dts[1].size_of(), c_offset: c_offset + b_batch_idx * m * n * dts[2].size_of(), }) .collect()), diff --git a/metal/src/tensor/mod.rs b/metal/src/tensor/mod.rs index b374b9542a..2597738aa8 100644 --- a/metal/src/tensor/mod.rs +++ b/metal/src/tensor/mod.rs @@ -21,7 +21,7 @@ pub enum MetalTensor { } impl MetalTensor { - pub const SUPPORTED_DT: [DatumType; 12] = [ + pub const SUPPORTED_DT: [DatumType; 11] = [ DatumType::Bool, DatumType::F32, DatumType::F16, @@ -33,7 +33,6 @@ impl MetalTensor { DatumType::U32, DatumType::I64, DatumType::U64, - DatumType::TDim, ]; pub fn tname(dt: DatumType) -> TractResult<&'static str> { @@ -49,7 +48,6 @@ impl MetalTensor { DatumType::I32 => "i32", DatumType::I64 => "i64", DatumType::Bool => "bool", - DatumType::TDim => "Tdim", _ => bail!("Unsupport dt {:?} for metal kernel function", dt), }) }