Skip to content

Commit

Permalink
fix wiring of non-preemptive packed einsum
Browse files (browse the repository at this point in the history)
  • Loading branch information
kali committed Jan 30, 2025
1 parent df49145 commit a70f79f
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 54 deletions.
37 changes: 24 additions & 13 deletions core/src/ops/einsum/kernel_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use tract_itertools::Itertools;
use tract_linalg::frame::PackedFormat;
use tract_linalg::mmm::panel_extract::PanelExtractor;
use tract_linalg::mmm::{KitDatumType, MMMInputValue, MatMatMul};
use tract_linalg::mmm::{KitDatumType, MMMInputValue, MatMatMul, WeightType};

use crate::internal::*;
use crate::ops::matmul::pack::OptMatMulPack;
Expand All @@ -27,9 +27,8 @@ pub fn wire_packing(
let b_dt = b_fact.datum_type;

if a_fact.konst.is_some() && op.n.as_i64().is_none() {
let a = a_fact.konst.unwrap();
let (b, impls, picker) = wire_linear(patch, prefix, op, &a, operands[1])?;
return Ok((operands[0], b, impls, picker));
return wire_for_variable_n(patch, prefix, op, operands[0], operands[1])
.context("In wire_linear");
}

// "simple" kernel selection
Expand Down Expand Up @@ -80,23 +79,18 @@ pub fn wire_packing(
Ok((pa, pb, vec![(mmm, packing, None)], mode_picker))
}

pub fn wire_linear(
pub fn wire_for_variable_n(
patch: &mut TypedModelPatch,
prefix: &str,
op: &EinSumAnnotatedAsMatMul,
a: &Arc<Tensor>,
mut a: OutletId,
b: OutletId,
) -> TractResult<(
OutletId,
OutletId,
Vec<(Box<dyn MatMatMul>, usize, Option<PanelExtractor>)>,
ModePicker,
)> {
let packed = a
.to_scalar::<Opaque>()?
.0
.downcast_ref::<Box<dyn MMMInputValue>>()
.unwrap()
.format();
let accumulator = if op.operating_dt.is_integer() {
KitDatumType::I32
} else if op.operating_dt == f16::datum_type() && tract_linalg::has_fp16() {
Expand All @@ -109,16 +103,32 @@ pub fn wire_linear(
DatumType::F32 => KitDatumType::F32,
_ => todo!(),
};

let a_konst = patch.outlet_fact(a)?.konst.as_ref().unwrap();
// preemptive packing ?
let prepack = a_konst
.to_scalar::<Opaque>()
.ok()
.and_then(|opaq| opaq.0.downcast_ref::<Box<dyn MMMInputValue>>());
let kit = tract_linalg::ops()
.mmm_kits()
.iter()
.filter(|kit| {
kit.static_packer.same_as(packed)
prepack
.map(|pre| kit.static_packer.same_as(pre.format()))
.unwrap_or_else(|| kit.weight == WeightType::from(a_konst.datum_type()))
&& kit.accumulator == accumulator
&& kit.activation == activation
})
.min_by_key(|kit| kit.generic_fallback as usize)
.with_context(|| format!("No kit found for pre-packed {a:?}"))?;
if a_konst.datum_type().is_number() {
let packed_a = kit
.static_packer
.prepare_tensor(&a_konst, op.a_k(), op.a_m())?;
let name = patch.node(a.node).name.clone();
a = patch.add_const(name, tensor0(Opaque::from(packed_a)))?;
}

let configs = [kit.item_for_mv(), kit.item_for_squarish()];

Expand All @@ -144,6 +154,7 @@ pub fn wire_linear(
)?[0];

Ok((
a,
pb,
configs
.iter()
Expand Down
181 changes: 140 additions & 41 deletions core/src/ops/einsum/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,8 @@ pub(crate) fn ensure_mkn_axes<'a>(
.axes
.iter_all_axes()
// Filter possible candidates (a k axis must appear exactly once in each input and not at all in the output)
.filter(|a| {
a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].len() == 0
})
.collect();
.filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].len() == 0)
.collect();

let non_trivial_k_axis = k_axes
.iter()
Expand All @@ -109,7 +107,11 @@ pub(crate) fn ensure_mkn_axes<'a>(
// TODO: handle the case where multiple consecutive k axes appear in the same order in both inputs.
bail!("Multiple k-axis candidate found");
} else {
non_trivial_k_axis.first().copied().or_else(|| k_axes.first()).copied()
non_trivial_k_axis
.first()
.copied()
.or_else(|| k_axes.first())
.copied()
};
let Some(k_axis) = k_axis else {
return Ok(AxesOrPatch::Patch(inject_k_axis(op, model, node)?));
Expand All @@ -124,7 +126,13 @@ pub(crate) fn ensure_mkn_axes<'a>(
})
.max_by_key(|a| output_shape[a.outputs[0][0]].as_i64().unwrap_or(i64::MAX));
let Some(m_axis) = m_axis else {
return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, false, &[k_axis])?));
return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(
op,
model,
node,
false,
&[k_axis],
)?));
};
let n_axis = op
.axes
Expand All @@ -146,19 +154,37 @@ pub(crate) fn ensure_mkn_axes<'a>(
};
for axis in op.axes.iter_all_axes() {
let one = TDim::one();
let in_left =
axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
let in_right =
axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one;
let in_left = axis.inputs[0]
.first()
.map(|pos| &input_shapes[0][*pos])
.unwrap_or(&one)
!= &one;
let in_right = axis.inputs[1]
.first()
.map(|pos| &input_shapes[1][*pos])
.unwrap_or(&one)
!= &one;
let in_out = axis.outputs[0]
.first()
.map(|pos| &output_shape[*pos])
.unwrap_or(&one)
!= &one;
if (in_left ^ in_right) && !in_out {
return Ok(AxesOrPatch::NotAMatMul(axis));
}
}
let m = input_shapes[0][m_axis.inputs[0][0]].clone();
let k = input_shapes[0][k_axis.inputs[0][0]].clone();
let n = input_shapes[1][n_axis.inputs[1][0]].clone();
Ok(AxesOrPatch::Annotated(EinSumAnnotatedAsMatMul { op, m_axis, k_axis, n_axis, m, k, n }))
Ok(AxesOrPatch::Annotated(EinSumAnnotatedAsMatMul {
op,
m_axis,
k_axis,
n_axis,
m,
k,
n,
}))
}

pub(super) fn inject_k_axis(
Expand All @@ -171,14 +197,19 @@ pub(super) fn inject_k_axis(
let mut patch = TypedModelPatch::new("inject k axis");
let mut wire = patch.taps(model, &node.inputs)?;
let repr = new_axes.available_label();
new_axes = new_axes.with_extra_axis(repr, InOut::In(0), 0)?.with_extra_axis_occurency(
repr,
InOut::In(1),
0,
)?;
new_axes = new_axes
.with_extra_axis(repr, InOut::In(0), 0)?
.with_extra_axis_occurency(repr, InOut::In(1), 0)?;
wire[0] = patch.wire_node(format!("{name}.add_k.0"), AxisOp::Add(0), &[wire[0]])?[0];
wire[1] = patch.wire_node(format!("{name}.add_k.1"), AxisOp::Add(0), &[wire[1]])?[0];
wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
wire = patch.wire_node(
&node.name,
EinSum {
axes: new_axes,
..op.clone()
},
&wire,
)?;
patch.shunt_outside(model, node.id.into(), wire[0])?;
Ok(patch)
}
Expand All @@ -194,21 +225,31 @@ pub(super) fn inject_m_or_n_axis(
let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
let input_to_fix = is_n as usize;
let label = if is_n { "n" } else { "m" };
let quasi_m_or_n_axis = op.axes.iter_all_axes().filter(|a| !exclude.contains(a)).find(|a| {
(a.inputs[1 - input_to_fix].len() == 0
|| input_shapes[1 - input_to_fix][a.inputs[1 - input_to_fix][0]].is_one())
&& (a.inputs[input_to_fix].len() == 1 || a.outputs[0].len() == 1)
});
let quasi_m_or_n_axis = op
.axes
.iter_all_axes()
.filter(|a| !exclude.contains(a))
.find(|a| {
(a.inputs[1 - input_to_fix].len() == 0
|| input_shapes[1 - input_to_fix][a.inputs[1 - input_to_fix][0]].is_one())
&& (a.inputs[input_to_fix].len() == 1 || a.outputs[0].len() == 1)
});
let name = &node.name;
let mut patch = TypedModelPatch::new("Injecting m or n axis");
let mut wire = patch.taps(model, &node.inputs)?;
if let Some(axis) = quasi_m_or_n_axis {
if axis.inputs[input_to_fix].len() == 1 {
let new_axes =
op.axes.clone().with_extra_axis('$', InOut::Out(0), 0)?.linking(axis.repr, '$')?;
let new_axes = op
.axes
.clone()
.with_extra_axis('$', InOut::Out(0), 0)?
.linking(axis.repr, '$')?;
wire = patch.wire_node(
format!("{name}.einsum"),
EinSum { axes: new_axes, ..op.clone() },
EinSum {
axes: new_axes,
..op.clone()
},
&wire,
)?;
wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
Expand All @@ -223,7 +264,14 @@ pub(super) fn inject_m_or_n_axis(
AxisOp::Add(0),
&[wire[input_to_fix]],
)?[0];
wire = patch.wire_node(&node.name, EinSum { axes: new_axes, ..op.clone() }, &wire)?;
wire = patch.wire_node(
&node.name,
EinSum {
axes: new_axes,
..op.clone()
},
&wire,
)?;
}
} else {
let repr = op.axes.available_label();
Expand All @@ -240,7 +288,10 @@ pub(super) fn inject_m_or_n_axis(
)?[0];
wire = patch.wire_node(
format!("{name}.einsum"),
EinSum { axes: new_axes, ..op.clone() },
EinSum {
axes: new_axes,
..op.clone()
},
&wire,
)?;
wire = patch.wire_node(&node.name, AxisOp::Rm(0), &wire)?;
Expand Down Expand Up @@ -273,7 +324,12 @@ fn dequant(
let mut taps = patch.taps(model, &node.inputs)?;
for ab in [0, 1] {
let scale_input = 4 + ab * 2;
if !patch.outlet_fact(taps[scale_input])?.shape.volume().is_one() {
if !patch
.outlet_fact(taps[scale_input])?
.shape
.volume()
.is_one()
{
let q_axis_in_output = op.axes.axis((InOut::In(scale_input), 0))?.outputs[0][0];
let output_rank = node.outputs[0].fact.rank();
for i in 1..(output_rank - q_axis_in_output) {
Expand All @@ -290,8 +346,22 @@ fn dequant(
bail!("Expect exactly 9 inputs")
};

wire_ensure_q8_flavour(&mut patch, &node.name, &mut a, "a", &mut a0, i8::datum_type())?;
wire_ensure_q8_flavour(&mut patch, &node.name, &mut b, "b", &mut b0, i8::datum_type())?;
wire_ensure_q8_flavour(
&mut patch,
&node.name,
&mut a,
"a",
&mut a0,
i8::datum_type(),
)?;
wire_ensure_q8_flavour(
&mut patch,
&node.name,
&mut b,
"b",
&mut b0,
i8::datum_type(),
)?;

let mut output = patch.wire_node(
&node.name,
Expand All @@ -316,13 +386,28 @@ fn dequant(
&[b_i32],
)?;

let sum_a =
wire_axes_fix(&mut patch, name, "sum_a", &op.axes.extract_sub_mapping(&[0], &[0])?, sum_a)?;
let sum_b =
wire_axes_fix(&mut patch, name, "sum_b", &op.axes.extract_sub_mapping(&[1], &[0])?, sum_b)?;
let sum_a = wire_axes_fix(
&mut patch,
name,
"sum_a",
&op.axes.extract_sub_mapping(&[0], &[0])?,
sum_a,
)?;
let sum_b = wire_axes_fix(
&mut patch,
name,
"sum_b",
&op.axes.extract_sub_mapping(&[1], &[0])?,
sum_b,
)?;
let bias = tvec!(bias);
let bias =
wire_axes_fix(&mut patch, name, "bias", &op.axes.extract_sub_mapping(&[2], &[0])?, bias)?;
let bias = wire_axes_fix(
&mut patch,
name,
"bias",
&op.axes.extract_sub_mapping(&[2], &[0])?,
bias,
)?;

let abc_scale = combine_scales(&mut patch, name, a_scale, b_scale, c_scale)?;

Expand All @@ -331,7 +416,14 @@ fn dequant(
let k = model.outlet_fact(node.inputs[0])?.shape[op.k_axis.inputs[0][0]].clone();
let output = compensate_zero_points(&mut patch, name, output[0], k, a0, b0, sum_a[0], sum_b[0])
.context("Zero point compensation")?;
let output = requant(&mut patch, name, output, op.q_params.unwrap(), abc_scale, c0)?;
let output = requant(
&mut patch,
name,
output,
op.q_params.unwrap(),
abc_scale,
c0,
)?;
patch.shunt_outside(model, node.id.into(), output)?;
Ok(Some(patch))
}
Expand Down Expand Up @@ -363,15 +455,19 @@ fn optimized_mat_mul(
model,
node,
&[node.inputs[1], node.inputs[0]],
EinSum { axes: AxesMapping::new(node.inputs.len(), 1, expr)?, ..op.op.clone() },
EinSum {
axes: AxesMapping::new(node.inputs.len(), 1, expr)?,
..op.op.clone()
},
)
.map(Some);
}

let mut patch = TypedModelPatch::new("Einsum to OptMatMul");
let name = &node.name;
let taps = patch.taps(model, &node.inputs)?;
let (a, b, mmms, mode_picker) = wire_packing(&mut patch, name, &taps[0..2], op)?;
let (a, b, mmms, mode_picker) =
wire_packing(&mut patch, name, &taps[0..2], op).context("Wiring packing")?;

let mut c_to_a_axis_mapping = tvec!();
let mut c_to_b_axis_mapping = tvec!();
Expand Down Expand Up @@ -402,7 +498,10 @@ fn optimized_mat_mul(
c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
};
let (mmms, packings, extractor): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(mmms);
let outputs = mmms.iter().map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
let outputs = mmms
.iter()
.map(|mmm| unsafe { mmm.c_view(op.c_m(), op.c_n()) })
.collect();
let trivial_packing =
mmms.len() == 1 && packings[0] == 0 && patch.outlet_fact(a)?.opaque_fact.is_none();
let opt = OptMatMul::new(
Expand Down

0 comments on commit a70f79f

Please sign in to comment.