diff --git a/core/src/air/extension.rs b/core/src/air/extension.rs index 1a546c8b16..6e8401bbdc 100644 --- a/core/src/air/extension.rs +++ b/core/src/air/extension.rs @@ -19,9 +19,9 @@ pub const DEGREE: usize = 4; #[repr(C)] pub struct Extension(pub [T; DEGREE]); // Degree 4 is hard coded for now. TODO: Change to a const generic -impl Extension { +impl Extension { // Returns the one element of the extension field - pub fn one>() -> Extension + pub fn one>() -> Extension where AB::Expr: AbstractField, { @@ -30,51 +30,72 @@ impl Extension { } // Converts a field element to extension element - pub fn from>(x: V) -> Extension - where - AB::Expr: From, - { - Extension(field_to_array(x.into())) + pub fn from>(x: E) -> Extension { + Extension(field_to_array(x)) } // Negates an extension field Element - pub fn neg>(self) -> Extension { + pub fn neg>(self) -> Extension { Extension(self.0.map(|x| AB::Expr::zero() - x)) } // Adds an extension field element - pub fn add>(self, rhs: &Self) -> Extension + pub fn add>(self, rhs: &Self) -> Extension where - V: Add + Copy, + E: Add, { let mut elements = Vec::new(); - for (e1, e2) in self.0.into_iter().zip_eq(rhs.0.into_iter()) { + for (e1, e2) in self.0.into_iter().zip_eq(rhs.0.clone().into_iter()) { elements.push(e1 + e2); } Extension(elements.try_into().unwrap()) } + // Subtracts an extension field element + pub fn sub>(self, rhs: &Self) -> Extension + where + E: Add, + { + let mut elements = Vec::new(); + + for (e1, e2) in self.0.into_iter().zip_eq(rhs.0.clone().into_iter()) { + elements.push(e1 - e2); + } + + Extension(elements.try_into().unwrap()) + } + // Multiplies an extension field element - pub fn mul>(self, rhs: &Self) -> Extension + pub fn mul>(self, rhs: &Self) -> Extension where - V: Mul + Copy, + E: Mul, { let mut elements = Vec::new(); - for (e1, e2) in self.0.into_iter().zip_eq(rhs.0.into_iter()) { + for (e1, e2) in self.0.into_iter().zip_eq(rhs.0.clone().into_iter()) { elements.push(e1 * e2); } Extension(elements.try_into().unwrap()) } - pub fn as_base_slice(&self) -> &[V] { + pub fn as_base_slice(&self) -> &[E] { &self.0 } } +impl Extension { + // Converts a field element with var base elements to one with expr base elements. + pub fn from_var>(self) -> Extension + where + V: Into, + { + Extension(self.0.map(|x| x.into())) + } +} + impl From> for Extension where F: Field, diff --git a/core/src/operations/div_extension.rs b/core/src/operations/div_extension.rs index a05fd65650..40abaeb340 100644 --- a/core/src/operations/div_extension.rs +++ b/core/src/operations/div_extension.rs @@ -2,11 +2,8 @@ //! use core::borrow::Borrow; use core::borrow::BorrowMut; -use p3_air::AirBuilder; use p3_field::extension::BinomialExtensionField; use p3_field::extension::BinomiallyExtendable; -use p3_field::AbstractField; -use p3_field::Field; use sp1_derive::AlignedBorrow; use std::mem::size_of; @@ -14,10 +11,14 @@ use crate::air::Extension; use crate::air::SP1AirBuilder; use crate::air::DEGREE; +use super::IsEqualExtOperation; + /// A set of columns needed to compute whether the given word is 0. #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct DivExtOperation { + pub is_equal: IsEqualExtOperation, + /// Result is the quotient pub result: Extension, } @@ -30,6 +31,10 @@ impl> DivExtOperation { ) -> BinomialExtensionField { let result = a / b; self.result = result.into(); + + let product = b * result; + self.is_equal.populate(a, product); + result } @@ -39,16 +44,12 @@ impl> DivExtOperation { b: Extension, cols: DivExtOperation, is_real: AB::Expr, - ) { + ) where + AB::F: BinomiallyExtendable, + { builder.assert_bool(is_real.clone()); - let product = b.mul(&cols.result); - builder.when(is_real.clone()).assert_eq(product, a); - - // If the result is 1, then the input is 0. - builder - .when(is_real.clone()) - .when(cols.result) - .assert_zero(a.clone()); + let product = b.mul::(&cols.result.from_var::()); + IsEqualExtOperation::::eval(builder, a, product, cols.is_equal, is_real.clone()); } } diff --git a/core/src/operations/is_equal_extension.rs b/core/src/operations/is_equal_extension.rs index 9f95f59de3..bd2361959c 100644 --- a/core/src/operations/is_equal_extension.rs +++ b/core/src/operations/is_equal_extension.rs @@ -36,11 +36,13 @@ impl> IsEqualExtOperation { b: Extension, cols: IsEqualExtOperation, is_real: AB::Expr, - ) { + ) where + AB::F: BinomiallyExtendable, + { builder.assert_bool(is_real.clone()); // Calculate differences. - let diff = a.sub(b); + let diff = a.sub::(&b); // Check if the difference is 0. IsZeroExtOperation::::eval(builder, diff, cols.is_diff_zero, is_real.clone());