Skip to content

Commit

Permalink
refactor!: Move canon check to BFieldElement
Browse files Browse the repository at this point in the history
- Reject non-canonical `BFieldElement`s when parsing string slices or
  interpreting byte sequences
- Move canon check of `BFieldElement`s from `Digest` to `BFieldElement`
- Deprecate functionality for platform-dependent endianness when
  decoding

BREAKING CHANGE: The error variant `NotCanonical` from enum
`TryFromDigestError` has moved into enum `ParseBFieldElementError`.
  • Loading branch information
jan-ferdinand committed Oct 7, 2024
1 parent 8482f93 commit 7dafa32
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 60 deletions.
12 changes: 9 additions & 3 deletions twenty-first/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ pub use crate::util_types::merkle_tree::MerkleTreeError;
pub enum ParseBFieldElementError {
#[error("invalid `u64`")]
ParseU64Error(#[source] <u64 as FromStr>::Err),

#[error("non-canonical {0} >= {} == `BFieldElement::P`", BFieldElement::P)]
NotCanonical(u64),

#[error(
"incorrect number of bytes: {0} != {} == `BFieldElement::BYTES`",
BFieldElement::BYTES
)]
InvalidNumBytes(usize),
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Error)]
Expand Down Expand Up @@ -42,9 +51,6 @@ pub enum TryFromDigestError {
#[error("invalid `BFieldElement`")]
InvalidBFieldElement(#[from] ParseBFieldElementError),

#[error("non-canonical {0} >= {} == `BFieldElement::P`", BFieldElement::P)]
NotCanonical(u64),

#[error("overflow converting to Digest")]
Overflow,
}
Expand Down
63 changes: 50 additions & 13 deletions twenty-first/src/math/b_field_element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,28 @@ impl BFieldElement {
pub const BYTES: usize = 8;

/// The base field's prime, _i.e._, 2^64 - 2^32 + 1.
pub const P: u64 = 0xffff_ffff_0000_0001u64;
pub const P: u64 = 0xffff_ffff_0000_0001;
pub const MAX: u64 = Self::P - 1;

/// 2^128 mod P; this is used for conversion of elements into Montgomery representation.
const R2: u64 = 0xFFFFFFFE00000001;
const R2: u64 = 0xffff_fffe_0000_0001;

/// -2^-1
pub const MINUS_TWO_INVERSE: Self = Self::new(9223372034707292160);
pub const MINUS_TWO_INVERSE: Self = Self::new(0x7fff_ffff_8000_0000);

#[inline]
pub const fn new(value: u64) -> Self {
Self(Self::montyred((value as u128) * (Self::R2 as u128)))
}

/// Construct a new base field element iff the given value is
/// [canonical][Self::is_canonical], an error otherwise.
fn try_new(v: u64) -> Result<Self, ParseBFieldElementError> {
Self::is_canonical(v)
.then(|| Self::new(v))
.ok_or(ParseBFieldElementError::NotCanonical(v))
}

#[inline]
pub const fn value(&self) -> u64 {
self.canonical_representation()
Expand Down Expand Up @@ -294,6 +302,10 @@ impl BFieldElement {
}

/// Convert a `BFieldElement` from a byte slice in native endianness.
#[deprecated(
since = "0.42.0",
note = "endianness must not be platform specific; use `<&[u8]>::try_from()` instead"
)]
pub fn from_ne_bytes(bytes: &[u8]) -> BFieldElement {
let mut bytes_copied: [u8; 8] = [0; 8];
bytes_copied.copy_from_slice(bytes);
Expand Down Expand Up @@ -393,7 +405,7 @@ impl FromStr for BFieldElement {

fn from_str(s: &str) -> Result<Self, Self::Err> {
let parsed = s.parse().map_err(Self::Err::ParseU64Error)?;
Ok(BFieldElement::new(parsed))
Self::try_new(parsed)
}
}

Expand Down Expand Up @@ -559,10 +571,21 @@ impl From<BFieldElement> for [u8; BFieldElement::BYTES] {
}
}

impl From<[u8; BFieldElement::BYTES]> for BFieldElement {
fn from(array: [u8; BFieldElement::BYTES]) -> Self {
let n = u64::from_le_bytes(array);
BFieldElement::new(n)
impl TryFrom<[u8; BFieldElement::BYTES]> for BFieldElement {
type Error = ParseBFieldElementError;

fn try_from(array: [u8; BFieldElement::BYTES]) -> Result<Self, Self::Error> {
Self::try_new(u64::from_le_bytes(array))
}
}

impl TryFrom<&[u8]> for BFieldElement {
type Error = ParseBFieldElementError;

fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
<[u8; BFieldElement::BYTES]>::try_from(bytes)
.map_err(|_| Self::Error::InvalidNumBytes(bytes.len()))?
.try_into()
}
}

Expand Down Expand Up @@ -820,6 +843,22 @@ mod b_prime_field_element_test {
prop_assert_eq!(bfe, deserialized);
}

#[proptest]
fn parsing_string_representing_canonical_u64_gives_correct_bfield_element(
#[strategy(0..=BFieldElement::MAX)] v: u64,
) {
let bfe = BFieldElement::from_str(&v.to_string()).unwrap();
prop_assert_eq!(v, bfe.value());
}

#[proptest]
fn parsing_string_representing_too_big_u64_as_bfield_element_gives_error(
#[strategy(BFieldElement::P..)] v: u64,
) {
let err = BFieldElement::from_str(&v.to_string()).err().unwrap();
prop_assert_eq!(ParseBFieldElementError::NotCanonical(v), err);
}

#[proptest]
fn zero_is_neutral_element_for_addition(bfe: BFieldElement) {
let zero = BFieldElement::ZERO;
Expand Down Expand Up @@ -961,16 +1000,14 @@ mod b_prime_field_element_test {
#[proptest]
fn byte_array_conversion(bfe: BFieldElement) {
let array: [u8; 8] = bfe.into();
let bfe_recalculated: BFieldElement = array.into();
let bfe_recalculated: BFieldElement = array.try_into()?;
prop_assert_eq!(bfe, bfe_recalculated);
}

#[proptest]
fn byte_array_outside_range_is_brought_into_range(#[strategy(BFieldElement::P..)] value: u64) {
fn byte_array_outside_range_is_not_accepted(#[strategy(BFieldElement::P..)] value: u64) {
let byte_array = value.to_le_bytes();
let bfe: BFieldElement = byte_array.into();
let expected_value = value - BFieldElement::P;
assert_eq!(expected_value, bfe.value());
prop_assert!(BFieldElement::try_from(byte_array).is_err());
}

#[proptest]
Expand Down
57 changes: 22 additions & 35 deletions twenty-first/src/math/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;

use crate::error::ParseBFieldElementError;
use crate::error::TryFromDigestError;
use crate::error::TryFromHexDigestError;
use crate::math::b_field_element::BFieldElement;
Expand Down Expand Up @@ -127,17 +126,10 @@ impl FromStr for Digest {
type Err = TryFromDigestError;

fn from_str(string: &str) -> Result<Self, Self::Err> {
let maybe_parsed_u64s: Result<Vec<_>, _> =
string.split(',').map(str::parse::<u64>).collect();
let parsed_u64s = maybe_parsed_u64s.map_err(ParseBFieldElementError::ParseU64Error)?;

// checks if each u64 is canonical before instantiating into BFE.
let bfe_try_from = |v: u64| -> Result<BFieldElement, _> {
let bfe = BFieldElement::is_canonical(v).then(|| BFieldElement::new(v));
bfe.ok_or(TryFromDigestError::NotCanonical(v))
};
let bfes: Vec<_> = parsed_u64s.into_iter().map(bfe_try_from).try_collect()?;

let bfes: Vec<_> = string
.split(',')
.map(str::parse::<BFieldElement>)
.try_collect()?;
let invalid_len_err = Self::Err::InvalidLength(bfes.len());
let digest_innards = bfes.try_into().map_err(|_| invalid_len_err)?;

Expand Down Expand Up @@ -170,10 +162,9 @@ impl From<Digest> for Vec<BFieldElement> {
}

impl From<Digest> for [u8; Digest::BYTES] {
fn from(item: Digest) -> Self {
let u64s = item.0.iter().map(|x| x.value());
u64s.map(|x| x.to_le_bytes())
.collect::<Vec<_>>()
fn from(Digest(innards): Digest) -> Self {
innards
.map(<[u8; BFieldElement::BYTES]>::from)
.concat()
.try_into()
.unwrap()
Expand All @@ -184,20 +175,9 @@ impl TryFrom<[u8; Digest::BYTES]> for Digest {
type Error = TryFromDigestError;

fn try_from(item: [u8; Self::BYTES]) -> Result<Self, Self::Error> {
let chunk_into_bfe = |chunk: &[u8]| -> Result<BFieldElement, _> {
let mut arr = [0u8; BFieldElement::BYTES];
arr.copy_from_slice(chunk);
let int = u64::from_le_bytes(arr);

// return bfe, or error if not canonical
BFieldElement::is_canonical(int)
.then(|| BFieldElement::new(int))
.ok_or(TryFromDigestError::NotCanonical(int))
};

let digest_innards: Vec<_> = item
.chunks_exact(BFieldElement::BYTES)
.map(chunk_into_bfe)
.map(BFieldElement::try_from)
.try_collect()?;

Ok(Self(digest_innards.try_into().unwrap()))
Expand Down Expand Up @@ -325,9 +305,11 @@ pub(crate) mod digest_tests {
use proptest_arbitrary_interop::arb;
use test_strategy::proptest;

use super::*;
use crate::error::ParseBFieldElementError;
use crate::prelude::*;

use super::*;

impl ProptestArbitrary for Digest {
type Parameters = ();
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
Expand Down Expand Up @@ -530,8 +512,10 @@ pub(crate) mod digest_tests {
fn try_from_bytes_not_canonical() -> Result<(), TryFromDigestError> {
let bytes: [u8; Digest::BYTES] = [255; Digest::BYTES];

assert!(Digest::try_from(bytes)
.is_err_and(|e| matches!(e, TryFromDigestError::NotCanonical(_))));
assert!(Digest::try_from(bytes).is_err_and(|e| matches!(
e,
TryFromDigestError::InvalidBFieldElement(ParseBFieldElementError::NotCanonical(_))
)));

Ok(())
}
Expand All @@ -541,9 +525,10 @@ pub(crate) mod digest_tests {
fn from_str_not_canonical() -> Result<(), TryFromDigestError> {
let str = format!("0,0,0,0,{}", u64::MAX);

assert!(
Digest::from_str(&str).is_err_and(|e| matches!(e, TryFromDigestError::NotCanonical(_)))
);
assert!(Digest::from_str(&str).is_err_and(|e| matches!(
e,
TryFromDigestError::InvalidBFieldElement(ParseBFieldElementError::NotCanonical(_))
)));

Ok(())
}
Expand Down Expand Up @@ -669,7 +654,9 @@ pub(crate) mod digest_tests {
)
.is_err_and(|e| matches!(
e,
TryFromHexDigestError::Digest(TryFromDigestError::NotCanonical(_))
TryFromHexDigestError::Digest(TryFromDigestError::InvalidBFieldElement(
ParseBFieldElementError::NotCanonical(_)
))
)));
}
}
Expand Down
13 changes: 4 additions & 9 deletions twenty-first/src/util_types/mmr/mmr_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,15 +459,10 @@ pub mod util {

impl<'a> Arbitrary<'a> for MmrAccumulator {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let mut buffer = [0u8; 8];
u.fill_buffer(&mut buffer)?;
let num_leafs = u64::from_be_bytes(buffer) >> 1; // num_leafs can be at most 63 bits

let mut peaks = vec![];
for _ in 0..(num_leafs.count_ones()) {
let peak = Digest::arbitrary(u)?;
peaks.push(peak);
}
let num_leafs = u.arbitrary::<u64>()? >> 1; // num_leafs can be at most 63 bits
let peaks = (0..num_leafs.count_ones())
.map(|_| Digest::arbitrary(u))
.try_collect()?;

Ok(MmrAccumulator::init(peaks, num_leafs))
}
Expand Down

0 comments on commit 7dafa32

Please sign in to comment.