Skip to content

Commit

Permalink
approximation level in assertion is tweakable
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 23, 2024
1 parent 9bc8be9 commit 4668dca
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 4 deletions.
8 changes: 8 additions & 0 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,14 @@ fn dump_subcommand<'a>() -> clap::Command<'a> {
fn assertions_options(command: clap::Command) -> clap::Command {
use clap::*;
command
.arg(
Arg::new("approx")
.takes_value(true)
.possible_values(&["exact", "close", "approximate", "super"])
.default_value("close")
.long("approx")
.help("Approximation level used in assertions."),
)
.arg(
Arg::new("assert-output")
.takes_value(true)
Expand Down
11 changes: 9 additions & 2 deletions cli/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ pub struct Assertions {
pub assert_outputs: bool,
pub assert_output_facts: Option<Vec<InferenceFact>>,
pub assert_op_count: Option<Vec<(String, usize)>>,
pub approximation: Approximation
}

impl Assertions {
Expand All @@ -1123,7 +1124,13 @@ impl Assertions {
.map(|mut args| Some((args.next()?.to_string(), args.next()?.parse().ok()?)))
.collect()
});

Ok(Assertions { assert_outputs, assert_output_facts, assert_op_count })
let approximation = match sub.value_of("approx").unwrap() {
"exact" => Approximation::Exact,
"close" => Approximation::Close,
"approximate" => Approximation::Approximate,
"super" => Approximation::SuperApproximate,
_ => panic!()
};
Ok(Assertions { assert_outputs, assert_output_facts, assert_op_count, approximation })
}
}
5 changes: 4 additions & 1 deletion cli/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ pub fn check_outputs(got: &[Vec<TValue>], params: &Parameters) -> TractResult<()
{
exp = exp.cast_to_dt(got.datum_type())?.into_owned().into_tvalue();
}
if let Err(e) = exp.close_enough(&got, true).context(format!("Checking output {ix}")) {
if let Err(e) = exp
.close_enough(&got, params.assertions.approximation)
.context(format!("Checking output {ix}"))
{
if error.is_some() {
error!("{:?}", e);
} else {
Expand Down
3 changes: 2 additions & 1 deletion data/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use std::sync::Arc;
pub mod litteral;
pub mod view;

#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[derive(Copy, Clone, Default, PartialEq, Eq, Debug)]
pub enum Approximation {
Exact,
#[default]
Close,
Approximate,
SuperApproximate,
Expand Down

0 comments on commit 4668dca

Please sign in to comment.