From 9a4ec68aa685c6a240d4873e86834f2ec2d31bbd Mon Sep 17 00:00:00 2001 From: moven0831 Date: Thu, 30 Jan 2025 19:28:32 +0900 Subject: [PATCH 1/2] refactor(bigint): have add/sub return carry signals with value-based approach --- .../msm/metal_msm/shader/bigint/bigint.metal | 112 +++++++----------- .../shader/bigint/bigint_add_unsafe.metal | 12 +- .../shader/bigint/bigint_add_wide.metal | 12 +- .../metal_msm/shader/bigint/bigint_sub.metal | 12 +- .../src/msm/metal_msm/shader/misc/types.metal | 10 ++ 5 files changed, 71 insertions(+), 87 deletions(-) diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal index da5b1f2..d585508 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal @@ -7,88 +7,79 @@ using namespace metal; BigInt bigint_zero() { BigInt s; - for (uint i = 0; i < NUM_LIMBS; i ++) { + for (uint i = 0; i < NUM_LIMBS; i++) { s.limbs[i] = 0; } return s; } -BigInt bigint_add_unsafe( +BigIntResult bigint_add_unsafe( BigInt lhs, BigInt rhs ) { - BigInt result; + BigIntResult res; + res.carry = 0; uint mask = (1 << LOG_LIMB_SIZE) - 1; - uint carry = 0; - - for (uint i = 0; i < NUM_LIMBS; i ++) { - uint c = lhs.limbs[i] + rhs.limbs[i] + carry; - result.limbs[i] = c & mask; - carry = c >> LOG_LIMB_SIZE; + for (uint i = 0; i < NUM_LIMBS; i++) { + uint c = lhs.limbs[i] + rhs.limbs[i] + res.carry; + res.value.limbs[i] = c & mask; + res.carry = c >> LOG_LIMB_SIZE; } - return result; + return res; } -BigIntWide bigint_add_wide( +BigIntResultWide bigint_add_wide( BigInt lhs, BigInt rhs ) { - BigIntWide result; + BigIntResultWide res; + res.carry = 0; uint mask = (1 << LOG_LIMB_SIZE) - 1; uint carry = 0; - - for (uint i = 0; i < NUM_LIMBS; i ++) { + for (uint i = 0; i < NUM_LIMBS; i++) { uint c = lhs.limbs[i] + rhs.limbs[i] + carry; - result.limbs[i] = c & mask; + res.value.limbs[i] = c & mask; carry = c >> LOG_LIMB_SIZE; } - result.limbs[NUM_LIMBS] = carry; - - return result; + res.value.limbs[NUM_LIMBS] = carry; + res.carry = carry; + return res; } -BigInt bigint_sub( +BigIntResult bigint_sub( BigInt lhs, BigInt rhs ) { - uint borrow = 0; - - BigInt res; - - for (uint i = 0; i < NUM_LIMBS; i ++) { - res.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - borrow; - - if (lhs.limbs[i] < (rhs.limbs[i] + borrow)) { - res.limbs[i] = res.limbs[i] + TWO_POW_WORD_SIZE; - borrow = 1; + BigIntResult res; + res.carry = 0; + for (uint i = 0; i < NUM_LIMBS; i++) { + res.value.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - res.carry; + if (lhs.limbs[i] < rhs.limbs[i] + res.carry) { + res.value.limbs[i] += TWO_POW_WORD_SIZE; + res.carry = 1; } else { - borrow = 0; + res.carry = 0; } } - return res; } -BigIntWide bigint_sub_wide( +BigIntResultWide bigint_sub_wide( BigIntWide lhs, BigIntWide rhs ) { - uint borrow = 0; - - BigIntWide res; - - for (uint i = 0; i < NUM_LIMBS; i ++) { - res.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - borrow; - - if (lhs.limbs[i] < (rhs.limbs[i] + borrow)) { - res.limbs[i] = res.limbs[i] + TWO_POW_WORD_SIZE; - borrow = 1; + BigIntResultWide res; + res.carry = 0; + for (uint i = 0; i < NUM_LIMBS; i++) { + res.value.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - res.carry; + if (lhs.limbs[i] < rhs.limbs[i] + res.carry) { + res.value.limbs[i] += TWO_POW_WORD_SIZE; + res.carry = 1; } else { - borrow = 0; + res.carry = 0; } } - return res; } @@ -96,15 +87,12 @@ bool bigint_gte( BigInt lhs, BigInt rhs ) { - for (uint idx = 0; idx < NUM_LIMBS; idx ++) { + // for (uint i = NUM_LIMBS-1; i >= 0; i--) is troublesome from unknown reason + for (uint idx = 0; idx < NUM_LIMBS; idx++) { uint i = NUM_LIMBS - 1 - idx; - if (lhs.limbs[i] < rhs.limbs[i]) { - return false; - } else if (lhs.limbs[i] > rhs.limbs[i]) { - return true; - } + if (lhs.limbs[i] < rhs.limbs[i]) return false; + else if (lhs.limbs[i] > rhs.limbs[i]) return true; } - return true; } @@ -112,15 +100,11 @@ bool bigint_wide_gte( BigIntWide lhs, BigIntWide rhs ) { - for (uint idx = 0; idx < NUM_LIMBS_WIDE; idx ++) { + for (uint idx = 0; idx < NUM_LIMBS_WIDE; idx++) { uint i = NUM_LIMBS_WIDE - 1 - idx; - if (lhs.limbs[i] < rhs.limbs[i]) { - return false; - } else if (lhs.limbs[i] > rhs.limbs[i]) { - return true; - } + if (lhs.limbs[i] < rhs.limbs[i]) return false; + else if (lhs.limbs[i] > rhs.limbs[i]) return true; } - return true; } @@ -129,29 +113,25 @@ bool bigint_eq( BigInt rhs ) { for (uint i = 0; i < NUM_LIMBS; i++) { - if (lhs.limbs[i] != rhs.limbs[i]) { - return false; - } + if (lhs.limbs[i] != rhs.limbs[i]) return false; } return true; } bool is_bigint_zero(BigInt x) { for (uint i = 0; i < NUM_LIMBS; i++) { - if (x.limbs[i] != 0) { - return false; - } + if (x.limbs[i] != 0) return false; } return true; } // Overload Operators constexpr BigInt operator+(const BigInt lhs, const BigInt rhs) { - return bigint_add_unsafe(lhs, rhs); + return bigint_add_unsafe(lhs, rhs).value; } constexpr BigInt operator-(const BigInt lhs, const BigInt rhs) { - return bigint_sub(lhs, rhs); + return bigint_sub(lhs, rhs).value; } constexpr bool operator>=(const BigInt lhs, const BigInt rhs) { diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal index ce54857..6b967d6 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal @@ -6,13 +6,11 @@ using namespace metal; #include "bigint.metal" kernel void run( - device BigInt* lhs [[ buffer(0) ]], - device BigInt* rhs [[ buffer(1) ]], - device BigInt* result [[ buffer(2) ]], + device BigInt* a [[ buffer(0) ]], + device BigInt* b [[ buffer(1) ]], + device BigInt* res [[ buffer(2) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a = *lhs; - BigInt b = *rhs; - BigInt res = a + b; - *result = res; + BigIntResult result = bigint_add_unsafe(*a, *b); + *res = result.value; } diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal index f150055..9a66517 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal @@ -6,13 +6,11 @@ using namespace metal; #include "bigint.metal" kernel void run( - device BigInt* lhs [[ buffer(0) ]], - device BigInt* rhs [[ buffer(1) ]], - device BigIntWide* result [[ buffer(2) ]], + device BigInt* a [[ buffer(0) ]], + device BigInt* b [[ buffer(1) ]], + device BigIntWide* res [[ buffer(2) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a = *lhs; - BigInt b = *rhs; - BigIntWide res = bigint_add_wide(a, b); - *result = res; + BigIntResultWide result = bigint_add_wide(*a, *b); + *res = result.value; } diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal index 9621d1c..95a8092 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal @@ -6,13 +6,11 @@ using namespace metal; #include "bigint.metal" kernel void run( - device BigInt* lhs [[ buffer(0) ]], - device BigInt* rhs [[ buffer(1) ]], - device BigInt* result [[ buffer(2) ]], + device BigInt* a [[ buffer(0) ]], + device BigInt* b [[ buffer(1) ]], + device BigInt* res [[ buffer(2) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a = *lhs; - BigInt b = *rhs; - BigInt res = a - b; - *result = res; + BigIntResult result = bigint_sub(*a, *b); + *res = result.value; } diff --git a/mopro-msm/src/msm/metal_msm/shader/misc/types.metal b/mopro-msm/src/msm/metal_msm/shader/misc/types.metal index 8bc7409..9ebf4bf 100644 --- a/mopro-msm/src/msm/metal_msm/shader/misc/types.metal +++ b/mopro-msm/src/msm/metal_msm/shader/misc/types.metal @@ -11,6 +11,16 @@ struct BigIntWide { array limbs; }; +struct BigIntResult { + BigInt value; + bool carry; +}; + +struct BigIntResultWide { + BigIntWide value; + bool carry; +}; + struct Jacobian { BigInt x; BigInt y; From b22fb3f6dfe6c7b19c30c7dea33416fa66a9c79b Mon Sep 17 00:00:00 2001 From: moven0831 Date: Thu, 30 Jan 2025 19:29:26 +0900 Subject: [PATCH 2/2] feat(field): add ff reduction function and tests --- .../msm/metal_msm/shader/curve/jacobian.metal | 10 +- .../msm/metal_msm/shader/curve/utils.metal | 4 +- .../src/msm/metal_msm/shader/field/ff.metal | 52 ++--- .../msm/metal_msm/shader/field/ff_add.metal | 13 +- .../metal_msm/shader/field/ff_reduce.metal | 15 ++ .../msm/metal_msm/shader/field/ff_sub.metal | 16 +- .../msm/metal_msm/tests/field/ff_reduce.rs | 200 ++++++++++++++++++ .../src/msm/metal_msm/tests/field/mod.rs | 2 + 8 files changed, 251 insertions(+), 61 deletions(-) create mode 100644 mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal create mode 100644 mopro-msm/src/msm/metal_msm/tests/field/ff_reduce.rs diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal index a61d63e..a5043c8 100644 --- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal +++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal @@ -50,12 +50,8 @@ Jacobian jacobian_add_2007_bl( Jacobian b, BigInt p ) { - if (is_jacobian_zero(a)) { - return b; - } - if (is_jacobian_zero(b)) { - return a; - } + if (is_jacobian_zero(a)) return b; + if (is_jacobian_zero(b)) return a; if (a == b) return jacobian_dbl_2009_l(a, p); BigInt x1 = a.x; @@ -199,4 +195,4 @@ Jacobian jacobian_scalar_mul( } return result; -} \ No newline at end of file +} diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/utils.metal b/mopro-msm/src/msm/metal_msm/shader/curve/utils.metal index 1e5cacd..112fe94 100644 --- a/mopro-msm/src/msm/metal_msm/shader/curve/utils.metal +++ b/mopro-msm/src/msm/metal_msm/shader/curve/utils.metal @@ -21,9 +21,9 @@ bool jacobian_eq( } bool is_jacobian_zero(Jacobian p) { - return (is_bigint_zero(p.z)); + return is_bigint_zero(p.z); } constexpr bool operator==(const Jacobian lhs, const Jacobian rhs) { return jacobian_eq(lhs, rhs); -} \ No newline at end of file +} diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal index f56f160..928de0c 100644 --- a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal +++ b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal @@ -6,27 +6,22 @@ using namespace metal; #include #include "../bigint/bigint.metal" +BigInt ff_reduce( + BigInt a, + BigInt p +) { + BigIntResult res = bigint_sub(a, p); + if (bigint_gte(res.value, p)) return a; + return res.value; +} + BigInt ff_add( BigInt a, BigInt b, BigInt p ) { - BigInt sum = a + b; - - BigInt res; - if (sum >= p) { - // s = a + b - p - BigInt s = sum - p; - for (uint i = 0; i < NUM_LIMBS; i ++) { - res.limbs[i] = s.limbs[i]; - } - } - else { - for (uint i = 0; i < NUM_LIMBS; i ++) { - res.limbs[i] = sum.limbs[i]; - } - } - return res; + BigIntResult res = bigint_add_unsafe(a, b); + return ff_reduce(res.value, p); } BigInt ff_sub( @@ -34,21 +29,16 @@ BigInt ff_sub( BigInt b, BigInt p ) { - // if a >= b - if (a >= b) { - // a - b - BigInt res = a - b; - for (uint i = 0; i < NUM_LIMBS; i ++) { - res.limbs[i] = res.limbs[i]; - } - return res; - } else { + bool a_gte_b = bigint_gte(a, b); + + if (a_gte_b) { + BigIntResult res = bigint_sub(a, b); + return res.value; + } + else { // p - (b - a) - BigInt r = b - a; - BigInt res = p - r; - for (uint i = 0; i < NUM_LIMBS; i ++) { - res.limbs[i] = res.limbs[i]; - } - return res; + BigIntResult diff = bigint_sub(b, a); + BigIntResult res = bigint_sub(p, diff.value); + return res.value; } } diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal index a10a17e..7b14e28 100644 --- a/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal +++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal @@ -6,16 +6,11 @@ using namespace metal; #include "ff.metal" kernel void run( - device BigInt* lhs [[ buffer(0) ]], - device BigInt* rhs [[ buffer(1) ]], + device BigInt* a [[ buffer(0) ]], + device BigInt* b [[ buffer(1) ]], device BigInt* prime [[ buffer(2) ]], - device BigInt* result [[ buffer(3) ]], + device BigInt* res [[ buffer(3) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a = *lhs; - BigInt b = *rhs; - BigInt p = *prime; - - BigInt res = ff_add(a, b, p); - *result = res; + *res = ff_add(*a, *b, *prime); } diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal new file mode 100644 index 0000000..95e2431 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal @@ -0,0 +1,15 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include +#include +#include "ff.metal" + +kernel void run( + device BigInt* a [[ buffer(0) ]], + device BigInt* prime [[ buffer(1) ]], + device BigInt* res [[ buffer(2) ]], + uint gid [[ thread_position_in_grid ]] +) { + *res = ff_reduce(*a, *prime); +} diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal index 5f32e09..095d9a0 100644 --- a/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal +++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal @@ -6,19 +6,11 @@ using namespace metal; #include "ff.metal" kernel void run( - device BigInt* lhs [[ buffer(0) ]], - device BigInt* rhs [[ buffer(1) ]], + device BigInt* a [[ buffer(0) ]], + device BigInt* b [[ buffer(1) ]], device BigInt* prime [[ buffer(2) ]], - device BigInt* result [[ buffer(3) ]], + device BigInt* res [[ buffer(3) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - BigInt p; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - p.limbs = prime->limbs; - - BigInt res = ff_sub(a, b, p); - result->limbs = res.limbs; + *res = ff_sub(*a, *b, *prime); } diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_reduce.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_reduce.rs new file mode 100644 index 0000000..05b6738 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_reduce.rs @@ -0,0 +1,200 @@ +use crate::msm::metal_msm::host::gpu::{ + create_buffer, create_empty_buffer, get_default_device, read_buffer, +}; +use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; +use crate::msm::metal_msm::utils::limbs_conversion::GenericLimbConversion; +use ark_bn254::Fq as BaseField; +use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand}; +use ark_std::rand; +use metal::*; + +#[test] +#[serial_test::serial] +pub fn test_ff_reduce_a_less_than_p() { + let log_limb_size = 16; + let num_limbs = 16; + + // Scalar field modulus for bn254 + let p = BaseField::MODULUS; + + let mut rng = rand::thread_rng(); + let raw_a = BigInt::<4>::rand(&mut rng); + + // Ensure a is non-negative + assert!(raw_a >= BigInt::from(0u64), "a must be non-negative"); + + // Perform a % p + let mut a = raw_a.clone(); + + // While result >= p, subtract p + while a >= p { + a.sub_with_borrow(&p); + } + // Ensure expected is non-negative and less than p + assert!(a >= BigInt::from(0u64), "a must be non-negative"); + assert!(a < p, "a must be less than p"); + + let a_limbs = a.to_limbs(num_limbs, log_limb_size); + let device = get_default_device(); + let a_buf = create_buffer(&device, &a_limbs); + let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size)); + let result_buf = create_empty_buffer(&device, num_limbs); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader", + num_limbs, + log_limb_size, + 0, + 0, + ); + let library_path = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader/field", + "ff_reduce.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&p_buf), 0); + encoder.set_buffer(2, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, num_limbs); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + assert!(result == a); + assert!(result_limbs == a_limbs); +} + +#[test] +#[serial_test::serial] +pub fn test_ff_reduce_a_greater_than_p_less_than_2p() { + let log_limb_size = 16; + let num_limbs = 16; + + // Scalar field modulus for bn254 + let p = BaseField::MODULUS; + let mut two_p = p.clone(); + two_p.add_with_carry(&p); + + let mut rng = rand::thread_rng(); + let raw_a = BigInt::<4>::rand(&mut rng); + + // Ensure a is non-negative + assert!(raw_a >= BigInt::from(0u64), "a must be non-negative"); + + // Perform a % p + let mut a = raw_a.clone(); + + // While result >= p, subtract p + while a >= p { + a.sub_with_borrow(&p); + } + let expected = a.clone(); + let expected_limbs = a.to_limbs(num_limbs, log_limb_size); + + // Adding p to a to ensure a is in the range [p, 2p) + a.add_with_carry(&p); + + // Ensure expected is non-negative and less than p + assert!(a >= BigInt::from(0u64), "a must be non-negative"); + assert!(a < two_p, "a must be less than 2p"); + assert!(a >= p, "a must be greater than or equal to p"); + + let a_limbs = a.to_limbs(num_limbs, log_limb_size); + let device = get_default_device(); + let a_buf = create_buffer(&device, &a_limbs); + let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size)); + let result_buf = create_empty_buffer(&device, num_limbs); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader", + num_limbs, + log_limb_size, + 0, + 0, + ); + let library_path = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader/field", + "ff_reduce.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&p_buf), 0); + encoder.set_buffer(2, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, num_limbs); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + assert!(result == expected); + assert!(result_limbs == expected_limbs); +} diff --git a/mopro-msm/src/msm/metal_msm/tests/field/mod.rs b/mopro-msm/src/msm/metal_msm/tests/field/mod.rs index 3be2ba6..6ce8cbe 100644 --- a/mopro-msm/src/msm/metal_msm/tests/field/mod.rs +++ b/mopro-msm/src/msm/metal_msm/tests/field/mod.rs @@ -1,4 +1,6 @@ #[cfg(test)] pub mod ff_add; #[cfg(test)] +pub mod ff_reduce; +#[cfg(test)] pub mod ff_sub;