Skip to content

Commit

Permalink
compute directly cardinality
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Feb 12, 2025
1 parent ad838bb commit df8c82c
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions core/src/ops/nn/reduce.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::internal::Axis;
use crate::internal::*;
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::{cast, wire_cast};
use crate::ops::cast::cast;
use crate::ops::change_axes::wire_with_rank_broadcast;
use crate::ops::element_wise::ElementWiseOp;
use crate::ops::math::{div, mul, square, Mul, Square};
use crate::ops::math::{div, square, Mul, Square};
use std::convert::TryFrom;
use std::iter::Sum;
use std::mem::transmute;
Expand Down Expand Up @@ -629,29 +629,37 @@ pub fn expand_mean_of_squares(
if op.reducer == Reducer::MeanOfSquares {
let mut patch = TypedModelPatch::default();
let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
let dt = model.outlet_fact(node.inputs[0])?.datum_type;
let input_fact = model.outlet_fact(node.inputs[0])?;
let dt = input_fact.datum_type;
if dt != f32::datum_type() {
wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
}
wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
let input_size = patch.outlet_fact(wire[0])?.shape.volume();
let input_size = patch.add_const(format!("{name}.input_size"), tensor0(input_size))?;
wire = patch.wire_node(
format!("{name}.sum"),
Reduce::new(op.axes.clone(), Reducer::Sum),
&wire,
)?;
let output_size = patch.outlet_fact(wire[0])?.shape.volume();
let output_size = patch.add_const(format!("{name}.output_size"), tensor0(output_size))?;
let norm = wire_cast(
let card = input_fact
.shape
.iter()
.enumerate()
.filter(|(ix, _dim)| op.axes.contains(ix))
.map(|(_ix, dim)| dim)
.product::<TDim>();
let card = patch.add_const(format!("{name}.card"), tensor0(card))?;
let card = patch.wire_node(
format!("{name}.card_to_f32"),
cast(f32::datum_type()),
&[card],
)?;

wire = wire_with_rank_broadcast(
format!("{name}.norm"),
&mut patch,
&[output_size, input_size],
f32::datum_type(),
div(),
&[wire[0], card[0]],
)?;
let norm = patch.wire_node(format!("{name}.norm"), div(), &norm)?[0];
wire =
wire_with_rank_broadcast(format!("{name}.card"), &mut patch, mul(), &[wire[0], norm])?;
if dt != f32::datum_type() {
wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
}
Expand Down

0 comments on commit df8c82c

Please sign in to comment.