diff --git a/Cargo.lock b/Cargo.lock index 7c37657754..a13cd54849 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,7 +119,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -135,7 +135,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", "syn-solidity", "tiny-keccak", ] @@ -151,7 +151,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", "syn-solidity", ] @@ -399,7 +399,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -427,7 +427,7 @@ checksum = "3c87f3f15e7794432337fc718554eaa4dc8f04c9677a950ffe366f20a162ae42" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -571,7 +571,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.66", + "syn 2.0.67", "which", ] @@ -869,7 +869,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1147,7 +1147,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1171,7 +1171,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1182,7 +1182,7 @@ checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" dependencies = [ "darling_core", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1226,7 +1226,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version 0.4.0", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1523,7 +1523,7 @@ dependencies = [ "regex", "serde", "serde_json", - "syn 2.0.66", + "syn 2.0.67", "toml", "walkdir", ] @@ -1541,7 +1541,7 @@ dependencies = [ "proc-macro2", "quote", "serde_json", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1567,7 +1567,7 @@ 
dependencies = [ "serde", "serde_json", "strum", - "syn 2.0.66", + "syn 2.0.67", "tempfile", "thiserror", "tiny-keccak", @@ -1842,7 +1842,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -2920,7 +2920,7 @@ dependencies = [ "proc-macro-crate 3.1.0", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3007,7 +3007,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3043,7 +3043,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3-air" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "p3-field", "p3-matrix", @@ -3052,7 +3052,7 @@ dependencies = [ [[package]] name = "p3-baby-bear" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "num-bigint 0.4.5", "p3-field", @@ -3066,7 +3066,7 @@ dependencies = [ [[package]] name = "p3-blake3" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "blake3", "p3-symmetric", @@ -3075,7 +3075,7 @@ dependencies = [ [[package]] name = 
"p3-bn254-fr" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "ff 0.13.0", "num-bigint 0.4.5", @@ -3089,7 +3089,7 @@ dependencies = [ [[package]] name = "p3-challenger" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "p3-field", "p3-maybe-rayon", @@ -3101,7 +3101,7 @@ dependencies = [ [[package]] name = "p3-commit" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "itertools 0.12.1", "p3-challenger", @@ -3114,7 +3114,7 @@ dependencies = [ [[package]] name = "p3-dft" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "p3-field", "p3-matrix", @@ -3126,7 +3126,7 @@ dependencies = [ [[package]] name = "p3-field" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies 
= [ "itertools 0.12.1", "num-bigint 0.4.5", @@ -3139,7 +3139,7 @@ dependencies = [ [[package]] name = "p3-fri" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "itertools 0.12.1", "p3-challenger", @@ -3157,7 +3157,7 @@ dependencies = [ [[package]] name = "p3-interpolation" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "p3-field", "p3-matrix", @@ -3167,7 +3167,7 @@ dependencies = [ [[package]] name = "p3-keccak" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "p3-symmetric", "tiny-keccak", @@ -3176,7 +3176,7 @@ dependencies = [ [[package]] name = "p3-keccak-air" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "p3-air", "p3-field", @@ -3189,7 +3189,7 @@ dependencies = [ [[package]] name = "p3-matrix" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = 
"git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "itertools 0.12.1", "p3-field", @@ -3203,7 +3203,7 @@ dependencies = [ [[package]] name = "p3-maybe-rayon" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "rayon", ] @@ -3211,7 +3211,7 @@ dependencies = [ [[package]] name = "p3-mds" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "itertools 0.12.1", "p3-dft", @@ -3225,7 +3225,7 @@ dependencies = [ [[package]] name = "p3-merkle-tree" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "itertools 0.12.1", "p3-commit", @@ -3241,7 +3241,7 @@ dependencies = [ [[package]] name = "p3-poseidon2" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "gcd", "p3-field", @@ -3253,7 +3253,7 @@ dependencies = [ [[package]] name = "p3-symmetric" version = "0.1.0" -source = 
"git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "itertools 0.12.1", "p3-field", @@ -3263,7 +3263,7 @@ dependencies = [ [[package]] name = "p3-uni-stark" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "itertools 0.12.1", "p3-air", @@ -3281,7 +3281,7 @@ dependencies = [ [[package]] name = "p3-util" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=42d52e8608c1d12d337cfc8bfd692777ef13532f#42d52e8608c1d12d337cfc8bfd692777ef13532f" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=b447924a508a780b997d451f9864e593efce0843#b447924a508a780b997d451f9864e593efce0843" dependencies = [ "serde", ] @@ -3452,7 +3452,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3536,7 +3536,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3645,7 +3645,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -4397,7 +4397,7 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -4469,7 +4469,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] 
[[package]] @@ -4494,7 +4494,7 @@ checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -5067,14 +5067,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "0d0208408ba0c3df17ed26eb06992cb1a1268d41b2c0e12e65203fbe3972cee5" [[package]] name = "subtle-encoding" @@ -5098,9 +5098,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.66" +version = "2.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = "ff8655ed1d86f3af4ee3fd3263786bc14245ad17c4c7e85ba7187fb3ae028c90" dependencies = [ "proc-macro2", "quote", @@ -5116,7 +5116,7 @@ dependencies = [ "paste", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -5193,7 +5193,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -5300,7 +5300,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -5451,7 +5451,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -5732,7 +5732,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", "wasm-bindgen-shared", ] @@ -5766,7 +5766,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", 
"wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -6105,7 +6105,7 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -6125,7 +6125,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b5b09af31a..76287a565d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,26 +30,26 @@ debug = true debug-assertions = true [workspace.dependencies] -p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-commit = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } +p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-commit = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", features = [ "nightly-features", -], rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-dft = { git = "https://github.com/Plonky3/Plonky3.git", rev = 
"42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-keccak = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-blake3 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-uni-stark = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-maybe-rayon = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } -p3-bn254-fr = { git = "https://github.com/Plonky3/Plonky3.git", rev = "42d52e8608c1d12d337cfc8bfd692777ef13532f" } +], rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-dft = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-keccak = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-blake3 = { git = 
"https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-uni-stark = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-maybe-rayon = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } +p3-bn254-fr = { git = "https://github.com/Plonky3/Plonky3.git", rev = "b447924a508a780b997d451f9864e593efce0843" } # For local development. diff --git a/core/src/air/builder.rs b/core/src/air/builder.rs index d957f43cd9..d5d87b2b23 100644 --- a/core/src/air/builder.rs +++ b/core/src/air/builder.rs @@ -52,6 +52,11 @@ pub trait BaseAirBuilder: AirBuilder + MessageBuilder } } + /// Asserts that an iterator of expressions are all zero. + fn assert_all_zero>(&mut self, iter: impl IntoIterator) { + iter.into_iter().for_each(|expr| self.assert_zero(expr)); + } + /// Will return `a` if `condition` is 1, else `b`. This assumes that `condition` is already /// checked to be a boolean. 
#[inline] diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 552c3fc348..a621e0ac1e 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -32,6 +32,7 @@ use syn::Data; use syn::DeriveInput; use syn::GenericParam; use syn::ItemFn; +use syn::WherePredicate; #[proc_macro_derive(AlignedBorrow)] pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { @@ -94,7 +95,13 @@ pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { #[proc_macro_derive( MachineAir, - attributes(sp1_core_path, execution_record_path, program_path, builder_path) + attributes( + sp1_core_path, + execution_record_path, + program_path, + builder_path, + eval_trait_bound + ) )] pub fn machine_air_derive(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); @@ -105,6 +112,7 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { let execution_record_path = find_execution_record_path(&ast.attrs); let program_path = find_program_path(&ast.attrs); let builder_path = find_builder_path(&ast.attrs); + let eval_trait_bound = find_eval_trait_bound(&ast.attrs); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); match &ast.data { @@ -257,6 +265,13 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { let (air_impl_generics, _, _) = new_generics.split_for_impl(); + let mut new_generics = generics.clone(); + let where_clause = new_generics.make_where_clause(); + if eval_trait_bound.is_some() { + let predicate: WherePredicate = syn::parse_str(&eval_trait_bound.unwrap()).unwrap(); + where_clause.predicates.push(predicate); + } + let air = quote! 
{ impl #air_impl_generics p3_air::Air for #name #ty_generics #where_clause { fn eval(&self, builder: &mut AB) { @@ -360,3 +375,17 @@ fn find_builder_path(attrs: &[syn::Attribute]) -> syn::Path { } parse_quote!(crate::air::SP1AirBuilder) } + +fn find_eval_trait_bound(attrs: &[syn::Attribute]) -> Option { + for attr in attrs { + if attr.path.is_ident("eval_trait_bound") { + if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() { + if let syn::Lit::Str(lit_str) = &meta.lit { + return Some(lit_str.value()); + } + } + } + } + + None +} diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 0314ddee8a..44b911d172 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -74,7 +74,7 @@ pub type OuterSC = BabyBearPoseidon2Outer; const REDUCE_DEGREE: usize = 3; const COMPRESS_DEGREE: usize = 9; -const WRAP_DEGREE: usize = 9; +const WRAP_DEGREE: usize = 17; pub type ReduceAir = RecursionAir; pub type CompressAir = RecursionAir; diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index 574a801411..c2986c62a4 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -23,7 +23,7 @@ use sp1_recursion_compiler::ir::{Usize, Witness}; use sp1_recursion_compiler::prelude::SymbolicVar; use sp1_recursion_core::air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH}; use sp1_recursion_core::stark::config::{outer_fri_config, BabyBearPoseidon2Outer}; -use sp1_recursion_core::stark::RecursionAirSkinnyDeg9; +use sp1_recursion_core::stark::RecursionAirWideDeg17; use sp1_recursion_program::commit::PolynomialSpaceVariable; use sp1_recursion_program::stark::RecursiveVerifierConstraintFolder; use sp1_recursion_program::types::QuotientDataValues; @@ -244,7 +244,7 @@ pub fn build_wrap_circuit( template_proof: ShardProof, ) -> Vec { let outer_config = OuterSC::new(); - let outer_machine = RecursionAirSkinnyDeg9::::wrap_machine(outer_config); + let outer_machine = RecursionAirWideDeg17::::wrap_machine(outer_config); let mut builder = 
Builder::::default(); let mut challenger = MultiField32ChallengerVariable::new(&mut builder); diff --git a/recursion/compiler/src/asm/compiler.rs b/recursion/compiler/src/asm/compiler.rs index 3ec01ebe48..47797aae68 100644 --- a/recursion/compiler/src/asm/compiler.rs +++ b/recursion/compiler/src/asm/compiler.rs @@ -517,7 +517,32 @@ impl + TwoAdicField> AsmCo _ => unimplemented!(), } } - + DslIr::Poseidon2AbsorbBabyBear(p2_hash_num, input) => match input { + Array::Dyn(input, input_size) => { + if let Usize::Var(input_size) = input_size { + self.push( + AsmInstruction::Poseidon2Absorb( + p2_hash_num.fp(), + input.fp(), + input_size.fp(), + ), + trace, + ); + } else { + unimplemented!(); + } + } + _ => unimplemented!(), + }, + DslIr::Poseidon2FinalizeBabyBear(p2_hash_num, output) => match output { + Array::Dyn(output, _) => { + self.push( + AsmInstruction::Poseidon2Finalize(p2_hash_num.fp(), output.fp()), + trace, + ); + } + _ => unimplemented!(), + }, DslIr::Commit(val, index) => { self.push(AsmInstruction::Commit(val.fp(), index.fp()), trace); } diff --git a/recursion/compiler/src/asm/instruction.rs b/recursion/compiler/src/asm/instruction.rs index 78befc94d4..5a38b67a4c 100644 --- a/recursion/compiler/src/asm/instruction.rs +++ b/recursion/compiler/src/asm/instruction.rs @@ -147,8 +147,16 @@ pub enum AsmInstruction { /// Perform a permutation of the Poseidon2 hash function on the array specified by the ptr. Poseidon2Permute(i32, i32), + + /// Perform a Poseidon2 compress. Poseidon2Compress(i32, i32, i32), + /// Performs a Poseidon2 absorb. + Poseidon2Absorb(i32, i32, i32), + + /// Performs a Poseidon2 finalize. + Poseidon2Finalize(i32, i32), + /// Print a variable. 
PrintV(i32), @@ -846,6 +854,28 @@ impl> AsmInstruction { false, "".to_string(), ), + AsmInstruction::Poseidon2Absorb(hash_num, input_ptr, input_len) => Instruction::new( + Opcode::Poseidon2Absorb, + i32_f(hash_num), + i32_f_arr(input_ptr), + i32_f_arr(input_len), + F::zero(), + F::zero(), + false, + false, + "".to_string(), + ), + AsmInstruction::Poseidon2Finalize(hash_num, output_ptr) => Instruction::new( + Opcode::Poseidon2Finalize, + i32_f(hash_num), + i32_f_arr(output_ptr), + f_u32(F::zero()), + F::zero(), + F::zero(), + false, + false, + "".to_string(), + ), AsmInstruction::Commit(val, index) => Instruction::new( Opcode::Commit, i32_f(val), @@ -1144,6 +1174,16 @@ impl> AsmInstruction { result, src1, src2 ) } + AsmInstruction::Poseidon2Absorb(hash_num, input_ptr, input_len) => { + write!( + f, + "poseidon2_absorb ({})fp, {})fp, ({})fp", + hash_num, input_ptr, input_len, + ) + } + AsmInstruction::Poseidon2Finalize(hash_num, output_ptr) => { + write!(f, "poseidon2_finalize ({})fp, {})fp", hash_num, output_ptr,) + } AsmInstruction::Commit(val, index) => { write!(f, "commit ({})fp ({})fp", val, index) } diff --git a/recursion/compiler/src/ir/builder.rs b/recursion/compiler/src/ir/builder.rs index 19fb164f28..b3d5741dda 100644 --- a/recursion/compiler/src/ir/builder.rs +++ b/recursion/compiler/src/ir/builder.rs @@ -90,7 +90,7 @@ impl IntoIterator for TracedVec { /// A builder for the DSL. /// /// Can compile to both assembly and a set of constraints. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct Builder { pub(crate) felt_count: u32, pub(crate) ext_count: u32, @@ -100,10 +100,35 @@ pub struct Builder { pub(crate) witness_var_count: u32, pub(crate) witness_felt_count: u32, pub(crate) witness_ext_count: u32, + pub(crate) p2_hash_num: Var, pub(crate) debug: bool, pub(crate) is_sub_builder: bool, } +impl Default for Builder { + fn default() -> Self { + // We need to create a temporary placeholder for the p2_hash_num variable. 
+ let placeholder_p2_hash_num = Var::new(0); + + let mut new_builder = Self { + felt_count: 0, + ext_count: 0, + var_count: 0, + witness_var_count: 0, + witness_felt_count: 0, + witness_ext_count: 0, + operations: Default::default(), + nb_public_values: None, + p2_hash_num: placeholder_p2_hash_num, + debug: false, + is_sub_builder: false, + }; + + new_builder.p2_hash_num = new_builder.uninit(); + new_builder + } +} + impl Builder { /// Creates a new builder with a given number of counts for each type. pub fn new_sub_builder( @@ -111,6 +136,7 @@ impl Builder { felt_count: u32, ext_count: u32, nb_public_values: Option>, + p2_hash_num: Var, debug: bool, ) -> Self { Self { @@ -124,6 +150,7 @@ impl Builder { witness_ext_count: 0, operations: Default::default(), nb_public_values, + p2_hash_num, debug, is_sub_builder: true, } @@ -517,9 +544,12 @@ impl<'a, C: Config> IfBuilder<'a, C> { self.builder.felt_count, self.builder.ext_count, self.builder.nb_public_values, + self.builder.p2_hash_num, self.builder.debug, ); f(&mut f_builder); + self.builder.p2_hash_num = f_builder.p2_hash_num; + let then_instructions = f_builder.operations; // Dispatch instructions to the correct conditional block. @@ -565,11 +595,14 @@ impl<'a, C: Config> IfBuilder<'a, C> { self.builder.felt_count, self.builder.ext_count, self.builder.nb_public_values, + self.builder.p2_hash_num, self.builder.debug, ); // Execute the `then` and `else_then` blocks and collect the instructions. 
then_f(&mut then_builder); + self.builder.p2_hash_num = then_builder.p2_hash_num; + let then_instructions = then_builder.operations; let mut else_builder = Builder::::new_sub_builder( @@ -577,9 +610,12 @@ impl<'a, C: Config> IfBuilder<'a, C> { self.builder.felt_count, self.builder.ext_count, self.builder.nb_public_values, + self.builder.p2_hash_num, self.builder.debug, ); else_f(&mut else_builder); + self.builder.p2_hash_num = else_builder.p2_hash_num; + let else_instructions = else_builder.operations; // Dispatch instructions to the correct conditional block. @@ -711,10 +747,12 @@ impl<'a, C: Config> RangeBuilder<'a, C> { self.builder.felt_count, self.builder.ext_count, self.builder.nb_public_values, + self.builder.p2_hash_num, self.builder.debug, ); f(loop_variable, &mut loop_body_builder); + self.builder.p2_hash_num = loop_body_builder.p2_hash_num; let loop_instructions = loop_body_builder.operations; diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index f5c2a1b856..ab0cb67098 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -198,6 +198,10 @@ pub enum DslIr { Array>, Array>, ), + /// Absorb an array of baby bear elements for a specified hash instance. + Poseidon2AbsorbBabyBear(Var, Array>), + /// Finalize and return the hash digest of a specified hash instance. + Poseidon2FinalizeBabyBear(Var, Array>), /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should only /// be used when target is a gnark circuit. CircuitPoseidon2Permute([Var; 3]), diff --git a/recursion/compiler/src/ir/poseidon.rs b/recursion/compiler/src/ir/poseidon.rs index 2f8b0ef86b..a69f86c8d6 100644 --- a/recursion/compiler/src/ir/poseidon.rs +++ b/recursion/compiler/src/ir/poseidon.rs @@ -32,6 +32,28 @@ impl Builder { )); } + /// Applies the Poseidon2 absorb function to the given array. 
+ /// + /// Reference: [p3_symmetric::PaddingFreeSponge] + pub fn poseidon2_absorb(&mut self, p2_hash_num: Var, input: &Array>) { + self.operations + .push(DslIr::Poseidon2AbsorbBabyBear(p2_hash_num, input.clone())); + } + + /// Applies the Poseidon2 finalize to the given hash number. + /// + /// Reference: [p3_symmetric::PaddingFreeSponge] + pub fn poseidon2_finalize_mut( + &mut self, + p2_hash_num: Var, + output: &Array>, + ) { + self.operations.push(DslIr::Poseidon2FinalizeBabyBear( + p2_hash_num, + output.clone(), + )); + } + /// Applies the Poseidon2 compression function to the given array. /// /// Reference: [p3_symmetric::TruncatedPermutation] @@ -104,33 +126,20 @@ impl Builder { array: &Array>>, ) -> Array> { self.cycle_tracker("poseidon2-hash"); - let mut state: Array> = self.dyn_array(PERMUTATION_WIDTH); - let idx: Var<_> = self.eval(C::N::zero()); + let p2_hash_num = self.p2_hash_num; self.range(0, array.len()).for_each(|i, builder| { let subarray = builder.get(array, i); - builder.range(0, subarray.len()).for_each(|j, builder| { - builder.cycle_tracker("poseidon2-hash-setup"); - let element = builder.get(&subarray, j); - builder.set_value(&mut state, idx, element); - builder.assign(idx, idx + C::N::one()); - builder.cycle_tracker("poseidon2-hash-setup"); - builder - .if_eq(idx, C::N::from_canonical_usize(HASH_RATE)) - .then(|builder| { - builder.poseidon2_permute_mut(&state); - builder.assign(idx, C::N::zero()); - }); - }); + builder.poseidon2_absorb(p2_hash_num, &subarray); }); - self.if_ne(idx, C::N::zero()).then(|builder| { - builder.poseidon2_permute_mut(&state); - }); + let output: Array> = self.dyn_array(DIGEST_SIZE); + self.poseidon2_finalize_mut(self.p2_hash_num, &output); + + self.assign(self.p2_hash_num, self.p2_hash_num + C::N::one()); - state.truncate(self, Usize::Const(DIGEST_SIZE)); self.cycle_tracker("poseidon2-hash"); - state + output } pub fn poseidon2_hash_ext( diff --git a/recursion/compiler/tests/poseidon2.rs 
b/recursion/compiler/tests/poseidon2.rs index 9ec2098f27..64b1442179 100644 --- a/recursion/compiler/tests/poseidon2.rs +++ b/recursion/compiler/tests/poseidon2.rs @@ -4,6 +4,7 @@ use p3_symmetric::Permutation; use rand::thread_rng; use rand::Rng; use sp1_core::stark::StarkGenericConfig; +use sp1_core::utils::setup_logger; use sp1_core::utils::BabyBearPoseidon2; use sp1_recursion_compiler::asm::AsmBuilder; use sp1_recursion_compiler::ir::Array; @@ -64,6 +65,7 @@ fn test_compiler_poseidon2_permute() { #[test] fn test_compiler_poseidon2_hash() { + setup_logger(); type SC = BabyBearPoseidon2; type F = ::Val; type EF = ::Challenge; @@ -74,19 +76,32 @@ fn test_compiler_poseidon2_hash() { let mut builder = AsmBuilder::::default(); - let random_state_vals: [F; 42] = rng.gen(); - println!("{:?}", random_state_vals); + let random_state_vals_1: [F; 42] = rng.gen(); + println!("{:?}", random_state_vals_1); + let random_state_vals_2: [F; 42] = rng.gen(); + println!("{:?}", random_state_vals_2); - let mut random_state_v1 = builder.dyn_array(random_state_vals.len()); - for (i, val) in random_state_vals.iter().enumerate() { + let mut random_state_v1 = + builder.dyn_array(random_state_vals_1.len() + random_state_vals_2.len()); + for (i, val) in random_state_vals_1.iter().enumerate() { builder.set(&mut random_state_v1, i, *val); } - let mut random_state_v2 = builder.dyn_array(random_state_vals.len()); - for (i, val) in random_state_vals.iter().enumerate() { - builder.set(&mut random_state_v2, i, *val); + for (i, val) in random_state_vals_2.iter().enumerate() { + builder.set(&mut random_state_v1, i + random_state_vals_1.len(), *val); + } + + let mut random_state_v2_1 = builder.dyn_array(random_state_vals_1.len()); + for (i, val) in random_state_vals_1.iter().enumerate() { + builder.set(&mut random_state_v2_1, i, *val); + } + let mut random_state_v2_2 = builder.dyn_array(random_state_vals_2.len()); + for (i, val) in random_state_vals_2.iter().enumerate() { + builder.set(&mut 
random_state_v2_2, i, *val); } - let mut nested_random_state = builder.dyn_array(1); - builder.set(&mut nested_random_state, 0, random_state_v2.clone()); + + let mut nested_random_state = builder.dyn_array(2); + builder.set(&mut nested_random_state, 0, random_state_v2_1.clone()); + builder.set(&mut nested_random_state, 1, random_state_v2_2.clone()); let result = builder.poseidon2_hash(&random_state_v1); let result_x = builder.poseidon2_hash_x(&nested_random_state); @@ -105,6 +120,7 @@ fn test_compiler_poseidon2_hash() { "The program executed successfully, number of cycles: {}", runtime.clk.as_canonical_u32() / 4 ); + runtime.print_stats(); } #[test] diff --git a/recursion/core/src/air/multi_builder.rs b/recursion/core/src/air/multi_builder.rs index 814b3ae3d0..a13dc0d742 100644 --- a/recursion/core/src/air/multi_builder.rs +++ b/recursion/core/src/air/multi_builder.rs @@ -1,4 +1,7 @@ -use p3_air::{AirBuilder, ExtensionBuilder, FilteredAirBuilder, PermutationAirBuilder}; +use p3_air::{ + AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, FilteredAirBuilder, + PermutationAirBuilder, +}; use sp1_core::air::MessageBuilder; /// The MultiBuilder is used for the multi table. 
It is used to create a virtual builder for one of @@ -81,3 +84,13 @@ impl<'a, AB: AirBuilder + MessageBuilder, M> MessageBuilder for MultiBuild self.inner.receive(message); } } + +impl<'a, AB: AirBuilder + AirBuilderWithPublicValues> AirBuilderWithPublicValues + for MultiBuilder<'a, AB> +{ + type PublicVar = AB::PublicVar; + + fn public_values(&self) -> &[Self::PublicVar] { + self.inner.public_values() + } +} diff --git a/recursion/core/src/cpu/columns/opcode.rs b/recursion/core/src/cpu/columns/opcode.rs index 513b563a00..a3d85a93b4 100644 --- a/recursion/core/src/cpu/columns/opcode.rs +++ b/recursion/core/src/cpu/columns/opcode.rs @@ -80,8 +80,10 @@ impl OpcodeSelectorCols { Opcode::TRAP => self.is_trap = F::one(), Opcode::HALT => self.is_halt = F::one(), Opcode::FRIFold => self.is_fri_fold = F::one(), + Opcode::Poseidon2Compress | Opcode::Poseidon2Absorb | Opcode::Poseidon2Finalize => { + self.is_poseidon = F::one() + } Opcode::ExpReverseBitsLen => self.is_exp_reverse_bits_len = F::one(), - Opcode::Poseidon2Compress => self.is_poseidon = F::one(), Opcode::Commit => self.is_commit = F::one(), Opcode::HintExt2Felt => self.is_ext_to_felt = F::one(), diff --git a/recursion/core/src/fri_fold/mod.rs b/recursion/core/src/fri_fold/mod.rs index e0da468987..bdbe76d1fd 100644 --- a/recursion/core/src/fri_fold/mod.rs +++ b/recursion/core/src/fri_fold/mod.rs @@ -1,6 +1,5 @@ #![allow(clippy::needless_range_loop)] -use crate::air::RecursionMemoryAirBuilder; use crate::memory::{MemoryReadCols, MemoryReadSingleCols, MemoryReadWriteCols}; use crate::runtime::Opcode; use core::borrow::Borrow; @@ -10,7 +9,7 @@ use p3_field::AbstractField; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sp1_core::air::{BaseAirBuilder, BinomialExtension, ExtensionAirBuilder, MachineAir}; +use sp1_core::air::{BaseAirBuilder, BinomialExtension, MachineAir}; use sp1_core::utils::pad_rows_fixed; use sp1_derive::AlignedBorrow; use std::borrow::BorrowMut; @@ 
-171,7 +170,7 @@ impl MachineAir for FriFoldChip } impl FriFoldChip { - pub fn eval_fri_fold( + pub fn eval_fri_fold( &self, builder: &mut AB, local: &FriFoldCols, @@ -179,16 +178,6 @@ impl FriFoldChip { receive_table: AB::Var, memory_access: AB::Var, ) { - // Dummy constraints to normalize to DEGREE when DEGREE > 3. - if DEGREE > 3 { - let lhs = (0..DEGREE) - .map(|_| local.is_real.into()) - .product::(); - let rhs = (0..DEGREE) - .map(|_| local.is_real.into()) - .product::(); - builder.assert_eq(lhs, rhs); - } // Constraint that the operands are sent from the CPU table. let first_iteration_clk = local.clk.into() - local.m.into(); let total_num_iterations = local.m.into() + AB::Expr::one(); @@ -400,6 +389,16 @@ where let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &FriFoldCols = (*local).borrow(); let next: &FriFoldCols = (*next).borrow(); + + // Dummy constraints to normalize to DEGREE. + let lhs = (0..DEGREE) + .map(|_| local.is_real.into()) + .product::(); + let rhs = (0..DEGREE) + .map(|_| local.is_real.into()) + .product::(); + builder.assert_eq(lhs, rhs); + self.eval_fri_fold::( builder, local, diff --git a/recursion/core/src/lib.rs b/recursion/core/src/lib.rs index 785179fa77..f1c93c956d 100644 --- a/recursion/core/src/lib.rs +++ b/recursion/core/src/lib.rs @@ -4,7 +4,6 @@ pub mod exp_reverse_bits; pub mod fri_fold; pub mod memory; pub mod multi; -pub mod poseidon2; pub mod poseidon2_wide; pub mod program; pub mod range_check; diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 0b93f52aa2..97ab427844 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -1,4 +1,7 @@ +use std::array; use std::borrow::{Borrow, BorrowMut}; +use std::cmp::max; +use std::ops::Deref; use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir}; @@ -11,7 +14,8 @@ use sp1_derive::AlignedBorrow; use crate::air::{MultiBuilder, SP1RecursionAirBuilder}; use crate::fri_fold::{FriFoldChip, 
FriFoldCols}; -use crate::poseidon2::{Poseidon2Chip, Poseidon2Cols}; +use crate::poseidon2_wide::columns::Poseidon2; +use crate::poseidon2_wide::{Poseidon2WideChip, WIDTH}; use crate::runtime::{ExecutionRecord, RecursionProgram}; pub const NUM_MULTI_COLS: usize = core::mem::size_of::>(); @@ -24,27 +28,31 @@ pub struct MultiChip { #[derive(AlignedBorrow, Clone, Copy)] #[repr(C)] pub struct MultiCols { - pub instruction: InstructionSpecificCols, - pub is_fri_fold: T, + + /// Rows that needs to receive a fri_fold syscall. pub fri_fold_receive_table: T, + /// Rows that needs to access memory. pub fri_fold_memory_access: T, pub is_poseidon2: T, - pub poseidon2_receive_table: T, - pub poseidon2_memory_access: T, -} -#[derive(Clone, Copy)] -#[repr(C)] -pub union InstructionSpecificCols { - fri_fold: FriFoldCols, - poseidon2: Poseidon2Cols, + /// Rows that needs to receive a poseidon2 syscall. + pub poseidon2_receive_table: T, + /// Hash/Permute state entries that needs to access memory. This is for the the first half of the permute state. + pub poseidon2_1st_half_memory_access: [T; WIDTH / 2], + /// Flag to indicate if all of the second half of a compress state needs to access memory. + pub poseidon2_2nd_half_memory_access: T, + /// Rows that need to send a range check. + pub poseidon2_send_range_check: T, } impl BaseAir for MultiChip { fn width(&self) -> usize { - NUM_MULTI_COLS + let fri_fold_width = Self::fri_fold_width::(); + let poseidon2_width = Self::poseidon2_width::(); + + max(fri_fold_width, poseidon2_width) + NUM_MULTI_COLS } } @@ -57,50 +65,59 @@ impl MachineAir for MultiChip { "Multi".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { - // This is a no-op. 
- } - fn generate_trace( &self, input: &ExecutionRecord, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let fri_fold_chip = FriFoldChip::<3> { + let fri_fold_chip = FriFoldChip:: { fixed_log2_rows: None, pad: false, }; - let poseidon2 = Poseidon2Chip { + let poseidon2 = Poseidon2WideChip:: { fixed_log2_rows: None, pad: false, }; let fri_fold_trace = fri_fold_chip.generate_trace(input, output); let mut poseidon2_trace = poseidon2.generate_trace(input, output); + let num_columns = as BaseAir>::width(self); + let mut rows = fri_fold_trace .clone() .rows_mut() .chain(poseidon2_trace.rows_mut()) .enumerate() .map(|(i, instruction_row)| { - let mut row = [F::zero(); NUM_MULTI_COLS]; - row[0..instruction_row.len()].copy_from_slice(instruction_row); - let cols: &mut MultiCols = row.as_mut_slice().borrow_mut(); - if i < fri_fold_trace.height() { - cols.is_fri_fold = F::one(); - - let fri_fold_cols = *cols.fri_fold(); - cols.fri_fold_receive_table = - FriFoldChip::<3>::do_receive_table(&fri_fold_cols); - cols.fri_fold_memory_access = - FriFoldChip::<3>::do_memory_access(&fri_fold_cols); - } else { - cols.is_poseidon2 = F::one(); + let process_fri_fold = i < fri_fold_trace.height(); + + let mut row = vec![F::zero(); num_columns]; + row[NUM_MULTI_COLS..NUM_MULTI_COLS + instruction_row.len()] + .copy_from_slice(instruction_row); + + if process_fri_fold { + let multi_cols: &mut MultiCols = row[0..NUM_MULTI_COLS].borrow_mut(); + multi_cols.is_fri_fold = F::one(); - let poseidon2_cols = *cols.poseidon2(); - cols.poseidon2_receive_table = Poseidon2Chip::do_receive_table(&poseidon2_cols); - cols.poseidon2_memory_access = Poseidon2Chip::do_memory_access(&poseidon2_cols); + let fri_fold_cols: &FriFoldCols = (*instruction_row).borrow(); + multi_cols.fri_fold_receive_table = + FriFoldChip::::do_receive_table(fri_fold_cols); + multi_cols.fri_fold_memory_access = + FriFoldChip::::do_memory_access(fri_fold_cols); + } else { + let multi_cols: &mut MultiCols = 
row[0..NUM_MULTI_COLS].borrow_mut(); + multi_cols.is_poseidon2 = F::one(); + + let poseidon2_cols = Poseidon2WideChip::::convert::(instruction_row); + multi_cols.poseidon2_receive_table = + poseidon2_cols.control_flow().is_syscall_row; + multi_cols.poseidon2_1st_half_memory_access = + array::from_fn(|i| poseidon2_cols.memory().memory_slot_used[i]); + multi_cols.poseidon2_2nd_half_memory_access = + poseidon2_cols.control_flow().is_compress; + multi_cols.poseidon2_send_range_check = poseidon2_cols.control_flow().is_absorb; } + row }) .collect_vec(); @@ -108,12 +125,12 @@ impl MachineAir for MultiChip { // Pad the trace to a power of two. pad_rows_fixed( &mut rows, - || [F::zero(); NUM_MULTI_COLS], + || vec![F::zero(); num_columns], self.fixed_log2_rows, ); // Convert the trace to a row major matrix. - RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_MULTI_COLS) + RowMajorMatrix::new(rows.into_iter().flatten().collect(), num_columns) } fn included(&self, _: &Self::Record) -> bool { @@ -124,26 +141,32 @@ impl MachineAir for MultiChip { impl Air for MultiChip where AB: SP1RecursionAirBuilder, + AB::Var: 'static, { fn eval(&self, builder: &mut AB) { let main = builder.main(); let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local: &MultiCols = (*local).borrow(); - let next: &MultiCols = (*next).borrow(); - // Add some dummy constraints to compress the interactions. - let mut expr = local.is_fri_fold * local.is_fri_fold; - for _ in 0..(DEGREE - 2) { - expr *= local.is_fri_fold.into(); - } - builder.assert_eq(expr.clone(), expr.clone()); + let local_slice: &[::Var] = &local; + let next_slice: &[::Var] = &next; + let local_multi_cols: &MultiCols = local_slice[0..NUM_MULTI_COLS].borrow(); + let next_multi_cols: &MultiCols = next_slice[0..NUM_MULTI_COLS].borrow(); - let next_is_real = next.is_fri_fold + next.is_poseidon2; - let local_is_real = local.is_fri_fold + local.is_poseidon2; + // Dummy constraints to normalize to DEGREE. 
+ let lhs = (0..DEGREE) + .map(|_| local_multi_cols.is_poseidon2.into()) + .product::(); + let rhs = (0..DEGREE) + .map(|_| local_multi_cols.is_poseidon2.into()) + .product::(); + builder.assert_eq(lhs, rhs); + + let next_is_real = next_multi_cols.is_fri_fold + next_multi_cols.is_poseidon2; + let local_is_real = local_multi_cols.is_fri_fold + local_multi_cols.is_poseidon2; // Assert that is_fri_fold and is_poseidon2 are bool and that at most one is set. - builder.assert_bool(local.is_fri_fold); - builder.assert_bool(local.is_poseidon2); + builder.assert_bool(local_multi_cols.is_fri_fold); + builder.assert_bool(local_multi_cols.is_poseidon2); builder.assert_bool(local_is_real.clone()); // Fri fold requires that it's rows are contiguous, since each invocation spans multiple rows @@ -160,119 +183,143 @@ where builder .when_transition() .when(next_is_real) - .when(local.is_poseidon2) - .assert_one(next.is_poseidon2); + .when(local_multi_cols.is_poseidon2) + .assert_one(next_multi_cols.is_poseidon2); + + let mut sub_builder = MultiBuilder::new( + builder, + local_multi_cols.is_fri_fold.into(), + next_multi_cols.is_fri_fold.into(), + ); - let mut sub_builder = - MultiBuilder::new(builder, local.is_fri_fold.into(), next.is_fri_fold.into()); + let local_fri_fold_cols = Self::fri_fold(&local); + let next_fri_fold_cols = Self::fri_fold(&next); - let fri_columns_local = local.fri_fold(); sub_builder.assert_eq( - local.is_fri_fold * FriFoldChip::<3>::do_memory_access::(fri_columns_local), - local.fri_fold_memory_access, + local_multi_cols.is_fri_fold + * FriFoldChip::::do_memory_access::(&local_fri_fold_cols), + local_multi_cols.fri_fold_memory_access, ); sub_builder.assert_eq( - local.is_fri_fold * FriFoldChip::<3>::do_receive_table::(fri_columns_local), - local.fri_fold_receive_table, + local_multi_cols.is_fri_fold + * FriFoldChip::::do_receive_table::(&local_fri_fold_cols), + local_multi_cols.fri_fold_receive_table, ); - let fri_fold_chip = FriFoldChip::<3>::default(); + 
let fri_fold_chip = FriFoldChip::::default(); fri_fold_chip.eval_fri_fold( &mut sub_builder, - local.fri_fold(), - next.fri_fold(), - local.fri_fold_receive_table, - local.fri_fold_memory_access, + &local_fri_fold_cols, + &next_fri_fold_cols, + local_multi_cols.fri_fold_receive_table, + local_multi_cols.fri_fold_memory_access, + ); + + let mut sub_builder = MultiBuilder::new( + builder, + local_multi_cols.is_poseidon2.into(), + next_multi_cols.is_poseidon2.into(), ); - let mut sub_builder = - MultiBuilder::new(builder, local.is_poseidon2.into(), next.is_poseidon2.into()); + let poseidon2_columns = MultiChip::::poseidon2(local_slice); + sub_builder.assert_eq( + local_multi_cols.is_poseidon2 * poseidon2_columns.control_flow().is_syscall_row, + local_multi_cols.poseidon2_receive_table, + ); + local_multi_cols + .poseidon2_1st_half_memory_access + .iter() + .enumerate() + .for_each(|(i, mem_access)| { + sub_builder.assert_eq( + local_multi_cols.is_poseidon2 * poseidon2_columns.memory().memory_slot_used[i], + *mem_access, + ); + }); - let poseidon2_columns = local.poseidon2(); sub_builder.assert_eq( - local.is_poseidon2 * Poseidon2Chip::do_receive_table::(poseidon2_columns), - local.poseidon2_receive_table, + local_multi_cols.is_poseidon2 * poseidon2_columns.control_flow().is_compress, + local_multi_cols.poseidon2_2nd_half_memory_access, ); + sub_builder.assert_eq( - local.is_poseidon2 * Poseidon2Chip::do_memory_access::(poseidon2_columns), - local.poseidon2_memory_access, + local_multi_cols.is_poseidon2 * poseidon2_columns.control_flow().is_absorb, + local_multi_cols.poseidon2_send_range_check, ); - let poseidon2_chip = Poseidon2Chip::default(); + let poseidon2_chip = Poseidon2WideChip::::default(); poseidon2_chip.eval_poseidon2( &mut sub_builder, - local.poseidon2(), - next.poseidon2(), - local.poseidon2_receive_table, - local.poseidon2_memory_access, + poseidon2_columns.as_ref(), + MultiChip::::poseidon2(next_slice).as_ref(), + 
local_multi_cols.poseidon2_receive_table, + local_multi_cols.poseidon2_1st_half_memory_access, + local_multi_cols.poseidon2_2nd_half_memory_access, + local_multi_cols.poseidon2_send_range_check, ); } } -// SAFETY: Each view is a valid interpretation of the underlying array. -impl MultiCols { - pub fn fri_fold(&self) -> &FriFoldCols { - unsafe { &self.instruction.fri_fold } + +impl MultiChip { + fn fri_fold_width() -> usize { + as BaseAir>::width(&FriFoldChip::::default()) + } + + fn fri_fold(row: &dyn Deref) -> FriFoldCols { + let row_slice: &[T] = row; + let fri_fold_width = Self::fri_fold_width::(); + let fri_fold_cols: &FriFoldCols = + (row_slice[NUM_MULTI_COLS..NUM_MULTI_COLS + fri_fold_width]).borrow(); + + *fri_fold_cols } - pub fn poseidon2(&self) -> &Poseidon2Cols { - unsafe { &self.instruction.poseidon2 } + fn poseidon2_width() -> usize { + as BaseAir>::width(&Poseidon2WideChip::::default()) + } + + fn poseidon2<'a, T>(row: impl Deref) -> Box + 'a> + where + T: Copy + 'a, + { + let row_slice: &[T] = &row; + let poseidon2_width = Self::poseidon2_width::(); + + Poseidon2WideChip::::convert::( + &row_slice[NUM_MULTI_COLS..NUM_MULTI_COLS + poseidon2_width], + ) } } #[cfg(test)] mod tests { - use itertools::Itertools; use std::time::Instant; use p3_baby_bear::BabyBear; use p3_baby_bear::DiffusionMatrixBabyBear; - use p3_field::AbstractField; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use p3_poseidon2::Poseidon2; use p3_poseidon2::Poseidon2ExternalMatrixGeneral; use sp1_core::stark::StarkGenericConfig; - use sp1_core::utils::inner_perm; use sp1_core::{ air::MachineAir, utils::{uni_stark_prove, uni_stark_verify, BabyBearPoseidon2}, }; use crate::multi::MultiChip; - use crate::{poseidon2::Poseidon2Event, runtime::ExecutionRecord}; - use p3_symmetric::Permutation; + use crate::poseidon2_wide::tests::generate_test_execution_record; + use crate::runtime::ExecutionRecord; #[test] fn prove_babybear() { let config = BabyBearPoseidon2::compressed(); let mut 
challenger = config.challenger(); - let chip = MultiChip::<5> { + let chip = MultiChip::<9> { fixed_log2_rows: None, }; - let test_inputs = (0..16) - .map(|i| [BabyBear::from_canonical_u32(i); 16]) - .collect_vec(); - - let gt: Poseidon2< - BabyBear, - Poseidon2ExternalMatrixGeneral, - DiffusionMatrixBabyBear, - 16, - 7, - > = inner_perm(); - - let expected_outputs = test_inputs - .iter() - .map(|input| gt.permute(*input)) - .collect::>(); - - let mut input_exec = ExecutionRecord::::default(); - for (input, output) in test_inputs.into_iter().zip_eq(expected_outputs) { - input_exec - .poseidon2_events - .push(Poseidon2Event::dummy_from_input(input, output)); - } + let input_exec = generate_test_execution_record(false); let trace: RowMajorMatrix = chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); println!( diff --git a/recursion/core/src/poseidon2/columns.rs b/recursion/core/src/poseidon2/columns.rs deleted file mode 100644 index fa12a655f2..0000000000 --- a/recursion/core/src/poseidon2/columns.rs +++ /dev/null @@ -1,62 +0,0 @@ -use sp1_derive::AlignedBorrow; - -use crate::{memory::MemoryReadWriteSingleCols, poseidon2_wide::external::WIDTH}; - -/// The column layout for the chip. -#[derive(AlignedBorrow, Clone, Copy)] -#[repr(C)] -pub struct Poseidon2Cols { - pub clk: T, - pub dst_input: T, - pub left_input: T, - pub right_input: T, - pub rounds: [T; 24], // 1 round for memory input; 1 round for initialize; 8 rounds for external; 13 rounds for internal; 1 round for memory output - pub do_receive: T, - pub do_memory: T, - pub round_specific_cols: RoundSpecificCols, - pub is_real: T, -} - -#[derive(AlignedBorrow, Clone, Copy)] -#[repr(C)] -pub union RoundSpecificCols { - computation: ComputationCols, - memory_access: MemAccessCols, -} - -// SAFETY: Each view is a valid interpretation of the underlying array. 
-impl RoundSpecificCols { - pub fn computation(&self) -> &ComputationCols { - unsafe { &self.computation } - } - - pub fn computation_mut(&mut self) -> &mut ComputationCols { - unsafe { &mut self.computation } - } - - pub fn memory_access(&self) -> &MemAccessCols { - unsafe { &self.memory_access } - } - - pub fn memory_access_mut(&mut self) -> &mut MemAccessCols { - unsafe { &mut self.memory_access } - } -} - -#[derive(AlignedBorrow, Clone, Copy)] -#[repr(C)] -pub struct ComputationCols { - pub input: [T; WIDTH], - pub add_rc: [T; WIDTH], - pub sbox_deg_3: [T; WIDTH], - pub sbox_deg_7: [T; WIDTH], - pub output: [T; WIDTH], -} - -#[derive(AlignedBorrow, Clone, Copy)] -#[repr(C)] -pub struct MemAccessCols { - pub addr_first_half: T, - pub addr_second_half: T, - pub mem_access: [MemoryReadWriteSingleCols; WIDTH], -} diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs deleted file mode 100644 index 21f56edb0b..0000000000 --- a/recursion/core/src/poseidon2/external.rs +++ /dev/null @@ -1,572 +0,0 @@ -use core::borrow::Borrow; -use core::mem::size_of; -use p3_air::AirBuilder; -use p3_air::{Air, BaseAir}; -use p3_field::AbstractField; -use p3_matrix::Matrix; -use sp1_core::air::{BaseAirBuilder, ExtensionAirBuilder, SP1AirBuilder}; -use sp1_primitives::RC_16_30_U32; - -use crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder}; -use crate::memory::MemoryCols; -use crate::poseidon2_wide::{apply_m_4, internal_linear_layer}; -use crate::runtime::Opcode; - -use super::columns::Poseidon2Cols; - -/// The number of main trace columns for `AddChip`. -pub const NUM_POSEIDON2_COLS: usize = size_of::>(); - -/// The width of the permutation. -pub const WIDTH: usize = 16; - -/// A chip that implements addition for the opcode ADD. 
-#[derive(Default)] -pub struct Poseidon2Chip { - pub fixed_log2_rows: Option, - pub pad: bool, -} - -impl BaseAir for Poseidon2Chip { - fn width(&self) -> usize { - NUM_POSEIDON2_COLS - } -} - -impl Poseidon2Chip { - pub fn eval_poseidon2( - &self, - builder: &mut AB, - local: &Poseidon2Cols, - next: &Poseidon2Cols, - receive_table: AB::Var, - memory_access: AB::Var, - ) { - const NUM_ROUNDS_F: usize = 8; - const NUM_ROUNDS_P: usize = 13; - const ROUNDS_F_1_BEGINNING: usize = 2; // Previous rounds are memory read and initial. - const ROUNDS_P_BEGINNING: usize = ROUNDS_F_1_BEGINNING + NUM_ROUNDS_F / 2; - const ROUNDS_P_END: usize = ROUNDS_P_BEGINNING + NUM_ROUNDS_P; - const ROUND_F_2_END: usize = ROUNDS_P_END + NUM_ROUNDS_F / 2; - - let is_memory_read = local.rounds[0]; - let is_initial = local.rounds[1]; - - // First half of the external rounds. - let mut is_external_layer = (ROUNDS_F_1_BEGINNING..ROUNDS_P_BEGINNING) - .map(|i| local.rounds[i].into()) - .sum::(); - - // Second half of the external rounds. - is_external_layer += (ROUNDS_P_END..ROUND_F_2_END) - .map(|i| local.rounds[i].into()) - .sum::(); - let is_internal_layer = (ROUNDS_P_BEGINNING..ROUNDS_P_END) - .map(|i| local.rounds[i].into()) - .sum::(); - let is_memory_write = local.rounds[local.rounds.len() - 1]; - - self.eval_control_flow_and_inputs(builder, local, next); - - self.eval_syscall(builder, local, receive_table); - - self.eval_mem( - builder, - local, - next, - is_memory_read, - is_memory_write, - memory_access, - ); - - self.eval_computation( - builder, - local, - next, - is_initial.into(), - is_external_layer.clone(), - is_internal_layer.clone(), - NUM_ROUNDS_F + NUM_ROUNDS_P + 1, - ); - } - - fn eval_control_flow_and_inputs( - &self, - builder: &mut AB, - local: &Poseidon2Cols, - next: &Poseidon2Cols, - ) { - let num_total_rounds = local.rounds.len(); - for i in 0..num_total_rounds { - // Verify that the round flags are correct. 
- builder.assert_bool(local.rounds[i]); - - // Assert that the next round is correct. - builder - .when_transition() - .assert_eq(local.rounds[i], next.rounds[(i + 1) % num_total_rounds]); - - if i != num_total_rounds - 1 { - builder - .when_transition() - .when(local.rounds[i]) - .assert_eq(local.clk, next.clk); - builder - .when_transition() - .when(local.rounds[i]) - .assert_eq(local.dst_input, next.dst_input); - builder - .when_transition() - .when(local.rounds[i]) - .assert_eq(local.left_input, next.left_input); - builder - .when_transition() - .when(local.rounds[i]) - .assert_eq(local.right_input, next.right_input); - } - } - - // Ensure that at most one of the round flags is set. - let round_acc = local - .rounds - .iter() - .fold(AB::Expr::zero(), |acc, round_flag| acc + *round_flag); - builder.assert_bool(round_acc); - - // Verify the do_memory flag. - builder.assert_eq( - local.do_memory, - local.is_real * (local.rounds[0] + local.rounds[23]), - ); - - // Verify the do_receive flag. - builder.assert_eq(local.do_receive, local.is_real * local.rounds[0]); - - // Verify the first row starts at round 0. - builder.when_first_row().assert_one(local.rounds[0]); - // The round count is not a power of 2, so the last row should not be real. - builder.when_last_row().assert_zero(local.is_real); - - // Verify that all is_real flags within a round are equal. 
- let is_last_round = local.rounds[23]; - builder - .when_transition() - .when_not(is_last_round) - .assert_eq(local.is_real, next.is_real); - } - - fn eval_mem( - &self, - builder: &mut AB, - local: &Poseidon2Cols, - next: &Poseidon2Cols, - is_memory_read: AB::Var, - is_memory_write: AB::Var, - memory_access: AB::Var, - ) { - let memory_access_cols = local.round_specific_cols.memory_access(); - builder - .when(local.is_real) - .when(is_memory_read) - .assert_eq(local.left_input, memory_access_cols.addr_first_half); - builder - .when(local.is_real) - .when(is_memory_read) - .assert_eq(local.right_input, memory_access_cols.addr_second_half); - - builder - .when(local.is_real) - .when(is_memory_write) - .assert_eq(local.dst_input, memory_access_cols.addr_first_half); - builder.when(local.is_real).when(is_memory_write).assert_eq( - local.dst_input + AB::F::from_canonical_usize(WIDTH / 2), - memory_access_cols.addr_second_half, - ); - - for i in 0..WIDTH { - let addr = if i < WIDTH / 2 { - memory_access_cols.addr_first_half + AB::Expr::from_canonical_usize(i) - } else { - memory_access_cols.addr_second_half + AB::Expr::from_canonical_usize(i - WIDTH / 2) - }; - builder.recursion_eval_memory_access_single( - local.clk + AB::Expr::one() * is_memory_write, - addr, - &memory_access_cols.mem_access[i], - memory_access, - ); - builder.when(local.is_real).when(is_memory_read).assert_eq( - *memory_access_cols.mem_access[i].value(), - *memory_access_cols.mem_access[i].prev_value(), - ); - } - - // For the memory read round, need to connect the memory val to the input of the next - // computation round. 
- let next_computation_col = next.round_specific_cols.computation(); - for i in 0..WIDTH { - builder - .when_transition() - .when(local.is_real) - .when(is_memory_read) - .assert_eq( - *memory_access_cols.mem_access[i].value(), - next_computation_col.input[i], - ); - } - } - - #[allow(clippy::too_many_arguments)] - fn eval_computation( - &self, - builder: &mut AB, - local: &Poseidon2Cols, - next: &Poseidon2Cols, - is_initial: AB::Expr, - is_external_layer: AB::Expr, - is_internal_layer: AB::Expr, - rounds: usize, - ) { - let computation_cols = local.round_specific_cols.computation(); - - // Convert the u32 round constants to field elements. - let constants: [[AB::F; WIDTH]; 30] = RC_16_30_U32 - .iter() - .map(|round| round.map(AB::F::from_wrapped_u32)) - .collect::>() - .try_into() - .unwrap(); - - // Apply the round constants. - // - // Initial Layer: Don't apply the round constants. - // External Layers: Apply the round constants. - // Internal Layers: Only apply the round constants to the first element. - for i in 0..WIDTH { - let mut result: AB::Expr = computation_cols.input[i].into(); - for r in 0..rounds { - if i == 0 { - result += local.rounds[r + 2] - * constants[r][i] - * (is_external_layer.clone() + is_internal_layer.clone()); - } else { - result += local.rounds[r + 2] * constants[r][i] * is_external_layer.clone(); - } - } - builder - .when(local.is_real) - .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) - .assert_eq(result, computation_cols.add_rc[i]); - } - - // Apply the sbox. - // - // To differentiate between external and internal layers, we use a masking operation - // to only apply the state change to the first element for internal layers. 
- for i in 0..WIDTH { - let sbox_deg_3 = computation_cols.add_rc[i] - * computation_cols.add_rc[i] - * computation_cols.add_rc[i]; - builder - .when(local.is_real) - .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) - .assert_eq(computation_cols.sbox_deg_3[i], sbox_deg_3); - let sbox_deg_7 = computation_cols.sbox_deg_3[i] - * computation_cols.sbox_deg_3[i] - * computation_cols.add_rc[i]; - builder - .when(local.is_real) - .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) - .assert_eq(sbox_deg_7, computation_cols.sbox_deg_7[i]); - } - let sbox_result: [AB::Expr; WIDTH] = computation_cols - .sbox_deg_7 - .iter() - .enumerate() - .map(|(i, x)| { - // The masked first result of the sbox. - // - // Initial Layer: Pass through the result of the round constant layer. - // External Layer: Pass through the result of the sbox layer. - // Internal Layer: Pass through the result of the sbox layer. - if i == 0 { - is_initial.clone() * computation_cols.add_rc[i] - + (is_external_layer.clone() + is_internal_layer.clone()) * *x - } - // The masked result of the rest of the sbox. - // - // Initial layer: Pass through the result of the round constant layer. - // External layer: Pass through the result of the sbox layer. - // Internal layer: Pass through the result of the round constant layer. - else { - (is_initial.clone() + is_internal_layer.clone()) * computation_cols.add_rc[i] - + (is_external_layer.clone()) * *x - } - }) - .collect::>() - .try_into() - .unwrap(); - - // EXTERNAL LAYER + INITIAL LAYER - { - // First, we apply M_4 to each consecutive four elements of the state. - // In Appendix B's terminology, this replaces each x_i with x_i'. - let mut state: [AB::Expr; WIDTH] = sbox_result.clone(); - for i in (0..WIDTH).step_by(4) { - apply_m_4(&mut state[i..i + 4]); - } - - // Now, we apply the outer circulant matrix (to compute the y_i values). - // - // We first precompute the four sums of every four elements. 
- let sums: [AB::Expr; 4] = core::array::from_fn(|k| { - (0..WIDTH) - .step_by(4) - .map(|j| state[j + k].clone()) - .sum::() - }); - - // The formula for each y_i involves 2x_i' term and x_j' terms for each j that equals i mod 4. - // In other words, we can add a single copy of x_i' to the appropriate one of our precomputed sums. - for i in 0..WIDTH { - state[i] += sums[i % 4].clone(); - builder - .when(local.is_real) - .when(is_external_layer.clone() + is_initial.clone()) - .assert_eq(state[i].clone(), computation_cols.output[i]); - } - } - - // INTERNAL LAYER - { - // Use a simple matrix multiplication as the permutation. - let mut state: [AB::Expr; WIDTH] = sbox_result.clone(); - internal_linear_layer(&mut state); - builder - .when(local.is_real) - .when(is_internal_layer.clone()) - .assert_all_eq(state.clone(), computation_cols.output); - } - - // Assert that the round's output values are equal the the next round's input values. For the - // last computation round, assert athat the output values are equal to the output memory values. - let next_row_computation = next.round_specific_cols.computation(); - let next_row_memory_access = next.round_specific_cols.memory_access(); - for i in 0..WIDTH { - let next_round_value = builder.if_else( - local.rounds[22], - *next_row_memory_access.mem_access[i].value(), - next_row_computation.input[i], - ); - - builder - .when_transition() - .when(local.is_real) - .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) - .assert_eq(computation_cols.output[i], next_round_value); - } - } - - fn eval_syscall( - &self, - builder: &mut AB, - local: &Poseidon2Cols, - receive_table: AB::Var, - ) { - // Constraint that the operands are sent from the CPU table. 
- let operands: [AB::Expr; 4] = [ - local.clk.into(), - local.dst_input.into(), - local.left_input.into(), - local.right_input.into(), - ]; - builder.receive_table( - Opcode::Poseidon2Compress.as_field::(), - &operands, - receive_table, - ); - } - - pub const fn do_receive_table(local: &Poseidon2Cols) -> T { - local.do_receive - } - - pub fn do_memory_access(local: &Poseidon2Cols) -> T { - local.do_memory - } -} - -impl Air for Poseidon2Chip -where - AB: SP1AirBuilder, -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local: &Poseidon2Cols = (*local).borrow(); - let next = main.row_slice(1); - let next: &Poseidon2Cols = (*next).borrow(); - - self.eval_poseidon2::( - builder, - local, - next, - Self::do_receive_table::(local), - Self::do_memory_access::(local), - ); - } -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use std::borrow::Borrow; - use std::time::Instant; - use zkhash::ark_ff::UniformRand; - - use p3_baby_bear::BabyBear; - use p3_baby_bear::DiffusionMatrixBabyBear; - use p3_matrix::{dense::RowMajorMatrix, Matrix}; - use p3_poseidon2::Poseidon2; - use p3_poseidon2::Poseidon2ExternalMatrixGeneral; - use sp1_core::stark::StarkGenericConfig; - use sp1_core::utils::inner_perm; - use sp1_core::{ - air::MachineAir, - utils::{uni_stark_prove, uni_stark_verify, BabyBearPoseidon2}, - }; - - use crate::{ - poseidon2::{Poseidon2Chip, Poseidon2Event}, - runtime::ExecutionRecord, - }; - use p3_symmetric::Permutation; - - use super::Poseidon2Cols; - - const ROWS_PER_PERMUTATION: usize = 24; - - #[test] - fn generate_trace() { - let chip = Poseidon2Chip { - fixed_log2_rows: None, - pad: true, - }; - - let rng = &mut rand::thread_rng(); - - let test_inputs: Vec<[BabyBear; 16]> = (0..16) - .map(|_| core::array::from_fn(|_| BabyBear::rand(rng))) - .collect_vec(); - - let gt: Poseidon2< - BabyBear, - Poseidon2ExternalMatrixGeneral, - DiffusionMatrixBabyBear, - 16, - 7, - > = inner_perm(); - - let 
expected_outputs = test_inputs - .iter() - .map(|input| gt.permute(*input)) - .collect::>(); - - let mut input_exec = ExecutionRecord::::default(); - for (input, output) in test_inputs.into_iter().zip_eq(expected_outputs.clone()) { - input_exec - .poseidon2_events - .push(Poseidon2Event::dummy_from_input(input, output)); - } - - let trace: RowMajorMatrix = - chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); - - for (i, expected_output) in expected_outputs.iter().enumerate() { - let row = trace.row(ROWS_PER_PERMUTATION * (i + 1) - 2).collect_vec(); - let cols: &Poseidon2Cols = row.as_slice().borrow(); - let computation_cols = cols.round_specific_cols.computation(); - assert_eq!(expected_output, &computation_cols.output); - } - } - - fn prove_babybear(inputs: Vec<[BabyBear; 16]>, outputs: Vec<[BabyBear; 16]>) { - let mut input_exec = ExecutionRecord::::default(); - for (input, output) in inputs.into_iter().zip_eq(outputs) { - input_exec - .poseidon2_events - .push(Poseidon2Event::dummy_from_input(input, output)); - } - - let chip = Poseidon2Chip { - fixed_log2_rows: None, - pad: true, - }; - let trace: RowMajorMatrix = - chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); - println!( - "trace dims is width: {:?}, height: {:?}", - trace.width(), - trace.height() - ); - - let start = Instant::now(); - let config = BabyBearPoseidon2::compressed(); - let mut challenger = config.challenger(); - let proof = uni_stark_prove(&config, &chip, &mut challenger, trace); - let duration = start.elapsed().as_secs_f64(); - println!("proof duration = {:?}", duration); - - let mut challenger: p3_challenger::DuplexChallenger< - BabyBear, - Poseidon2, - 16, - 8, - > = config.challenger(); - let start = Instant::now(); - uni_stark_verify(&config, &chip, &mut challenger, &proof) - .expect("expected proof to be valid"); - - let duration = start.elapsed().as_secs_f64(); - println!("verify duration = {:?}", duration); - } - - #[test] - fn 
prove_babybear_success() { - let rng = &mut rand::thread_rng(); - - let test_inputs: Vec<[BabyBear; 16]> = (0..16) - .map(|_| core::array::from_fn(|_| BabyBear::rand(rng))) - .collect_vec(); - - let gt: Poseidon2< - BabyBear, - Poseidon2ExternalMatrixGeneral, - DiffusionMatrixBabyBear, - 16, - 7, - > = inner_perm(); - - let expected_outputs = test_inputs - .iter() - .map(|input| gt.permute(*input)) - .collect::>(); - - prove_babybear(test_inputs, expected_outputs) - } - - #[test] - #[should_panic] - fn prove_babybear_failure() { - let rng = &mut rand::thread_rng(); - let test_inputs: Vec<[BabyBear; 16]> = (0..16) - .map(|_| core::array::from_fn(|_| BabyBear::rand(rng))) - .collect_vec(); - - let bad_outputs: Vec<[BabyBear; 16]> = (0..16) - .map(|_| core::array::from_fn(|_| BabyBear::rand(rng))) - .collect_vec(); - - prove_babybear(test_inputs, bad_outputs) - } -} diff --git a/recursion/core/src/poseidon2/mod.rs b/recursion/core/src/poseidon2/mod.rs deleted file mode 100644 index 2c42ce7219..0000000000 --- a/recursion/core/src/poseidon2/mod.rs +++ /dev/null @@ -1,47 +0,0 @@ -#![allow(clippy::needless_range_loop)] - -use crate::poseidon2::external::WIDTH; -mod columns; -pub mod external; -mod trace; -use crate::air::Block; -use crate::memory::MemoryRecord; -use p3_field::PrimeField32; - -pub use columns::Poseidon2Cols; -pub use external::Poseidon2Chip; - -#[derive(Debug, Clone)] -pub struct Poseidon2Event { - pub clk: F, - pub dst: F, // from a_val - pub left: F, // from b_val - pub right: F, // from c_val - pub input: [F; WIDTH], - pub result_array: [F; WIDTH], - pub input_records: [MemoryRecord; WIDTH], - pub result_records: [MemoryRecord; WIDTH], -} - -impl Poseidon2Event { - /// A way to construct a dummy event from an input array, used for testing. 
- pub fn dummy_from_input(input: [F; WIDTH], output: [F; WIDTH]) -> Self { - let input_records = core::array::from_fn(|i| { - MemoryRecord::new_read(F::zero(), Block::from(input[i]), F::one(), F::zero()) - }); - let output_records: [MemoryRecord; WIDTH] = core::array::from_fn(|i| { - MemoryRecord::new_read(F::zero(), Block::from(output[i]), F::two(), F::zero()) - }); - - Self { - clk: F::one(), - dst: F::zero(), - left: F::zero(), - right: F::zero(), - input, - result_array: [F::zero(); WIDTH], - input_records, - result_records: output_records, - } - } -} diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs deleted file mode 100644 index 567c09fc7d..0000000000 --- a/recursion/core/src/poseidon2/trace.rs +++ /dev/null @@ -1,203 +0,0 @@ -use std::borrow::BorrowMut; - -use p3_field::PrimeField32; -use p3_matrix::dense::RowMajorMatrix; -use sp1_core::{air::MachineAir, utils::pad_rows_fixed}; -use sp1_primitives::RC_16_30_U32; -use tracing::instrument; - -use crate::{ - poseidon2_wide::{external_linear_layer, internal_linear_layer}, - runtime::{ExecutionRecord, RecursionProgram}, -}; - -use super::{ - external::{NUM_POSEIDON2_COLS, WIDTH}, - Poseidon2Chip, Poseidon2Cols, -}; - -impl MachineAir for Poseidon2Chip { - type Record = ExecutionRecord; - - type Program = RecursionProgram; - - fn name(&self) -> String { - "Poseidon2".to_string() - } - - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { - // This is a no-op. 
- } - - #[instrument(name = "generate poseidon2 trace", level = "debug", skip_all, fields(rows = input.poseidon2_events.len()))] - fn generate_trace( - &self, - input: &ExecutionRecord, - _: &mut ExecutionRecord, - ) -> RowMajorMatrix { - let mut rows = Vec::new(); - - // 1 round for memory input; 1 round for initialize; 8 rounds for external; 13 rounds for internal; 1 round for memory output - let rounds_f = 8; - let rounds_p = 13; - let rounds = rounds_f + rounds_p + 3; - let rounds_p_beginning = 2 + rounds_f / 2; - let p_end = rounds_p_beginning + rounds_p; - - for poseidon2_event in input.poseidon2_events.iter() { - let mut round_input = Default::default(); - for r in 0..rounds { - let mut row = [F::zero(); NUM_POSEIDON2_COLS]; - let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); - cols.is_real = F::one(); - - let is_receive = r == 0; - let is_memory_read = r == 0; - let is_initial_layer = r == 1; - let is_external_layer = - (r >= 2 && r < rounds_p_beginning) || (r >= p_end && r < p_end + rounds_f / 2); - let is_internal_layer = r >= rounds_p_beginning && r < p_end; - let is_memory_write = r == rounds - 1; - - let sum = (is_memory_read as u32) - + (is_initial_layer as u32) - + (is_external_layer as u32) - + (is_internal_layer as u32) - + (is_memory_write as u32); - assert!( - sum == 0 || sum == 1, - "{} {} {} {} {}", - is_memory_read, - is_initial_layer, - is_external_layer, - is_internal_layer, - is_memory_write - ); - - cols.clk = poseidon2_event.clk; - cols.dst_input = poseidon2_event.dst; - cols.left_input = poseidon2_event.left; - cols.right_input = poseidon2_event.right; - cols.rounds[r] = F::one(); - - if is_receive { - cols.do_receive = F::one(); - } - - if is_memory_read || is_memory_write { - let memory_access_cols = cols.round_specific_cols.memory_access_mut(); - - if is_memory_read { - memory_access_cols.addr_first_half = poseidon2_event.left; - memory_access_cols.addr_second_half = poseidon2_event.right; - for i in 0..WIDTH { - 
memory_access_cols.mem_access[i] - .populate(&poseidon2_event.input_records[i]); - } - } else { - memory_access_cols.addr_first_half = poseidon2_event.dst; - memory_access_cols.addr_second_half = - poseidon2_event.dst + F::from_canonical_usize(WIDTH / 2); - for i in 0..WIDTH { - memory_access_cols.mem_access[i] - .populate(&poseidon2_event.result_records[i]); - } - } - cols.do_memory = F::one(); - } else { - let computation_cols = cols.round_specific_cols.computation_mut(); - - if is_initial_layer { - round_input = poseidon2_event.input; - } - - computation_cols.input = round_input; - - if is_initial_layer { - // Don't apply the round constants. - computation_cols - .add_rc - .copy_from_slice(&computation_cols.input); - } else if is_external_layer { - // Apply the round constants. - for j in 0..WIDTH { - computation_cols.add_rc[j] = computation_cols.input[j] - + F::from_wrapped_u32(RC_16_30_U32[r - 2][j]); - } - } else { - // Apply the round constants only on the first element. - computation_cols - .add_rc - .copy_from_slice(&computation_cols.input); - computation_cols.add_rc[0] = - computation_cols.input[0] + F::from_wrapped_u32(RC_16_30_U32[r - 2][0]); - }; - - // Apply the sbox. - for j in 0..WIDTH { - let sbox_deg_3 = computation_cols.add_rc[j] - * computation_cols.add_rc[j] - * computation_cols.add_rc[j]; - computation_cols.sbox_deg_3[j] = sbox_deg_3; - computation_cols.sbox_deg_7[j] = - sbox_deg_3 * sbox_deg_3 * computation_cols.add_rc[j]; - } - - // What state to use for the linear layer. - let mut state = if is_initial_layer { - computation_cols.add_rc - } else if is_external_layer { - computation_cols.sbox_deg_7 - } else { - let mut state = computation_cols.add_rc; - state[0] = computation_cols.sbox_deg_7[0]; - state - }; - - // Apply either the external or internal linear layer. 
- if is_initial_layer || is_external_layer { - external_linear_layer(&mut state); - } else if is_internal_layer { - internal_linear_layer(&mut state) - } - - // Copy the state to the output. - computation_cols.output.copy_from_slice(&state); - - round_input = computation_cols.output; - } - - rows.push(row); - } - } - - let num_real_rows = rows.len(); - - // Pad the trace to a power of two. - if self.pad { - pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_POSEIDON2_COLS], - self.fixed_log2_rows, - ); - } - - let mut round_num = 0; - for row in rows[num_real_rows..].iter_mut() { - let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); - cols.rounds[round_num] = F::one(); - - round_num = (round_num + 1) % rounds; - } - - // Convert the trace to a row major matrix. - RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_POSEIDON2_COLS, - ) - } - - fn included(&self, record: &Self::Record) -> bool { - !record.poseidon2_events.is_empty() - } -} diff --git a/recursion/core/src/poseidon2_wide/air/control_flow.rs b/recursion/core/src/poseidon2_wide/air/control_flow.rs new file mode 100644 index 0000000000..50bcbd984f --- /dev/null +++ b/recursion/core/src/poseidon2_wide/air/control_flow.rs @@ -0,0 +1,365 @@ +use p3_air::AirBuilder; +use p3_field::AbstractField; +use sp1_core::{air::BaseAirBuilder, operations::IsZeroOperation}; + +use crate::{ + air::SP1RecursionAirBuilder, + poseidon2_wide::{ + columns::{ + control_flow::ControlFlow, opcode_workspace::OpcodeWorkspace, + syscall_params::SyscallParams, Poseidon2, + }, + Poseidon2WideChip, RATE, + }, + range_check::RangeCheckOpcode, +}; + +impl Poseidon2WideChip { + /// Constraints related to control flow. 
+ pub(crate) fn eval_control_flow( + &self, + builder: &mut AB, + local_row: &dyn Poseidon2, + next_row: &dyn Poseidon2, + send_range_check: AB::Var, + ) where + AB::Var: 'static, + { + let local_control_flow = local_row.control_flow(); + let next_control_flow = next_row.control_flow(); + + let local_is_real = local_control_flow.is_compress + + local_control_flow.is_absorb + + local_control_flow.is_finalize; + let next_is_real = next_control_flow.is_compress + + next_control_flow.is_absorb + + next_control_flow.is_finalize; + + builder.assert_bool(local_control_flow.is_compress); + builder.assert_bool(local_control_flow.is_compress_output); + builder.assert_bool(local_control_flow.is_absorb); + builder.assert_bool(local_control_flow.is_finalize); + builder.assert_bool(local_control_flow.is_syscall_row); + builder.assert_bool(local_is_real.clone()); + + self.eval_global_control_flow( + builder, + local_control_flow, + next_control_flow, + local_row.syscall_params(), + next_row.syscall_params(), + local_row.opcode_workspace(), + next_row.opcode_workspace(), + local_is_real.clone(), + next_is_real.clone(), + ); + + self.eval_hash_control_flow( + builder, + local_control_flow, + local_row.opcode_workspace(), + next_row.opcode_workspace(), + local_row.syscall_params(), + send_range_check, + ); + } + + /// This function will verify that all hash rows are before the compress rows and that the first + /// row is the first absorb syscall. These constraints will require that there is at least one + /// absorb, finalize, and compress system call. 
+ #[allow(clippy::too_many_arguments)] + fn eval_global_control_flow( + &self, + builder: &mut AB, + local_control_flow: &ControlFlow, + next_control_flow: &ControlFlow, + local_syscall_params: &SyscallParams, + next_syscall_params: &SyscallParams, + local_opcode_workspace: &OpcodeWorkspace, + next_opcode_workspace: &OpcodeWorkspace, + local_is_real: AB::Expr, + next_is_real: AB::Expr, + ) { + // We require that the first row is an absorb syscall and that the hash_num == 0. + let mut first_row_builder = builder.when_first_row(); + first_row_builder.assert_one(local_control_flow.is_absorb); + first_row_builder.assert_one(local_control_flow.is_syscall_row); + first_row_builder.assert_zero(local_syscall_params.absorb().hash_num); + first_row_builder.assert_one(local_opcode_workspace.absorb().is_first_hash_row); + + let mut transition_builder = builder.when_transition(); + + // For absorb rows, constrain the following: + // 1) next row is either an absorb or syscall finalize. + // 2) when last absorb row, then the next row is a syscall row. + // 2) hash_num == hash_num'. + { + let mut absorb_transition_builder = + transition_builder.when(local_control_flow.is_absorb); + absorb_transition_builder + .assert_one(next_control_flow.is_absorb + next_control_flow.is_finalize); + absorb_transition_builder + .when(local_opcode_workspace.absorb().is_last_row::()) + .assert_one(next_control_flow.is_syscall_row); + + absorb_transition_builder + .when(next_control_flow.is_absorb) + .assert_eq( + local_syscall_params.absorb().hash_num, + next_syscall_params.absorb().hash_num, + ); + absorb_transition_builder + .when(next_control_flow.is_finalize) + .assert_eq( + local_syscall_params.absorb().hash_num, + next_syscall_params.finalize().hash_num, + ); + } + + // For finalize rows, constrain the following: + // 1) next row is syscall compress or syscall absorb. 
+ // 2) if next row is absorb -> hash_num + 1 == hash_num' + // 3) if next row is absorb -> is_first_hash' == true + { + let mut finalize_transition_builder = + transition_builder.when(local_control_flow.is_finalize); + + finalize_transition_builder + .assert_one(next_control_flow.is_absorb + next_control_flow.is_compress); + finalize_transition_builder.assert_one(next_control_flow.is_syscall_row); + + finalize_transition_builder + .when(next_control_flow.is_absorb) + .assert_eq( + local_syscall_params.finalize().hash_num + AB::Expr::one(), + next_syscall_params.absorb().hash_num, + ); + finalize_transition_builder + .when(next_control_flow.is_absorb) + .assert_one(next_opcode_workspace.absorb().is_first_hash_row); + } + + // For compress rows, constrain the following: + // 1) if compress syscall -> next row is a compress output + // 2) if compress output -> next row is a compress syscall or not real + { + transition_builder + .when(local_control_flow.is_compress) + .when(local_control_flow.is_syscall_row) + .assert_one(next_control_flow.is_compress_output); + + transition_builder + .when(local_control_flow.is_compress_output) + .assert_one( + next_control_flow.is_compress + (AB::Expr::one() - next_is_real.clone()), + ); + + transition_builder + .when(local_control_flow.is_compress_output) + .when(next_control_flow.is_compress) + .assert_one(next_control_flow.is_syscall_row); + } + + // Constrain that there is only one is_real -> not is real transition. Also contrain that + // the last real row is a compress output row. 
+ { + transition_builder + .when_not(local_is_real.clone()) + .assert_zero(next_is_real.clone()); + + transition_builder + .when(local_is_real.clone()) + .when_not(next_is_real.clone()) + .assert_one(local_control_flow.is_compress_output); + + builder + .when_last_row() + .when(local_is_real.clone()) + .assert_one(local_control_flow.is_compress_output); + } + } + + #[allow(clippy::too_many_arguments)] + fn eval_hash_control_flow( + &self, + builder: &mut AB, + local_control_flow: &ControlFlow, + local_opcode_workspace: &OpcodeWorkspace, + next_opcode_workspace: &OpcodeWorkspace, + local_syscall_params: &SyscallParams, + send_range_check: AB::Var, + ) { + let local_hash_workspace = local_opcode_workspace.absorb(); + let next_hash_workspace = next_opcode_workspace.absorb(); + let last_row_ending_cursor_is_seven = + local_hash_workspace.last_row_ending_cursor_is_seven.result; + + // Constrain the materialized control flow flags. + { + let mut absorb_builder = builder.when(local_control_flow.is_absorb); + + absorb_builder.assert_eq( + local_hash_workspace.is_syscall_not_last_row, + local_control_flow.is_syscall_row + * (AB::Expr::one() - local_hash_workspace.is_last_row::()), + ); + absorb_builder.assert_eq( + local_hash_workspace.not_syscall_not_last_row, + (AB::Expr::one() - local_control_flow.is_syscall_row) + * (AB::Expr::one() - local_hash_workspace.is_last_row::()), + ); + absorb_builder.assert_eq( + local_hash_workspace.is_syscall_is_last_row, + local_control_flow.is_syscall_row * local_hash_workspace.is_last_row::(), + ); + absorb_builder.assert_eq( + local_hash_workspace.not_syscall_is_last_row, + (AB::Expr::one() - local_control_flow.is_syscall_row) + * local_hash_workspace.is_last_row::(), + ); + absorb_builder.assert_eq( + local_hash_workspace.is_last_row_ending_cursor_is_seven, + local_hash_workspace.is_last_row::() * last_row_ending_cursor_is_seven, + ); + absorb_builder.assert_eq( + local_hash_workspace.is_last_row_ending_cursor_not_seven, + 
local_hash_workspace.is_last_row::() + * (AB::Expr::one() - last_row_ending_cursor_is_seven), + ); + + builder.assert_eq( + local_control_flow.is_absorb_not_last_row, + local_control_flow.is_absorb + * (AB::Expr::one() - local_hash_workspace.is_last_row::()), + ); + + builder.assert_eq( + local_control_flow.is_absorb_no_perm, + local_control_flow.is_absorb + * (AB::Expr::one() - local_hash_workspace.do_perm::()), + ) + } + + // For the absorb syscall row, ensure correct value of num_remaining_rows, last_row_num_consumed, + // and num_remaining_rows_is_zero. + { + let mut absorb_builder = builder.when(local_control_flow.is_absorb); + + // Verify that state_cursor + syscall input_len - 1 == num_remaining_rows * RATE + last_row_ending_cursor. + // The minus one is needed, since `last_row_ending_cursor` is inclusive of the last element, + // while state_cursor + syscall input_len is not. + absorb_builder + .when(local_control_flow.is_syscall_row) + .assert_eq( + local_hash_workspace.state_cursor + local_syscall_params.absorb().input_len + - AB::Expr::one(), + local_hash_workspace.num_remaining_rows * AB::Expr::from_canonical_usize(RATE) + + local_hash_workspace.last_row_ending_cursor, + ); + + // Range check that last_row_ending_cursor is between [0, 7]. + (0..3).for_each(|i| { + absorb_builder.assert_bool(local_hash_workspace.last_row_ending_cursor_bitmap[i]) + }); + let expected_last_row_ending_cursor: AB::Expr = local_hash_workspace + .last_row_ending_cursor_bitmap + .iter() + .zip(0..3) + .map(|(bit, exp)| *bit * AB::Expr::from_canonical_u32(2u32.pow(exp))) + .sum::(); + absorb_builder + .when(local_control_flow.is_syscall_row) + .assert_eq( + local_hash_workspace.last_row_ending_cursor, + expected_last_row_ending_cursor, + ); + + // Range check that num_remaining_rows is between [0, 2^18-1]. 
+ builder.send_range_check( + AB::Expr::from_canonical_u8(RangeCheckOpcode::U16 as u8), + local_hash_workspace.num_remaining_rows, + send_range_check, + ); + } + + // For all non last absorb rows, verify that num_remaining_rows decrements and + // that last_row_ending_cursor is copied down. + { + let mut transition_builder = builder.when_transition(); + let mut absorb_transition_builder = + transition_builder.when(local_control_flow.is_absorb); + + absorb_transition_builder + .when_not(local_hash_workspace.is_last_row::()) + .assert_eq( + next_hash_workspace.num_remaining_rows, + local_hash_workspace.num_remaining_rows - AB::Expr::one(), + ); + + // Copy down the last_row_ending_cursor value within the absorb call. + absorb_transition_builder + .when_not(local_hash_workspace.is_last_row::()) + .assert_eq( + next_hash_workspace.last_row_ending_cursor, + local_hash_workspace.last_row_ending_cursor, + ); + } + + // Constrain the state cursor. There are three constraints: + // 1) For the first hash row, verify that state_cursor == 0. + // 2) For the last absorb rows, verify that constrain + // state_cursor' = (last_row_ending_cursor + 1) % RATE. + // 3) For all non syscall rows, the state_cursor should be 0. + { + let mut absorb_builder = builder.when(local_control_flow.is_absorb); + + absorb_builder + .when(local_hash_workspace.is_first_hash_row) + .assert_zero(local_hash_workspace.state_cursor); + + absorb_builder + .when(local_hash_workspace.is_last_row_ending_cursor_is_seven) + .assert_zero(next_hash_workspace.state_cursor); + + absorb_builder + .when(local_hash_workspace.is_last_row_ending_cursor_not_seven) + .assert_eq( + next_hash_workspace.state_cursor, + local_hash_workspace.last_row_ending_cursor + AB::Expr::one(), + ); + + absorb_builder + .when_not(local_control_flow.is_syscall_row) + .assert_zero(local_hash_workspace.state_cursor); + } + + // Eval the absorb's iszero operations. 
+ { + // Drop absorb_builder so that builder can be used in the IsZeroOperation eval. + IsZeroOperation::::eval( + builder, + local_hash_workspace.last_row_ending_cursor - AB::Expr::from_canonical_usize(7), + local_hash_workspace.last_row_ending_cursor_is_seven, + local_control_flow.is_absorb.into(), + ); + + IsZeroOperation::::eval( + builder, + local_hash_workspace.num_remaining_rows.into(), + local_hash_workspace.num_remaining_rows_is_zero, + local_control_flow.is_absorb.into(), + ); + } + + // Apply control flow constraints for finalize. + { + // Eval state_cursor_is_zero. + IsZeroOperation::::eval( + builder, + local_opcode_workspace.finalize().state_cursor.into(), + local_opcode_workspace.finalize().state_cursor_is_zero, + local_control_flow.is_finalize.into(), + ); + } + } +} diff --git a/recursion/core/src/poseidon2_wide/air/memory.rs b/recursion/core/src/poseidon2_wide/air/memory.rs new file mode 100644 index 0000000000..506de55eee --- /dev/null +++ b/recursion/core/src/poseidon2_wide/air/memory.rs @@ -0,0 +1,222 @@ +use p3_air::AirBuilder; +use p3_field::AbstractField; +use sp1_core::air::BaseAirBuilder; + +use crate::{ + air::SP1RecursionAirBuilder, + memory::MemoryCols, + poseidon2_wide::{ + columns::{ + control_flow::ControlFlow, memory::Memory, opcode_workspace::OpcodeWorkspace, + syscall_params::SyscallParams, + }, + Poseidon2WideChip, WIDTH, + }, +}; + +impl Poseidon2WideChip { + /// Eval the memory related columns. + #[allow(clippy::too_many_arguments)] + pub(crate) fn eval_mem( + &self, + builder: &mut AB, + syscall_params: &SyscallParams, + local_memory: &Memory, + next_memory: &Memory, + opcode_workspace: &OpcodeWorkspace, + control_flow: &ControlFlow, + first_half_memory_access: [AB::Var; WIDTH / 2], + second_half_memory_access: AB::Var, + ) { + let clk = syscall_params.get_raw_params()[0]; + let is_real = control_flow.is_compress + control_flow.is_absorb + control_flow.is_finalize; + + // Constrain the memory flags. 
+ for i in 0..WIDTH / 2 { + builder.assert_bool(local_memory.memory_slot_used[i]); + + // The memory slot flag will be used as the memory access multiplicity flag, so we need to + // ensure that those values are zero for all non real rows. + builder + .when_not(is_real.clone()) + .assert_zero(local_memory.memory_slot_used[i]); + + // For compress and finalize, all of the slots should be true. + builder + .when(control_flow.is_compress + control_flow.is_finalize) + .assert_one(local_memory.memory_slot_used[i]); + + // For absorb, need to make sure the memory_slots_used is consistent with the start_cursor and + // end_cursor (i.e. start_cursor + num_consumed); + self.eval_absorb_memory_slots(builder, control_flow, local_memory, opcode_workspace); + } + + // Verify the start_addr column. + { + // For compress syscall rows, the start_addr should be the param's left ptr. + builder + .when(control_flow.is_compress * control_flow.is_syscall_row) + .assert_eq(syscall_params.compress().left_ptr, local_memory.start_addr); + + // For compress output rows, the start_addr should be the param's dst ptr. + builder + .when(control_flow.is_compress_output) + .assert_eq(syscall_params.compress().dst_ptr, local_memory.start_addr); + + // For absorb syscall rows, the start_addr should initially be from the syscall param's + // input_ptr, and for subsequent rows, it's incremented by the number of consumed elements. + builder + .when(control_flow.is_absorb) + .when(control_flow.is_syscall_row) + .assert_eq(syscall_params.absorb().input_ptr, local_memory.start_addr); + builder.when(control_flow.is_absorb_not_last_row).assert_eq( + next_memory.start_addr, + local_memory.start_addr + opcode_workspace.absorb().num_consumed::(), + ); + + // For finalize syscall rows, the start_addr should be the param's output ptr. 
+ builder.when(control_flow.is_finalize).assert_eq( + syscall_params.finalize().output_ptr, + local_memory.start_addr, + ); + } + + // Contrain memory access for the first half of the memory accesses. + { + let mut addr: AB::Expr = local_memory.start_addr.into(); + for i in 0..WIDTH / 2 { + builder.recursion_eval_memory_access_single( + clk + control_flow.is_compress_output, + addr.clone(), + &local_memory.memory_accesses[i], + first_half_memory_access[i], + ); + + let compress_syscall_row = control_flow.is_compress * control_flow.is_syscall_row; + // For read only accesses, assert the value didn't change. + builder + .when(compress_syscall_row + control_flow.is_absorb) + .assert_eq( + *local_memory.memory_accesses[i].prev_value(), + *local_memory.memory_accesses[i].value(), + ); + + addr = addr.clone() + local_memory.memory_slot_used[i].into(); + } + } + + // Contrain memory access for the 2nd half of the memory accesses. + { + let compress_workspace = opcode_workspace.compress(); + + // Verify the start addr. + let is_compress_syscall = control_flow.is_compress * control_flow.is_syscall_row; + builder.when(is_compress_syscall.clone()).assert_eq( + compress_workspace.start_addr, + syscall_params.compress().right_ptr, + ); + builder.when(control_flow.is_compress_output).assert_eq( + compress_workspace.start_addr, + syscall_params.compress().dst_ptr + AB::Expr::from_canonical_usize(WIDTH / 2), + ); + + let mut addr: AB::Expr = compress_workspace.start_addr.into(); + for i in 0..WIDTH / 2 { + builder.recursion_eval_memory_access_single( + clk + control_flow.is_compress_output, + addr.clone(), + &compress_workspace.memory_accesses[i], + second_half_memory_access, + ); + + // For read only accesses, assert the value didn't change. 
+ builder.when(is_compress_syscall.clone()).assert_eq( + *compress_workspace.memory_accesses[i].prev_value(), + *compress_workspace.memory_accesses[i].value(), + ); + + addr = addr.clone() + AB::Expr::one(); + } + } + } + + fn eval_absorb_memory_slots( + &self, + builder: &mut AB, + control_flow: &ControlFlow, + local_memory: &Memory, + opcode_workspace: &OpcodeWorkspace, + ) { + // To verify that the absorb memory slots are correct, we take the derivative of the memory slots, + // (e.g. memory_slot_used[i] - memory_slot_used[i - 1]), and assert the following: + // 1) when start_mem_idx_bitmap[i] == 1 -> derivative == 1 + // 2) when end_mem_idx_bitmap[i + 1] == 1 -> derivative == -1 + // 3) when start_mem_idx_bitmap[i] == 0 and end_mem_idx_bitmap[i + 1] == 0 -> derivative == 0 + let mut absorb_builder = builder.when(control_flow.is_absorb); + + let start_mem_idx_bitmap = opcode_workspace.absorb().start_mem_idx_bitmap; + let end_mem_idx_bitmap = opcode_workspace.absorb().end_mem_idx_bitmap; + for i in 0..WIDTH / 2 { + let derivative: AB::Expr = if i == 0 { + local_memory.memory_slot_used[i].into() + } else { + local_memory.memory_slot_used[i] - local_memory.memory_slot_used[i - 1] + }; + + let is_start_mem_idx = start_mem_idx_bitmap[i].into(); + + let is_previous_end_mem_idx = if i == 0 { + AB::Expr::zero() + } else { + end_mem_idx_bitmap[i - 1].into() + }; + + absorb_builder + .when(is_start_mem_idx.clone()) + .assert_one(derivative.clone()); + + absorb_builder + .when(is_previous_end_mem_idx.clone()) + .assert_zero(derivative.clone() + AB::Expr::one()); + + absorb_builder + .when_not(is_start_mem_idx + is_previous_end_mem_idx) + .assert_zero(derivative); + } + + // Verify that all elements of start_mem_idx_bitmap and end_mem_idx_bitmap are bool. 
+ start_mem_idx_bitmap.iter().for_each(|bit| { + absorb_builder.assert_bool(*bit); + }); + end_mem_idx_bitmap.iter().for_each(|bit| { + absorb_builder.assert_bool(*bit); + }); + + // Verify correct value of start_mem_idx_bitmap and end_mem_idx_bitmap. + let start_mem_idx: AB::Expr = start_mem_idx_bitmap + .iter() + .enumerate() + .map(|(i, bit)| AB::Expr::from_canonical_usize(i) * *bit) + .sum(); + absorb_builder.assert_eq(start_mem_idx, opcode_workspace.absorb().state_cursor); + + let end_mem_idx: AB::Expr = end_mem_idx_bitmap + .iter() + .enumerate() + .map(|(i, bit)| AB::Expr::from_canonical_usize(i) * *bit) + .sum(); + + // When we are not in the last row, end_mem_idx should be zero. + absorb_builder + .when_not(opcode_workspace.absorb().is_last_row::()) + .assert_zero(end_mem_idx.clone()); + + // When we are in the last row, end_mem_idx bitmap should equal last_row_ending_cursor. + absorb_builder + .when(opcode_workspace.absorb().is_last_row::()) + .assert_eq( + end_mem_idx, + opcode_workspace.absorb().last_row_ending_cursor, + ); + } +} diff --git a/recursion/core/src/poseidon2_wide/air/mod.rs b/recursion/core/src/poseidon2_wide/air/mod.rs new file mode 100644 index 0000000000..e69ecb066e --- /dev/null +++ b/recursion/core/src/poseidon2_wide/air/mod.rs @@ -0,0 +1,204 @@ +//! The air module contains the AIR constraints for the poseidon2 chip. Those constraints will +//! enforce the following properties: +//! +//! # Layout of the poseidon2 chip: +//! +//! All the hash related rows should be in the first part of the chip and all the compress +//! related rows in the second part. E.g. the chip should has this format: +//! +//! absorb row (for hash num 1) +//! absorb row (for hash num 1) +//! absorb row (for hash num 1) +//! finalize row (for hash num 1) +//! absorb row (for hash num 2) +//! absorb row (for hash num 2) +//! finalize row (for hash num 2) +//! . +//! . +//! . +//! compress syscall/input row +//! compress output row +//! +//! # Absorb rows +//! +//! 
For absorb rows, the AIR needs to ensure that all of the input is written into the hash state +//! and that it is written into the correct parts of that state. To do this, the AIR will first ensure +//! the correct values for num_remaining_rows (e.g. total number of rows of an absorb syscall) and +//! the last_row_ending_cursor. It does this by checking the following: +//! +//! 1. start_state_cursor + syscall_input_len - 1 == num_remaining_rows * RATE + last_row_ending_cursor +//! 2. range check syscall_input_len to be [0, 2^16 - 1] +//! 3. range check last_row_ending_cursor to be [0, RATE - 1] +//! +//! For all subsequent absorb rows, the num_remaining_rows will be decremented by 1, and the +//! last_row_ending_cursor will be copied down to all of the rows. Also, for the next absorb/finalize +//! syscall, its state_cursor is set to (last_row_ending_cursor + 1) % RATE. +//! +//! From num_remaining_rows and syscall column, we know the absorb's first row and last row. +//! From that fact, we can then enforce the following state writes. +//! +//! 1. is_first_row && is_last_row -> state writes are [state_cursor..last_row_ending_cursor] +//! 2. is_first_row && !is_last_row -> state writes are [state_cursor..RATE - 1] +//! 3. !is_first_row && !is_last_row -> state writes are [0..RATE - 1] +//! 4. !is_first_row && is_last_row -> state writes are [0..last_row_ending_cursor] +//! +//! From the state writes range, we can then populate a bitmap that specifies which state elements +//! should be overwritten (stored in Memory.memory_slot_used columns). To verify that this bitmap +//! is correct, we utilize the column's derivative (memory_slot_used[i] - memory_slot_used[i-1], +//! where memory_slot_used[-1] is 0). +//! +//! 1. When idx == state write start_idx -> derivative == 1 +//! 2. When idx == (state write end_idx + 1) -> derivative == -1 +//! 3. For all other cases, derivative == 0 +//! +//! 
In addition to determining the hash state writes, the AIR also needs to ensure that the do_perm +//! flag is correct (which is used to determine if a permutation should be done). It does this +//! by enforcing the following. +//! +//! 1. is_first_row && !is_last_row -> do_perm == 1 +//! 2. !is_first_row && !is_last_row -> do_perm == 1 +//! 3. is_last_row && last_row_ending_cursor == RATE - 1 -> do_perm == 1 +//! 4. is_last_row && last_row_ending_cursor != RATE - 1 -> do_perm == 0 +//! +//! # Finalize rows +//! +//! For finalize, the main flag that needs to be checked is do_perm. If state_cursor == 0, then +//! do_perm should be 0, otherwise it should be 1. If state_cursor == 0, that means that the +//! previous row did a perm. +//! +//! # Compress rows +//! +//! For compress, the main invariant that needs to be checked is that every syscall compress row +//! verifies the correct memory read accesses, does the permutation, and copies the permuted value +//! into the next row. That row should then verify the correct memory write accesses. + +use p3_air::{Air, BaseAir}; +use p3_matrix::Matrix; + +use crate::air::SP1RecursionAirBuilder; + +pub mod control_flow; +pub mod memory; +pub mod permutation; +pub mod state_transition; +pub mod syscall_params; + +use super::{ + columns::{Poseidon2, NUM_POSEIDON2_DEGREE3_COLS, NUM_POSEIDON2_DEGREE9_COLS}, + Poseidon2WideChip, WIDTH, +}; + +impl BaseAir for Poseidon2WideChip { + fn width(&self) -> usize { + if DEGREE == 3 { + NUM_POSEIDON2_DEGREE3_COLS + } else if DEGREE == 9 || DEGREE == 17 { + NUM_POSEIDON2_DEGREE9_COLS + } else { + panic!("Unsupported degree: {}", DEGREE); + } + } +} + +impl Air for Poseidon2WideChip +where + AB: SP1RecursionAirBuilder, + AB::Var: 'static, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local_row = Self::convert::(main.row_slice(0)); + let next_row = Self::convert::(main.row_slice(1)); + + // Dummy constraints to normalize to DEGREE.
+ let lhs = (0..DEGREE) + .map(|_| local_row.control_flow().is_compress.into()) + .product::(); + let rhs = (0..DEGREE) + .map(|_| local_row.control_flow().is_compress.into()) + .product::(); + builder.assert_eq(lhs, rhs); + + self.eval_poseidon2( + builder, + local_row.as_ref(), + next_row.as_ref(), + local_row.control_flow().is_syscall_row, + local_row.memory().memory_slot_used, + local_row.control_flow().is_compress, + local_row.control_flow().is_absorb, + ); + } +} + +impl Poseidon2WideChip { + #[allow(clippy::too_many_arguments)] + pub(crate) fn eval_poseidon2( + &self, + builder: &mut AB, + local_row: &dyn Poseidon2, + next_row: &dyn Poseidon2, + receive_syscall: AB::Var, + first_half_memory_access: [AB::Var; WIDTH / 2], + second_half_memory_access: AB::Var, + send_range_check: AB::Var, + ) where + AB: SP1RecursionAirBuilder, + AB::Var: 'static, + { + let local_control_flow = local_row.control_flow(); + let next_control_flow = next_row.control_flow(); + let local_syscall = local_row.syscall_params(); + let next_syscall = next_row.syscall_params(); + let local_memory = local_row.memory(); + let next_memory = next_row.memory(); + let local_perm = local_row.permutation(); + let local_opcode_workspace = local_row.opcode_workspace(); + let next_opcode_workspace = next_row.opcode_workspace(); + + // Check that all the control flow columns are correct. + self.eval_control_flow(builder, local_row, next_row, send_range_check); + + // Check that the syscall columns are correct. + self.eval_syscall_params( + builder, + local_syscall, + next_syscall, + local_control_flow, + next_control_flow, + receive_syscall, + ); + + // Check that all the memory access columns are correct. + self.eval_mem( + builder, + local_syscall, + local_memory, + next_memory, + local_opcode_workspace, + local_control_flow, + first_half_memory_access, + second_half_memory_access, + ); + + // Check that the permutation columns are correct. 
+ self.eval_perm( + builder, + local_perm.as_ref(), + local_memory, + local_opcode_workspace, + local_control_flow, + ); + + // Check that the permutation output is copied to the next row correctly. + self.eval_state_transition( + builder, + local_control_flow, + local_opcode_workspace, + next_opcode_workspace, + local_perm.as_ref(), + local_memory, + next_memory, + ); + } +} diff --git a/recursion/core/src/poseidon2_wide/air/permutation.rs b/recursion/core/src/poseidon2_wide/air/permutation.rs new file mode 100644 index 0000000000..c9920a8a2e --- /dev/null +++ b/recursion/core/src/poseidon2_wide/air/permutation.rs @@ -0,0 +1,177 @@ +use std::array; + +use p3_field::AbstractField; +use sp1_primitives::RC_16_30_U32; + +use crate::{ + air::SP1RecursionAirBuilder, + memory::MemoryCols, + poseidon2_wide::{ + columns::{ + control_flow::ControlFlow, memory::Memory, opcode_workspace::OpcodeWorkspace, + permutation::Permutation, + }, + external_linear_layer, internal_linear_layer, Poseidon2WideChip, NUM_EXTERNAL_ROUNDS, + NUM_INTERNAL_ROUNDS, WIDTH, + }, +}; + +impl Poseidon2WideChip { + pub(crate) fn eval_perm( + &self, + builder: &mut AB, + perm_cols: &dyn Permutation, + memory: &Memory, + opcode_workspace: &OpcodeWorkspace, + control_flow: &ControlFlow, + ) { + // Construct the input array of the permutation. That array is dependent on the row type. + // For compress_syscall rows, the input is from the memory access values. For absorb, the + // input is the previous state, with select elements being read from the memory access values. + // For finalize, the input is the previous state. 
+ let input: [AB::Expr; WIDTH] = array::from_fn(|i| { + let previous_state = opcode_workspace.absorb().previous_state[i]; + + let (compress_input, absorb_input, finalize_input) = if i < WIDTH / 2 { + let mem_value = *memory.memory_accesses[i].value(); + + let compress_input = mem_value; + let absorb_input = + builder.if_else(memory.memory_slot_used[i], mem_value, previous_state); + let finalize_input = previous_state.into(); + + (compress_input, absorb_input, finalize_input) + } else { + let compress_input = + *opcode_workspace.compress().memory_accesses[i - WIDTH / 2].value(); + let absorb_input = previous_state.into(); + let finalize_input = previous_state.into(); + + (compress_input, absorb_input, finalize_input) + }; + + control_flow.is_compress * compress_input + + control_flow.is_absorb * absorb_input + + control_flow.is_finalize * finalize_input + }); + + // Apply the initial round. + let initial_round_output = { + let mut initial_round_output = input; + external_linear_layer(&mut initial_round_output); + initial_round_output + }; + let external_round_0_state: [AB::Expr; WIDTH] = core::array::from_fn(|i| { + let state = perm_cols.external_rounds_state()[0]; + state[i].into() + }); + + builder.assert_all_eq(external_round_0_state.clone(), initial_round_output); + + // Apply the first half of external rounds. + for r in 0..NUM_EXTERNAL_ROUNDS / 2 { + self.eval_external_round(builder, perm_cols, r); + } + + // Apply the internal rounds. + self.eval_internal_rounds(builder, perm_cols); + + // Apply the second half of external rounds. + for r in NUM_EXTERNAL_ROUNDS / 2..NUM_EXTERNAL_ROUNDS { + self.eval_external_round(builder, perm_cols, r); + } + } + + fn eval_external_round( + &self, + builder: &mut AB, + perm_cols: &dyn Permutation, + r: usize, + ) { + let external_state = perm_cols.external_rounds_state()[r]; + + // Add the round constants. 
+ let round = if r < NUM_EXTERNAL_ROUNDS / 2 { + r + } else { + r + NUM_INTERNAL_ROUNDS + }; + let add_rc: [AB::Expr; WIDTH] = core::array::from_fn(|i| { + external_state[i].into() + AB::F::from_wrapped_u32(RC_16_30_U32[round][i]) + }); + + // Apply the sboxes. + // See `populate_external_round` for why we don't have columns for the sbox output here. + let mut sbox_deg_7: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero()); + let mut sbox_deg_3: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero()); + for i in 0..WIDTH { + let calculated_sbox_deg_3 = add_rc[i].clone() * add_rc[i].clone() * add_rc[i].clone(); + + if let Some(external_sbox) = perm_cols.external_rounds_sbox() { + builder.assert_eq(external_sbox[r][i].into(), calculated_sbox_deg_3); + sbox_deg_3[i] = external_sbox[r][i].into(); + } else { + sbox_deg_3[i] = calculated_sbox_deg_3; + } + + sbox_deg_7[i] = sbox_deg_3[i].clone() * sbox_deg_3[i].clone() * add_rc[i].clone(); + } + + // Apply the linear layer. + let mut state = sbox_deg_7; + external_linear_layer(&mut state); + + let next_state_cols = if r == NUM_EXTERNAL_ROUNDS / 2 - 1 { + perm_cols.internal_rounds_state() + } else if r == NUM_EXTERNAL_ROUNDS - 1 { + perm_cols.perm_output() + } else { + &perm_cols.external_rounds_state()[r + 1] + }; + for i in 0..WIDTH { + builder.assert_eq(next_state_cols[i], state[i].clone()); + } + } + + fn eval_internal_rounds( + &self, + builder: &mut AB, + perm_cols: &dyn Permutation, + ) { + let state = &perm_cols.internal_rounds_state(); + let s0 = perm_cols.internal_rounds_s0(); + let mut state: [AB::Expr; WIDTH] = core::array::from_fn(|i| state[i].into()); + for r in 0..NUM_INTERNAL_ROUNDS { + // Add the round constant. 
+ let round = r + NUM_EXTERNAL_ROUNDS / 2; + let add_rc = if r == 0 { + state[0].clone() + } else { + s0[r - 1].into() + } + AB::Expr::from_wrapped_u32(RC_16_30_U32[round][0]); + + let mut sbox_deg_3 = add_rc.clone() * add_rc.clone() * add_rc.clone(); + if let Some(internal_sbox) = perm_cols.internal_rounds_sbox() { + builder.assert_eq(internal_sbox[r], sbox_deg_3); + sbox_deg_3 = internal_sbox[r].into(); + } + + // See `populate_internal_rounds` for why we don't have columns for the sbox output here. + let sbox_deg_7 = sbox_deg_3.clone() * sbox_deg_3.clone() * add_rc.clone(); + + // Apply the linear layer. + // See `populate_internal_rounds` for why we don't have columns for the new state here. + state[0] = sbox_deg_7.clone(); + internal_linear_layer(&mut state); + + if r < NUM_INTERNAL_ROUNDS - 1 { + builder.assert_eq(s0[r], state[0].clone()); + } + } + + let external_state = perm_cols.external_rounds_state()[NUM_EXTERNAL_ROUNDS / 2]; + for i in 0..WIDTH { + builder.assert_eq(external_state[i], state[i].clone()) + } + } +} diff --git a/recursion/core/src/poseidon2_wide/air/state_transition.rs b/recursion/core/src/poseidon2_wide/air/state_transition.rs new file mode 100644 index 0000000000..1b4b522a5d --- /dev/null +++ b/recursion/core/src/poseidon2_wide/air/state_transition.rs @@ -0,0 +1,123 @@ +use std::array; + +use p3_air::AirBuilder; +use sp1_core::{air::BaseAirBuilder, utils::DIGEST_SIZE}; + +use crate::{ + air::SP1RecursionAirBuilder, + memory::MemoryCols, + poseidon2_wide::{ + columns::{ + control_flow::ControlFlow, memory::Memory, opcode_workspace::OpcodeWorkspace, + permutation::Permutation, + }, + Poseidon2WideChip, WIDTH, + }, +}; + +impl Poseidon2WideChip { + #[allow(clippy::too_many_arguments)] + pub(crate) fn eval_state_transition( + &self, + builder: &mut AB, + control_flow: &ControlFlow, + local_opcode_workspace: &OpcodeWorkspace, + next_opcode_workspace: &OpcodeWorkspace, + permutation: &dyn Permutation, + local_memory: &Memory, + next_memory: 
&Memory, + ) { + // For compress syscall rows, verify that the permutation output's state is equal to + // the compress output memory values. + { + let compress_output_mem_values: [AB::Var; WIDTH] = array::from_fn(|i| { + if i < WIDTH / 2 { + *next_memory.memory_accesses[i].value() + } else { + *next_opcode_workspace.compress().memory_accesses[i - WIDTH / 2].value() + } + }); + + builder + .when_transition() + .when(control_flow.is_compress) + .when(control_flow.is_syscall_row) + .assert_all_eq(compress_output_mem_values, *permutation.perm_output()); + } + + // Absorb rows. + { + // Check that the state is zero on the first_hash_row. + builder + .when(control_flow.is_absorb) + .when(local_opcode_workspace.absorb().is_first_hash_row) + .assert_all_zero(local_opcode_workspace.absorb().previous_state); + + // Check that the state is equal to the permutation output when the permutation is applied. + builder + .when(control_flow.is_absorb) + .when(local_opcode_workspace.absorb().do_perm::()) + .assert_all_eq( + local_opcode_workspace.absorb().state, + *permutation.perm_output(), + ); + + // Construct the input into the permutation. + let input: [AB::Expr; WIDTH] = array::from_fn(|i| { + if i < WIDTH / 2 { + builder.if_else( + local_memory.memory_slot_used[i], + *local_memory.memory_accesses[i].value(), + local_opcode_workspace.absorb().previous_state[i], + ) + } else { + local_opcode_workspace.absorb().previous_state[i].into() + } + }); + + // Check that the state is equal to the permutation input when the permutation is not applied. + builder + .when(control_flow.is_absorb_no_perm) + .assert_all_eq(local_opcode_workspace.absorb().state, input); + + // Check that the state is copied to the next row. + builder + .when_transition() + .when(control_flow.is_absorb) + .assert_all_eq( + local_opcode_workspace.absorb().state, + next_opcode_workspace.absorb().previous_state, + ); + } + + // Finalize rows.
+ { + // Check that the state is equal to the permutation output when the permutation is applied. + builder + .when(control_flow.is_finalize) + .when(local_opcode_workspace.finalize().do_perm::()) + .assert_all_eq( + local_opcode_workspace.finalize().state, + *permutation.perm_output(), + ); + + // Check that the state is equal to the previous state when the permutation is not applied. + builder + .when(control_flow.is_finalize) + .when_not(local_opcode_workspace.finalize().do_perm::()) + .assert_all_eq( + local_opcode_workspace.finalize().state, + local_opcode_workspace.finalize().previous_state, + ); + + // Check that the finalize memory values are equal to the state. + let output_mem_values: [AB::Var; DIGEST_SIZE] = + array::from_fn(|i| *local_memory.memory_accesses[i].value()); + + builder.when(control_flow.is_finalize).assert_all_eq( + output_mem_values, + local_opcode_workspace.finalize().state[0..DIGEST_SIZE].to_vec(), + ); + } + } +} diff --git a/recursion/core/src/poseidon2_wide/air/syscall_params.rs b/recursion/core/src/poseidon2_wide/air/syscall_params.rs new file mode 100644 index 0000000000..db57a2cf7d --- /dev/null +++ b/recursion/core/src/poseidon2_wide/air/syscall_params.rs @@ -0,0 +1,88 @@ +use p3_air::AirBuilder; +use sp1_core::air::BaseAirBuilder; + +use crate::{ + air::SP1RecursionAirBuilder, + poseidon2_wide::{ + columns::{control_flow::ControlFlow, syscall_params::SyscallParams}, + Poseidon2WideChip, + }, + runtime::Opcode, +}; + +impl Poseidon2WideChip { + /// Eval the syscall parameters. + pub(crate) fn eval_syscall_params( + &self, + builder: &mut AB, + local_syscall: &SyscallParams, + next_syscall: &SyscallParams, + local_control_flow: &ControlFlow, + next_control_flow: &ControlFlow, + receive_syscall: AB::Var, + ) { + // Constraint that the operands are sent from the CPU table. 
+ let params = local_syscall.get_raw_params(); + let opcodes: [AB::Expr; 3] = [ + Opcode::Poseidon2Compress, + Opcode::Poseidon2Absorb, + Opcode::Poseidon2Finalize, + ] + .map(|x| x.as_field::().into()); + let opcode_selectors = [ + local_control_flow.is_compress, + local_control_flow.is_absorb, + local_control_flow.is_finalize, + ]; + + let used_opcode: AB::Expr = opcodes + .iter() + .zip(opcode_selectors.iter()) + .map(|(opcode, opcode_selector)| opcode.clone() * *opcode_selector) + .sum(); + + builder.receive_table(used_opcode, &params, receive_syscall); + + let mut transition_builder = builder.when_transition(); + + // Verify that the syscall parameters are copied to the compress output row. + { + let mut compress_syscall_builder = transition_builder + .when(local_control_flow.is_compress * local_control_flow.is_syscall_row); + + let local_syscall_params = local_syscall.compress(); + let next_syscall_params = next_syscall.compress(); + compress_syscall_builder.assert_eq(local_syscall_params.clk, next_syscall_params.clk); + compress_syscall_builder + .assert_eq(local_syscall_params.dst_ptr, next_syscall_params.dst_ptr); + compress_syscall_builder + .assert_eq(local_syscall_params.left_ptr, next_syscall_params.left_ptr); + compress_syscall_builder.assert_eq( + local_syscall_params.right_ptr, + next_syscall_params.right_ptr, + ); + } + + // Verify that the syscall parameters are copied down to all the non syscall absorb rows.
+ { + let mut absorb_syscall_builder = transition_builder.when(local_control_flow.is_absorb); + let mut absorb_syscall_builder = + absorb_syscall_builder.when_not(next_control_flow.is_syscall_row); + + let local_syscall_params = local_syscall.absorb(); + let next_syscall_params = next_syscall.absorb(); + + absorb_syscall_builder.assert_eq(local_syscall_params.clk, next_syscall_params.clk); + absorb_syscall_builder + .assert_eq(local_syscall_params.hash_num, next_syscall_params.hash_num); + absorb_syscall_builder.assert_eq( + local_syscall_params.input_ptr, + next_syscall_params.input_ptr, + ); + absorb_syscall_builder.assert_eq( + local_syscall_params.input_len, + next_syscall_params.input_len, + ); + } + } +} diff --git a/recursion/core/src/poseidon2_wide/columns.rs b/recursion/core/src/poseidon2_wide/columns.rs deleted file mode 100644 index 371b9d620d..0000000000 --- a/recursion/core/src/poseidon2_wide/columns.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::mem::size_of; - -use sp1_derive::AlignedBorrow; - -use crate::memory::{MemoryReadSingleCols, MemoryReadWriteSingleCols}; - -use super::external::{NUM_EXTERNAL_ROUNDS, NUM_INTERNAL_ROUNDS, WIDTH}; - -/// An enum the encapsulates mutable references to a wide version of poseidon2 chip (contains -/// intermediate sbox colunns) and a narrow version of the poseidon2 chip (doesn't contain -/// intermediate sbox columns). -pub(crate) enum Poseidon2ColTypeMut<'a, T> { - Wide(&'a mut Poseidon2SBoxCols), - Narrow(&'a mut Poseidon2Cols), -} - -impl Poseidon2ColTypeMut<'_, T> { - /// Returns mutable references to the poseidon2 columns and optional the intermediate sbox columns. 
- #[allow(clippy::type_complexity)] - pub fn get_cols_mut( - &mut self, - ) -> ( - &mut Poseidon2Cols, - Option<&mut [[T; WIDTH]; NUM_EXTERNAL_ROUNDS]>, - Option<&mut [T; NUM_INTERNAL_ROUNDS]>, - ) { - match self { - Poseidon2ColTypeMut::Wide(cols) => ( - &mut cols.poseidon2_cols, - Some(&mut cols.external_rounds_sbox), - Some(&mut cols.internal_rounds_sbox), - ), - Poseidon2ColTypeMut::Narrow(cols) => (cols, None, None), - } - } -} - -/// An immutable version of Poseidon2ColTypeMut. -pub(crate) enum Poseidon2ColType { - Wide(Poseidon2SBoxCols), - Narrow(Poseidon2Cols), -} - -impl Poseidon2ColType { - /// Returns reference to the poseidon2 columns. - pub fn get_poseidon2_cols(&self) -> Poseidon2Cols { - match self { - Poseidon2ColType::Wide(cols) => cols.poseidon2_cols.clone(), - Poseidon2ColType::Narrow(cols) => cols.clone(), - } - } - - /// Returns the external sbox columns for the given round. - pub const fn get_external_sbox(&self, round: usize) -> Option<&[T; WIDTH]> { - match self { - Poseidon2ColType::Wide(cols) => Some(&cols.external_rounds_sbox[round]), - Poseidon2ColType::Narrow(_) => None, - } - } - - /// Returns the internal sbox columns. - pub const fn get_internal_sbox(&self) -> Option<&[T; NUM_INTERNAL_ROUNDS]> { - match self { - Poseidon2ColType::Wide(cols) => Some(&cols.internal_rounds_sbox), - Poseidon2ColType::Narrow(_) => None, - } - } -} - -/// Memory columns for Poseidon2. -#[derive(AlignedBorrow, Clone, Copy)] -#[repr(C)] -pub struct Poseidon2MemCols { - pub timestamp: T, - pub dst: T, - pub left: T, - pub right: T, - pub input: [MemoryReadSingleCols; WIDTH], - pub output: [MemoryReadWriteSingleCols; WIDTH], - pub is_real: T, -} - -pub const NUM_POSEIDON2_COLS: usize = size_of::>(); - -/// Columns for the "narrow" Poseidon2 chip. -/// -/// As an optimization, we can represent all of the internal rounds without columns for intermediate -/// states except for the 0th element. 
This is because the linear layer that comes after the sbox is -/// degree 1, so all state elements at the end can be expressed as a degree-3 polynomial of: -/// 1) the 0th state element at rounds prior to the current round -/// 2) the rest of the state elements at the beginning of the internal rounds -#[derive(AlignedBorrow, Clone, Copy)] -#[repr(C)] -pub struct Poseidon2Cols { - pub(crate) memory: Poseidon2MemCols, - pub(crate) external_rounds_state: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS], - pub(crate) internal_rounds_state: [T; WIDTH], - pub(crate) internal_rounds_s0: [T; NUM_INTERNAL_ROUNDS - 1], -} - -pub const NUM_POSEIDON2_SBOX_COLS: usize = size_of::>(); - -/// Columns for the "wide" Poseidon2 chip. -#[derive(AlignedBorrow, Clone, Copy)] -#[repr(C)] -pub struct Poseidon2SBoxCols { - pub(crate) poseidon2_cols: Poseidon2Cols, - pub(crate) external_rounds_sbox: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS], - pub(crate) internal_rounds_sbox: [T; NUM_INTERNAL_ROUNDS], -} diff --git a/recursion/core/src/poseidon2_wide/columns/control_flow.rs b/recursion/core/src/poseidon2_wide/columns/control_flow.rs new file mode 100644 index 0000000000..1280a7a6a1 --- /dev/null +++ b/recursion/core/src/poseidon2_wide/columns/control_flow.rs @@ -0,0 +1,24 @@ +use sp1_derive::AlignedBorrow; + +/// Columns related to control flow. +#[derive(AlignedBorrow, Clone, Copy, Debug)] +#[repr(C)] +pub struct ControlFlow { + /// Specifies if this row is for compress. + pub is_compress: T, + /// Specifies if this row is for the compress output. + pub is_compress_output: T, + + /// Specifies if this row is for absorb. + pub is_absorb: T, + /// Specifies if this row is for absorb with no permutation. + pub is_absorb_no_perm: T, + /// Specifies if this row is for an absorb that is not the last row. + pub is_absorb_not_last_row: T, + + /// Specifies if this row is for finalize. + pub is_finalize: T, + + /// Specifies if this row needs to receive a syscall interaction.
+ pub is_syscall_row: T, +} diff --git a/recursion/core/src/poseidon2_wide/columns/memory.rs b/recursion/core/src/poseidon2_wide/columns/memory.rs new file mode 100644 index 0000000000..63b62783ad --- /dev/null +++ b/recursion/core/src/poseidon2_wide/columns/memory.rs @@ -0,0 +1,17 @@ +use sp1_derive::AlignedBorrow; + +use crate::{memory::MemoryReadWriteSingleCols, poseidon2_wide::WIDTH}; + +/// This struct is the columns for the WIDTH/2 sequential memory slots. +/// For compress rows, this is used for the first half of read/write from the permutation state. +/// For hash related rows, this is reading absorb input and writing finalize output. +#[derive(AlignedBorrow, Clone, Copy, Debug)] +#[repr(C)] +pub struct Memory { + /// The first address of the memory sequence. + pub start_addr: T, + /// Bitmap of whether each memory address is accessed. This is set to all 1 for compress and + /// finalize rows. + pub memory_slot_used: [T; WIDTH / 2], + pub memory_accesses: [MemoryReadWriteSingleCols; WIDTH / 2], +} diff --git a/recursion/core/src/poseidon2_wide/columns/mod.rs b/recursion/core/src/poseidon2_wide/columns/mod.rs new file mode 100644 index 0000000000..baac52dc1a --- /dev/null +++ b/recursion/core/src/poseidon2_wide/columns/mod.rs @@ -0,0 +1,250 @@ +use std::mem::{size_of, transmute}; + +use sp1_core::utils::indices_arr; +use sp1_derive::AlignedBorrow; + +use self::{ + control_flow::ControlFlow, + memory::Memory, + opcode_workspace::OpcodeWorkspace, + permutation::{Permutation, PermutationNoSbox, PermutationSBox}, + syscall_params::SyscallParams, +}; + +use super::WIDTH; + +pub mod control_flow; +pub mod memory; +pub mod opcode_workspace; +pub mod permutation; +pub mod syscall_params; + +/// Trait for getter methods for Poseidon2 columns.
+pub trait Poseidon2<'a, T: Copy + 'a> { + fn control_flow(&self) -> &ControlFlow; + + fn syscall_params(&self) -> &SyscallParams; + + fn memory(&self) -> &Memory; + + fn opcode_workspace(&self) -> &OpcodeWorkspace; + + fn permutation(&self) -> Box + 'a>; +} + +/// Trait for setter methods for Poseidon2 columns. +pub trait Poseidon2Mut<'a, T: Copy + 'a> { + fn control_flow_mut(&mut self) -> &mut ControlFlow; + + fn syscall_params_mut(&mut self) -> &mut SyscallParams; + + fn memory_mut(&mut self) -> &mut Memory; + + fn opcode_workspace_mut(&mut self) -> &mut OpcodeWorkspace; +} + +/// Enum to enable dynamic dispatch for the Poseidon2 columns. +#[allow(dead_code)] +enum Poseidon2Enum { + P2Degree3(Poseidon2Degree3), + P2Degree9(Poseidon2Degree9), +} + +impl<'a, T: Copy + 'a> Poseidon2<'a, T> for Poseidon2Enum { + // type Perm = PermutationSBox; + + fn control_flow(&self) -> &ControlFlow { + match self { + Poseidon2Enum::P2Degree3(p) => p.control_flow(), + Poseidon2Enum::P2Degree9(p) => p.control_flow(), + } + } + + fn syscall_params(&self) -> &SyscallParams { + match self { + Poseidon2Enum::P2Degree3(p) => p.syscall_params(), + Poseidon2Enum::P2Degree9(p) => p.syscall_params(), + } + } + + fn memory(&self) -> &Memory { + match self { + Poseidon2Enum::P2Degree3(p) => p.memory(), + Poseidon2Enum::P2Degree9(p) => p.memory(), + } + } + + fn opcode_workspace(&self) -> &OpcodeWorkspace { + match self { + Poseidon2Enum::P2Degree3(p) => p.opcode_workspace(), + Poseidon2Enum::P2Degree9(p) => p.opcode_workspace(), + } + } + + fn permutation(&self) -> Box + 'a> { + match self { + Poseidon2Enum::P2Degree3(p) => p.permutation(), + Poseidon2Enum::P2Degree9(p) => p.permutation(), + } + } +} + +/// Enum to enable dynamic dispatch for the Poseidon2 columns. 
+#[allow(dead_code)] +enum Poseidon2MutEnum<'a, T: Copy> { + P2Degree3(&'a mut Poseidon2Degree3), + P2Degree9(&'a mut Poseidon2Degree9), +} + +impl<'a, T: Copy + 'a> Poseidon2Mut<'a, T> for Poseidon2MutEnum<'a, T> { + fn control_flow_mut(&mut self) -> &mut ControlFlow { + match self { + Poseidon2MutEnum::P2Degree3(p) => p.control_flow_mut(), + Poseidon2MutEnum::P2Degree9(p) => p.control_flow_mut(), + } + } + + fn syscall_params_mut(&mut self) -> &mut SyscallParams { + match self { + Poseidon2MutEnum::P2Degree3(p) => p.syscall_params_mut(), + Poseidon2MutEnum::P2Degree9(p) => p.syscall_params_mut(), + } + } + + fn memory_mut(&mut self) -> &mut Memory { + match self { + Poseidon2MutEnum::P2Degree3(p) => p.memory_mut(), + Poseidon2MutEnum::P2Degree9(p) => p.memory_mut(), + } + } + + fn opcode_workspace_mut(&mut self) -> &mut OpcodeWorkspace { + match self { + Poseidon2MutEnum::P2Degree3(p) => p.opcode_workspace_mut(), + Poseidon2MutEnum::P2Degree9(p) => p.opcode_workspace_mut(), + } + } +} + +pub const NUM_POSEIDON2_DEGREE3_COLS: usize = size_of::>(); + +const fn make_col_map_degree3() -> Poseidon2Degree3 { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_POSEIDON2_DEGREE3_COLS], Poseidon2Degree3>(indices_arr) + } +} +pub const POSEIDON2_DEGREE3_COL_MAP: Poseidon2Degree3 = make_col_map_degree3(); + +/// Struct for the poseidon2 chip that contains sbox columns. 
+#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct Poseidon2Degree3 { + pub control_flow: ControlFlow, + pub syscall_input: SyscallParams, + pub memory: Memory, + pub opcode_specific_cols: OpcodeWorkspace, + pub permutation_cols: PermutationSBox, + pub state_cursor: [T; WIDTH / 2], // Only used for absorb +} + +impl<'a, T: Copy + 'a> Poseidon2<'a, T> for Poseidon2Degree3 { + fn control_flow(&self) -> &ControlFlow { + &self.control_flow + } + + fn syscall_params(&self) -> &SyscallParams { + &self.syscall_input + } + + fn memory(&self) -> &Memory { + &self.memory + } + + fn opcode_workspace(&self) -> &OpcodeWorkspace { + &self.opcode_specific_cols + } + + fn permutation(&self) -> Box + 'a> { + Box::new(self.permutation_cols) + } +} + +impl<'a, T: Copy + 'a> Poseidon2Mut<'a, T> for &'a mut Poseidon2Degree3 { + fn control_flow_mut(&mut self) -> &mut ControlFlow { + &mut self.control_flow + } + + fn syscall_params_mut(&mut self) -> &mut SyscallParams { + &mut self.syscall_input + } + + fn memory_mut(&mut self) -> &mut Memory { + &mut self.memory + } + + fn opcode_workspace_mut(&mut self) -> &mut OpcodeWorkspace { + &mut self.opcode_specific_cols + } +} + +pub const NUM_POSEIDON2_DEGREE9_COLS: usize = size_of::>(); +const fn make_col_map_degree9() -> Poseidon2Degree9 { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_POSEIDON2_DEGREE9_COLS], Poseidon2Degree9>(indices_arr) + } +} +pub const POSEIDON2_DEGREE9_COL_MAP: Poseidon2Degree9 = make_col_map_degree9(); + +/// Struct for the poseidon2 chip that doesn't contain sbox columns. 
+#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct Poseidon2Degree9 { + pub control_flow: ControlFlow, + pub syscall_input: SyscallParams, + pub memory: Memory, + pub opcode_specific_cols: OpcodeWorkspace, + pub permutation_cols: PermutationNoSbox, + pub state_cursor: [T; WIDTH / 2], // Only used for absorb +} + +impl<'a, T: Copy + 'a> Poseidon2<'a, T> for Poseidon2Degree9 { + fn control_flow(&self) -> &ControlFlow { + &self.control_flow + } + + fn syscall_params(&self) -> &SyscallParams { + &self.syscall_input + } + + fn memory(&self) -> &Memory { + &self.memory + } + + fn opcode_workspace(&self) -> &OpcodeWorkspace { + &self.opcode_specific_cols + } + + fn permutation(&self) -> Box + 'a> { + Box::new(self.permutation_cols) + } +} + +impl<'a, T: Copy + 'a> Poseidon2Mut<'a, T> for &'a mut Poseidon2Degree9 { + fn control_flow_mut(&mut self) -> &mut ControlFlow { + &mut self.control_flow + } + + fn syscall_params_mut(&mut self) -> &mut SyscallParams { + &mut self.syscall_input + } + + fn memory_mut(&mut self) -> &mut Memory { + &mut self.memory + } + + fn opcode_workspace_mut(&mut self) -> &mut OpcodeWorkspace { + &mut self.opcode_specific_cols + } +} diff --git a/recursion/core/src/poseidon2_wide/columns/opcode_workspace.rs b/recursion/core/src/poseidon2_wide/columns/opcode_workspace.rs new file mode 100644 index 0000000000..139db2e329 --- /dev/null +++ b/recursion/core/src/poseidon2_wide/columns/opcode_workspace.rs @@ -0,0 +1,143 @@ +use p3_field::AbstractField; +use sp1_core::operations::IsZeroOperation; +use sp1_derive::AlignedBorrow; + +use crate::{ + air::SP1RecursionAirBuilder, + memory::MemoryReadWriteSingleCols, + poseidon2_wide::{RATE, WIDTH}, +}; + +/// Workspace columns. They are different for each opcode. +#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub union OpcodeWorkspace { + compress: CompressWorkspace, + absorb: AbsorbWorkspace, + finalize: FinalizeWorkspace, +} +/// Getter and setter functions for the opcode workspace. 
+impl OpcodeWorkspace { + pub fn compress(&self) -> &CompressWorkspace { + unsafe { &self.compress } + } + + pub fn compress_mut(&mut self) -> &mut CompressWorkspace { + unsafe { &mut self.compress } + } + + pub fn absorb(&self) -> &AbsorbWorkspace { + unsafe { &self.absorb } + } + + pub fn absorb_mut(&mut self) -> &mut AbsorbWorkspace { + unsafe { &mut self.absorb } + } + + pub fn finalize(&self) -> &FinalizeWorkspace { + unsafe { &self.finalize } + } + + pub fn finalize_mut(&mut self) -> &mut FinalizeWorkspace { + unsafe { &mut self.finalize } + } +} + +/// Workspace columns for compress. This is used for memory read/writes for the 2nd half of the +/// compress permutation state. +#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct CompressWorkspace { + pub start_addr: T, + pub memory_accesses: [MemoryReadWriteSingleCols; WIDTH / 2], +} + +/// Workspace columns for absorb. +#[derive(AlignedBorrow, Clone, Copy, Debug)] +#[repr(C)] +pub struct AbsorbWorkspace { + /// State related columns. + pub previous_state: [T; WIDTH], + pub state: [T; WIDTH], + pub state_cursor: T, + + /// Control flow columns. + pub is_first_hash_row: T, + pub num_remaining_rows: T, + pub num_remaining_rows_is_zero: IsZeroOperation, + + /// Memory columns. + pub start_mem_idx_bitmap: [T; WIDTH / 2], + pub end_mem_idx_bitmap: [T; WIDTH / 2], + + /// This is the state index of the last element consumed by the absorb syscall. + pub last_row_ending_cursor: T, + pub last_row_ending_cursor_is_seven: IsZeroOperation, // Needed when doing the (last_row_ending_cursor + 1) % 8 calculation. + pub last_row_ending_cursor_bitmap: [T; 3], + + /// Materialized control flow flags to deal with max constraint degree. + /// Is an absorb syscall row which is not the last row for that absorb. + pub is_syscall_not_last_row: T, + /// Is an absorb syscall row that is the last row for that absorb.
+ pub is_syscall_is_last_row: T, + /// Is not an absorb syscall row and is not the last row for that absorb. + pub not_syscall_not_last_row: T, + /// Is not an absorb syscall row and is last row for that absorb. + pub not_syscall_is_last_row: T, + /// Is the last of an absorb and the state is filled up (e.g. it's ending cursor is 7). + pub is_last_row_ending_cursor_is_seven: T, + /// Is the last of an absorb and the state is not filled up (e.g. it's ending cursor is not 7). + pub is_last_row_ending_cursor_not_seven: T, +} + +/// Methods that are "virtual" columns (e.g. will return expressions). +impl AbsorbWorkspace { + pub(crate) fn is_last_row(&self) -> AB::Expr + where + T: Into, + { + self.num_remaining_rows_is_zero.result.into() + } + + pub(crate) fn do_perm(&self) -> AB::Expr + where + T: Into, + { + self.is_syscall_not_last_row.into() + + self.not_syscall_not_last_row.into() + + self.is_last_row_ending_cursor_is_seven.into() + } + + pub(crate) fn num_consumed(&self) -> AB::Expr + where + T: Into, + { + self.is_syscall_not_last_row.into() + * (AB::Expr::from_canonical_usize(RATE) - self.state_cursor.into()) + + self.is_syscall_is_last_row.into() + * (self.last_row_ending_cursor.into() - self.state_cursor.into() + AB::Expr::one()) + + self.not_syscall_not_last_row.into() * AB::Expr::from_canonical_usize(RATE) + + self.not_syscall_is_last_row.into() + * (self.last_row_ending_cursor.into() + AB::Expr::one()) + } +} + +/// Workspace columns for finalize. +#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct FinalizeWorkspace { + /// State related columns. 
+ pub previous_state: [T; WIDTH], + pub state: [T; WIDTH], + pub state_cursor: T, + pub state_cursor_is_zero: IsZeroOperation, +} + +impl FinalizeWorkspace { + pub(crate) fn do_perm(&self) -> AB::Expr + where + T: Into, + { + AB::Expr::one() - self.state_cursor_is_zero.result.into() + } +} diff --git a/recursion/core/src/poseidon2_wide/columns/permutation.rs b/recursion/core/src/poseidon2_wide/columns/permutation.rs new file mode 100644 index 0000000000..caaab08fd0 --- /dev/null +++ b/recursion/core/src/poseidon2_wide/columns/permutation.rs @@ -0,0 +1,239 @@ +use std::{borrow::BorrowMut, mem::size_of}; + +use sp1_derive::AlignedBorrow; + +use crate::poseidon2_wide::{NUM_EXTERNAL_ROUNDS, NUM_INTERNAL_ROUNDS, WIDTH}; + +use super::{POSEIDON2_DEGREE3_COL_MAP, POSEIDON2_DEGREE9_COL_MAP}; + +/// Trait that describes getter functions for the permutation columns. +pub trait Permutation { + fn external_rounds_state(&self) -> &[[T; WIDTH]]; + + fn internal_rounds_state(&self) -> &[T; WIDTH]; + + fn internal_rounds_s0(&self) -> &[T; NUM_INTERNAL_ROUNDS - 1]; + + fn external_rounds_sbox(&self) -> Option<&[[T; WIDTH]; NUM_EXTERNAL_ROUNDS]>; + + fn internal_rounds_sbox(&self) -> Option<&[T; NUM_INTERNAL_ROUNDS]>; + + fn perm_output(&self) -> &[T; WIDTH]; +} + +/// Trait that describes setter functions for the permutation columns. +pub trait PermutationMut { + #[allow(clippy::type_complexity)] + fn get_cols_mut( + &mut self, + ) -> ( + &mut [[T; WIDTH]], + &mut [T; WIDTH], + &mut [T; NUM_INTERNAL_ROUNDS - 1], + Option<&mut [[T; WIDTH]; NUM_EXTERNAL_ROUNDS]>, + Option<&mut [T; NUM_INTERNAL_ROUNDS]>, + &mut [T; WIDTH], + ); +} + +/// Permutation columns struct with S-boxes. 
+#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct PermutationSBox { + pub external_rounds_state: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS], + pub internal_rounds_state: [T; WIDTH], + pub internal_rounds_s0: [T; NUM_INTERNAL_ROUNDS - 1], + pub external_rounds_sbox: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS], + pub internal_rounds_sbox: [T; NUM_INTERNAL_ROUNDS], + pub output_state: [T; WIDTH], +} + +impl Permutation for PermutationSBox { + fn external_rounds_state(&self) -> &[[T; WIDTH]] { + &self.external_rounds_state + } + + fn internal_rounds_state(&self) -> &[T; WIDTH] { + &self.internal_rounds_state + } + + fn internal_rounds_s0(&self) -> &[T; NUM_INTERNAL_ROUNDS - 1] { + &self.internal_rounds_s0 + } + + fn external_rounds_sbox(&self) -> Option<&[[T; WIDTH]; NUM_EXTERNAL_ROUNDS]> { + Some(&self.external_rounds_sbox) + } + + fn internal_rounds_sbox(&self) -> Option<&[T; NUM_INTERNAL_ROUNDS]> { + Some(&self.internal_rounds_sbox) + } + + fn perm_output(&self) -> &[T; WIDTH] { + &self.output_state + } +} + +impl PermutationMut for &mut PermutationSBox { + fn get_cols_mut( + &mut self, + ) -> ( + &mut [[T; WIDTH]], + &mut [T; WIDTH], + &mut [T; NUM_INTERNAL_ROUNDS - 1], + Option<&mut [[T; WIDTH]; NUM_EXTERNAL_ROUNDS]>, + Option<&mut [T; NUM_INTERNAL_ROUNDS]>, + &mut [T; WIDTH], + ) { + ( + &mut self.external_rounds_state, + &mut self.internal_rounds_state, + &mut self.internal_rounds_s0, + Some(&mut self.external_rounds_sbox), + Some(&mut self.internal_rounds_sbox), + &mut self.output_state, + ) + } +} + +/// Permutation columns struct without S-boxes. 
+#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct PermutationNoSbox { + pub external_rounds_state: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS], + pub internal_rounds_state: [T; WIDTH], + pub internal_rounds_s0: [T; NUM_INTERNAL_ROUNDS - 1], + pub output_state: [T; WIDTH], +} + +impl Permutation for PermutationNoSbox { + fn external_rounds_state(&self) -> &[[T; WIDTH]] { + &self.external_rounds_state + } + + fn internal_rounds_state(&self) -> &[T; WIDTH] { + &self.internal_rounds_state + } + + fn internal_rounds_s0(&self) -> &[T; NUM_INTERNAL_ROUNDS - 1] { + &self.internal_rounds_s0 + } + + fn external_rounds_sbox(&self) -> Option<&[[T; WIDTH]; NUM_EXTERNAL_ROUNDS]> { + None + } + + fn internal_rounds_sbox(&self) -> Option<&[T; NUM_INTERNAL_ROUNDS]> { + None + } + + fn perm_output(&self) -> &[T; WIDTH] { + &self.output_state + } +} + +impl PermutationMut for &mut PermutationNoSbox { + fn get_cols_mut( + &mut self, + ) -> ( + &mut [[T; WIDTH]], + &mut [T; WIDTH], + &mut [T; NUM_INTERNAL_ROUNDS - 1], + Option<&mut [[T; WIDTH]; NUM_EXTERNAL_ROUNDS]>, + Option<&mut [T; NUM_INTERNAL_ROUNDS]>, + &mut [T; WIDTH], + ) { + ( + &mut self.external_rounds_state, + &mut self.internal_rounds_state, + &mut self.internal_rounds_s0, + None, + None, + &mut self.output_state, + ) + } +} + +/// Permutation columns struct without S-boxes and half of the external rounds. 
+#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct PermutationNoSboxHalfExternal { + pub external_rounds_state: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS / 2], + pub internal_rounds_state: [T; WIDTH], + pub internal_rounds_s0: [T; NUM_INTERNAL_ROUNDS - 1], + pub output_state: [T; WIDTH], +} + +impl Permutation for PermutationNoSboxHalfExternal { + fn external_rounds_state(&self) -> &[[T; WIDTH]] { + &self.external_rounds_state + } + + fn internal_rounds_state(&self) -> &[T; WIDTH] { + &self.internal_rounds_state + } + + fn internal_rounds_s0(&self) -> &[T; NUM_INTERNAL_ROUNDS - 1] { + &self.internal_rounds_s0 + } + + fn external_rounds_sbox(&self) -> Option<&[[T; WIDTH]; NUM_EXTERNAL_ROUNDS]> { + None + } + + fn internal_rounds_sbox(&self) -> Option<&[T; NUM_INTERNAL_ROUNDS]> { + None + } + + fn perm_output(&self) -> &[T; WIDTH] { + &self.output_state + } +} + +impl PermutationMut for &mut PermutationNoSboxHalfExternal { + fn get_cols_mut( + &mut self, + ) -> ( + &mut [[T; WIDTH]], + &mut [T; WIDTH], + &mut [T; NUM_INTERNAL_ROUNDS - 1], + Option<&mut [[T; WIDTH]; NUM_EXTERNAL_ROUNDS]>, + Option<&mut [T; NUM_INTERNAL_ROUNDS]>, + &mut [T; WIDTH], + ) { + ( + &mut self.external_rounds_state, + &mut self.internal_rounds_state, + &mut self.internal_rounds_s0, + None, + None, + &mut self.output_state, + ) + } +} + +pub fn permutation_mut<'a, 'b: 'a, T, const DEGREE: usize>( + row: &'b mut [T], +) -> Box + 'a> +where + T: Copy, +{ + if DEGREE == 3 { + let start = POSEIDON2_DEGREE3_COL_MAP + .permutation_cols + .external_rounds_state[0][0]; + let end = start + size_of::>(); + let convert: &mut PermutationSBox = row[start..end].borrow_mut(); + Box::new(convert) + } else if DEGREE == 9 || DEGREE == 17 { + let start = POSEIDON2_DEGREE9_COL_MAP + .permutation_cols + .external_rounds_state[0][0]; + let end = start + size_of::>(); + + let convert: &mut PermutationNoSbox = row[start..end].borrow_mut(); + Box::new(convert) + } else { + panic!("Unsupported degree"); + } +} diff 
--git a/recursion/core/src/poseidon2_wide/columns/syscall_params.rs b/recursion/core/src/poseidon2_wide/columns/syscall_params.rs new file mode 100644 index 0000000000..b03d6b81ed --- /dev/null +++ b/recursion/core/src/poseidon2_wide/columns/syscall_params.rs @@ -0,0 +1,82 @@ +use std::mem::size_of; + +use sp1_derive::AlignedBorrow; + +const SYSCALL_PARAMS_SIZE: usize = size_of::>(); + +/// Syscall params columns. They are different for each opcode. +#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub union SyscallParams { + compress: CompressParams, + absorb: AbsorbParams, + finalize: FinalizeParams, +} + +impl SyscallParams { + pub fn compress(&self) -> &CompressParams { + assert!(size_of::>() == SYSCALL_PARAMS_SIZE); + unsafe { &self.compress } + } + + pub fn compress_mut(&mut self) -> &mut CompressParams { + unsafe { &mut self.compress } + } + + pub fn absorb(&self) -> &AbsorbParams { + assert!(size_of::>() == SYSCALL_PARAMS_SIZE); + unsafe { &self.absorb } + } + + pub fn absorb_mut(&mut self) -> &mut AbsorbParams { + unsafe { &mut self.absorb } + } + + pub fn finalize(&self) -> &FinalizeParams { + assert!(size_of::>() == SYSCALL_PARAMS_SIZE); + unsafe { &self.finalize } + } + + pub fn finalize_mut(&mut self) -> &mut FinalizeParams { + unsafe { &mut self.finalize } + } + + pub fn get_raw_params(&self) -> [T; SYSCALL_PARAMS_SIZE] { + // All of the union's fields should have the same size, so just choose one of them to return + // the elements. 
+ let compress = self.compress(); + [ + compress.clk, + compress.dst_ptr, + compress.left_ptr, + compress.right_ptr, + ] + } +} + +#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct CompressParams { + pub clk: T, + pub dst_ptr: T, + pub left_ptr: T, + pub right_ptr: T, +} + +#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct AbsorbParams { + pub clk: T, + pub hash_num: T, + pub input_ptr: T, + pub input_len: T, +} + +#[derive(AlignedBorrow, Clone, Copy)] +#[repr(C)] +pub struct FinalizeParams { + pub clk: T, + pub hash_num: T, + pub output_ptr: T, + pub pad: T, +} diff --git a/recursion/core/src/poseidon2_wide/events.rs b/recursion/core/src/poseidon2_wide/events.rs new file mode 100644 index 0000000000..5d17d27faf --- /dev/null +++ b/recursion/core/src/poseidon2_wide/events.rs @@ -0,0 +1,153 @@ +use p3_field::PrimeField32; +use p3_symmetric::Permutation; + +use crate::memory::MemoryRecord; +use crate::poseidon2_wide::WIDTH; +use crate::runtime::DIGEST_SIZE; + +use super::RATE; + +#[derive(Debug, Clone)] +pub enum Poseidon2HashEvent { + Absorb(Poseidon2AbsorbEvent), + Finalize(Poseidon2FinalizeEvent), +} + +#[derive(Debug, Clone)] +pub struct Poseidon2CompressEvent { + pub clk: F, + pub dst: F, // from a_val + pub left: F, // from b_val + pub right: F, // from c_val + pub input: [F; WIDTH], + pub result_array: [F; WIDTH], + pub input_records: [MemoryRecord; WIDTH], + pub result_records: [MemoryRecord; WIDTH], +} + +#[derive(Debug, Clone)] +pub struct Poseidon2AbsorbEvent { + pub clk: F, + pub hash_num: F, // from a_val + pub input_addr: F, // from b_val + pub input_len: F, // from c_val + + pub iterations: Vec>, + pub is_first_aborb: bool, +} + +impl Poseidon2AbsorbEvent { + pub(crate) fn new( + clk: F, + hash_num: F, + input_addr: F, + input_len: F, + is_first_absorb: bool, + ) -> Self { + Self { + clk, + hash_num, + input_addr, + input_len, + iterations: Vec::new(), + is_first_aborb: is_first_absorb, + } + } +} + +impl Poseidon2AbsorbEvent { 
+ pub(crate) fn populate_iterations( + &mut self, + start_addr: F, + input_len: F, + memory_records: &[MemoryRecord], + permuter: &impl Permutation<[F; WIDTH]>, + hash_state: &mut [F; WIDTH], + hash_state_cursor: &mut usize, + ) { + let mut input_records = Vec::new(); + let mut previous_state = *hash_state; + let mut iter_num_consumed = 0; + + let start_addr = start_addr.as_canonical_u32(); + let end_addr = start_addr + input_len.as_canonical_u32(); + + for (addr_iter, memory_record) in (start_addr..end_addr).zip(memory_records.iter()) { + input_records.push(*memory_record); + + hash_state[*hash_state_cursor] = memory_record.value[0]; + *hash_state_cursor += 1; + iter_num_consumed += 1; + + // Do a permutation when the hash state is full. + if *hash_state_cursor == RATE { + let perm_input = *hash_state; + *hash_state = permuter.permute(*hash_state); + + self.iterations.push(Poseidon2AbsorbIteration { + state_cursor: *hash_state_cursor - iter_num_consumed, + start_addr: F::from_canonical_u32(addr_iter - iter_num_consumed as u32 + 1), + input_records, + perm_input, + perm_output: *hash_state, + previous_state, + state: *hash_state, + do_perm: true, + }); + + previous_state = *hash_state; + input_records = Vec::new(); + *hash_state_cursor = 0; + iter_num_consumed = 0; + } + } + + if *hash_state_cursor != 0 { + // Note that we still do a permutation, generate the trace and enforce permutation + // constraints for every absorb and finalize row. 
+ self.iterations.push(Poseidon2AbsorbIteration { + state_cursor: *hash_state_cursor - iter_num_consumed, + start_addr: F::from_canonical_u32(end_addr - iter_num_consumed as u32), + input_records, + perm_input: *hash_state, + perm_output: permuter.permute(*hash_state), + previous_state, + state: *hash_state, + do_perm: false, + }); + } + } +} + +#[derive(Debug, Clone)] +pub struct Poseidon2AbsorbIteration { + pub state_cursor: usize, + pub start_addr: F, + pub input_records: Vec>, + + pub perm_input: [F; WIDTH], + pub perm_output: [F; WIDTH], + + pub previous_state: [F; WIDTH], + pub state: [F; WIDTH], + + pub do_perm: bool, +} + +#[derive(Debug, Clone)] +pub struct Poseidon2FinalizeEvent { + pub clk: F, + pub hash_num: F, // from a_val + pub output_ptr: F, // from b_val + pub output_records: [MemoryRecord; DIGEST_SIZE], + + pub state_cursor: usize, + + pub perm_input: [F; WIDTH], + pub perm_output: [F; WIDTH], + + pub previous_state: [F; WIDTH], + pub state: [F; WIDTH], + + pub do_perm: bool, +} diff --git a/recursion/core/src/poseidon2_wide/external.rs b/recursion/core/src/poseidon2_wide/external.rs deleted file mode 100644 index 9181d3a1fa..0000000000 --- a/recursion/core/src/poseidon2_wide/external.rs +++ /dev/null @@ -1,587 +0,0 @@ -use crate::poseidon2_wide::columns::{ - Poseidon2ColType, Poseidon2ColTypeMut, Poseidon2Cols, Poseidon2SBoxCols, NUM_POSEIDON2_COLS, - NUM_POSEIDON2_SBOX_COLS, -}; -use crate::runtime::Opcode; -use core::borrow::Borrow; -use p3_air::{Air, BaseAir}; -use p3_field::{AbstractField, PrimeField32}; -use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::Matrix; -use sp1_core::air::{BaseAirBuilder, MachineAir, SP1AirBuilder}; -use sp1_core::utils::pad_rows_fixed; -use sp1_primitives::RC_16_30_U32; -use std::borrow::BorrowMut; -use tracing::instrument; - -use crate::air::SP1RecursionAirBuilder; -use crate::memory::MemoryCols; - -use crate::poseidon2_wide::{external_linear_layer, internal_linear_layer}; -use 
crate::runtime::{ExecutionRecord, RecursionProgram}; - -use super::columns::Poseidon2MemCols; - -/// The width of the permutation. -pub const WIDTH: usize = 16; - -pub const NUM_EXTERNAL_ROUNDS: usize = 8; -pub const NUM_INTERNAL_ROUNDS: usize = 13; -pub const NUM_ROUNDS: usize = NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS; - -/// A chip that implements addition for the opcode ADD. -#[derive(Default)] -pub struct Poseidon2WideChip { - pub fixed_log2_rows: Option, -} - -impl MachineAir for Poseidon2WideChip { - type Record = ExecutionRecord; - - type Program = RecursionProgram; - - fn name(&self) -> String { - format!("Poseidon2Wide {}", DEGREE) - } - - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { - // This is a no-op. - } - - #[instrument(name = "generate poseidon2 wide trace", level = "debug", skip_all, fields(rows = input.poseidon2_events.len()))] - fn generate_trace( - &self, - input: &ExecutionRecord, - _: &mut ExecutionRecord, - ) -> RowMajorMatrix { - let mut rows = Vec::new(); - - assert!(DEGREE >= 3, "Minimum supported constraint degree is 3"); - let use_sbox_3 = DEGREE < 7; - let num_columns = >::width(self); - - for event in &input.poseidon2_events { - let mut row = vec![F::zero(); num_columns]; - - let mut cols = if use_sbox_3 { - let cols: &mut Poseidon2SBoxCols = row.as_mut_slice().borrow_mut(); - Poseidon2ColTypeMut::Wide(cols) - } else { - let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); - Poseidon2ColTypeMut::Narrow(cols) - }; - - let (poseidon2_cols, mut external_sbox, mut internal_sbox) = cols.get_cols_mut(); - - let memory = &mut poseidon2_cols.memory; - memory.timestamp = event.clk; - memory.dst = event.dst; - memory.left = event.left; - memory.right = event.right; - memory.is_real = F::one(); - - // Apply the initial round. 
- for i in 0..WIDTH { - memory.input[i].populate(&event.input_records[i]); - } - - for i in 0..WIDTH { - memory.output[i].populate(&event.result_records[i]); - } - - poseidon2_cols.external_rounds_state[0] = event.input; - external_linear_layer(&mut poseidon2_cols.external_rounds_state[0]); - - // Apply the first half of external rounds. - for r in 0..NUM_EXTERNAL_ROUNDS / 2 { - let next_state = populate_external_round(poseidon2_cols, &mut external_sbox, r); - - if r == NUM_EXTERNAL_ROUNDS / 2 - 1 { - poseidon2_cols.internal_rounds_state = next_state; - } else { - poseidon2_cols.external_rounds_state[r + 1] = next_state; - } - } - - // Apply the internal rounds. - poseidon2_cols.external_rounds_state[NUM_EXTERNAL_ROUNDS / 2] = - populate_internal_rounds(poseidon2_cols, &mut internal_sbox); - - // Apply the second half of external rounds. - for r in NUM_EXTERNAL_ROUNDS / 2..NUM_EXTERNAL_ROUNDS { - let next_state = populate_external_round(poseidon2_cols, &mut external_sbox, r); - if r == NUM_EXTERNAL_ROUNDS - 1 { - // Do nothing, since we set the cols.output by populating the output records - // after this loop. - for i in 0..WIDTH { - assert_eq!(event.result_records[i].value[0], next_state[i]); - } - } else { - poseidon2_cols.external_rounds_state[r + 1] = next_state; - } - } - - rows.push(row); - } - - // Pad the trace to a power of two. - pad_rows_fixed( - &mut rows, - || vec![F::zero(); num_columns], - self.fixed_log2_rows, - ); - - // Convert the trace to a row major matrix. 
- let trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), num_columns); - - #[cfg(debug_assertions)] - println!( - "poseidon2 wide trace dims is width: {:?}, height: {:?}", - trace.width(), - trace.height() - ); - - trace - } - - fn included(&self, record: &Self::Record) -> bool { - !record.poseidon2_events.is_empty() - } -} - -fn populate_external_round( - poseidon2_cols: &mut Poseidon2Cols, - sbox: &mut Option<&mut [[F; WIDTH]; NUM_EXTERNAL_ROUNDS]>, - r: usize, -) -> [F; WIDTH] { - let mut state = { - let round_state: &mut [F; WIDTH] = poseidon2_cols.external_rounds_state[r].borrow_mut(); - - // Add round constants. - // - // Optimization: Since adding a constant is a degree 1 operation, we can avoid adding - // columns for it, and instead include it in the constraint for the x^3 part of the sbox. - let round = if r < NUM_EXTERNAL_ROUNDS / 2 { - r - } else { - r + NUM_INTERNAL_ROUNDS - }; - let mut add_rc = *round_state; - for i in 0..WIDTH { - add_rc[i] += F::from_wrapped_u32(RC_16_30_U32[round][i]); - } - - // Apply the sboxes. - // Optimization: since the linear layer that comes after the sbox is degree 1, we can - // avoid adding columns for the result of the sbox, and instead include the x^3 -> x^7 - // part of the sbox in the constraint for the linear layer - let mut sbox_deg_7: [F; 16] = [F::zero(); WIDTH]; - let mut sbox_deg_3: [F; 16] = [F::zero(); WIDTH]; - for i in 0..WIDTH { - sbox_deg_3[i] = add_rc[i] * add_rc[i] * add_rc[i]; - sbox_deg_7[i] = sbox_deg_3[i] * sbox_deg_3[i] * add_rc[i]; - } - - if let Some(sbox) = sbox.as_deref_mut() { - sbox[r] = sbox_deg_3; - } - - sbox_deg_7 - }; - - // Apply the linear layer. 
- external_linear_layer(&mut state); - state -} - -fn populate_internal_rounds( - poseidon2_cols: &mut Poseidon2Cols, - sbox: &mut Option<&mut [F; NUM_INTERNAL_ROUNDS]>, -) -> [F; WIDTH] { - let mut state: [F; WIDTH] = poseidon2_cols.internal_rounds_state; - let mut sbox_deg_3: [F; NUM_INTERNAL_ROUNDS] = [F::zero(); NUM_INTERNAL_ROUNDS]; - for r in 0..NUM_INTERNAL_ROUNDS { - // Add the round constant to the 0th state element. - // Optimization: Since adding a constant is a degree 1 operation, we can avoid adding - // columns for it, just like for external rounds. - let round = r + NUM_EXTERNAL_ROUNDS / 2; - let add_rc = state[0] + F::from_wrapped_u32(RC_16_30_U32[round][0]); - - // Apply the sboxes. - // Optimization: since the linear layer that comes after the sbox is degree 1, we can - // avoid adding columns for the result of the sbox, just like for external rounds. - sbox_deg_3[r] = add_rc * add_rc * add_rc; - let sbox_deg_7 = sbox_deg_3[r] * sbox_deg_3[r] * add_rc; - - // Apply the linear layer. - state[0] = sbox_deg_7; - internal_linear_layer(&mut state); - - // Optimization: since we're only applying the sbox to the 0th state element, we only - // need to have columns for the 0th state element at every step. This is because the - // linear layer is degree 1, so all state elements at the end can be expressed as a - // degree-3 polynomial of the state at the beginning of the internal rounds and the 0th - // state element at rounds prior to the current round - if r < NUM_INTERNAL_ROUNDS - 1 { - poseidon2_cols.internal_rounds_s0[r] = state[0]; - } - } - - let ret_state = state; - - if let Some(sbox) = sbox.as_deref_mut() { - *sbox = sbox_deg_3; - } - - ret_state -} - -fn eval_external_round( - builder: &mut AB, - cols: &Poseidon2ColType, - r: usize, - is_real: AB::Var, -) { - let poseidon2_cols = cols.get_poseidon2_cols(); - let external_state = poseidon2_cols.external_rounds_state[r]; - - // Add the round constants. 
- let round = if r < NUM_EXTERNAL_ROUNDS / 2 { - r - } else { - r + NUM_INTERNAL_ROUNDS - }; - let add_rc: [AB::Expr; WIDTH] = core::array::from_fn(|i| { - external_state[i].into() + is_real * AB::F::from_wrapped_u32(RC_16_30_U32[round][i]) - }); - - // Apply the sboxes. - // See `populate_external_round` for why we don't have columns for the sbox output here. - let mut sbox_deg_7: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero()); - let mut sbox_deg_3: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero()); - let expected_sbox_deg_3 = cols.get_external_sbox(r); - for i in 0..WIDTH { - sbox_deg_3[i] = add_rc[i].clone() * add_rc[i].clone() * add_rc[i].clone(); - - if let Some(expected) = expected_sbox_deg_3 { - builder.assert_eq(expected[i], sbox_deg_3[i].clone()); - sbox_deg_3[i] = expected[i].into(); - } - - sbox_deg_7[i] = sbox_deg_3[i].clone() * sbox_deg_3[i].clone() * add_rc[i].clone(); - } - - // Apply the linear layer. - let mut state = sbox_deg_7; - external_linear_layer(&mut state); - - let next_state_cols = if r == NUM_EXTERNAL_ROUNDS / 2 - 1 { - poseidon2_cols.internal_rounds_state - } else if r == NUM_EXTERNAL_ROUNDS - 1 { - core::array::from_fn(|i| *poseidon2_cols.memory.output[i].value()) - } else { - poseidon2_cols.external_rounds_state[r + 1] - }; - for i in 0..WIDTH { - builder.assert_eq(next_state_cols[i], state[i].clone()); - } -} - -fn eval_internal_rounds( - builder: &mut AB, - cols: &Poseidon2ColType, - is_real: AB::Var, -) { - let poseidon2_cols = cols.get_poseidon2_cols(); - let state = &poseidon2_cols.internal_rounds_state; - let s0 = poseidon2_cols.internal_rounds_s0; - let sbox_3 = cols.get_internal_sbox(); - let mut state: [AB::Expr; WIDTH] = core::array::from_fn(|i| state[i].into()); - for r in 0..NUM_INTERNAL_ROUNDS { - // Add the round constant. 
- let round = r + NUM_EXTERNAL_ROUNDS / 2; - let add_rc = if r == 0 { - state[0].clone() - } else { - s0[r - 1].into() - } + is_real * AB::Expr::from_wrapped_u32(RC_16_30_U32[round][0]); - - let mut sbox_deg_3 = add_rc.clone() * add_rc.clone() * add_rc.clone(); - if let Some(expected) = sbox_3 { - builder.assert_eq(expected[r], sbox_deg_3); - sbox_deg_3 = expected[r].into(); - } - - // See `populate_internal_rounds` for why we don't have columns for the sbox output here. - let sbox_deg_7 = sbox_deg_3.clone() * sbox_deg_3 * add_rc.clone(); - - // Apply the linear layer. - // See `populate_internal_rounds` for why we don't have columns for the new state here. - state[0] = sbox_deg_7.clone(); - internal_linear_layer(&mut state); - - if r < NUM_INTERNAL_ROUNDS - 1 { - builder.assert_eq(s0[r], state[0].clone()); - } - } - - let external_state = poseidon2_cols.external_rounds_state[NUM_EXTERNAL_ROUNDS / 2]; - for i in 0..WIDTH { - builder.assert_eq(external_state[i], state[i].clone()) - } -} - -impl BaseAir for Poseidon2WideChip { - fn width(&self) -> usize { - match DEGREE { - d if d < 7 => NUM_POSEIDON2_SBOX_COLS, - _ => NUM_POSEIDON2_COLS, - } - } -} - -fn eval_mem(builder: &mut AB, local: &Poseidon2MemCols) { - // Evaluate all of the memory. - for i in 0..WIDTH { - let input_addr = if i < WIDTH / 2 { - local.left + AB::F::from_canonical_usize(i) - } else { - local.right + AB::F::from_canonical_usize(i - WIDTH / 2) - }; - - builder.recursion_eval_memory_access_single( - local.timestamp, - input_addr, - &local.input[i], - local.is_real, - ); - - let output_addr = local.dst + AB::F::from_canonical_usize(i); - builder.recursion_eval_memory_access_single( - local.timestamp + AB::F::from_canonical_usize(1), - output_addr, - &local.output[i], - local.is_real, - ); - } - - // Constraint that the operands are sent from the CPU table. 
- let operands: [AB::Expr; 4] = [ - local.timestamp.into(), - local.dst.into(), - local.left.into(), - local.right.into(), - ]; - builder.receive_table( - Opcode::Poseidon2Compress.as_field::(), - &operands, - local.is_real, - ); -} - -impl Air for Poseidon2WideChip -where - AB: SP1RecursionAirBuilder, -{ - fn eval(&self, builder: &mut AB) { - assert!(DEGREE >= 3, "Minimum supported constraint degree is 3"); - let main = builder.main(); - let cols = main.row_slice(0); - let cols = match DEGREE { - d if d < 7 => { - let cols: &Poseidon2SBoxCols = (*cols).borrow(); - Poseidon2ColType::Wide(*cols) - } - _ => { - let cols: &Poseidon2Cols = (*cols).borrow(); - Poseidon2ColType::Narrow(*cols) - } - }; - - let poseidon2_cols = cols.get_poseidon2_cols(); - let memory = poseidon2_cols.memory; - eval_mem(builder, &memory); - - // Dummy constraints to normalize to DEGREE. - let lhs = (0..DEGREE) - .map(|_| memory.is_real.into()) - .product::(); - let rhs = (0..DEGREE) - .map(|_| memory.is_real.into()) - .product::(); - builder.assert_eq(lhs, rhs); - - // Apply the initial round. - let initial_round_output = { - let mut initial_round_output: [AB::Expr; WIDTH] = - core::array::from_fn(|i| (*poseidon2_cols.memory.input[i].value()).into()); - external_linear_layer(&mut initial_round_output); - initial_round_output - }; - let external_round_0_state: [AB::Expr; WIDTH] = core::array::from_fn(|i| { - let state = poseidon2_cols.external_rounds_state[0]; - state[i].into() - }); - builder - .when(memory.is_real) - .assert_all_eq(external_round_0_state.clone(), initial_round_output); - - // Apply the first half of external rounds. - for r in 0..NUM_EXTERNAL_ROUNDS / 2 { - eval_external_round(builder, &cols, r, memory.is_real); - } - - // Apply the internal rounds. - eval_internal_rounds(builder, &cols, memory.is_real); - - // Apply the second half of external rounds. 
- for r in NUM_EXTERNAL_ROUNDS / 2..NUM_EXTERNAL_ROUNDS { - eval_external_round(builder, &cols, r, memory.is_real); - } - - // Make the degree equivalent to WIDTH to compress the interaction columns. - let mut dummy = memory.is_real * memory.is_real; - for _ in 0..(DEGREE - 2) { - dummy *= memory.is_real.into(); - } - builder.assert_eq(dummy.clone(), dummy.clone()); - } -} - -#[cfg(test)] -mod tests { - use std::time::Instant; - - use crate::poseidon2::Poseidon2Event; - use crate::poseidon2_wide::external::WIDTH; - use crate::{poseidon2_wide::external::Poseidon2WideChip, runtime::ExecutionRecord}; - use itertools::Itertools; - use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; - use p3_field::AbstractField; - use p3_matrix::dense::RowMajorMatrix; - use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; - use p3_symmetric::Permutation; - use sp1_core::air::MachineAir; - use sp1_core::stark::StarkGenericConfig; - use sp1_core::utils::{inner_perm, uni_stark_prove, uni_stark_verify, BabyBearPoseidon2}; - use zkhash::ark_ff::UniformRand; - - fn generate_trace_degree() { - let chip = Poseidon2WideChip:: { - fixed_log2_rows: None, - }; - - let test_inputs = vec![ - [BabyBear::from_canonical_u32(1); WIDTH], - [BabyBear::from_canonical_u32(2); WIDTH], - [BabyBear::from_canonical_u32(3); WIDTH], - [BabyBear::from_canonical_u32(4); WIDTH], - ]; - - let gt: Poseidon2< - BabyBear, - Poseidon2ExternalMatrixGeneral, - DiffusionMatrixBabyBear, - 16, - 7, - > = inner_perm(); - - let expected_outputs = test_inputs - .iter() - .map(|input| gt.permute(*input)) - .collect::>(); - - let mut input_exec = ExecutionRecord::::default(); - for (input, output) in test_inputs.clone().into_iter().zip_eq(expected_outputs) { - input_exec - .poseidon2_events - .push(Poseidon2Event::dummy_from_input(input, output)); - } - - // Generate trace will assert for the expected outputs. 
- chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); - } - - /// A test generating a trace for a single permutation that checks that the output is correct - #[test] - fn generate_trace() { - generate_trace_degree::<3>(); - generate_trace_degree::<7>(); - } - - fn poseidon2_wide_prove_babybear_degree( - inputs: Vec<[BabyBear; 16]>, - outputs: Vec<[BabyBear; 16]>, - ) { - let chip = Poseidon2WideChip:: { - fixed_log2_rows: None, - }; - let mut input_exec = ExecutionRecord::::default(); - for (input, output) in inputs.into_iter().zip_eq(outputs) { - input_exec - .poseidon2_events - .push(Poseidon2Event::dummy_from_input(input, output)); - } - let trace: RowMajorMatrix = - chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); - - let config = BabyBearPoseidon2::compressed(); - let mut challenger = config.challenger(); - - let start = Instant::now(); - let proof = uni_stark_prove(&config, &chip, &mut challenger, trace); - let duration = start.elapsed().as_secs_f64(); - println!("proof duration = {:?}", duration); - - let mut challenger = config.challenger(); - let start = Instant::now(); - uni_stark_verify(&config, &chip, &mut challenger, &proof) - .expect("expected proof to be valid"); - - let duration = start.elapsed().as_secs_f64(); - println!("verify duration = {:?}", duration); - } - - #[test] - fn poseidon2_wide_prove_babybear_success() { - let rng = &mut rand::thread_rng(); - - let test_inputs: Vec<[BabyBear; 16]> = (0..1000) - .map(|_| core::array::from_fn(|_| BabyBear::rand(rng))) - .collect_vec(); - - let gt: Poseidon2< - BabyBear, - Poseidon2ExternalMatrixGeneral, - DiffusionMatrixBabyBear, - 16, - 7, - > = inner_perm(); - - let expected_outputs = test_inputs - .iter() - .map(|input| gt.permute(*input)) - .collect::>(); - - poseidon2_wide_prove_babybear_degree::<3>(test_inputs.clone(), expected_outputs.clone()); - poseidon2_wide_prove_babybear_degree::<7>(test_inputs, expected_outputs); - } - - #[test] - #[should_panic] - fn 
poseidon2_wide_prove_babybear_failure() { - let rng = &mut rand::thread_rng(); - - let test_inputs = (0..1000) - .map(|i| [BabyBear::from_canonical_u32(i); WIDTH]) - .collect_vec(); - - let bad_outputs: Vec<[BabyBear; 16]> = (0..1000) - .map(|_| core::array::from_fn(|_| BabyBear::rand(rng))) - .collect_vec(); - - poseidon2_wide_prove_babybear_degree::<3>(test_inputs.clone(), bad_outputs.clone()); - poseidon2_wide_prove_babybear_degree::<7>(test_inputs, bad_outputs); - } -} diff --git a/recursion/core/src/poseidon2_wide/mod.rs b/recursion/core/src/poseidon2_wide/mod.rs index f5531bad84..14f0af267a 100644 --- a/recursion/core/src/poseidon2_wide/mod.rs +++ b/recursion/core/src/poseidon2_wide/mod.rs @@ -1,19 +1,72 @@ #![allow(clippy::needless_range_loop)] -use crate::poseidon2_wide::external::WIDTH; +use std::borrow::Borrow; +use std::borrow::BorrowMut; +use std::ops::Deref; + use p3_baby_bear::{MONTY_INVERSE, POSEIDON2_INTERNAL_MATRIX_DIAG_16_BABYBEAR_MONTY}; use p3_field::AbstractField; use p3_field::PrimeField32; -mod columns; -pub mod external; +pub mod air; +pub mod columns; +pub mod events; +pub mod trace; -pub use external::Poseidon2WideChip; use p3_poseidon2::matmul_internal; -#[derive(Debug, Clone)] -pub struct Poseidon2Event { - pub input: [F; WIDTH], +use self::columns::Poseidon2; +use self::columns::Poseidon2Degree3; +use self::columns::Poseidon2Degree9; +use self::columns::Poseidon2Mut; + +/// The width of the permutation. +pub const WIDTH: usize = 16; +pub const RATE: usize = WIDTH / 2; + +pub const NUM_EXTERNAL_ROUNDS: usize = 8; +pub const NUM_INTERNAL_ROUNDS: usize = 13; +pub const NUM_ROUNDS: usize = NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS; + +/// A chip that implements addition for the opcode ADD. +#[derive(Default)] +pub struct Poseidon2WideChip { + pub fixed_log2_rows: Option, + pub pad: bool, +} + +impl<'a, const DEGREE: usize> Poseidon2WideChip { + /// Transmute a row it to an immutable Poseidon2 instance. 
+ pub(crate) fn convert(row: impl Deref) -> Box + 'a> + where + T: Copy + 'a, + { + if DEGREE == 3 { + let convert: &Poseidon2Degree3 = (*row).borrow(); + Box::new(*convert) + } else if DEGREE == 9 || DEGREE == 17 { + let convert: &Poseidon2Degree9 = (*row).borrow(); + Box::new(*convert) + } else { + panic!("Unsupported degree"); + } + } + + /// Transmute a row it to a mutable Poseidon2 instance. + pub(crate) fn convert_mut<'b: 'a, F: PrimeField32>( + &self, + row: &'b mut Vec, + ) -> Box + 'a> { + if DEGREE == 3 { + let convert: &mut Poseidon2Degree3 = row.as_mut_slice().borrow_mut(); + Box::new(convert) + } else if DEGREE == 9 || DEGREE == 17 { + let convert: &mut Poseidon2Degree9 = row.as_mut_slice().borrow_mut(); + Box::new(convert) + } else { + panic!("Unsupported degree"); + } + } } pub fn apply_m_4(x: &mut [AF]) @@ -60,3 +113,211 @@ pub(crate) fn internal_linear_layer(state: &mut [F; WIDTH]) { let monty_inverse = F::from_wrapped_u32(MONTY_INVERSE.as_canonical_u32()); state.iter_mut().for_each(|i| *i *= monty_inverse.clone()); } + +#[cfg(test)] +pub(crate) mod tests { + use std::array; + use std::time::Instant; + + use crate::air::Block; + use crate::memory::MemoryRecord; + use crate::poseidon2_wide::events::Poseidon2HashEvent; + use crate::runtime::{ExecutionRecord, DIGEST_SIZE}; + use itertools::Itertools; + use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; + use p3_field::AbstractField; + use p3_matrix::dense::RowMajorMatrix; + use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; + use p3_symmetric::Permutation; + use rand::random; + use sp1_core::air::MachineAir; + use sp1_core::stark::StarkGenericConfig; + use sp1_core::utils::{inner_perm, uni_stark_prove, uni_stark_verify, BabyBearPoseidon2}; + use zkhash::ark_ff::UniformRand; + + use super::events::{Poseidon2AbsorbEvent, Poseidon2CompressEvent, Poseidon2FinalizeEvent}; + use super::{Poseidon2WideChip, WIDTH}; + + fn poseidon2_wide_prove_babybear_degree( + input_exec: ExecutionRecord, + ) 
{ + let chip = Poseidon2WideChip:: { + fixed_log2_rows: None, + pad: true, + }; + + let trace: RowMajorMatrix = + chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); + + let config = BabyBearPoseidon2::compressed(); + let mut challenger = config.challenger(); + + let start = Instant::now(); + let proof = uni_stark_prove(&config, &chip, &mut challenger, trace); + let duration = start.elapsed().as_secs_f64(); + println!("proof duration = {:?}", duration); + + let mut challenger = config.challenger(); + let start = Instant::now(); + uni_stark_verify(&config, &chip, &mut challenger, &proof) + .expect("expected proof to be valid"); + + let duration = start.elapsed().as_secs_f64(); + println!("verify duration = {:?}", duration); + } + + fn dummy_memory_access_records( + memory_values: Vec, + prev_ts: BabyBear, + ts: BabyBear, + ) -> Vec> { + memory_values + .iter() + .map(|value| MemoryRecord::new_read(BabyBear::zero(), Block::from(*value), ts, prev_ts)) + .collect_vec() + } + + pub(crate) fn generate_test_execution_record( + incorrect_trace: bool, + ) -> ExecutionRecord { + const NUM_ABSORBS: usize = 1000; + const NUM_COMPRESSES: usize = 1000; + + let mut input_exec = ExecutionRecord::::default(); + + let rng = &mut rand::thread_rng(); + let permuter: Poseidon2< + BabyBear, + Poseidon2ExternalMatrixGeneral, + DiffusionMatrixBabyBear, + 16, + 7, + > = inner_perm(); + + // Generate hash test events. 
+ let hash_test_input_sizes: [usize; NUM_ABSORBS] = + array::from_fn(|_| random::() % 128 + 1); + hash_test_input_sizes + .iter() + .enumerate() + .for_each(|(i, input_size)| { + let test_input = (0..*input_size).map(|_| BabyBear::rand(rng)).collect_vec(); + + let prev_ts = BabyBear::from_canonical_usize(i); + let absorb_ts = BabyBear::from_canonical_usize(i + 1); + let finalize_ts = BabyBear::from_canonical_usize(i + 2); + let hash_num = BabyBear::from_canonical_usize(i); + let start_addr = BabyBear::from_canonical_usize(i + 1); + let input_len = BabyBear::from_canonical_usize(*input_size); + + let mut absorb_event = + Poseidon2AbsorbEvent::new(absorb_ts, hash_num, start_addr, input_len, true); + + let mut hash_state = [BabyBear::zero(); WIDTH]; + let mut hash_state_cursor = 0; + absorb_event.populate_iterations( + start_addr, + input_len, + &dummy_memory_access_records(test_input.clone(), prev_ts, absorb_ts), + &permuter, + &mut hash_state, + &mut hash_state_cursor, + ); + + input_exec + .poseidon2_hash_events + .push(Poseidon2HashEvent::Absorb(absorb_event)); + + let do_perm = hash_state_cursor != 0; + let mut perm_output = permuter.permute(hash_state); + if incorrect_trace { + perm_output = [BabyBear::rand(rng); WIDTH]; + } + + let state = if do_perm { perm_output } else { hash_state }; + + input_exec + .poseidon2_hash_events + .push(Poseidon2HashEvent::Finalize(Poseidon2FinalizeEvent { + clk: finalize_ts, + hash_num, + output_ptr: start_addr, + output_records: dummy_memory_access_records( + state.as_slice().to_vec(), + absorb_ts, + finalize_ts, + )[0..DIGEST_SIZE] + .try_into() + .unwrap(), + state_cursor: hash_state_cursor, + perm_input: hash_state, + perm_output, + previous_state: hash_state, + state, + do_perm, + })); + }); + + let compress_test_inputs: Vec<[BabyBear; WIDTH]> = (0..NUM_COMPRESSES) + .map(|_| core::array::from_fn(|_| BabyBear::rand(rng))) + .collect_vec(); + compress_test_inputs + .iter() + .enumerate() + .for_each(|(i, input)| { + let mut 
result_array = permuter.permute(*input); + if incorrect_trace { + result_array = core::array::from_fn(|_| BabyBear::rand(rng)); + } + let prev_ts = BabyBear::from_canonical_usize(i); + let input_ts = BabyBear::from_canonical_usize(i + 1); + let output_ts = BabyBear::from_canonical_usize(i + 2); + + let dst = BabyBear::from_canonical_usize(i + 1); + let left = dst + BabyBear::from_canonical_usize(WIDTH / 2); + let right = left + BabyBear::from_canonical_usize(WIDTH / 2); + + let compress_event = Poseidon2CompressEvent { + clk: input_ts, + dst, + left, + right, + input: *input, + result_array, + input_records: dummy_memory_access_records(input.to_vec(), prev_ts, input_ts) + .try_into() + .unwrap(), + result_records: dummy_memory_access_records( + result_array.to_vec(), + input_ts, + output_ts, + ) + .try_into() + .unwrap(), + }; + + input_exec.poseidon2_compress_events.push(compress_event); + }); + + input_exec + } + + #[test] + fn poseidon2_wide_prove_babybear_success() { + // Generate test input exec record. + let input_exec = generate_test_execution_record(false); + + poseidon2_wide_prove_babybear_degree::<3>(input_exec.clone()); + poseidon2_wide_prove_babybear_degree::<9>(input_exec); + } + + #[test] + #[should_panic] + fn poseidon2_wide_prove_babybear_failure() { + // Generate test input exec record. 
+ let input_exec = generate_test_execution_record(true); + + poseidon2_wide_prove_babybear_degree::<3>(input_exec.clone()); + poseidon2_wide_prove_babybear_degree::<9>(input_exec); + } +} diff --git a/recursion/core/src/poseidon2_wide/trace.rs b/recursion/core/src/poseidon2_wide/trace.rs new file mode 100644 index 0000000000..f2049efcc3 --- /dev/null +++ b/recursion/core/src/poseidon2_wide/trace.rs @@ -0,0 +1,548 @@ +use std::borrow::Borrow; + +use p3_air::BaseAir; +use p3_field::PrimeField32; +use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use sp1_core::{air::MachineAir, utils::pad_rows_fixed}; +use sp1_primitives::RC_16_30_U32; +use tracing::instrument; + +use crate::poseidon2_wide::columns::permutation::permutation_mut; +use crate::poseidon2_wide::events::Poseidon2HashEvent; +use crate::range_check::{RangeCheckEvent, RangeCheckOpcode}; +use crate::{ + poseidon2_wide::{external_linear_layer, NUM_EXTERNAL_ROUNDS, WIDTH}, + runtime::{ExecutionRecord, RecursionProgram}, +}; + +use super::events::{Poseidon2AbsorbEvent, Poseidon2CompressEvent, Poseidon2FinalizeEvent}; +use super::RATE; +use super::{internal_linear_layer, Poseidon2WideChip, NUM_INTERNAL_ROUNDS}; + +impl MachineAir for Poseidon2WideChip { + type Record = ExecutionRecord; + + type Program = RecursionProgram; + + fn name(&self) -> String { + format!("Poseidon2Wide {}", DEGREE) + } + + #[instrument(name = "generate poseidon2 wide trace", level = "debug", skip_all, fields(rows = input.poseidon2_compress_events.len()))] + fn generate_trace( + &self, + input: &ExecutionRecord, + output: &mut ExecutionRecord, + ) -> RowMajorMatrix { + let mut rows = Vec::new(); + + let num_columns = as BaseAir>::width(self); + + // Populate the hash events. 
+ for event in &input.poseidon2_hash_events { + match event { + Poseidon2HashEvent::Absorb(absorb_event) => { + rows.extend(self.populate_absorb_event(absorb_event, num_columns, output)); + } + + Poseidon2HashEvent::Finalize(finalize_event) => { + rows.push(self.populate_finalize_event(finalize_event, num_columns)); + } + } + } + + // Populate the compress events. + for event in &input.poseidon2_compress_events { + rows.extend(self.populate_compress_event(event, num_columns)); + } + + if self.pad { + // Pad the trace to a power of two. + pad_rows_fixed( + &mut rows, + || { + let mut padded_row = vec![F::zero(); num_columns]; + self.populate_permutation([F::zero(); WIDTH], None, &mut padded_row); + padded_row + }, + self.fixed_log2_rows, + ); + } + + // Convert the trace to a row major matrix. + let trace = + RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), num_columns); + + #[cfg(debug_assertions)] + println!( + "poseidon2 wide trace dims is width: {:?}, height: {:?}", + trace.width(), + trace.height() + ); + + trace + } + + fn included(&self, record: &Self::Record) -> bool { + !record.poseidon2_compress_events.is_empty() + } +} + +impl Poseidon2WideChip { + pub fn populate_compress_event( + &self, + compress_event: &Poseidon2CompressEvent, + num_columns: usize, + ) -> Vec> { + let mut compress_rows = Vec::new(); + + let mut input_row = vec![F::zero(); num_columns]; + // Populate the control flow fields. + { + let mut cols = self.convert_mut(&mut input_row); + let control_flow = cols.control_flow_mut(); + + control_flow.is_compress = F::one(); + control_flow.is_syscall_row = F::one(); + } + + // Populate the syscall params fields. 
+ { + let mut cols = self.convert_mut(&mut input_row); + let syscall_params = cols.syscall_params_mut().compress_mut(); + + syscall_params.clk = compress_event.clk; + syscall_params.dst_ptr = compress_event.dst; + syscall_params.left_ptr = compress_event.left; + syscall_params.right_ptr = compress_event.right; + } + + // Populate the memory fields. + { + let mut cols = self.convert_mut(&mut input_row); + let memory = cols.memory_mut(); + + memory.start_addr = compress_event.left; + // Populate the first half of the memory inputs in the memory struct. + for i in 0..WIDTH / 2 { + memory.memory_slot_used[i] = F::one(); + memory.memory_accesses[i].populate(&compress_event.input_records[i]); + } + } + + // Populate the opcode workspace fields. + { + let mut cols = self.convert_mut(&mut input_row); + let compress_cols = cols.opcode_workspace_mut().compress_mut(); + compress_cols.start_addr = compress_event.right; + + // Populate the second half of the memory inputs. + for i in 0..WIDTH / 2 { + compress_cols.memory_accesses[i] + .populate(&compress_event.input_records[i + WIDTH / 2]); + } + } + + // Populate the permutation fields. 
+ self.populate_permutation( + compress_event.input, + Some(compress_event.result_array), + &mut input_row, + ); + + compress_rows.push(input_row); + + let mut output_row = vec![F::zero(); num_columns]; + { + let mut cols = self.convert_mut(&mut output_row); + let control_flow = cols.control_flow_mut(); + + control_flow.is_compress = F::one(); + control_flow.is_compress_output = F::one(); + } + + { + let mut cols = self.convert_mut(&mut output_row); + let syscall_cols = cols.syscall_params_mut().compress_mut(); + + syscall_cols.clk = compress_event.clk; + syscall_cols.dst_ptr = compress_event.dst; + syscall_cols.left_ptr = compress_event.left; + syscall_cols.right_ptr = compress_event.right; + } + + { + let mut cols = self.convert_mut(&mut output_row); + let memory = cols.memory_mut(); + + memory.start_addr = compress_event.dst; + // Populate the first half of the memory inputs in the memory struct. + for i in 0..WIDTH / 2 { + memory.memory_slot_used[i] = F::one(); + memory.memory_accesses[i].populate(&compress_event.result_records[i]); + } + } + + { + let mut cols = self.convert_mut(&mut output_row); + let compress_cols = cols.opcode_workspace_mut().compress_mut(); + + compress_cols.start_addr = compress_event.dst + F::from_canonical_usize(WIDTH / 2); + for i in 0..WIDTH / 2 { + compress_cols.memory_accesses[i] + .populate(&compress_event.result_records[i + WIDTH / 2]); + } + } + + self.populate_permutation(compress_event.result_array, None, &mut output_row); + + compress_rows.push(output_row); + compress_rows + } + + pub fn populate_absorb_event( + &self, + absorb_event: &Poseidon2AbsorbEvent, + num_columns: usize, + output: &mut ExecutionRecord, + ) -> Vec> { + let mut absorb_rows = Vec::new(); + + // We currently don't support an input_len of 0, since it will need special logic in the AIR. 
+ assert!(absorb_event.input_len > F::zero()); + + let mut last_row_ending_cursor = 0; + let num_absorb_rows = absorb_event.iterations.len(); + + for (iter_num, absorb_iter) in absorb_event.iterations.iter().enumerate() { + let mut absorb_row = vec![F::zero(); num_columns]; + let is_syscall_row = iter_num == 0; + let is_last_row = iter_num == num_absorb_rows - 1; + + // Populate the control flow fields. + { + let mut cols = self.convert_mut(&mut absorb_row); + let control_flow = cols.control_flow_mut(); + + control_flow.is_absorb = F::one(); + control_flow.is_syscall_row = F::from_bool(is_syscall_row); + control_flow.is_absorb_no_perm = F::from_bool(!absorb_iter.do_perm); + control_flow.is_absorb_not_last_row = F::from_bool(!is_last_row); + } + + // Populate the syscall params fields. + { + let mut cols = self.convert_mut(&mut absorb_row); + let syscall_params = cols.syscall_params_mut().absorb_mut(); + + syscall_params.clk = absorb_event.clk; + syscall_params.hash_num = absorb_event.hash_num; + syscall_params.input_ptr = absorb_event.input_addr; + syscall_params.input_len = absorb_event.input_len; + } + + // Populate the memory fields. + { + let mut cols = self.convert_mut(&mut absorb_row); + let memory = cols.memory_mut(); + + memory.start_addr = absorb_iter.start_addr; + for (i, input_record) in absorb_iter.input_records.iter().enumerate() { + memory.memory_slot_used[i + absorb_iter.state_cursor] = F::one(); + memory.memory_accesses[i + absorb_iter.state_cursor].populate(input_record); + } + } + + // Populate the opcode workspace fields. + { + let mut cols = self.convert_mut(&mut absorb_row); + let absorb_workspace = cols.opcode_workspace_mut().absorb_mut(); + + let num_remaining_rows = num_absorb_rows - 1 - iter_num; + absorb_workspace.num_remaining_rows = F::from_canonical_usize(num_remaining_rows); + output.add_range_check_events(&[RangeCheckEvent::new( + RangeCheckOpcode::U16, + num_remaining_rows as u16, + )]); + + // Calculate last_row_num_consumed. 
+ // For absorb calls that span multiple rows (e.g. the last row is not the syscall row), + // last_row_num_consumed = (input_len + state_cursor) % 8 at the syscall row. + // For absorb calls that are only one row, last_row_num_consumed = absorb_event.input_len. + if is_syscall_row { + last_row_ending_cursor = (absorb_iter.state_cursor + + absorb_event.input_len.as_canonical_u32() as usize + - 1) + % RATE; + } + + absorb_workspace.last_row_ending_cursor = + F::from_canonical_usize(last_row_ending_cursor); + + absorb_workspace + .last_row_ending_cursor_is_seven + .populate_from_field_element( + F::from_canonical_usize(last_row_ending_cursor) + - F::from_canonical_usize(7), + ); + + (0..3).for_each(|i| { + absorb_workspace.last_row_ending_cursor_bitmap[i] = + F::from_bool((last_row_ending_cursor) & (1 << i) == (1 << i)) + }); + + absorb_workspace + .num_remaining_rows_is_zero + .populate(num_remaining_rows as u32); + + absorb_workspace.is_syscall_not_last_row = + F::from_bool(is_syscall_row && !is_last_row); + absorb_workspace.is_syscall_is_last_row = + F::from_bool(is_syscall_row && is_last_row); + absorb_workspace.not_syscall_not_last_row = + F::from_bool(!is_syscall_row && !is_last_row); + absorb_workspace.not_syscall_is_last_row = + F::from_bool(!is_syscall_row && is_last_row); + absorb_workspace.is_last_row_ending_cursor_is_seven = + F::from_bool(is_last_row && last_row_ending_cursor == 7); + absorb_workspace.is_last_row_ending_cursor_not_seven = + F::from_bool(is_last_row && last_row_ending_cursor != 7); + + absorb_workspace.state = absorb_iter.state; + absorb_workspace.previous_state = absorb_iter.previous_state; + absorb_workspace.state_cursor = F::from_canonical_usize(absorb_iter.state_cursor); + absorb_workspace.is_first_hash_row = + F::from_bool(iter_num == 0 && absorb_event.is_first_aborb); + + absorb_workspace.start_mem_idx_bitmap[absorb_iter.state_cursor] = F::one(); + if is_last_row { + absorb_workspace.end_mem_idx_bitmap[last_row_ending_cursor] = 
F::one(); + } + } + + // Populate the permutation fields. + self.populate_permutation( + absorb_iter.perm_input, + if absorb_iter.do_perm { + Some(absorb_iter.perm_output) + } else { + None + }, + &mut absorb_row, + ); + + absorb_rows.push(absorb_row); + } + + absorb_rows + } + + pub fn populate_finalize_event( + &self, + finalize_event: &Poseidon2FinalizeEvent, + num_columns: usize, + ) -> Vec { + let mut finalize_row = vec![F::zero(); num_columns]; + + // Populate the control flow fields. + { + let mut cols = self.convert_mut(&mut finalize_row); + let control_flow = cols.control_flow_mut(); + control_flow.is_finalize = F::one(); + control_flow.is_syscall_row = F::one(); + } + + // Populate the syscall params fields. + { + let mut cols = self.convert_mut(&mut finalize_row); + + let syscall_params = cols.syscall_params_mut().finalize_mut(); + syscall_params.clk = finalize_event.clk; + syscall_params.hash_num = finalize_event.hash_num; + syscall_params.output_ptr = finalize_event.output_ptr; + } + + // Populate the memory fields. + { + let mut cols = self.convert_mut(&mut finalize_row); + let memory = cols.memory_mut(); + + memory.start_addr = finalize_event.output_ptr; + for i in 0..WIDTH / 2 { + memory.memory_slot_used[i] = F::one(); + memory.memory_accesses[i].populate(&finalize_event.output_records[i]); + } + } + + // Populate the opcode workspace fields. + { + let mut cols = self.convert_mut(&mut finalize_row); + let finalize_workspace = cols.opcode_workspace_mut().finalize_mut(); + + finalize_workspace.previous_state = finalize_event.previous_state; + finalize_workspace.state = finalize_event.state; + finalize_workspace.state_cursor = F::from_canonical_usize(finalize_event.state_cursor); + finalize_workspace + .state_cursor_is_zero + .populate(finalize_event.state_cursor as u32); + } + + // Populate the permutation fields. 
+ self.populate_permutation( + finalize_event.perm_input, + if finalize_event.do_perm { + Some(finalize_event.perm_output) + } else { + None + }, + &mut finalize_row, + ); + + finalize_row + } + + pub fn populate_permutation( + &self, + input: [F; WIDTH], + expected_output: Option<[F; WIDTH]>, + input_row: &mut [F], + ) { + let mut permutation = permutation_mut::(input_row); + + let ( + external_rounds_state, + internal_rounds_state, + internal_rounds_s0, + mut external_sbox, + mut internal_sbox, + output_state, + ) = permutation.get_cols_mut(); + + external_rounds_state[0] = input; + external_linear_layer(&mut external_rounds_state[0]); + + // Apply the first half of external rounds. + for r in 0..NUM_EXTERNAL_ROUNDS / 2 { + let next_state = + self.populate_external_round(external_rounds_state, &mut external_sbox, r); + if r == NUM_EXTERNAL_ROUNDS / 2 - 1 { + *internal_rounds_state = next_state; + } else { + external_rounds_state[r + 1] = next_state; + } + } + + // Apply the internal rounds. + external_rounds_state[NUM_EXTERNAL_ROUNDS / 2] = self.populate_internal_rounds( + internal_rounds_state, + internal_rounds_s0, + &mut internal_sbox, + ); + + // Apply the second half of external rounds. + for r in NUM_EXTERNAL_ROUNDS / 2..NUM_EXTERNAL_ROUNDS { + let next_state = + self.populate_external_round(external_rounds_state, &mut external_sbox, r); + if r == NUM_EXTERNAL_ROUNDS - 1 { + for i in 0..WIDTH { + output_state[i] = next_state[i]; + if let Some(expected_output) = expected_output { + assert_eq!(expected_output[i], next_state[i]); + } + } + } else { + external_rounds_state[r + 1] = next_state; + } + } + } + + fn populate_external_round( + &self, + external_rounds_state: &[[F; WIDTH]], + sbox: &mut Option<&mut [[F; WIDTH]; NUM_EXTERNAL_ROUNDS]>, + r: usize, + ) -> [F; WIDTH] { + let mut state = { + let round_state: &[F; WIDTH] = external_rounds_state[r].borrow(); + + // Add round constants. 
+ // + // Optimization: Since adding a constant is a degree 1 operation, we can avoid adding + // columns for it, and instead include it in the constraint for the x^3 part of the sbox. + let round = if r < NUM_EXTERNAL_ROUNDS / 2 { + r + } else { + r + NUM_INTERNAL_ROUNDS + }; + let mut add_rc = *round_state; + for i in 0..WIDTH { + add_rc[i] += F::from_wrapped_u32(RC_16_30_U32[round][i]); + } + + // Apply the sboxes. + // Optimization: since the linear layer that comes after the sbox is degree 1, we can + // avoid adding columns for the result of the sbox, and instead include the x^3 -> x^7 + // part of the sbox in the constraint for the linear layer + let mut sbox_deg_7: [F; 16] = [F::zero(); WIDTH]; + let mut sbox_deg_3: [F; 16] = [F::zero(); WIDTH]; + for i in 0..WIDTH { + sbox_deg_3[i] = add_rc[i] * add_rc[i] * add_rc[i]; + sbox_deg_7[i] = sbox_deg_3[i] * sbox_deg_3[i] * add_rc[i]; + } + + if let Some(sbox) = sbox.as_deref_mut() { + sbox[r] = sbox_deg_3; + } + + sbox_deg_7 + }; + + // Apply the linear layer. + external_linear_layer(&mut state); + state + } + + fn populate_internal_rounds( + &self, + internal_rounds_state: &[F; WIDTH], + internal_rounds_s0: &mut [F; NUM_INTERNAL_ROUNDS - 1], + sbox: &mut Option<&mut [F; NUM_INTERNAL_ROUNDS]>, + ) -> [F; WIDTH] { + let mut state: [F; WIDTH] = *internal_rounds_state; + let mut sbox_deg_3: [F; NUM_INTERNAL_ROUNDS] = [F::zero(); NUM_INTERNAL_ROUNDS]; + for r in 0..NUM_INTERNAL_ROUNDS { + // Add the round constant to the 0th state element. + // Optimization: Since adding a constant is a degree 1 operation, we can avoid adding + // columns for it, just like for external rounds. + let round = r + NUM_EXTERNAL_ROUNDS / 2; + let add_rc = state[0] + F::from_wrapped_u32(RC_16_30_U32[round][0]); + + // Apply the sboxes. + // Optimization: since the linear layer that comes after the sbox is degree 1, we can + // avoid adding columns for the result of the sbox, just like for external rounds. 
+ sbox_deg_3[r] = add_rc * add_rc * add_rc; + let sbox_deg_7 = sbox_deg_3[r] * sbox_deg_3[r] * add_rc; + + // Apply the linear layer. + state[0] = sbox_deg_7; + internal_linear_layer(&mut state); + + // Optimization: since we're only applying the sbox to the 0th state element, we only + // need to have columns for the 0th state element at every step. This is because the + // linear layer is degree 1, so all state elements at the end can be expressed as a + // degree-3 polynomial of the state at the beginning of the internal rounds and the 0th + // state element at rounds prior to the current round + if r < NUM_INTERNAL_ROUNDS - 1 { + internal_rounds_s0[r] = state[0]; + } + } + + let ret_state = state; + + if let Some(sbox) = sbox.as_deref_mut() { + *sbox = sbox_deg_3; + } + + ret_state + } +} diff --git a/recursion/core/src/runtime/mod.rs b/recursion/core/src/runtime/mod.rs index 741ce6e7ea..4d6b5a2631 100644 --- a/recursion/core/src/runtime/mod.rs +++ b/recursion/core/src/runtime/mod.rs @@ -4,6 +4,7 @@ mod program; mod record; mod utils; +use std::array; use std::collections::VecDeque; use std::process::exit; use std::{marker::PhantomData, sync::Arc}; @@ -25,7 +26,9 @@ use crate::cpu::CpuEvent; use crate::exp_reverse_bits::ExpReverseBitsLenEvent; use crate::fri_fold::FriFoldEvent; use crate::memory::{compute_addr_diff, MemoryRecord}; -use crate::poseidon2::Poseidon2Event; +use crate::poseidon2_wide::events::{ + Poseidon2AbsorbEvent, Poseidon2CompressEvent, Poseidon2FinalizeEvent, Poseidon2HashEvent, +}; use crate::range_check::{RangeCheckEvent, RangeCheckOpcode}; use p3_field::{ExtensionField, PrimeField32}; @@ -131,6 +134,12 @@ pub struct Runtime, Diffusion> { >, >, + p2_hash_state: [F; PERMUTATION_WIDTH], + + p2_hash_state_cursor: usize, + + p2_current_hash_num: Option, + _marker: PhantomData, } @@ -179,6 +188,9 @@ where access: CpuRecord::default(), witness_stream: VecDeque::new(), cycle_tracker: HashMap::new(), + p2_hash_state: [F::zero(); PERMUTATION_WIDTH], 
+ p2_hash_state_cursor: 0, + p2_current_hash_num: None, _marker: PhantomData, } } @@ -209,6 +221,9 @@ where access: CpuRecord::default(), witness_stream: VecDeque::new(), cycle_tracker: HashMap::new(), + p2_hash_state: [F::zero(); PERMUTATION_WIDTH], + p2_hash_state_cursor: 0, + p2_current_hash_num: None, _marker: PhantomData, } } @@ -689,16 +704,106 @@ where )); } - self.record.poseidon2_events.push(Poseidon2Event { - clk: timestamp, - dst, - left, - right, - input: array, - result_array: result, - input_records, - result_records: result_records.try_into().unwrap(), + self.record + .poseidon2_compress_events + .push(Poseidon2CompressEvent { + clk: timestamp, + dst, + left, + right, + input: array, + result_array: result, + input_records, + result_records: result_records.try_into().unwrap(), + }); + + (a, b, c) = (a_val, b_val, c_val); + } + + Opcode::Poseidon2Absorb => { + self.nb_poseidons += 1; + let (a_val, b_val, c_val) = self.all_rr(&instruction); + + let hash_num = a_val[0]; + let start_addr = b_val[0]; + let input_len = c_val[0]; + let timestamp = self.clk; + + // We currently don't support an input_len of 0, since it will need special logic in the AIR. + assert!(input_len > F::zero()); + + let is_first_absorb = self.p2_current_hash_num.is_none() + || self.p2_current_hash_num.unwrap() != hash_num; + + let mut absorb_event = Poseidon2AbsorbEvent::new( + timestamp, + hash_num, + start_addr, + input_len, + is_first_absorb, + ); + + let memory_records: Vec> = (0..input_len.as_canonical_u32()) + .map(|i| self.mr(start_addr + F::from_canonical_u32(i), timestamp).0) + .collect_vec(); + + let permuter = self.perm.as_ref().unwrap().clone(); + absorb_event.populate_iterations( + start_addr, + input_len, + &memory_records, + &permuter, + &mut self.p2_hash_state, + &mut self.p2_hash_state_cursor, + ); + + // Update the current hash number. 
+ self.p2_current_hash_num = Some(hash_num); + + self.record + .poseidon2_hash_events + .push(Poseidon2HashEvent::Absorb(absorb_event)); + + (a, b, c) = (a_val, b_val, c_val); + } + + Opcode::Poseidon2Finalize => { + self.nb_poseidons += 1; + let (a_val, b_val, c_val) = self.all_rr(&instruction); + + let p2_hash_num = a_val[0]; + let output_ptr = b_val[0]; + let timestamp = self.clk; + + let do_perm = self.p2_hash_state_cursor != 0; + let perm_output = self.perm.as_ref().unwrap().permute(self.p2_hash_state); + let state = if do_perm { + perm_output + } else { + self.p2_hash_state + }; + let output_records: [MemoryRecord; DIGEST_SIZE] = array::from_fn(|i| { + self.mw(output_ptr + F::from_canonical_usize(i), state[i], timestamp) }); + + self.record + .poseidon2_hash_events + .push(Poseidon2HashEvent::Finalize(Poseidon2FinalizeEvent { + clk: timestamp, + hash_num: p2_hash_num, + output_ptr, + output_records, + state_cursor: self.p2_hash_state_cursor, + perm_input: self.p2_hash_state, + perm_output, + previous_state: self.p2_hash_state, + state, + do_perm, + })); + + self.p2_hash_state_cursor = 0; + self.p2_hash_state = [F::zero(); PERMUTATION_WIDTH]; + (a, b, c) = (a_val, b_val, c_val); } Opcode::HintBits => { diff --git a/recursion/core/src/runtime/opcode.rs b/recursion/core/src/runtime/opcode.rs index d6db8abc13..fa9913dd34 100644 --- a/recursion/core/src/runtime/opcode.rs +++ b/recursion/core/src/runtime/opcode.rs @@ -32,9 +32,13 @@ pub enum Opcode { TRAP = 30, HALT = 31, - // Hash instructions. + // Poseidon2 compress. Poseidon2Compress = 39, + // Poseidon2 hash. + Poseidon2Absorb = 46, + Poseidon2Finalize = 47, + // Bit instructions. 
HintBits = 32, diff --git a/recursion/core/src/runtime/record.rs b/recursion/core/src/runtime/record.rs index 0c5663d97d..f1e00b78c8 100644 --- a/recursion/core/src/runtime/record.rs +++ b/recursion/core/src/runtime/record.rs @@ -10,14 +10,15 @@ use crate::air::Block; use crate::cpu::CpuEvent; use crate::exp_reverse_bits::ExpReverseBitsLenEvent; use crate::fri_fold::FriFoldEvent; -use crate::poseidon2::Poseidon2Event; +use crate::poseidon2_wide::events::{Poseidon2CompressEvent, Poseidon2HashEvent}; use crate::range_check::RangeCheckEvent; #[derive(Default, Debug, Clone)] pub struct ExecutionRecord { pub program: Arc>, pub cpu_events: Vec>, - pub poseidon2_events: Vec>, + pub poseidon2_compress_events: Vec>, + pub poseidon2_hash_events: Vec>, pub fri_fold_events: Vec>, pub range_check_events: HashMap, pub exp_reverse_bits_len_events: Vec>, @@ -51,7 +52,14 @@ impl MachineRecord for ExecutionRecord { fn stats(&self) -> HashMap { let mut stats = HashMap::new(); stats.insert("cpu_events".to_string(), self.cpu_events.len()); - stats.insert("poseidon2_events".to_string(), self.poseidon2_events.len()); + stats.insert( + "poseidon2_events".to_string(), + self.poseidon2_compress_events.len(), + ); + stats.insert( + "poseidon2_events".to_string(), + self.poseidon2_hash_events.len(), + ); stats.insert("fri_fold_events".to_string(), self.fri_fold_events.len()); stats.insert( "range_check_events".to_string(), diff --git a/recursion/core/src/stark/mod.rs b/recursion/core/src/stark/mod.rs index 8f24874f5d..b08e12bea8 100644 --- a/recursion/core/src/stark/mod.rs +++ b/recursion/core/src/stark/mod.rs @@ -4,8 +4,8 @@ pub mod utils; use crate::{ cpu::CpuChip, exp_reverse_bits::ExpReverseBitsLenChip, fri_fold::FriFoldChip, - memory::MemoryGlobalChip, multi::MultiChip, poseidon2::Poseidon2Chip, - poseidon2_wide::Poseidon2WideChip, program::ProgramChip, range_check::RangeCheckChip, + memory::MemoryGlobalChip, multi::MultiChip, poseidon2_wide::Poseidon2WideChip, + program::ProgramChip, 
range_check::RangeCheckChip, }; use core::iter::once; use p3_field::{extension::BinomiallyExtendable, PrimeField32}; @@ -16,19 +16,20 @@ use std::marker::PhantomData; use crate::runtime::D; pub type RecursionAirWideDeg3 = RecursionAir; -pub type RecursionAirSkinnyDeg9 = RecursionAir; +pub type RecursionAirWideDeg9 = RecursionAir; +pub type RecursionAirWideDeg17 = RecursionAir; #[derive(MachineAir)] #[sp1_core_path = "sp1_core"] #[execution_record_path = "crate::runtime::ExecutionRecord"] #[program_path = "crate::runtime::RecursionProgram"] #[builder_path = "crate::air::SP1RecursionAirBuilder"] +#[eval_trait_bound = "AB::Var: 'static"] pub enum RecursionAir, const DEGREE: usize> { Program(ProgramChip), Cpu(CpuChip), MemoryGlobal(MemoryGlobalChip), Poseidon2Wide(Poseidon2WideChip), - Poseidon2Skinny(Poseidon2Chip), FriFold(FriFoldChip), RangeCheck(RangeCheckChip), Multi(MultiChip), @@ -76,6 +77,7 @@ impl, const DEGREE: usize> RecursionAi DEGREE, > { fixed_log2_rows: None, + pad: true, }))) .chain(once(RecursionAir::FriFold(FriFoldChip:: { fixed_log2_rows: None, @@ -116,14 +118,14 @@ impl, const DEGREE: usize> RecursionAi pub fn get_wrap_all() -> Vec { once(RecursionAir::Program(ProgramChip)) .chain(once(RecursionAir::Cpu(CpuChip { - fixed_log2_rows: Some(20), + fixed_log2_rows: Some(19), _phantom: PhantomData, }))) .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip { - fixed_log2_rows: Some(19), + fixed_log2_rows: Some(20), }))) .chain(once(RecursionAir::Multi(MultiChip { - fixed_log2_rows: Some(19), + fixed_log2_rows: Some(12), }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) .chain(once(RecursionAir::ExpReverseBitsLen( diff --git a/recursion/core/src/stark/utils.rs b/recursion/core/src/stark/utils.rs index 15c642c046..66fa892b81 100644 --- a/recursion/core/src/stark/utils.rs +++ b/recursion/core/src/stark/utils.rs @@ -7,7 +7,7 @@ use crate::air::Block; use crate::runtime::RecursionProgram; use crate::runtime::Runtime; use 
crate::stark::RecursionAir; -use crate::stark::RecursionAirSkinnyDeg9; +use crate::stark::RecursionAirWideDeg9; use p3_field::PrimeField32; use sp1_core::utils::run_test_machine; use std::collections::VecDeque; @@ -54,7 +54,7 @@ pub fn run_test_recursion( } if test_config == TestConfig::All || test_config == TestConfig::SkinnyDeg7 { - let machine = RecursionAirSkinnyDeg9::machine(BabyBearPoseidon2::compressed()); + let machine = RecursionAirWideDeg9::machine(BabyBearPoseidon2::compressed()); let (pk, vk) = machine.setup(&program); let record = runtime.record.clone(); let result = run_test_machine(record, machine, pk, vk); @@ -64,7 +64,7 @@ pub fn run_test_recursion( } if test_config == TestConfig::All || test_config == TestConfig::SkinnyDeg7Wrap { - let machine = RecursionAirSkinnyDeg9::wrap_machine(BabyBearPoseidon2::compressed()); + let machine = RecursionAirWideDeg9::wrap_machine(BabyBearPoseidon2::compressed()); let (pk, vk) = machine.setup(&program); let record = runtime.record.clone(); let result = run_test_machine(record, machine, pk, vk); diff --git a/recursion/program/src/machine/mod.rs b/recursion/program/src/machine/mod.rs index 645e46b911..64e6012f53 100644 --- a/recursion/program/src/machine/mod.rs +++ b/recursion/program/src/machine/mod.rs @@ -77,7 +77,7 @@ mod tests { let (compress_pk, compress_vk) = compress_machine.setup(&compress_program); // Make the wrap program. - let wrap_machine = RecursionAir::<_, 5>::machine(BabyBearPoseidon2Outer::default()); + let wrap_machine = RecursionAir::<_, 17>::machine(BabyBearPoseidon2Outer::default()); let wrap_program = SP1RootVerifier::::build(&compress_machine, &compress_vk, false); diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index c9c8e4000e..fddd2a83d5 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -477,6 +477,10 @@ pub(crate) mod tests { // Observe all the commitments. 
let mut builder = Builder::::default(); + // Add a hash invocation, since the poseidon2 table expects that it's in the first row. + let hash_input = builder.constant(vec![vec![F::one()]]); + builder.poseidon2_hash_x(&hash_input); + let mut challenger = DuplexChallengerVariable::new(&mut builder); let preprocessed_commit_val: [F; DIGEST_SIZE] = vk.commit.into(); @@ -518,6 +522,10 @@ pub(crate) mod tests { fn test_public_values_program() -> RecursionProgram { let mut builder = Builder::::default(); + // Add a hash invocation, since the poseidon2 table expects that it's in the first row. + let hash_input = builder.constant(vec![vec![F::one()]]); + builder.poseidon2_hash_x(&hash_input); + let mut public_values_stream: Vec> = (0..RECURSIVE_PROOF_NUM_PV_ELTS) .map(|_| builder.uninit()) .collect();