Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add heterogeneous Datum types support for Metal Matmuls #1631

Merged
merged 9 commits into from
Jan 30, 2025
21 changes: 16 additions & 5 deletions metal/src/kernels/matmul/basic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ impl GemmKernel for BasicMatMul {
"basic"
}

fn is_supported_dt(&self, dt: DatumType) -> bool {
Self::tname(dt).is_ok()
}

fn dispatch_eval(
&self,
context: &MetalContext,
Expand All @@ -28,7 +24,7 @@ impl GemmKernel for BasicMatMul {
c_buffer: &Buffer,
) -> TractResult<()> {
let GemmDispatchParams {
dt,
dts,
batch,
m,
k,
Expand All @@ -39,6 +35,21 @@ impl GemmKernel for BasicMatMul {
b_offset,
c_offset,
} = params;

ensure!(
Self::tname(dts[0]).is_ok(),
"Unsupported datum type for Metal BasicMatmul {:?}",
dts[0]
);
ensure!(
dts[0] == dts[1] && dts[0] == dts[2],
"Metal BasicMatmul only supports homogenous datum types. I: {:?}, {:?}. O: {:?}",
dts[0],
dts[1],
dts[2]
);

let dt = dts[0];
for b_idx in 0..batch {
let a_offset = a_offset + b_idx * m * k * dt.size_of();
let b_offset = b_offset + b_idx * n * k * dt.size_of();
Expand Down
21 changes: 15 additions & 6 deletions metal/src/kernels/matmul/mfa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ impl GemmKernel for MfaGemm {
"mfa"
}

fn is_supported_dt(&self, dt: DatumType) -> bool {
matches!(dt, DatumType::F32 | DatumType::F16)
}

fn dispatch_eval(
&self,
context: &MetalContext,
Expand All @@ -33,7 +29,7 @@ impl GemmKernel for MfaGemm {
c_buffer: &Buffer,
) -> TractResult<()> {
let GemmDispatchParams {
dt,
dts,
batch,
m,
k,
Expand All @@ -56,9 +52,22 @@ impl GemmKernel for MfaGemm {
natural_strides(&[batch, k, n])
};

ensure!(
matches!(dts[0], DatumType::F32 | DatumType::F16),
"Unsupported datum type for Mfa {:?}",
dts[0]
);
ensure!(
dts[0] == dts[1] && dts[0] == dts[2],
"Mfa only supports homogeneous datum types. I: {:?}, {:?}. O: {:?}",
dts[0],
dts[1],
dts[2]
);

dispatch_metal_mfa_gemm(
context,
dt,
dts[0],
(batch, m, n, k),
unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) },
a_offset,
Expand Down
35 changes: 21 additions & 14 deletions metal/src/kernels/matmul/mlx_gemm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ impl GemmKernel for MlxGemm {
"mlx"
}

fn is_supported_dt(&self, dt: DatumType) -> bool {
matches!(dt, DatumType::F32 | DatumType::F16)
}

fn dispatch_eval(
&self,
context: &MetalContext,
Expand All @@ -67,7 +63,7 @@ impl GemmKernel for MlxGemm {
c_buffer: &Buffer,
) -> TractResult<()> {
let GemmDispatchParams {
dt,
dts,
batch,
m,
k,
Expand All @@ -90,10 +86,23 @@ impl GemmKernel for MlxGemm {
natural_strides(&[batch, k, n])
};

ensure!(
matches!(dts[0], DatumType::F32 | DatumType::F16),
"Unsupported datum type for MlxGemm {:?}",
dts[0]
);
ensure!(
dts[0] == dts[1] && dts[0] == dts[2],
"MlxGemm only supports homogeneous datum types. I: {:?}, {:?}. O: {:?}",
dts[0],
dts[1],
dts[2]
);

if m == 1 || n == 1 {
dispatch_metal_mlx_gemv(
context,
dt,
dts[0],
(batch, m, n, k),
unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) },
a_offset,
Expand All @@ -109,7 +118,7 @@ impl GemmKernel for MlxGemm {
} else {
dispatch_metal_mlx_gemm(
context,
dt,
dts[0],
(batch, m, n, k),
unsafe { std::mem::transmute::<&[isize], &[usize]>(a_strides.as_slice()) },
a_offset,
Expand Down Expand Up @@ -146,9 +155,8 @@ pub fn dispatch_metal_mlx_gemv(
output_offset: usize,
) -> Result<()> {
ensure!(m == 1 || n == 1);
assert!(a_strides.len() >= 2 && b_strides.len() >= 2);
assert!(a_strides.len() >= 2);
ensure!(matches!(dt, DatumType::F32 | DatumType::F16));
ensure!(a_strides.len() >= 2 && b_strides.len() >= 2);
ensure!(a_strides.len() >= 2);

let lda = if a_trans { m } else { k };
let ldb = if b_trans { k } else { n };
Expand Down Expand Up @@ -283,9 +291,8 @@ pub fn dispatch_metal_mlx_gemm(
output_offset: usize,
debug: bool,
) -> Result<()> {
assert!(rhs_stride.len() >= 2);
assert!(lhs_stride.len() >= 2);
ensure!(matches!(dt, DatumType::F32 | DatumType::F16));
ensure!(rhs_stride.len() >= 2);
ensure!(lhs_stride.len() >= 2);

let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
Expand Down Expand Up @@ -423,7 +430,7 @@ pub fn dispatch_metal_mlx_gemm(
pub fn kernel_name_gemm(dt: DatumType, transpose_a: bool, transpose_b: bool) -> Result<String> {
let t_a = if transpose_a { "t" } else { "n" };
let t_b = if transpose_b { "t" } else { "n" };
ensure!(matches!(dt, DatumType::F32 | DatumType::F16));

let tname = MetalTensor::tname(dt)?;
Ok(format!("gemm_{t_a}{t_b}_{tname}_{tname}_32_32_16_2_2"))
}
Expand Down
Loading
Loading