
Metal GGML Q40 support #1643

Merged
merged 15 commits on Feb 14, 2025
29 changes: 26 additions & 3 deletions core/src/ops/einsum/as_matmul.rs
@@ -279,7 +279,7 @@ impl TypedOp for BasicMatMul {
let [a, b] = inputs else {
bail!("Expects 2 inputs");
};
if a.datum_type.is_number() {
if a.datum_type.is_number() && b.datum_type.is_number() {
ensure!(a.rank() == b.rank());
ensure!(a.rank() >= 2);
ensure!(
@@ -313,9 +313,32 @@ impl TypedOp for BasicMatMul {
.quantize_output
.unwrap_or(b.datum_type)
.fact(self.output_shape(&a_shape, &b.shape))))
} else if let Some(opf) = inputs[1]
.opaque_fact
.as_ref()
.and_then(|of| of.downcast_ref::<BlockQuantFact>())
.or_else(|| {
inputs[1]
.konst
.as_ref()
.and_then(|k| k.to_scalar::<Opaque>().ok())
.and_then(|o| o.downcast_ref::<BlockQuantValue>())
.map(|v| &v.fact)
})
{
let b_shape: ShapeFact = b
.shape
.iter()
.cloned()
.chain(opf.shape.iter().map(|d| d.to_dim()))
.collect();
Ok(tvec!(self
.quantize_output
.unwrap_or(a.datum_type)
.fact(self.output_shape(&a.shape, &b_shape))))
} else {
todo!()
}
}

as_op!();
1 change: 1 addition & 0 deletions metal/src/kernels/matmul/basic/mod.rs
@@ -34,6 +34,7 @@ impl GemmKernel for BasicMatMul {
transpose_b,
b_offset,
c_offset,
..
} = params;

ensure!(
144 changes: 144 additions & 0 deletions metal/src/kernels/matmul/ggml_gemm/ggml_mm_mv.metal
@@ -4,6 +4,14 @@

using namespace metal;

#define N_SIMDWIDTH 32 // assuming SIMD group size is 32

#define QK4_0 32
typedef struct {
half d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
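// Illustration, not part of this diff: a block_q4_0 packs QK4_0 = 32 weights into
// 18 bytes (one half-precision scale plus 16 bytes of nibbles) with an implicit
// zero-point of 8. A hypothetical host-side C++ reference for dequantizing one block,
// following the usual GGML layout (low nibbles are elements 0..15, high nibbles 16..31);
// on the host, `half` would need an fp16 type such as _Float16.
void dequantize_block_q4_0_ref(const block_q4_0 * b, float * out /* 32 floats */) {
    const float d = (float) b->d;
    for (int j = 0; j < QK4_0 / 2; ++j) {
        out[j]             = d * ((b->qs[j] & 0x0F) - 8); // low nibble  -> element j
        out[j + QK4_0 / 2] = d * ((b->qs[j] >>   4) - 8); // high nibble -> element j + 16
    }
}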

typedef struct {
int32_t batch;
int32_t m;
@@ -207,6 +215,122 @@ typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;

template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;

// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied by the appropriate scale factors
// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;

float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };

device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);

for (int i = 0; i < 8; i += 2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
}

return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
}
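// Illustration, not part of this diff: the loop above multiplies masked but unshifted
// nibbles. This works because the caller pre-scales the cached y values by 1, 1/256,
// 1/16 and 1/4096 (see the yl[] setup in mul_vec_q_n_f32_impl below), cancelling the
// implicit <<0, <<8, <<4 and <<12 of the four mask positions, and the -8 zero-point is
// applied once for all 16 elements through sumy * -8.f. A hypothetical scalar C++
// reference for the same half-block dot product (il is 0 or 8, as in the kernel;
// y points at the 32 floats covered by this block):
float dot_half_block_q4_0_ref(const block_q4_0 * b, const float * y, int il) {
    float sum = 0.f;
    const float d = (float) b->d;
    for (int i = 0; i < 8; ++i) {
        const uint8_t q = b->qs[il + i];
        sum += d * ((q & 0x0F) - 8) * y[il + i];      // low nibble  -> element il + i
        sum += d * ((q >>   4) - 8) * y[il + i + 16]; // high nibble -> element il + i + 16
    }
    return sum;
}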

// putting them in the kernel causes a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
//Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not
// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
void mul_vec_q_n_f32_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.k/QK4_0;

const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;

const int first_row = (r0 * nsg + sgitg) * nr;

const uint i12 = im%args.batch;
const uint i13 = im/args.batch;

//const uint64_t offset0 = first_row*args. + (i12/args.channel_broadcast_ratio)*args.b_strides[1] + (i13/args.batch_broadcast_ratio)*args.b_strides[0];
const uint64_t offset1 = r1*args.a_strides[2] + (i12 )*args.a_strides[1] + (i13 )*args.a_strides[0];

//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);

// pointers to src0 rows
device const block_q_type * ax[nr];
for (int row = 0; row < nr; ++row) {
const uint64_t offset0 = (first_row + row)*args.b_strides[2] + (i12/args.channel_broadcast_ratio)*args.b_strides[1] + (i13/args.batch_broadcast_ratio)*args.b_strides[0];

ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
}

float yl[16]; // src1 vector cache
float sumf[nr] = {0.f};

const short ix = (tiisg/2);
const short il = (tiisg%2)*8;

device const float * yb = y + ix*QK4_0 + il;

// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
float sumy[2] = { 0.f, 0.f };

#pragma unroll
for (int i = 0; i < 8; i += 2) {
sumy[0] += yb[i + 0] + yb[i + 1];
yl[i + 0] = yb[i + 0];
yl[i + 1] = yb[i + 1]/256.f;

sumy[1] += yb[i + 16] + yb[i + 17];
yl[i + 8] = yb[i + 16]/16.f;
yl[i + 9] = yb[i + 17]/4096.f;
}

#pragma unroll
for (int row = 0; row < nr; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
}

yb += QK4_0 * 16;
}

device float * dst_f32 = (device float *) dst + im*args.n*args.m + r1*args.n;

for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);

if (tiisg == 0 && first_row + row < args.n) {
dst_f32[first_row + row] = tot;
}
}
}

kernel void kernel_mul_mv_q4_0_f32(
constant ggml_metal_kargs_mul & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
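// Illustration, not part of this diff: a hypothetical host-side oracle for the simplest
// case this kernel covers (batch == 1, m == 1, contiguous row-major q4_0 weights), the
// kind of scalar reference kernel_mul_mv_q4_0_f32 could be validated against.
// Computes dst[row] = dot(W[row, :], y) for row in [0, n), with k % QK4_0 == 0.
void mul_mv_q4_0_f32_ref(const block_q4_0 * W, const float * y, float * dst, int n, int k) {
    const int nb = k / QK4_0; // q4_0 blocks per weight row
    for (int row = 0; row < n; ++row) {
        const block_q4_0 * x = W + row * nb;
        float sum = 0.f;
        for (int ib = 0; ib < nb; ++ib) {
            const float d = (float) x[ib].d;
            const float * yb = y + ib * QK4_0;
            for (int j = 0; j < QK4_0 / 2; ++j) {
                const uint8_t q = x[ib].qs[j];
                sum += d * ((q & 0x0F) - 8) * yb[j];
                sum += d * ((q >>   4) - 8) * yb[j + QK4_0 / 2];
            }
        }
        dst[row] = sum;
    }
}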

#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
@@ -372,7 +496,27 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
reg = (type4x4)(*src);
}

template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;

float4x4 reg_f;

for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
}

reg = (type4x4) reg_f;
}
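// Note, not part of this diff: this is the same mask-without-shift trick as in
// block_q_n_dot_y. For il == 0 the masks pick the low nibbles, so qs[i] & 0x000F is the
// raw quant q and d1 == d, while qs[i] & 0x0F00 is q << 8, which d2 == d1/256 undoes.
// For il == 1 the masks pick the high nibbles (q << 4 and q << 12), compensated by
// d1 == d/16 and d2 == d/4096. Either way each stored value is d1 * (qs & mask) + md
// == d * q - 8 * d, i.e. the usual d * (q - 8).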

typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;

template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;