From 93e0fad13d0f92f7cf32a26b277be640d92461f1 Mon Sep 17 00:00:00 2001 From: Rute Figueiredo Date: Mon, 30 Oct 2023 21:03:27 +0000 Subject: [PATCH] Feature/optimise step selector builder (#154) ## Optimisation of Step Selector for Log2(n) Columns ### Description - **Goal:** is to optimise the Step Selector Builder to have Log2(n) Columns per n Step Types instead of one Column per each Step Type. - **Current Implementation:** implemented a new type of Selector and implemented the build function for it - **Explanation:** If we use binary representation for the step types instead of using the actual columns we will get an optimised version for the Selector Builder since we can get the total of columns from the following expression ```math n\_cols = \lceil \log_2(n\_step\_types + 1) \rceil ``` We need take into consideration that we also need to save one binary value for the case when there is no step type, because of this we do the math with \(n\_step\_types + 1\). ### What is missing - [x] Ensure compatibility with the backend - [x] Unit tests --------- Co-authored-by: nullbitx8 <92404251+nullbitx8@users.noreply.github.com> Co-authored-by: John Smith Co-authored-by: Leo Lara Co-authored-by: Jaewon In Co-authored-by: Steve Wang --- Cargo.toml | 2 +- src/plonkish/compiler/step_selector.rs | 151 +++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 44a9e6ef..c15cf949 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chiquito" -version = "0.1.2023101700" +version = "0.1.2023101100" edition = "2021" license = "MIT OR Apache-2.0" authors = ["Leo Lara "] diff --git a/src/plonkish/compiler/step_selector.rs b/src/plonkish/compiler/step_selector.rs index a95cdd60..e49b7f92 100644 --- a/src/plonkish/compiler/step_selector.rs +++ b/src/plonkish/compiler/step_selector.rs @@ -188,6 +188,64 @@ impl StepSelectorBuilder for TwoStepsSelectorBuilder { } } +#[derive(Debug, Default, Clone)] +pub struct LogNSelectorBuilder {} + +impl StepSelectorBuilder for LogNSelectorBuilder { + fn build(&self, unit: &mut CompilationUnit) { + let mut selector: StepSelector = StepSelector { + selector_expr: HashMap::new(), + selector_expr_not: HashMap::new(), + selector_assignment: HashMap::new(), + columns: Vec::new(), + }; + + let n_step_types = unit.step_types.len() as u64; + let n_cols = (n_step_types as f64 + 1.0).log2().ceil() as u64; + + let mut annotation; + for index in 0..n_cols { + annotation = format!("'binary selector column {}'", index); + + let column = Column::advice(annotation.clone(), 0); + selector.columns.push(column.clone()); + } + + let mut step_value = 1; + for step in unit.step_types.values() { + let mut combined_expr = PolyExpr::Const(F::ONE); + let mut assignments = Vec::new(); + + for i in 0..n_cols { + let bit = (step_value >> i) & 1; // Extract the i-th bit of step_value + let column = &selector.columns[i as usize]; + + if bit == 1 { + combined_expr = combined_expr * column.query(0, format!("Column {}", i)); + assignments.push((column.query(0, format!("Column {}", i)), F::ONE)); + } else { + combined_expr = combined_expr + * (PolyExpr::Const(F::ONE) - column.query(0, format!("Column {}", i))); + } + } + + selector + .selector_expr + .insert(step.uuid(), combined_expr.clone()); + selector + .selector_expr_not + .insert(step.uuid(), PolyExpr::Const(F::ONE) - combined_expr.clone()); + selector + .selector_assignment + .insert(step.uuid(), assignments); + step_value += 1; + } + + unit.columns.extend_from_slice(&selector.columns); + unit.selector = selector; + } +} + fn other_step_type(unit: &CompilationUnit, uuid: UUID) -> Option>> { for step_type in unit.step_types.values() { if step_type.uuid() != uuid { @@ -197,3 +255,96 @@ fn other_step_type(unit: &CompilationUnit, uuid: UUID) -> Option() -> CompilationUnit { + CompilationUnit::default() + } + + fn add_step_types_to_unit(unit: &mut CompilationUnit, n_step_types: usize) { + for i in 0..n_step_types { + let uuid_value = Uuid::now_v1(&[1, 2, 3, 4, 5, 6]).as_u128(); + unit.step_types.insert( + uuid_value, + Rc::new(StepType::new(uuid_value, format!("StepType{}", i))), + ); + } + } + + fn assert_common_tests(unit: &CompilationUnit, expected_cols: usize) { + assert_eq!(unit.columns.len(), expected_cols); + assert_eq!(unit.selector.columns.len(), expected_cols); + for step_type in unit.step_types.values() { + assert!(unit + .selector + .selector_assignment + .contains_key(&step_type.uuid())); + assert!(unit.selector.selector_expr.contains_key(&step_type.uuid())); + } + } + + #[test] + fn test_log_n_selector_builder_3_step_types() { + let builder = LogNSelectorBuilder {}; + let mut unit = mock_compilation_unit::(); + + add_step_types_to_unit(&mut unit, 3); + builder.build(&mut unit); + assert_common_tests(&unit, 2); + + // Asserts expressions for 3 step types + let expr10_temp = format!( + "(0x1 * {:#?} * (0x1 + (-{:#?})))", + &unit.selector.columns[0].query::(0, "Column 0"), + &unit.selector.columns[1].query::(0, "Column 1") + ); + let expr01_temp = format!( + "(0x1 * (0x1 + (-{:#?})) * {:#?})", + &unit.selector.columns[0].query::(0, "Column 0"), + &unit.selector.columns[1].query::(0, "Column 1") + ); + let expr11_temp = format!( + "(0x1 * {:#?} * {:#?})", + &unit.selector.columns[0].query::(0, "Column 0"), + &unit.selector.columns[1].query::(0, "Column 1") + ); + let expected_exprs = [expr01_temp.trim(), expr10_temp.trim(), expr11_temp.trim()]; + + for expr in unit.selector.selector_expr.values() { + let expr_str = format!("{:#?}", expr); + assert!( + expected_exprs.contains(&expr_str.trim()), + "Unexpected expression: {}", + expr_str + ); + } + } + + #[test] + fn test_log_n_selector_builder_4_step_types() { + let builder = LogNSelectorBuilder {}; + let mut unit = mock_compilation_unit::(); + + add_step_types_to_unit(&mut unit, 4); + builder.build(&mut unit); + assert_common_tests(&unit, 3); + } + + #[test] + fn test_log_n_selector_builder_10_step_types() { + let builder = LogNSelectorBuilder {}; + let mut unit = mock_compilation_unit::(); + + add_step_types_to_unit(&mut unit, 10); + builder.build(&mut unit); + + let expected_cols = (10_f64 + 1.0).log2().ceil() as usize; + assert_common_tests(&unit, expected_cols); + } +}