From 1f257029ea6379da068071155c568f6d8c6c23b3 Mon Sep 17 00:00:00 2001 From: mrmr1993 Date: Tue, 10 Oct 2023 06:34:32 +0100 Subject: [PATCH] Chunking for public input --- src/lib/pickles/wrap_verifier.ml | 117 ++++++++++++++++++------------- 1 file changed, 68 insertions(+), 49 deletions(-) diff --git a/src/lib/pickles/wrap_verifier.ml b/src/lib/pickles/wrap_verifier.ml index da3e52994d5..58468e6e965 100644 --- a/src/lib/pickles/wrap_verifier.ml +++ b/src/lib/pickles/wrap_verifier.ml @@ -304,19 +304,23 @@ struct , (domains : (Domains.t, n) Vector.t) ) srs i = Vector.map domains ~f:(fun d -> let d = Int.pow 2 (Domain.log2_size d.h) in - match[@warning "-4"] + let chunks = (Kimchi_bindings.Protocol.SRS.Fp.lagrange_commitment srs d i) .unshifted - with - | [| Finite g |] -> - let g = Inner_curve.Constant.of_affine g in - Inner_curve.constant g - | _ -> - assert false ) + in + Array.map chunks ~f:(function + | Finite g -> + let g = Inner_curve.Constant.of_affine g in + Inner_curve.constant g + | Infinity -> + (* Point at infinity should be impossible in the SRS *) + assert false ) ) |> Vector.map2 (which_branch :> (Boolean.var, n) Vector.t) - ~f:(fun b (x, y) -> Field.((b :> t) * x, (b :> t) * y)) - |> Vector.reduce_exn ~f:(Double.map2 ~f:Field.( + )) + ~f:(fun b pts -> + Array.map pts ~f:(fun (x, y) -> Field.((b :> t) * x, (b :> t) * y)) + ) + |> Vector.reduce_exn ~f:(Array.map2_exn ~f:(Double.map2 ~f:Field.( + ))) let scaled_lagrange (type n) c ~domain: @@ -324,24 +328,29 @@ struct , (domains : (Domains.t, n) Vector.t) ) srs i = Vector.map domains ~f:(fun d -> let d = Int.pow 2 (Domain.log2_size d.h) in - match[@warning "-4"] + let chunks = (Kimchi_bindings.Protocol.SRS.Fp.lagrange_commitment srs d i) .unshifted - with - | [| Finite g |] -> - let g = Inner_curve.Constant.of_affine g in - Inner_curve.Constant.scale g c |> Inner_curve.constant - | _ -> - assert false ) + in + Array.map chunks ~f:(function + | Finite g -> + let g = Inner_curve.Constant.of_affine g in + Inner_curve.Constant.scale g c |> Inner_curve.constant + | Infinity -> + (* Point at infinity should be impossible in the SRS *) + assert false ) ) |> Vector.map2 (which_branch :> (Boolean.var, n) Vector.t) - ~f:(fun b (x, y) -> Field.((b :> t) * x, (b :> t) * y)) - |> Vector.reduce_exn ~f:(Double.map2 ~f:Field.( + )) + ~f:(fun b pts -> + Array.map pts ~f:(fun (x, y) -> Field.((b :> t) * x, (b :> t) * y)) + ) + |> Vector.reduce_exn ~f:(Array.map2_exn ~f:(Double.map2 ~f:Field.( + ))) let lagrange_with_correction (type n) ~input_length ~domain: ( (which_branch : n One_hot_vector.t) - , (domains : (Domains.t, n) Vector.t) ) srs i : Inner_curve.t Double.t = + , (domains : (Domains.t, n) Vector.t) ) srs i : + Inner_curve.t Double.t array = with_label __LOC__ (fun () -> let actual_shift = (* TODO: num_bits should maybe be input_length - 1. *) @@ -352,18 +361,19 @@ struct in let base_and_correction (h : Domain.t) = let d = Int.pow 2 (Domain.log2_size h) in - match[@warning "-4"] + let chunks = (Kimchi_bindings.Protocol.SRS.Fp.lagrange_commitment srs d i) .unshifted - with - | [| Finite g |] -> - let open Inner_curve.Constant in - let g = of_affine g in - ( Inner_curve.constant g - , Inner_curve.constant (negate (pow2pow g actual_shift)) ) - | xs -> - failwithf "expected commitment to have length 1. got %d" - (Array.length xs) () + in + Array.map chunks ~f:(function + | Finite g -> + let open Inner_curve.Constant in + let g = of_affine g in + ( Inner_curve.constant g + , Inner_curve.constant (negate (pow2pow g actual_shift)) ) + | Infinity -> + (* Point at infinity should be impossible in the SRS *) + assert false ) in match domains with | [] -> @@ -377,11 +387,16 @@ struct |> Vector.map2 (which_branch :> (Boolean.var, n) Vector.t) ~f:(fun b pr -> - Double.map pr ~f:(fun (x, y) -> - Field.((b :> t) * x, (b :> t) * y) ) ) + Array.map pr + ~f: + (Double.map ~f:(fun (x, y) -> + Field.((b :> t) * x, (b :> t) * y) ) ) ) |> Vector.reduce_exn - ~f:(Double.map2 ~f:(Double.map2 ~f:Field.( + ))) - |> Double.map ~f:(Double.map ~f:(Util.seal (module Impl))) ) + ~f: + (Array.map2_exn + ~f:(Double.map2 ~f:(Double.map2 ~f:Field.( + ))) ) + |> Array.map + ~f:(Double.map ~f:(Double.map ~f:(Util.seal (module Impl)))) ) let _h_precomp = Lazy.map ~f:Inner_curve.Scaling_precomputation.create Generators.h @@ -848,36 +863,40 @@ struct (List.filter_map terms ~f:(function | `Cond_add _ -> None - | `Add_with_correction (_, (_, corr)) -> - Some corr ) ) - ~f:(Ops.add_fast ?check_finite:None) ) + | `Add_with_correction (_, chunks) -> + Some (Array.map ~f:snd chunks) ) ) + ~f:(Array.map2_exn ~f:(Ops.add_fast ?check_finite:None)) ) in with_label __LOC__ (fun () -> let init = List.fold (List.filter_map ~f:Fn.id constant_part) ~init:correction - ~f:(Ops.add_fast ?check_finite:None) + ~f:(Array.map2_exn ~f:(Ops.add_fast ?check_finite:None)) in List.fold terms ~init ~f:(fun acc term -> match term with | `Cond_add (b, g) -> with_label __LOC__ (fun () -> - Inner_curve.if_ b ~then_:(Ops.add_fast g acc) - ~else_:acc ) - | `Add_with_correction ((x, num_bits), (g, _)) -> - Ops.add_fast acc - (Ops.scale_fast2' - (module Other_field.With_top_bit0) - g x ~num_bits ) ) ) ) - |> Inner_curve.negate + Array.map2_exn acc g ~f:(fun acc g -> + Inner_curve.if_ b ~then_:(Ops.add_fast g acc) + ~else_:acc ) ) + | `Add_with_correction ((x, num_bits), chunks) -> + Array.map2_exn acc chunks ~f:(fun acc (g, _) -> + Ops.add_fast acc + (Ops.scale_fast2' + (module Other_field.With_top_bit0) + g x ~num_bits ) ) ) ) ) + |> Array.map ~f:Inner_curve.negate in let x_hat = with_label "x_hat blinding" (fun () -> - Ops.add_fast x_hat - (Inner_curve.constant (Lazy.force Generators.h)) ) + Array.map x_hat ~f:(fun x_hat -> + Ops.add_fast x_hat + (Inner_curve.constant (Lazy.force Generators.h)) ) ) in - absorb sponge PC (Boolean.true_, x_hat) ; + Array.iter x_hat ~f:(fun x_hat -> + absorb sponge PC (Boolean.true_, x_hat) ) ; let w_comm = messages.w_comm in Vector.iter ~f:absorb_g w_comm ; let joint_combiner = @@ -1234,7 +1253,7 @@ struct Pickles_types.Opt.Maybe (keep, [| p |]) ) |> append_chain (snd (Max_proofs_verified.add len_6)) - ( [ [| x_hat |] + ( [ x_hat ; [| ft_comm |] ; z_comm ; m.generic_comm