Skip to content

Commit

Permalink
Implement for chapter 3 (#3)
Browse files Browse the repository at this point in the history
1. univariate lagrange polynomial
    1. evaluate mpoly
2. Fix univariable lagrange interpolate poly.    

* TODO
   1. The lagrange of mpoly havn't done.
  • Loading branch information
SuccinctPaul authored Jul 23, 2023
1 parent 9da0757 commit e783285
Show file tree
Hide file tree
Showing 17 changed files with 540 additions and 58 deletions.
2 changes: 1 addition & 1 deletion 2_Freivalds_Algorithm/src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl Matrix {
}

fn get_columns(&self, column_index: usize) -> Vec<Scalar> {
assert!(0 <= column_index || self.cols > column_index);
assert!(self.cols > column_index);

self.values
.iter()
Expand Down
10 changes: 9 additions & 1 deletion 2_univariate_lagrange_interpolation/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,12 @@ ff = "0.13.0"
bls12_381 = "0.8.0"
rand = "0.8.5"
rand_core = { version = "0.6.4", default-features = false, features = ["std"] }
rayon = "1.7.0"
rayon = "1.7.0"

[dev-dependencies]
criterion = "0.3"


[[bench]]
name = "lagrange_interpolate"
harness = false
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#[macro_use]
extern crate criterion;

use bls12_381::Scalar;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use ff::{Field, PrimeField};
use rand_core::OsRng;
use univariate_lagrange_interpolation::polynomial::Polynomial;

fn bench_lagrange_interpolate(c: &mut Criterion) {
let MIN_K: u32 = std::env::var("DEGREE")
.unwrap_or_else(|_| "16".to_string())
.parse()
.expect("Cannot parse DEGREE env var as u32");

const MAX_K: u32 = 19;

// values
let max_n = 1 << MAX_K;
let domain: Vec<Scalar> = (0..max_n).map(|i| Scalar::from_u128(i)).collect::<Vec<_>>();
let values: Vec<Scalar> = (0..max_n)
.map(|_| Scalar::random(OsRng))
.collect::<Vec<_>>();

let mut group = c.benchmark_group("lagrange_interpolate");

for k in MIN_K..=MAX_K {
let n: u128 = 1 << k;

let x = &domain[..n];
let y = &values[..n];

group.bench_function(BenchmarkId::new("k", k), |b| {
b.iter(|| Polynomial::lagrange_interpolate(x.clone(), y.clone()));
});
}

group.finish();
}

criterion_group!(benches, bench_lagrange_interpolate);
criterion_main!(benches);
111 changes: 60 additions & 51 deletions 2_univariate_lagrange_interpolation/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,67 +1,76 @@
use crate::polynomial::Polynomial;
use bls12_381::Scalar;
use ff::PrimeField;
#![allow(non_snake_case)]

mod polynomial;
mod utils;
pub mod polynomial;

/// Encode vector into polynomial.
// TODO: can do a bench for diff impl
// eg:
// 1: https://github.com/Neptune-Crypto/twenty-first/twenty-first/src/shared_math/polynomial.rs#lagrange_interpolate
// 2. halo2's
// 3. arkwork's
// 4. lambda-work's

#[test]
fn encode() {
let two = Scalar::one().add(&Scalar::one());
#[cfg(test)]
mod test {
use crate::polynomial::Polynomial;
use bls12_381::Scalar;
use ff::PrimeField;
use std::ops::Sub;
use std::time::Instant;

// p = 1 + 2x + x^2
let a = vec![Scalar::one(), two, Scalar::one()];
// Encode vector into polynomial.
#[test]
fn encode() {
let two = Scalar::one().add(&Scalar::one());

let poly = Polynomial::encode(a);
// p = 1 + 2x + x^2
let a = vec![Scalar::one(), two, Scalar::one()];

let z = poly.evaluate(Scalar::one());
let poly = Polynomial::encode(a);

assert_eq!(Scalar::from_u128(4), z);
let z = poly.evaluate(Scalar::one());

let z = poly.evaluate(two.double());
assert_eq!(Scalar::from_u128(25), z);
assert_eq!(Scalar::from_u128(4), z);

for i in 1..10 {
println!("{:?}", poly.evaluate(Scalar::from_u128(i)));
}
}
let z = poly.evaluate(two.double());
assert_eq!(Scalar::from_u128(25), z);

#[test]
fn lagrange_interpolate() {
// aim: p = 1 + 2x + x^2
for i in 1..10 {
println!("{:?}", poly.evaluate(Scalar::from_u128(i)));
}
}

let domain = vec![
Scalar::from_u128(1),
Scalar::from_u128(2),
Scalar::from_u128(3),
Scalar::from_u128(4),
Scalar::from_u128(5),
Scalar::from_u128(6),
Scalar::from_u128(7),
Scalar::from_u128(8),
Scalar::from_u128(9),
];
let evals = vec![
Scalar::from_u128(4),
Scalar::from_u128(9),
Scalar::from_u128(10),
Scalar::from_u128(19),
Scalar::from_u128(24),
Scalar::from_u128(31),
Scalar::from_u128(40),
Scalar::from_u128(51),
Scalar::from_u128(64),
];
#[test]
fn lagrange_interpolate() {
// aim: p = 1 + 2x + x^2

let poly = Polynomial::lagrange_interpolate(domain.clone(), evals.clone());
let domain = vec![
Scalar::from_u128(1),
Scalar::from_u128(2),
Scalar::from_u128(3),
Scalar::from_u128(4),
Scalar::from_u128(5),
Scalar::from_u128(6),
Scalar::from_u128(7),
Scalar::from_u128(8),
Scalar::from_u128(9),
];
let evals = vec![
Scalar::from_u128(4),
Scalar::from_u128(9),
Scalar::from_u128(10),
Scalar::from_u128(19),
Scalar::from_u128(24),
Scalar::from_u128(31),
Scalar::from_u128(40),
Scalar::from_u128(51),
Scalar::from_u128(64),
];

let z = poly.evaluate(Scalar::from_u128(3));
println!("{:?}", z);
let poly = Polynomial::lagrange_interpolate(domain.clone(), evals.clone());

// todo meet errors
for (x, y) in domain.iter().zip(evals) {
assert_eq!(poly.evaluate(*x), y);
for (x, y) in domain.iter().zip(evals) {
assert_eq!(poly.evaluate(*x), y);
}
println!("pass");
}
}
92 changes: 89 additions & 3 deletions 2_univariate_lagrange_interpolation/src/polynomial.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use bls12_381::Scalar;
use ff::BatchInvert;
use rayon::{current_num_threads, scope, Scope};
use std::iter::Scan;
use rayon::{current_num_threads, scope};

// p(x) = = a_0 + a_1 * X + ... + a_n * X^(n-1)
//
Expand Down Expand Up @@ -34,6 +33,92 @@ impl Polynomial {
pub fn lagrange_interpolate(domains: Vec<Scalar>, evals: Vec<Scalar>) -> Self {
assert_eq!(domains.len(), evals.len());

if evals.len() == 1 {
// Constant polynomial
Self {
coeffs: vec![evals[0]],
}
} else {
let poly_size = domains.len();
let lag_basis_poly_size = poly_size - 1;

// 1. divisors = vec(x_j - x_k). prepare for L_j(X)=∏(X−x_k)/(x_j−x_k)
let mut divisors = Vec::with_capacity(poly_size);
for (j, x_j) in domains.iter().enumerate() {
// divisor_j
let mut divisor = Vec::with_capacity(lag_basis_poly_size);
// obtain domain for x_k
for x_k in domains
.iter()
.enumerate()
.filter(|&(k, _)| k != j)
.map(|(_, x)| x)
{
divisor.push(*x_j - x_k);
}
divisors.push(divisor);
}
// Inverse (x_j - x_k)^(-1) for each j != k to compute L_j(X)=∏(X−x_k)/(x_j−x_k)
divisors
.iter_mut()
.flat_map(|v| v.iter_mut())
.batch_invert();

// 2. Calculate L_j(X) : L_j(X)=∏(X−x_k) divisors_j
let mut L_j_vec: Vec<Vec<Scalar>> = Vec::with_capacity(poly_size);

for (j, divisor_j) in divisors.into_iter().enumerate() {
let mut L_j: Vec<Scalar> = Vec::with_capacity(poly_size);
L_j.push(Scalar::one());

// (X−x_k) * divisors_j
let mut product = Vec::with_capacity(lag_basis_poly_size);

// obtain domain for x_k
for (x_k, divisor) in domains
.iter()
.enumerate()
.filter(|&(k, _)| k != j)
.map(|(_, x)| x)
.zip(divisor_j.into_iter())
{
product.resize(L_j.len() + 1, Scalar::zero());

// loop (poly_size + 1) round
// calculate L_j(X)=∏(X−x_k) divisors_j with coefficient form.
for ((a, b), product) in L_j
.iter()
.chain(std::iter::once(&Scalar::zero()))
.zip(std::iter::once(&Scalar::zero()).chain(L_j.iter()))
.zip(product.iter_mut())
{
*product = *a * (-divisor * x_k) + *b * divisor;
}
std::mem::swap(&mut L_j, &mut product);
}

assert_eq!(L_j.len(), poly_size);
assert_eq!(product.len(), poly_size - 1);

L_j_vec.push(L_j);
}

// p(x)=∑y_j⋅L_j(X) in coefficients
let mut final_poly = vec![Scalar::zero(); poly_size];
// 3. p(x)=∑y_j⋅L_j(X)
for (L_j, y_j) in L_j_vec.iter().zip(evals) {
for (final_coeff, L_j_coeff) in final_poly.iter_mut().zip(L_j.into_iter()) {
*final_coeff += L_j_coeff * y_j;
}
}
Self { coeffs: final_poly }
}
}

// todo parallel
pub fn lagrange_interpolate_parallel(domains: Vec<Scalar>, evals: Vec<Scalar>) -> Self {
assert_eq!(domains.len(), evals.len());

if evals.len() == 1 {
// Constant polynomial
Self {
Expand Down Expand Up @@ -121,8 +206,10 @@ impl Polynomial {
let coeffs = self.coeffs.clone();
let poly_size = self.coeffs.len();

// p(x) = = a_0 + a_1 * X + ... + a_n * X^(n-1), revert it and fold sum it
fn eval(poly: &[Scalar], point: Scalar) -> Scalar {
poly.iter()
.rev()
.fold(Scalar::zero(), |acc, coeff| acc * point + coeff)
}

Expand All @@ -148,4 +235,3 @@ impl Polynomial {
}
}
}
// canonical set of inputs
1 change: 0 additions & 1 deletion 2_univariate_lagrange_interpolation/src/utils.rs

This file was deleted.

16 changes: 16 additions & 0 deletions 3_multilinear_lagrange_interpolation/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "multilinear_lagrange_interpolation"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ff = "0.13.0"
bls12_381 = "0.8.0"
rand = "0.8.5"
rand_core = { version = "0.6.4", default-features = false, features = ["std"] }
rayon = "1.7.0"

[dev-dependencies]
criterion = "0.3"
4 changes: 4 additions & 0 deletions 3_multilinear_lagrange_interpolation/reference.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Multivariable Polynomial Reference
1. https://github.com/int-e/twenty-first/blob/cbda032f2c7b3ba44aa5582a1057c6b8b32e4ab6/twenty-first/src/shared_math/mpolynomial.rs
2. https://github.com/benruijl/symbolica/blob/cfb96196dd6f8dcae813157f9f42ad9eaa64a4ab/src/poly/polynomial.rs
3. https://github.com/arkworks-rs/algebra/blob/master/poly/src/polynomial/multivariate/sparse.rs
Loading

0 comments on commit e783285

Please sign in to comment.