From a29e9c6fea07088877060e9d039cab6ec14a25d5 Mon Sep 17 00:00:00 2001 From: moven0831 Date: Fri, 3 Jan 2025 23:43:10 +0800 Subject: [PATCH 1/2] refactor(mont): test and benchmarks to use random BigInt values --- Cargo.lock | 3 ++- mopro-msm/Cargo.toml | 1 + .../tests/mont_backend/mont_benchmarks.rs | 23 +++++-------------- .../tests/mont_backend/mont_mul_cios.rs | 16 ++++--------- .../tests/mont_backend/mont_mul_modified.rs | 16 ++++--------- .../tests/mont_backend/mont_mul_optimised.rs | 16 ++++--------- 6 files changed, 24 insertions(+), 51 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2a306ae..0f2a5f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -1746,6 +1746,7 @@ dependencies = [ "objc", "once_cell", "proptest", + "rand 0.8.5", "rayon", "ruint", "serde", diff --git a/mopro-msm/Cargo.toml b/mopro-msm/Cargo.toml index 2a19c10..d50e7f3 100644 --- a/mopro-msm/Cargo.toml +++ b/mopro-msm/Cargo.toml @@ -46,6 +46,7 @@ objc = { version = "=0.2.7" } proptest = { version = "1.4.0" } rayon = "1.5.1" itertools = "0.13.0" +rand = "0.8.5" [build-dependencies] color-eyre = "0.6" diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs index 04c8a05..c4f5bd3 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs @@ -7,7 +7,8 @@ use crate::msm::metal_msm::utils::mont_params::{ use ark_bn254::Fr as ScalarField; use ark_ff::{BigInt, PrimeField}; use metal::*; -use num_bigint::BigUint; +use num_bigint::{BigUint, RandBigInt}; +use rand::thread_rng; use stopwatch::Stopwatch; #[test] @@ -53,26 +54,14 @@ fn expensive_computation( } pub fn benchmark(log_limb_size: u32, shader_file: &str) -> Result { - let p = BigUint::parse_bytes( - b"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001", - 16, - ) - .unwrap(); - assert!(p == ScalarField::MODULUS.try_into().unwrap()); + let p: BigUint = ScalarField::MODULUS.try_into().unwrap(); let p_bitwidth = calc_bitwidth(&p); let num_limbs = calc_num_limbs(log_limb_size, p_bitwidth); - let a = BigUint::parse_bytes( - b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", - 16, - ) - .unwrap(); - let b = BigUint::parse_bytes( - b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", - 16, - ) - .unwrap(); + let mut rng = thread_rng(); + let a = rng.gen_biguint_below(&p); + let b = rng.gen_biguint_below(&p); let nsafe = calc_nsafe(log_limb_size); if nsafe == 0 { diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs index a30ad34..ba12fa8 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs @@ -9,7 +9,8 @@ use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, cal use ark_bn254::Fr as ScalarField; use ark_ff::{BigInt, PrimeField}; use metal::*; -use num_bigint::BigUint; +use num_bigint::{BigUint, RandBigInt}; +use rand::thread_rng; #[test] #[serial_test::serial] @@ -34,16 +35,9 @@ pub fn do_test(log_limb_size: u32) { let res = calc_rinv_and_n0(&p, &r, log_limb_size); let n0 = res.1; - let a = BigUint::parse_bytes( - b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", - 16, - ) - .unwrap(); - let b = BigUint::parse_bytes( - b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", - 16, - ) - .unwrap(); + let mut rng = thread_rng(); + let a = rng.gen_biguint_below(&p); + let b = rng.gen_biguint_below(&p); let a_r = &a * &r % &p; let b_r = &b * &r % &p; diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs index 2a4a85e..673d1f8 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs @@ -9,7 +9,8 @@ use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, cal use ark_bn254::Fr as ScalarField; use ark_ff::{BigInt, PrimeField}; use metal::*; -use num_bigint::BigUint; +use num_bigint::{BigUint, RandBigInt}; +use rand::thread_rng; #[test] #[serial_test::serial] @@ -34,16 +35,9 @@ pub fn do_test(log_limb_size: u32) { let res = calc_rinv_and_n0(&p, &r, log_limb_size); let n0 = res.1; - let a = BigUint::parse_bytes( - b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", - 16, - ) - .unwrap(); - let b = BigUint::parse_bytes( - b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", - 16, - ) - .unwrap(); + let mut rng = thread_rng(); + let a = rng.gen_biguint_below(&p); + let b = rng.gen_biguint_below(&p); let a_r = &a * &r % &p; let b_r = &b * &r % &p; diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs index 500c0ee..6fe416b 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs @@ -12,7 +12,8 @@ use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_rinv_and_n use ark_bn254::Fr as ScalarField; use ark_ff::{BigInt, PrimeField}; use metal::*; -use num_bigint::BigUint; +use num_bigint::{BigUint, RandBigInt}; +use rand::thread_rng; #[test] #[serial_test::serial] @@ -37,16 +38,9 @@ pub fn do_test(log_limb_size: u32) { let res = calc_rinv_and_n0(&p, &r, log_limb_size); let n0 = res.1; - let a = BigUint::parse_bytes( - b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", - 16, - ) - .unwrap(); - let b = BigUint::parse_bytes( - b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", - 16, - ) - .unwrap(); + let mut rng = thread_rng(); + let a = rng.gen_biguint_below(&p); + let b = rng.gen_biguint_below(&p); let a_r = &a * &r % &p; let b_r = &b * &r % &p; From 5350f186a5efe3c5a5cb6d573f3268fc49a2ea2e Mon Sep 17 00:00:00 2001 From: moven0831 Date: Fri, 3 Jan 2025 23:46:01 +0800 Subject: [PATCH 2/2] tests(mont): replace ScalarField with BaseField in mont_mul benchmarks and tests --- .../tests/mont_backend/mont_benchmarks.rs | 12 +++++----- .../tests/mont_backend/mont_mul_cios.rs | 14 +++++------ .../tests/mont_backend/mont_mul_modified.rs | 14 +++++------ .../tests/mont_backend/mont_mul_optimised.rs | 24 +++++++++---------- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs index c4f5bd3..b8383a7 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs @@ -4,7 +4,7 @@ use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; use crate::msm::metal_msm::utils::mont_params::{ calc_bitwidth, calc_mont_radix, calc_nsafe, calc_num_limbs, calc_rinv_and_n0, }; -use ark_bn254::Fr as ScalarField; +use ark_bn254::Fq as BaseField; use ark_ff::{BigInt, PrimeField}; use metal::*; use num_bigint::{BigUint, RandBigInt}; @@ -54,7 +54,7 @@ fn expensive_computation( } pub fn benchmark(log_limb_size: u32, shader_file: &str) -> Result { - let p: BigUint = ScalarField::MODULUS.try_into().unwrap(); + let p: BigUint = BaseField::MODULUS.try_into().unwrap(); let p_bitwidth = calc_bitwidth(&p); let num_limbs = calc_num_limbs(log_limb_size, p_bitwidth); @@ -77,20 +77,20 @@ pub fn benchmark(log_limb_size: u32, shader_file: &str) -> Result { let cost = 2u32.pow(16u32) as usize; let expected = expensive_computation(cost, &a, &b, &p, &r); - let expected_limbs = ScalarField::from_bigint(expected.clone().try_into().unwrap()) + let expected_limbs = BaseField::from_bigint(expected.clone().try_into().unwrap()) .unwrap() .into_bigint() .to_limbs(num_limbs, log_limb_size); - let ar_limbs = ScalarField::from_bigint(a_r.clone().try_into().unwrap()) + let ar_limbs = BaseField::from_bigint(a_r.clone().try_into().unwrap()) .unwrap() .into_bigint() .to_limbs(num_limbs, log_limb_size); - let br_limbs = ScalarField::from_bigint(b_r.clone().try_into().unwrap()) + let br_limbs = BaseField::from_bigint(b_r.clone().try_into().unwrap()) .unwrap() .into_bigint() .to_limbs(num_limbs, log_limb_size); - let p_limbs = &ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size); + let p_limbs = &BaseField::MODULUS.to_limbs(num_limbs, log_limb_size); let device = get_default_device(); diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs index ba12fa8..cabca7d 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs @@ -6,7 +6,7 @@ use crate::msm::metal_msm::host::gpu::{ use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0}; -use ark_bn254::Fr as ScalarField; +use ark_bn254::Fq as BaseField; use ark_ff::{BigInt, PrimeField}; use metal::*; use num_bigint::{BigUint, RandBigInt}; @@ -25,11 +25,11 @@ pub fn test_mont_mul_15() { } pub fn do_test(log_limb_size: u32) { - let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32; + let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32; let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize; let r = calc_mont_radix(num_limbs, log_limb_size); - let p: BigUint = ScalarField::MODULUS.try_into().unwrap(); + let p: BigUint = BaseField::MODULUS.try_into().unwrap(); let nsafe = calc_nsafe(log_limb_size); let res = calc_rinv_and_n0(&p, &r, log_limb_size); @@ -43,9 +43,9 @@ pub fn do_test(log_limb_size: u32) { let b_r = &b * &r % &p; let expected = (&a * &b * &r) % &p; - let a_r_in_ark = ScalarField::from_bigint(a_r.clone().try_into().unwrap()).unwrap(); - let b_r_in_ark = ScalarField::from_bigint(b_r.clone().try_into().unwrap()).unwrap(); - let expected_in_ark = ScalarField::from_bigint(expected.clone().try_into().unwrap()).unwrap(); + let a_r_in_ark = BaseField::from_bigint(a_r.clone().try_into().unwrap()).unwrap(); + let b_r_in_ark = BaseField::from_bigint(b_r.clone().try_into().unwrap()).unwrap(); + let expected_in_ark = BaseField::from_bigint(expected.clone().try_into().unwrap()).unwrap(); let expected_limbs = expected_in_ark .into_bigint() .to_limbs(num_limbs, log_limb_size); @@ -61,7 +61,7 @@ pub fn do_test(log_limb_size: u32) { ); let p_buf = create_buffer( &device, - &ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size), + &BaseField::MODULUS.to_limbs(num_limbs, log_limb_size), ); let result_buf = create_empty_buffer(&device, num_limbs); diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs index 673d1f8..0928c60 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs @@ -6,7 +6,7 @@ use crate::msm::metal_msm::host::gpu::{ use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0}; -use ark_bn254::Fr as ScalarField; +use ark_bn254::Fq as BaseField; use ark_ff::{BigInt, PrimeField}; use metal::*; use num_bigint::{BigUint, RandBigInt}; @@ -25,11 +25,11 @@ pub fn test_mont_mul_15() { } pub fn do_test(log_limb_size: u32) { - let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32; + let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32; let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize; let r = calc_mont_radix(num_limbs, log_limb_size); - let p: BigUint = ScalarField::MODULUS.try_into().unwrap(); + let p: BigUint = BaseField::MODULUS.try_into().unwrap(); let nsafe = calc_nsafe(log_limb_size); let res = calc_rinv_and_n0(&p, &r, log_limb_size); @@ -43,9 +43,9 @@ pub fn do_test(log_limb_size: u32) { let b_r = &b * &r % &p; let expected = (&a * &b * &r) % &p; - let a_r_in_ark = ScalarField::from_bigint(a_r.clone().try_into().unwrap()).unwrap(); - let b_r_in_ark = ScalarField::from_bigint(b_r.clone().try_into().unwrap()).unwrap(); - let expected_in_ark = ScalarField::from_bigint(expected.clone().try_into().unwrap()).unwrap(); + let a_r_in_ark = BaseField::from_bigint(a_r.clone().try_into().unwrap()).unwrap(); + let b_r_in_ark = BaseField::from_bigint(b_r.clone().try_into().unwrap()).unwrap(); + let expected_in_ark = BaseField::from_bigint(expected.clone().try_into().unwrap()).unwrap(); let expected_limbs = expected_in_ark .into_bigint() .to_limbs(num_limbs, log_limb_size); @@ -61,7 +61,7 @@ pub fn do_test(log_limb_size: u32) { ); let p_buf = create_buffer( &device, - &ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size), + &BaseField::MODULUS.to_limbs(num_limbs, log_limb_size), ); let result_buf = create_empty_buffer(&device, num_limbs); diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs index 6fe416b..d21751c 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs @@ -9,7 +9,7 @@ use crate::msm::metal_msm::host::gpu::{ use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_rinv_and_n0}; -use ark_bn254::Fr as ScalarField; +use ark_bn254::Fq as BaseField; use ark_ff::{BigInt, PrimeField}; use metal::*; use num_bigint::{BigUint, RandBigInt}; @@ -29,11 +29,11 @@ pub fn test_mont_mul_13() { pub fn do_test(log_limb_size: u32) { // Calculate num_limbs based on modulus size and limb size - let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32; + let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32; let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize; let r = calc_mont_radix(num_limbs, log_limb_size); - let p: BigUint = ScalarField::MODULUS.try_into().unwrap(); + let p: BigUint = BaseField::MODULUS.try_into().unwrap(); let res = calc_rinv_and_n0(&p, &r, log_limb_size); let n0 = res.1; @@ -46,9 +46,9 @@ pub fn do_test(log_limb_size: u32) { let b_r = &b * &r % &p; let expected = (&a * &b * &r) % &p; - let a_r_in_ark = ScalarField::from_bigint(a_r.clone().try_into().unwrap()).unwrap(); - let b_r_in_ark = ScalarField::from_bigint(b_r.clone().try_into().unwrap()).unwrap(); - let expected_in_ark = ScalarField::from_bigint(expected.clone().try_into().unwrap()).unwrap(); + let a_r_in_ark = BaseField::from_bigint(a_r.clone().try_into().unwrap()).unwrap(); + let b_r_in_ark = BaseField::from_bigint(b_r.clone().try_into().unwrap()).unwrap(); + let expected_in_ark = BaseField::from_bigint(expected.clone().try_into().unwrap()).unwrap(); let expected_limbs = expected_in_ark .into_bigint() .to_limbs(num_limbs, log_limb_size); @@ -64,7 +64,7 @@ pub fn do_test(log_limb_size: u32) { ); let p_buf = create_buffer( &device, - &ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size), + &BaseField::MODULUS.to_limbs(num_limbs, log_limb_size), ); let result_buf = create_empty_buffer(&device, num_limbs); @@ -133,17 +133,17 @@ pub fn do_test(log_limb_size: u32) { pub fn test_number_conversions() { // Setup parameters let log_limb_size = 12; - let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32; + let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32; let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize; // Create test values using small numbers for clarity let original_biguint = BigUint::parse_bytes(b"123456789", 10).unwrap(); - // Convert BigUint to ScalarField + // Convert BigUint to BaseField let scalar_field_value = - ScalarField::from_bigint(original_biguint.clone().try_into().unwrap()).unwrap(); + BaseField::from_bigint(original_biguint.clone().try_into().unwrap()).unwrap(); - // Convert ScalarField to limbs + // Convert BaseField to limbs let limbs = scalar_field_value .into_bigint() .to_limbs(num_limbs, log_limb_size); @@ -168,7 +168,7 @@ pub fn test_number_conversions() { ]; for value in test_values { - let scalar = ScalarField::from_bigint(value.clone().try_into().unwrap()).unwrap(); + let scalar = BaseField::from_bigint(value.clone().try_into().unwrap()).unwrap(); let value_limbs = scalar.into_bigint().to_limbs(num_limbs, log_limb_size); let converted: BigUint = BigInt::from_limbs(&value_limbs, log_limb_size) .try_into()