Skip to content

Commit

Permalink
Merge pull request #41 from zkmopro/feat/metal/ff-reduce
Browse files Browse the repository at this point in the history
Feat/metal/ff reduce
  • Loading branch information
moven0831 authored Jan 30, 2025
2 parents 9e97c78 + b22fb3f commit aca53d1
Show file tree
Hide file tree
Showing 13 changed files with 322 additions and 148 deletions.
112 changes: 46 additions & 66 deletions mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal
Original file line number Diff line number Diff line change
Expand Up @@ -7,120 +7,104 @@ using namespace metal;

BigInt bigint_zero() {
BigInt s;
for (uint i = 0; i < NUM_LIMBS; i ++) {
for (uint i = 0; i < NUM_LIMBS; i++) {
s.limbs[i] = 0;
}
return s;
}

BigInt bigint_add_unsafe(
BigIntResult bigint_add_unsafe(
BigInt lhs,
BigInt rhs
) {
BigInt result;
BigIntResult res;
res.carry = 0;
uint mask = (1 << LOG_LIMB_SIZE) - 1;
uint carry = 0;

for (uint i = 0; i < NUM_LIMBS; i ++) {
uint c = lhs.limbs[i] + rhs.limbs[i] + carry;
result.limbs[i] = c & mask;
carry = c >> LOG_LIMB_SIZE;
for (uint i = 0; i < NUM_LIMBS; i++) {
uint c = lhs.limbs[i] + rhs.limbs[i] + res.carry;
res.value.limbs[i] = c & mask;
res.carry = c >> LOG_LIMB_SIZE;
}
return result;
return res;
}

BigIntWide bigint_add_wide(
BigIntResultWide bigint_add_wide(
BigInt lhs,
BigInt rhs
) {
BigIntWide result;
BigIntResultWide res;
res.carry = 0;
uint mask = (1 << LOG_LIMB_SIZE) - 1;
uint carry = 0;

for (uint i = 0; i < NUM_LIMBS; i ++) {
for (uint i = 0; i < NUM_LIMBS; i++) {
uint c = lhs.limbs[i] + rhs.limbs[i] + carry;
result.limbs[i] = c & mask;
res.value.limbs[i] = c & mask;
carry = c >> LOG_LIMB_SIZE;
}
result.limbs[NUM_LIMBS] = carry;

return result;
res.value.limbs[NUM_LIMBS] = carry;
res.carry = carry;
return res;
}

BigInt bigint_sub(
BigIntResult bigint_sub(
BigInt lhs,
BigInt rhs
) {
uint borrow = 0;

BigInt res;

for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - borrow;

if (lhs.limbs[i] < (rhs.limbs[i] + borrow)) {
res.limbs[i] = res.limbs[i] + TWO_POW_WORD_SIZE;
borrow = 1;
BigIntResult res;
res.carry = 0;
for (uint i = 0; i < NUM_LIMBS; i++) {
res.value.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - res.carry;
if (lhs.limbs[i] < rhs.limbs[i] + res.carry) {
res.value.limbs[i] += TWO_POW_WORD_SIZE;
res.carry = 1;
} else {
borrow = 0;
res.carry = 0;
}
}

return res;
}


BigIntWide bigint_sub_wide(
BigIntResultWide bigint_sub_wide(
BigIntWide lhs,
BigIntWide rhs
) {
uint borrow = 0;

BigIntWide res;

for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - borrow;

if (lhs.limbs[i] < (rhs.limbs[i] + borrow)) {
res.limbs[i] = res.limbs[i] + TWO_POW_WORD_SIZE;
borrow = 1;
BigIntResultWide res;
res.carry = 0;
for (uint i = 0; i < NUM_LIMBS; i++) {
res.value.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - res.carry;
if (lhs.limbs[i] < rhs.limbs[i] + res.carry) {
res.value.limbs[i] += TWO_POW_WORD_SIZE;
res.carry = 1;
} else {
borrow = 0;
res.carry = 0;
}
}

return res;
}

bool bigint_gte(
BigInt lhs,
BigInt rhs
) {
for (uint idx = 0; idx < NUM_LIMBS; idx ++) {
// for (uint i = NUM_LIMBS-1; i >= 0; i--) is troublesome from unknown reason
for (uint idx = 0; idx < NUM_LIMBS; idx++) {
uint i = NUM_LIMBS - 1 - idx;
if (lhs.limbs[i] < rhs.limbs[i]) {
return false;
} else if (lhs.limbs[i] > rhs.limbs[i]) {
return true;
}
if (lhs.limbs[i] < rhs.limbs[i]) return false;
else if (lhs.limbs[i] > rhs.limbs[i]) return true;
}

return true;
}

bool bigint_wide_gte(
BigIntWide lhs,
BigIntWide rhs
) {
for (uint idx = 0; idx < NUM_LIMBS_WIDE; idx ++) {
for (uint idx = 0; idx < NUM_LIMBS_WIDE; idx++) {
uint i = NUM_LIMBS_WIDE - 1 - idx;
if (lhs.limbs[i] < rhs.limbs[i]) {
return false;
} else if (lhs.limbs[i] > rhs.limbs[i]) {
return true;
}
if (lhs.limbs[i] < rhs.limbs[i]) return false;
else if (lhs.limbs[i] > rhs.limbs[i]) return true;
}

return true;
}

Expand All @@ -129,29 +113,25 @@ bool bigint_eq(
BigInt rhs
) {
for (uint i = 0; i < NUM_LIMBS; i++) {
if (lhs.limbs[i] != rhs.limbs[i]) {
return false;
}
if (lhs.limbs[i] != rhs.limbs[i]) return false;
}
return true;
}

bool is_bigint_zero(BigInt x) {
for (uint i = 0; i < NUM_LIMBS; i++) {
if (x.limbs[i] != 0) {
return false;
}
if (x.limbs[i] != 0) return false;
}
return true;
}

// Overload Operators
constexpr BigInt operator+(const BigInt lhs, const BigInt rhs) {
return bigint_add_unsafe(lhs, rhs);
return bigint_add_unsafe(lhs, rhs).value;
}

constexpr BigInt operator-(const BigInt lhs, const BigInt rhs) {
return bigint_sub(lhs, rhs);
return bigint_sub(lhs, rhs).value;
}

constexpr bool operator>=(const BigInt lhs, const BigInt rhs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ using namespace metal;
#include "bigint.metal"

kernel void run(
device BigInt* lhs [[ buffer(0) ]],
device BigInt* rhs [[ buffer(1) ]],
device BigInt* result [[ buffer(2) ]],
device BigInt* a [[ buffer(0) ]],
device BigInt* b [[ buffer(1) ]],
device BigInt* res [[ buffer(2) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a = *lhs;
BigInt b = *rhs;
BigInt res = a + b;
*result = res;
BigIntResult result = bigint_add_unsafe(*a, *b);
*res = result.value;
}
12 changes: 5 additions & 7 deletions mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ using namespace metal;
#include "bigint.metal"

kernel void run(
device BigInt* lhs [[ buffer(0) ]],
device BigInt* rhs [[ buffer(1) ]],
device BigIntWide* result [[ buffer(2) ]],
device BigInt* a [[ buffer(0) ]],
device BigInt* b [[ buffer(1) ]],
device BigIntWide* res [[ buffer(2) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a = *lhs;
BigInt b = *rhs;
BigIntWide res = bigint_add_wide(a, b);
*result = res;
BigIntResultWide result = bigint_add_wide(*a, *b);
*res = result.value;
}
12 changes: 5 additions & 7 deletions mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ using namespace metal;
#include "bigint.metal"

kernel void run(
device BigInt* lhs [[ buffer(0) ]],
device BigInt* rhs [[ buffer(1) ]],
device BigInt* result [[ buffer(2) ]],
device BigInt* a [[ buffer(0) ]],
device BigInt* b [[ buffer(1) ]],
device BigInt* res [[ buffer(2) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a = *lhs;
BigInt b = *rhs;
BigInt res = a - b;
*result = res;
BigIntResult result = bigint_sub(*a, *b);
*res = result.value;
}
10 changes: 3 additions & 7 deletions mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,8 @@ Jacobian jacobian_add_2007_bl(
Jacobian b,
BigInt p
) {
if (is_jacobian_zero(a)) {
return b;
}
if (is_jacobian_zero(b)) {
return a;
}
if (is_jacobian_zero(a)) return b;
if (is_jacobian_zero(b)) return a;
if (a == b) return jacobian_dbl_2009_l(a, p);

BigInt x1 = a.x;
Expand Down Expand Up @@ -199,4 +195,4 @@ Jacobian jacobian_scalar_mul(
}

return result;
}
}
4 changes: 2 additions & 2 deletions mopro-msm/src/msm/metal_msm/shader/curve/utils.metal
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ bool jacobian_eq(
}

bool is_jacobian_zero(Jacobian p) {
return (is_bigint_zero(p.z));
return is_bigint_zero(p.z);
}

constexpr bool operator==(const Jacobian lhs, const Jacobian rhs) {
return jacobian_eq(lhs, rhs);
}
}
52 changes: 21 additions & 31 deletions mopro-msm/src/msm/metal_msm/shader/field/ff.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,39 @@ using namespace metal;
#include <metal_math>
#include "../bigint/bigint.metal"

BigInt ff_reduce(
BigInt a,
BigInt p
) {
BigIntResult res = bigint_sub(a, p);
if (bigint_gte(res.value, p)) return a;
return res.value;
}

BigInt ff_add(
BigInt a,
BigInt b,
BigInt p
) {
BigInt sum = a + b;

BigInt res;
if (sum >= p) {
// s = a + b - p
BigInt s = sum - p;
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = s.limbs[i];
}
}
else {
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = sum.limbs[i];
}
}
return res;
BigIntResult res = bigint_add_unsafe(a, b);
return ff_reduce(res.value, p);
}

BigInt ff_sub(
BigInt a,
BigInt b,
BigInt p
) {
// if a >= b
if (a >= b) {
// a - b
BigInt res = a - b;
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = res.limbs[i];
}
return res;
} else {
bool a_gte_b = bigint_gte(a, b);

if (a_gte_b) {
BigIntResult res = bigint_sub(a, b);
return res.value;
}
else {
// p - (b - a)
BigInt r = b - a;
BigInt res = p - r;
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = res.limbs[i];
}
return res;
BigIntResult diff = bigint_sub(b, a);
BigIntResult res = bigint_sub(p, diff.value);
return res.value;
}
}
13 changes: 4 additions & 9 deletions mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@ using namespace metal;
#include "ff.metal"

kernel void run(
device BigInt* lhs [[ buffer(0) ]],
device BigInt* rhs [[ buffer(1) ]],
device BigInt* a [[ buffer(0) ]],
device BigInt* b [[ buffer(1) ]],
device BigInt* prime [[ buffer(2) ]],
device BigInt* result [[ buffer(3) ]],
device BigInt* res [[ buffer(3) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a = *lhs;
BigInt b = *rhs;
BigInt p = *prime;

BigInt res = ff_add(a, b, p);
*result = res;
*res = ff_add(*a, *b, *prime);
}
15 changes: 15 additions & 0 deletions mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// source: https://github.com/geometryxyz/msl-secp256k1

using namespace metal;
#include <metal_stdlib>
#include <metal_math>
#include "ff.metal"

kernel void run(
device BigInt* a [[ buffer(0) ]],
device BigInt* prime [[ buffer(1) ]],
device BigInt* res [[ buffer(2) ]],
uint gid [[ thread_position_in_grid ]]
) {
*res = ff_reduce(*a, *prime);
}
Loading

0 comments on commit aca53d1

Please sign in to comment.