From 609d47dea177c8bc7c60593600548f0eacc4477f Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 1 Jun 2025 11:36:14 -0500 Subject: [PATCH 01/92] refactor: addInstance instead of series etc --- internal/gkr/gkrinfo/info.go | 14 +-- std/gkrapi/compile.go | 228 +++++++++++++++++------------------ 2 files changed, 117 insertions(+), 125 deletions(-) diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index de9a845e8d..55fdb7e71d 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -30,13 +30,12 @@ type ( IsGkrVar []bool } StoringInfo struct { - Circuit Circuit - Dependencies [][]InputDependency // nil for input wires - NbInstances int - HashName string - SolveHintID solver.HintID - ProveHintID solver.HintID - Prints []PrintInfo + Circuit Circuit + NbInstances int + HashName string + SolveHintID solver.HintID + ProveHintID solver.HintID + Prints []PrintInfo } Permutations struct { @@ -58,7 +57,6 @@ func (w Wire) IsOutput() bool { func (d *StoringInfo) NewInputVariable() int { i := len(d.Circuit) d.Circuit = append(d.Circuit, Wire{}) - d.Dependencies = append(d.Dependencies, nil) return i } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 6293e86add..d8d4d3e960 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -1,7 +1,6 @@ package gkrapi import ( - "errors" "fmt" "math/bits" @@ -23,18 +22,16 @@ type circuitDataForSnark struct { assignments gkrtypes.WireAssignment } -type Solution struct { - toStore gkrinfo.StoringInfo - assignments gkrtypes.WireAssignment - parentApi frontend.API - permutations gkrinfo.Permutations -} +type InitialChallengeGetter func() []frontend.Variable -func (api *API) nbInstances() int { - if len(api.assignments) == 0 { - return -1 - } - return api.assignments.NbInstances() +// Circuit represents a GKR circuit. +type Circuit struct { + toStore gkrinfo.StoringInfo + assignments gkrtypes.WireAssignment + getInitialChallenges InitialChallengeGetter // optional getter for the initial Fiat-Shamir challenge + ins []gkr.Variable + outs []gkr.Variable + api frontend.API // the parent API used for hints } // New creates a new GKR API @@ -50,162 +47,159 @@ func log2(x uint) int { return bits.TrailingZeros(x) } -// Series like in an electric circuit, binds an input of an instance to an output of another -func (api *API) Series(input, output gkr.Variable, inputInstance, outputInstance int) *API { - if api.assignments[input][inputInstance] != nil { - panic("dependency attempting to override explicit value assignment") - } - api.toStore.Dependencies[input] = - append(api.toStore.Dependencies[input], gkrinfo.InputDependency{ - OutputWire: int(output), - OutputInstance: outputInstance, - InputInstance: inputInstance, - }) - return api +// NewInput creates a new input variable. +func (api *API) NewInput() gkr.Variable { + return gkr.Variable(api.toStore.NewInputVariable()) } -// Import creates a new input variable, whose values across all instances are given by assignment. -// If the value in an instance depends on an output of another instance, leave the corresponding index in assignment nil and use Series to specify the dependency. -func (api *API) Import(assignment []frontend.Variable) (gkr.Variable, error) { - nbInstances := len(assignment) - logNbInstances := log2(uint(nbInstances)) - if logNbInstances == -1 { - return -1, errors.New("number of assignments must be a power of 2") - } +type compileOption func(*Circuit) - if currentNbInstances := api.nbInstances(); currentNbInstances != -1 && currentNbInstances != nbInstances { - return -1, errors.New("number of assignments must be consistent across all variables") +// WithInitialChallenge provides a getter for the initial Fiat-Shamir challenge. +// If not provided, the initial challenge will be a commitment to all the input and output values of the circuit. +func WithInitialChallenge(getInitialChallenge InitialChallengeGetter) compileOption { + return func(c *Circuit) { + c.getInitialChallenges = getInitialChallenge } - api.assignments = append(api.assignments, assignment) - return gkr.Variable(api.toStore.NewInputVariable()), nil } -// appendNonNil filters out nil values from src and appends the non-nil values to dst. -// i.e. dst = [0,1], src = [nil, 2, nil, 3] => dst = [0,1,2,3]. -func appendNonNil(dst *[]frontend.Variable, src []frontend.Variable) { - for i := range src { - if src[i] != nil { - *dst = append(*dst, src[i]) - } +// Compile finalizes the GKR circuit. +// From this point on, the circuit cannot be modified. +// But instances can be added to the circuit. +func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, options ...compileOption) *Circuit { + // TODO define levels here + res := Circuit{ + toStore: api.toStore, + assignments: make(gkrtypes.WireAssignment, len(api.toStore.Circuit)), + api: parentApi, } -} -// Solve finalizes the GKR circuit and returns the output variables in the order created -func (api *API) Solve(parentApi frontend.API) (Solution, error) { + api.toStore.HashName = fiatshamirHashName - var p gkrinfo.Permutations - var err error - if p, err = api.toStore.Compile(api.assignments.NbInstances()); err != nil { - return Solution{}, err + for _, opt := range options { + opt(&res) } - api.assignments.Permute(p) - nbInstances := api.toStore.NbInstances - circuit := api.toStore.Circuit + for i := range res.toStore.Circuit { + if res.toStore.Circuit[i].IsOutput() { + res.outs = append(res.ins, gkr.Variable(i)) + } + if res.toStore.Circuit[i].IsInput() { + res.ins = append(res.ins, gkr.Variable(i)) + } + } + res.toStore.SolveHintID = solver.GetHintID(SolveHintPlaceholder(res.toStore)) + res.toStore.ProveHintID = solver.GetHintID(ProveHintPlaceholder(fiatshamirHashName)) - solveHintNIn := 0 - solveHintNOut := 0 + parentApi.Compiler().Defer(res.verify) - for i := range circuit { - v := &circuit[i] - in, out := v.IsInput(), v.IsOutput() - if in && out { - return Solution{}, fmt.Errorf("unused input (variable #%d)", i) - } + return &res +} - if in { - solveHintNIn += nbInstances - len(api.toStore.Dependencies[i]) - } else if out { - solveHintNOut += nbInstances +// AddInstance adds a new instance to the GKR circuit, returning the values of output variables for the instance. +func (c *Circuit) AddInstance(input map[gkr.Variable]frontend.Variable) (map[gkr.Variable]frontend.Variable, error) { + if len(input) != len(c.ins) { + for k := range input { + if k >= gkr.Variable(len(c.ins)) { + return nil, fmt.Errorf("variable %d is out of bounds (max %d)", k, len(c.ins)-1) + } + if !c.toStore.Circuit[k].IsInput() { + return nil, fmt.Errorf("value provided for non-input variable %d", k) + } } } - - // arrange inputs wire first, then in the order solved - ins := make([]frontend.Variable, 0, solveHintNIn) - for i := range circuit { - if circuit[i].IsInput() { - appendNonNil(&ins, api.assignments[i]) + hintIn := make([]frontend.Variable, 1+len(c.ins)) // first input denotes the instance number + hintIn[0] = c.toStore.NbInstances + for hintInI, in := range c.ins { + if inV, ok := input[in]; !ok { + return nil, fmt.Errorf("missing entry for input variable %d", in) + } else { + hintIn[hintInI+1] = inV } } - solveHintPlaceholder := SolveHintPlaceholder(api.toStore) - outsSerialized, err := parentApi.Compiler().NewHint(solveHintPlaceholder, solveHintNOut, ins...) - api.toStore.SolveHintID = solver.GetHintID(solveHintPlaceholder) + c.toStore.NbInstances++ + solveHintPlaceholder := SolveHintPlaceholder(c.toStore) + outsSerialized, err := c.api.Compiler().NewHint(solveHintPlaceholder, len(c.outs), hintIn...) if err != nil { - return Solution{}, err - } - - for i := range circuit { - if circuit[i].IsOutput() { - api.assignments[i] = outsSerialized[:nbInstances] - outsSerialized = outsSerialized[nbInstances:] - } + return nil, fmt.Errorf("failed to create solve hint: %w", err) } - - for i := range circuit { - for _, dep := range api.toStore.Dependencies[i] { - api.assignments[i][dep.InputInstance] = api.assignments[dep.OutputWire][dep.OutputInstance] - } + res := make(map[gkr.Variable]frontend.Variable, len(c.outs)) + for i, v := range c.outs { + res[v] = outsSerialized[i] + c.assignments[v] = append(c.assignments[v], outsSerialized[i]) } - return Solution{ - toStore: api.toStore, - assignments: api.assignments, - parentApi: parentApi, - permutations: p, - }, nil + return res, nil } -// Export returns the values of an output variable across all instances -func (s Solution) Export(v gkr.Variable) []frontend.Variable { - return utils.Map(s.permutations.SortedInstances, utils.SliceAt(s.assignments[v])) -} +// verify encodes the verification circuitry for the GKR circuit +func (c *Circuit) verify(api frontend.API) error { + if api != c.api { + panic("api mismatch") + } + + if len(c.outs) == 0 || len(c.assignments[0]) == 0 { + return nil + } -// Verify encodes the verification circuitry for the GKR circuit -func (s Solution) Verify(hashName string, initialChallenges ...frontend.Variable) error { var ( - err error - proofSerialized []frontend.Variable - proof gadget.Proof + err error + proofSerialized []frontend.Variable + proof gadget.Proof + initialChallenges []frontend.Variable ) - forSnark := newCircuitDataForSnark(s.toStore, s.assignments) - logNbInstances := log2(uint(s.assignments.NbInstances())) + if c.getInitialChallenges != nil { + initialChallenges = c.getInitialChallenges() + } else { + // default initial challenge is a commitment to all input and output values + initialChallenges = make([]frontend.Variable, 0, (len(c.ins)+len(c.outs))*len(c.assignments[c.ins[0]])) + for _, in := range c.ins { + initialChallenges = append(initialChallenges, c.assignments[in]...) + } + for _, out := range c.outs { + initialChallenges = append(initialChallenges, c.assignments[out]...) + } - hintIns := make([]frontend.Variable, len(initialChallenges)+1) // hack: adding one of the outputs of the solve hint to ensure "prove" is called after "solve" - for i, w := range s.toStore.Circuit { - if w.IsOutput() { - hintIns[0] = s.assignments[i][len(s.assignments[i])-1] - break + if initialChallenges[0], err = api.(frontend.Committer).Commit(initialChallenges...); err != nil { + return fmt.Errorf("failed to commit to in/out values: %w", err) } + initialChallenges = initialChallenges[:1] // use the commitment as the only initial challenge } + + forSnark := newCircuitDataForSnark(c.toStore, c.assignments) + logNbInstances := log2(uint(c.assignments.NbInstances())) + + hintIns := make([]frontend.Variable, len(initialChallenges)+1) // hack: adding one of the outputs of the solve hint to ensure "prove" is called after "solve" + firstOutputAssignment := c.assignments[c.outs[0]] + hintIns[0] = firstOutputAssignment[len(firstOutputAssignment)-1] // take the last output of the first output wire + copy(hintIns[1:], initialChallenges) - proveHintPlaceholder := ProveHintPlaceholder(hashName) - if proofSerialized, err = s.parentApi.Compiler().NewHint( + proveHintPlaceholder := ProveHintPlaceholder(c.toStore.HashName) + if proofSerialized, err = api.Compiler().NewHint( proveHintPlaceholder, gadget.ProofSize(forSnark.circuit, logNbInstances), hintIns...); err != nil { return err } - s.toStore.ProveHintID = solver.GetHintID(proveHintPlaceholder) + c.toStore.ProveHintID = solver.GetHintID(proveHintPlaceholder) - forSnarkSorted := utils.MapRange(0, len(s.toStore.Circuit), slicePtrAt(forSnark.circuit)) + forSnarkSorted := utils.MapRange(0, len(c.toStore.Circuit), slicePtrAt(forSnark.circuit)) if proof, err = gadget.DeserializeProof(forSnarkSorted, proofSerialized); err != nil { return err } var hsh hash.FieldHasher - if hsh, err = hash.GetFieldHasher(hashName, s.parentApi); err != nil { + if hsh, err = hash.GetFieldHasher(c.toStore.HashName, api); err != nil { return err } - s.toStore.HashName = hashName - err = gadget.Verify(s.parentApi, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) + err = gadget.Verify(api, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) if err != nil { return err } - return s.parentApi.(gkrinfo.ConstraintSystem).SetGkrInfo(s.toStore) + return api.(gkrinfo.ConstraintSystem).SetGkrInfo(c.toStore) } func slicePtrAt[T any](slice []T) func(int) *T { From 79d4cbfef3705f1ac25b7e349dd2e7739ac24cf4 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 1 Jun 2025 16:03:13 -0500 Subject: [PATCH 02/92] feat: check for duplicate gates, allow limiting curves for gate --- constraint/solver/gkrgates/registry.go | 52 ++++++++++++++++--- .../backend/template/gkr/gate_testing.go.tmpl | 11 ++++ internal/gkr/bls12-377/gate_testing.go | 11 ++++ internal/gkr/bls12-381/gate_testing.go | 11 ++++ internal/gkr/bls24-315/gate_testing.go | 11 ++++ internal/gkr/bls24-317/gate_testing.go | 11 ++++ internal/gkr/bn254/gate_testing.go | 11 ++++ internal/gkr/bw6-633/gate_testing.go | 11 ++++ internal/gkr/bw6-761/gate_testing.go | 11 ++++ internal/gkr/gkrtypes/types.go | 26 +++++++--- internal/gkr/small_rational/gate_testing.go | 11 ++++ std/gkrapi/compile.go | 32 +++++------- 12 files changed, 179 insertions(+), 30 deletions(-) diff --git a/constraint/solver/gkrgates/registry.go b/constraint/solver/gkrgates/registry.go index 49610a1789..88d0d3daa8 100644 --- a/constraint/solver/gkrgates/registry.go +++ b/constraint/solver/gkrgates/registry.go @@ -7,6 +7,7 @@ import ( "runtime" "sync" + "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" @@ -100,12 +101,43 @@ func WithCurves(curves ...ecc.ID) registerOption { // - f is the polynomial function defining the gate. // - nbIn is the number of inputs to the gate. func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { - s := registerSettings{degree: -1, solvableVar: -1, name: GetDefaultGateName(f), curves: []ecc.ID{ecc.BN254}} + s := registerSettings{degree: -1, solvableVar: -1, name: GetDefaultGateName(f)} for _, option := range options { option(&s) } - for _, curve := range s.curves { + curvesForTesting := s.curves + allowedCurves := s.curves + if len(curvesForTesting) == 0 { + // no restriction on curves, but only test on BN254 + curvesForTesting = []ecc.ID{ecc.BN254} + allowedCurves = gnark.Curves() + } + + if g, ok := gates[s.name]; ok { + // gate already registered + if reflect.ValueOf(f).Pointer() != reflect.ValueOf(gates[s.name].Evaluate).Pointer() { + return fmt.Errorf("gate \"%s\" already registered with a different function", s.name) + } + // it still might be an anonymous function with different parameters. + // need to test further + if g.NbIn() != nbIn { + return fmt.Errorf("gate \"%s\" already registered with a different number of inputs (%d != %d)", s.name, g.NbIn(), nbIn) + } + + for _, curve := range curvesForTesting { + gateVer, err := NewGateVerifier(curve) + if err != nil { + return err + } + if !gateVer.equal(f, g.Evaluate, nbIn) { + return fmt.Errorf("mismatch with already registered gate \"%s\" (degree %d) over curve %s", s.name, s.degree, curve) + } + } + + } + + for _, curve := range curvesForTesting { gateVer, err := NewGateVerifier(curve) if err != nil { return err @@ -118,12 +150,12 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { const maxAutoDegreeBound = 32 var err error if s.degree, err = gateVer.findDegree(f, maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", s.name, err) + return fmt.Errorf("for gate \"%s\": %v", s.name, err) } } else { if !s.noDegreeVerification { // check that the given degree is correct if err = gateVer.verifyDegree(f, s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", s.name, err) + return fmt.Errorf("for gate \"%s\": %v", s.name, err) } } } @@ -135,7 +167,7 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { } else { // solvable variable given if !s.noSolvableVarVerification && !gateVer.isVarSolvable(f, s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, s.name) + return fmt.Errorf("cannot verify the solvability of variable %d in gate \"%s\"", s.solvableVar, s.name) } } @@ -143,7 +175,7 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { gatesLock.Lock() defer gatesLock.Unlock() - gates[s.name] = gkrtypes.NewGate(f, nbIn, s.degree, s.solvableVar) + gates[s.name] = gkrtypes.NewGate(f, nbIn, s.degree, s.solvableVar, gkrtypes.WithCurves(allowedCurves...)) return nil } @@ -160,6 +192,7 @@ type gateVerifier struct { isAdditive func(f gkr.GateFunction, i int, nbIn int) bool findDegree func(f gkr.GateFunction, max, nbIn int) (int, error) verifyDegree func(f gkr.GateFunction, claimedDegree, nbIn int) error + equal func(f1, f2 gkr.GateFunction, nbIn int) bool } func NewGateVerifier(curve ecc.ID) (*gateVerifier, error) { @@ -172,30 +205,37 @@ func NewGateVerifier(curve ecc.ID) (*gateVerifier, error) { o.isAdditive = bls12377.IsGateFunctionAdditive o.findDegree = bls12377.FindGateFunctionDegree o.verifyDegree = bls12377.VerifyGateFunctionDegree + o.equal = bls12377.EqualGateFunction case ecc.BLS12_381: o.isAdditive = bls12381.IsGateFunctionAdditive o.findDegree = bls12381.FindGateFunctionDegree o.verifyDegree = bls12381.VerifyGateFunctionDegree + o.equal = bls12381.EqualGateFunction case ecc.BLS24_315: o.isAdditive = bls24315.IsGateFunctionAdditive o.findDegree = bls24315.FindGateFunctionDegree o.verifyDegree = bls24315.VerifyGateFunctionDegree + o.equal = bls24315.EqualGateFunction case ecc.BLS24_317: o.isAdditive = bls24317.IsGateFunctionAdditive o.findDegree = bls24317.FindGateFunctionDegree o.verifyDegree = bls24317.VerifyGateFunctionDegree + o.equal = bls24317.EqualGateFunction case ecc.BN254: o.isAdditive = bn254.IsGateFunctionAdditive o.findDegree = bn254.FindGateFunctionDegree o.verifyDegree = bn254.VerifyGateFunctionDegree + o.equal = bn254.EqualGateFunction case ecc.BW6_633: o.isAdditive = bw6633.IsGateFunctionAdditive o.findDegree = bw6633.FindGateFunctionDegree o.verifyDegree = bw6633.VerifyGateFunctionDegree + o.equal = bw6633.EqualGateFunction case ecc.BW6_761: o.isAdditive = bw6761.IsGateFunctionAdditive o.findDegree = bw6761.FindGateFunctionDegree o.verifyDegree = bw6761.VerifyGateFunctionDegree + o.equal = bw6761.EqualGateFunction default: err = fmt.Errorf("unsupported curve %s", curve) } diff --git a/internal/generator/backend/template/gkr/gate_testing.go.tmpl b/internal/generator/backend/template/gkr/gate_testing.go.tmpl index 534b4b01c8..89d1343be6 100644 --- a/internal/generator/backend/template/gkr/gate_testing.go.tmpl +++ b/internal/generator/backend/template/gkr/gate_testing.go.tmpl @@ -155,6 +155,17 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error return nil } +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make({{.FieldPackageName}}.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} + {{- if not .CanUseFFT }} // interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) // Note that the runtime is O(len(X)³) diff --git a/internal/gkr/bls12-377/gate_testing.go b/internal/gkr/bls12-377/gate_testing.go index 415a5ff5b3..9e5a3868f3 100644 --- a/internal/gkr/bls12-377/gate_testing.go +++ b/internal/gkr/bls12-377/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls12-381/gate_testing.go b/internal/gkr/bls12-381/gate_testing.go index ef7694dc18..5b281fd634 100644 --- a/internal/gkr/bls12-381/gate_testing.go +++ b/internal/gkr/bls12-381/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls24-315/gate_testing.go b/internal/gkr/bls24-315/gate_testing.go index 1682d24771..058b53cc06 100644 --- a/internal/gkr/bls24-315/gate_testing.go +++ b/internal/gkr/bls24-315/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls24-317/gate_testing.go b/internal/gkr/bls24-317/gate_testing.go index 1bffab29e3..ed418ff1b0 100644 --- a/internal/gkr/bls24-317/gate_testing.go +++ b/internal/gkr/bls24-317/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bn254/gate_testing.go b/internal/gkr/bn254/gate_testing.go index 716ba3891b..e9311a3ea5 100644 --- a/internal/gkr/bn254/gate_testing.go +++ b/internal/gkr/bn254/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bw6-633/gate_testing.go b/internal/gkr/bw6-633/gate_testing.go index 0fafa45a0d..8074b9621c 100644 --- a/internal/gkr/bw6-633/gate_testing.go +++ b/internal/gkr/bw6-633/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bw6-761/gate_testing.go b/internal/gkr/bw6-761/gate_testing.go index 6eda2ebe73..0bae6258dc 100644 --- a/internal/gkr/bw6-761/gate_testing.go +++ b/internal/gkr/bw6-761/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/gkrtypes/types.go b/internal/gkr/gkrtypes/types.go index 7aed5ccd27..201f063952 100644 --- a/internal/gkr/gkrtypes/types.go +++ b/internal/gkr/gkrtypes/types.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" + "github.com/consensys/gnark" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/utils" @@ -17,15 +19,27 @@ type Gate struct { nbIn int // number of inputs degree int // total degree of the polynomial solvableVar int // if there is a variable whose value can be uniquely determined from the value of the gate and the other inputs, its index, -1 otherwise + curves []ecc.ID // curves that the gate is allowed to be used over } -func NewGate(f gkr.GateFunction, nbIn int, degree int, solvableVar int) *Gate { +func NewGate(f gkr.GateFunction, nbIn int, degree int, solvableVar int, curves []ecc.ID) *Gate { + return &Gate{ evaluate: f, nbIn: nbIn, degree: degree, solvableVar: solvableVar, + curves: curves, + } +} + +func (g *Gate) SupportsCurve(curve ecc.ID) bool { + for _, c := range g.curves { + if c == curve { + return true + } } + return false } func (g *Gate) Evaluate(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { @@ -388,33 +402,33 @@ var ErrZeroFunction = errors.New("detected a zero function") func Identity() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return in[0] - }, 1, 1, 0) + }, 1, 1, 0, gnark.Curves()) } // Add2 gate: (x, y) -> x + y func Add2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Add(in[0], in[1]) - }, 2, 1, 0) + }, 2, 1, 0, gnark.Curves()) } // Sub2 gate: (x, y) -> x - y func Sub2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Sub(in[0], in[1]) - }, 2, 1, 0) + }, 2, 1, 0, gnark.Curves()) } // Neg gate: x -> -x func Neg() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Neg(in[0]) - }, 1, 1, 0) + }, 1, 1, 0, gnark.Curves()) } // Mul2 gate: (x, y) -> x * y func Mul2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Mul(in[0], in[1]) - }, 2, 2, -1) + }, 2, 2, -1, gnark.Curves()) } diff --git a/internal/gkr/small_rational/gate_testing.go b/internal/gkr/small_rational/gate_testing.go index dc29624d7b..6e3dea5781 100644 --- a/internal/gkr/small_rational/gate_testing.go +++ b/internal/gkr/small_rational/gate_testing.go @@ -142,6 +142,17 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error return nil } +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(small_rational.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} + // interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) // Note that the runtime is O(len(X)³) func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index d8d4d3e960..0ba6213286 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -4,6 +4,7 @@ import ( "fmt" "math/bits" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" @@ -167,7 +168,10 @@ func (c *Circuit) verify(api frontend.API) error { initialChallenges = initialChallenges[:1] // use the commitment as the only initial challenge } - forSnark := newCircuitDataForSnark(c.toStore, c.assignments) + forSnark, err := newCircuitDataForSnark(utils.FieldToCurve(api.Compiler().Field()), c.toStore, c.assignments) + if err != nil { + return fmt.Errorf("failed to create circuit data for snark: %w", err) + } logNbInstances := log2(uint(c.assignments.NbInstances())) hintIns := make([]frontend.Variable, len(initialChallenges)+1) // hack: adding one of the outputs of the solve hint to ensure "prove" is called after "solve" @@ -208,30 +212,22 @@ func slicePtrAt[T any](slice []T) func(int) *T { } } -func ite[T any](condition bool, ifNot, IfSo T) T { - if condition { - return IfSo +func newCircuitDataForSnark(curve ecc.ID, info gkrinfo.StoringInfo, assignment gkrtypes.WireAssignment) (circuitDataForSnark, error) { + circuit, err := gkrtypes.CircuitInfoToCircuit(info.Circuit, gkrgates.Get) + if err != nil { + return circuitDataForSnark{}, fmt.Errorf("failed to convert GKR info to circuit: %w", err) } - return ifNot -} - -func newCircuitDataForSnark(info gkrinfo.StoringInfo, assignment gkrtypes.WireAssignment) circuitDataForSnark { - circuit := make(gkrtypes.Circuit, len(info.Circuit)) - snarkAssignment := make(gkrtypes.WireAssignment, len(info.Circuit)) for i := range circuit { - w := info.Circuit[i] - circuit[i] = gkrtypes.Wire{ - Gate: gkrgates.Get(ite(w.IsInput(), gkr.GateName(w.Gate), gkr.Identity)), - Inputs: w.Inputs, - NbUniqueOutputs: w.NbUniqueOutputs, + if !circuit[i].Gate.SupportsCurve(curve) { + return circuitDataForSnark{}, fmt.Errorf("gate \"%s\" not usable over curve \"%s\"", info.Circuit[i].Gate, curve) } - snarkAssignment[i] = assignment[i] } + return circuitDataForSnark{ circuit: circuit, - assignments: snarkAssignment, - } + assignments: assignment, + }, nil } func init() { From 82e33e54a94a7e4bb3bfd1aa6f3a9702be09a145 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 1 Jun 2025 17:03:06 -0500 Subject: [PATCH 03/92] refactor: gkr api tests --- constraint/solver/gkrgates/registry.go | 31 ++- std/gkrapi/api.go | 3 +- std/gkrapi/api_test.go | 366 +++++++++---------------- std/gkrapi/testing.go | 120 -------- 4 files changed, 153 insertions(+), 367 deletions(-) delete mode 100644 std/gkrapi/testing.go diff --git a/constraint/solver/gkrgates/registry.go b/constraint/solver/gkrgates/registry.go index 88d0d3daa8..2e1d8642ef 100644 --- a/constraint/solver/gkrgates/registry.go +++ b/constraint/solver/gkrgates/registry.go @@ -100,7 +100,9 @@ func WithCurves(curves ...ecc.ID) registerOption { // - name is a human-readable name for the gate. // - f is the polynomial function defining the gate. // - nbIn is the number of inputs to the gate. -func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { +// +// If the gate is already registered, it will return false and no error. +func Register(f gkr.GateFunction, nbIn int, options ...registerOption) (registered bool, err error) { s := registerSettings{degree: -1, solvableVar: -1, name: GetDefaultGateName(f)} for _, option := range options { option(&s) @@ -114,33 +116,37 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { allowedCurves = gnark.Curves() } + gatesLock.Lock() + defer gatesLock.Unlock() + if g, ok := gates[s.name]; ok { // gate already registered if reflect.ValueOf(f).Pointer() != reflect.ValueOf(gates[s.name].Evaluate).Pointer() { - return fmt.Errorf("gate \"%s\" already registered with a different function", s.name) + return false, fmt.Errorf("gate \"%s\" already registered with a different function", s.name) } // it still might be an anonymous function with different parameters. // need to test further if g.NbIn() != nbIn { - return fmt.Errorf("gate \"%s\" already registered with a different number of inputs (%d != %d)", s.name, g.NbIn(), nbIn) + return false, fmt.Errorf("gate \"%s\" already registered with a different number of inputs (%d != %d)", s.name, g.NbIn(), nbIn) } for _, curve := range curvesForTesting { gateVer, err := NewGateVerifier(curve) if err != nil { - return err + return false, err } if !gateVer.equal(f, g.Evaluate, nbIn) { - return fmt.Errorf("mismatch with already registered gate \"%s\" (degree %d) over curve %s", s.name, s.degree, curve) + return false, fmt.Errorf("mismatch with already registered gate \"%s\" (degree %d) over curve %s", s.name, s.degree, curve) } } + return false, nil // gate already registered } for _, curve := range curvesForTesting { gateVer, err := NewGateVerifier(curve) if err != nil { - return err + return false, err } if s.degree == -1 { // find a degree @@ -148,14 +154,13 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { panic("invalid settings") } const maxAutoDegreeBound = 32 - var err error if s.degree, err = gateVer.findDegree(f, maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate \"%s\": %v", s.name, err) + return false, fmt.Errorf("for gate \"%s\": %v", s.name, err) } } else { if !s.noDegreeVerification { // check that the given degree is correct if err = gateVer.verifyDegree(f, s.degree, nbIn); err != nil { - return fmt.Errorf("for gate \"%s\": %v", s.name, err) + return false, fmt.Errorf("for gate \"%s\": %v", s.name, err) } } } @@ -167,16 +172,14 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { } else { // solvable variable given if !s.noSolvableVarVerification && !gateVer.isVarSolvable(f, s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate \"%s\"", s.solvableVar, s.name) + return false, fmt.Errorf("cannot verify the solvability of variable %d in gate \"%s\"", s.solvableVar, s.name) } } } - gatesLock.Lock() - defer gatesLock.Unlock() - gates[s.name] = gkrtypes.NewGate(f, nbIn, s.degree, s.solvableVar, gkrtypes.WithCurves(allowedCurves...)) - return nil + gates[s.name] = gkrtypes.NewGate(f, nbIn, s.degree, s.solvableVar, allowedCurves) + return true, nil } func Get(name gkr.GateName) *gkrtypes.Gate { diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index 771613ce0d..18a9b23279 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -23,12 +23,11 @@ func (api *API) NamedGate(gate gkr.GateName, in ...gkr.Variable) gkr.Variable { Inputs: utils.Map(in, frontendVarToInt), }) api.assignments = append(api.assignments, nil) - api.toStore.Dependencies = append(api.toStore.Dependencies, nil) // formality. Dependencies are only defined for input vars. return gkr.Variable(len(api.toStore.Circuit) - 1) } func (api *API) Gate(gate gkr.GateFunction, in ...gkr.Variable) gkr.Variable { - if err := gkrgates.Register(gate, len(in)); err != nil { + if _, err := gkrgates.Register(gate, len(in)); err != nil { panic(err) } return api.NamedGate(gkrgates.GetDefaultGateName(gate), in...) diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index 7bb255d70c..5cd9163ed8 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -16,10 +16,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" gcHash "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/backend/groth16" - "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/std/gkrapi/gkr" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" @@ -40,23 +38,21 @@ type doubleNoDependencyCircuit struct { func (c *doubleNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } + x := gkrApi.NewInput() z := gkrApi.Add(x, x) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - Z := solution.Export(z) - for i := range Z { - api.AssertIsEqual(Z[i], api.Mul(2, c.X[i])) - } + gkrCircuit := gkrApi.Compile(api, c.hashName) - return solution.Verify(c.hashName) + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(2, c.X[i])) + } + return nil } func TestDoubleNoDependencyCircuit(t *testing.T) { @@ -88,23 +84,21 @@ type sqNoDependencyCircuit struct { func (c *sqNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } + x := gkrApi.NewInput() z := gkrApi.Mul(x, x) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - Z := solution.Export(z) - for i := range Z { - api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.X[i])) - } + gkrCircuit := gkrApi.Compile(api, c.hashName) - return solution.Verify(c.hashName) + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(c.X[i], c.X[i])) + } + return nil } func TestSqNoDependencyCircuit(t *testing.T) { @@ -135,29 +129,23 @@ type mulNoDependencyCircuit struct { func (c *mulNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x, y gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } - if y, err = gkrApi.Import(c.Y); err != nil { - return err - } - gkrApi.Println(0, "values of x and y in instance number", 0, x, y) + x := gkrApi.NewInput() + y := gkrApi.NewInput() + z := gkrApi.Add(x, y) - z := gkrApi.Mul(x, y) - gkrApi.Println(1, "value of z in instance number", 1, z) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - Z := solution.Export(z) + gkrCircuit := gkrApi.Compile(api, c.hashName) + instanceIn := make(map[gkr.Variable]frontend.Variable) for i := range c.X { - api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.Y[i])) + instanceIn[x] = c.X[i] + instanceIn[y] = c.Y[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(c.Y[i], c.X[i])) } - - return solution.Verify(c.hashName) + return nil } func TestMulNoDependency(t *testing.T) { @@ -191,91 +179,68 @@ func TestMulNoDependency(t *testing.T) { } type mulWithDependencyCircuit struct { - XLast frontend.Variable + XFirst frontend.Variable Y []frontend.Variable hashName string } func (c *mulWithDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x, y gkr.Variable - var err error - - X := make([]frontend.Variable, len(c.Y)) - X[len(c.Y)-1] = c.XLast - if x, err = gkrApi.Import(X); err != nil { - return err - } - if y, err = gkrApi.Import(c.Y); err != nil { - return err - } + x := gkrApi.NewInput() // x is the state variable + y := gkrApi.NewInput() z := gkrApi.Mul(x, y) - for i := len(X) - 1; i > 0; i-- { - gkrApi.Series(x, z, i-1, i) - } + gkrCircuit := gkrApi.Compile(api, c.hashName) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - X = solution.Export(x) - Y := solution.Export(y) - Z := solution.Export(z) + state := c.XFirst + instanceIn := make(map[gkr.Variable]frontend.Variable) + + for i := range c.Y { + instanceIn[x] = state + instanceIn[y] = c.Y[i] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } - lastI := len(X) - 1 - api.AssertIsEqual(Z[lastI], api.Mul(c.XLast, Y[lastI])) - for i := 0; i < lastI; i++ { - api.AssertIsEqual(Z[i], api.Mul(Z[i+1], Y[i])) + state = instanceOut[z] // update state for the next iteration + api.AssertIsEqual(state, api.Mul(state, c.Y[i])) } - return solution.Verify(c.hashName) + return nil } func TestSolveMulWithDependency(t *testing.T) { assert := test.NewAssert(t) assignment := mulWithDependencyCircuit{ - XLast: 1, - Y: []frontend.Variable{3, 2}, + XFirst: 1, + Y: []frontend.Variable{3, 2}, } circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"} assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254)) } func TestApiMul(t *testing.T) { - var ( - x gkr.Variable - y gkr.Variable - z gkr.Variable - err error - ) api := New() - x, err = api.Import([]frontend.Variable{nil, nil}) - require.NoError(t, err) - y, err = api.Import([]frontend.Variable{nil, nil}) - require.NoError(t, err) - z = api.Mul(x, y) + x := api.NewInput() + y := api.NewInput() + z := api.Mul(x, y) assertSliceEqual(t, api.toStore.Circuit[z].Inputs, []int{int(x), int(y)}) // TODO: Find out why assert.Equal gives false positives ( []*Wire{x,x} as second argument passes when it shouldn't ) } func BenchmarkMiMCMerkleTree(b *testing.B) { - depth := 14 - bottom := make([]frontend.Variable, 1<= 0; d-- { - for i := 0; i < 1< 1 { + nextLayer := curLayer[:len(curLayer)/2] - challenge, err := api.(frontend.Committer).Commit(Z...) - if err != nil { - return err - } + for i := range nextLayer { + instanceIn[x] = curLayer[2*i] + instanceIn[y] = curLayer[2*i+1] - return solution.Verify("-20", challenge) + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + nextLayer[i] = instanceOut[z] // store the result of the hash + } + + curLayer = nextLayer + } + return nil } -func registerMiMC() { +func init() { stdHash.Register("MIMC", func(api frontend.API) (stdHash.FieldHasher, error) { m, err := mimc.NewMiMC(api) return &m, err }) } -func init() { - registerMiMC() - registerMiMCGate() -} - -func registerMiMCGate() { - // register mimc gate - panicIfError(gkrgates.Register(func(api gkr.GateAPI, input ...frontend.Variable) frontend.Variable { - mimcSnarkTotalCalls++ +func mimcGate(api gkr.GateAPI, input ...frontend.Variable) frontend.Variable { + mimcSnarkTotalCalls++ - if len(input) != 2 { - panic("mimc has fan-in 2") - } - sum := api.Add(input[0], input[1] /*, m.Ark*/) + if len(input) != 2 { + panic("mimc has fan-in 2") + } + sum := api.Add(input[0], input[1] /*, m.Ark*/) - sumCubed := api.Mul(sum, sum, sum) // sum^3 - return api.Mul(sumCubed, sumCubed, sum) - }, 2, gkrgates.WithDegree(7), gkrgates.WithName("MIMC"))) + sumCubed := api.Mul(sum, sum, sum) // sum^3 + return api.Mul(sumCubed, sumCubed, sum) } type constPseudoHash int @@ -465,26 +422,25 @@ type mimcNoDepCircuit struct { } func (c *mimcNoDepCircuit) Define(api frontend.API) error { - _gkr := New() - x, err := _gkr.Import(c.X) - if err != nil { - return err - } - var ( - y gkr.Variable - solution Solution - ) - if y, err = _gkr.Import(c.Y); err != nil { - return err - } + // define the circuit + gkrApi := New() + x := gkrApi.NewInput() + y := gkrApi.NewInput() + gkrApi.Gate(mimcGate, x, y) - z := _gkr.NamedGate("MIMC", x, y) + gkrCircuit := gkrApi.Compile(api, c.hashName) - if solution, err = _gkr.Solve(api); err != nil { - return err + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceIn[y] = c.Y[i] + + _, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } } - Z := solution.Export(z) - return solution.Verify(c.hashName, Z...) + return nil } func mimcNoDepCircuits(mimcDepth, nbInstances int, hashName string) (circuit, assignment frontend.Circuit) { @@ -566,58 +522,6 @@ func mimcNoGkrCircuits(mimcDepth, nbInstances int) (circuit, assignment frontend return } -func TestSolveInTestEngine(t *testing.T) { - assignment := testSolveInTestEngineCircuit{ - X: []frontend.Variable{2, 3, 4, 5, 6, 7, 8, 9}, - } - circuit := testSolveInTestEngineCircuit{ - X: make([]frontend.Variable, len(assignment.X)), - } - - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BN254.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS24_315.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_381.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS24_317.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BW6_633.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField())) -} - -type testSolveInTestEngineCircuit struct { - X []frontend.Variable -} - -func (c *testSolveInTestEngineCircuit) Define(api frontend.API) error { - gkrBn254 := New() - x, err := gkrBn254.Import(c.X) - if err != nil { - return err - } - Y := make([]frontend.Variable, len(c.X)) - Y[0] = 1 - y, err := gkrBn254.Import(Y) - if err != nil { - return err - } - - z := gkrBn254.Mul(x, y) - - for i := range len(c.X) - 1 { - gkrBn254.Series(y, z, i+1, i) - } - - assignments := gkrBn254.SolveInTestEngine(api) - - product := frontend.Variable(1) - for i := range c.X { - api.AssertIsEqual(assignments[y][i], product) - product = api.Mul(product, c.X[i]) - api.AssertIsEqual(assignments[z][i], product) - } - - return nil -} - func panicIfError(err error) { if err != nil { panic(err) diff --git a/std/gkrapi/testing.go b/std/gkrapi/testing.go deleted file mode 100644 index 17163c0b5a..0000000000 --- a/std/gkrapi/testing.go +++ /dev/null @@ -1,120 +0,0 @@ -package gkrapi - -import ( - "errors" - "fmt" - "sync" - - "github.com/consensys/gnark/constraint/solver/gkrgates" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/std/gkrapi/gkr" - stdHash "github.com/consensys/gnark/std/hash" -) - -type solveInTestEngineSettings struct { - hashName string -} - -type SolveInTestEngineOption func(*solveInTestEngineSettings) - -func WithHashName(name string) SolveInTestEngineOption { - return func(s *solveInTestEngineSettings) { - s.hashName = name - } -} - -// SolveInTestEngine solves the defined circuit directly inside the SNARK circuit. This means that the method does not compute the GKR proof of the circuit and does not embed the GKR proof verifier inside a SNARK. -// The output is the values of all variables, across all instances; i.e. indexed variable-first, instance-second. -// This method only works under the test engine and should only be called to debug a GKR circuit, as the GKR prover's errors can be obscure. -func (api *API) SolveInTestEngine(parentApi frontend.API, options ...SolveInTestEngineOption) [][]frontend.Variable { - gateVer, err := gkrgates.NewGateVerifier(utils.FieldToCurve(parentApi.Compiler().Field())) - if err != nil { - panic(err) - } - - var s solveInTestEngineSettings - for _, o := range options { - o(&s) - } - if s.hashName != "" { - // hash something and make sure it gives the same answer both on prover and verifier sides - // TODO @Tabaie If indeed cheap, move this feature to Verify so that it is always run - h, err := stdHash.GetFieldHasher(s.hashName, parentApi) - if err != nil { - panic(err) - } - nbBytes := (parentApi.Compiler().FieldBitLen() + 7) / 8 - toHash := frontend.Variable(0) - for i := range nbBytes { - toHash = parentApi.Add(parentApi.Mul(toHash, 256), i%256) - } - h.Reset() - h.Write(toHash) - hashed := h.Sum() - - hintOut, err := parentApi.Compiler().NewHint(CheckHashHint(s.hashName), 1, toHash, hashed) - if err != nil { - panic(err) - } - parentApi.AssertIsEqual(hintOut[0], hashed) // the hint already checks this - } - - res := make([][]frontend.Variable, len(api.toStore.Circuit)) - var verifiedGates sync.Map - for i, w := range api.toStore.Circuit { - res[i] = make([]frontend.Variable, api.nbInstances()) - copy(res[i], api.assignments[i]) - if len(w.Inputs) == 0 { - continue - } - } - for instanceI := range api.nbInstances() { - for wireI, w := range api.toStore.Circuit { - deps := api.toStore.Dependencies[wireI] - if len(deps) != 0 && len(w.Inputs) != 0 { - panic(fmt.Errorf("non-input wire %d should not have dependencies", wireI)) - } - for _, dep := range deps { - if dep.InputInstance == instanceI { - if dep.OutputInstance >= instanceI { - panic(fmt.Errorf("out of order dependency not yet supported in SolveInTestEngine; (wire %d, instance %d) depends on (wire %d, instance %d)", wireI, instanceI, dep.OutputWire, dep.OutputInstance)) - } - if res[wireI][instanceI] != nil { - panic(fmt.Errorf("dependency (wire %d, instance %d) <- (wire %d, instance %d) attempting to override existing value assignment", wireI, instanceI, dep.OutputWire, dep.OutputInstance)) - } - res[wireI][instanceI] = res[dep.OutputWire][dep.OutputInstance] - } - } - - if res[wireI][instanceI] == nil { // no assignment or dependency - if len(w.Inputs) == 0 { - panic(fmt.Errorf("input wire %d, instance %d has no dependency or explicit assignment", wireI, instanceI)) - } - ins := make([]frontend.Variable, len(w.Inputs)) - for i, in := range w.Inputs { - ins[i] = res[in][instanceI] - } - gate := gkrgates.Get(gkr.GateName(w.Gate)) - if gate == nil && !w.IsInput() { - panic(fmt.Errorf("gate %s not found", w.Gate)) - } - if _, ok := verifiedGates.Load(w.Gate); !ok { - verifiedGates.Store(w.Gate, struct{}{}) - - err = errors.Join( - gateVer.VerifyDegree(gate), - gateVer.VerifySolvability(gate), - ) - if err != nil { - panic(fmt.Errorf("gate %s: %w", w.Gate, err)) - } - } - if gate != nil { - res[wireI][instanceI] = gate.Evaluate(parentApi, ins...) - } - } - } - } - return res -} From ba06e5d433765e8fd417ab2bc024a91e95388599 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 1 Jun 2025 17:17:13 -0500 Subject: [PATCH 04/92] refactor gkr example --- std/gkrapi/compile_test.go | 139 ------------------------------------- std/gkrapi/example_test.go | 119 ++++++++++++------------------- 2 files changed, 44 insertions(+), 214 deletions(-) delete mode 100644 std/gkrapi/compile_test.go diff --git a/std/gkrapi/compile_test.go b/std/gkrapi/compile_test.go deleted file mode 100644 index a0ca992ed4..0000000000 --- a/std/gkrapi/compile_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package gkrapi - -import ( - "testing" - - "github.com/consensys/gnark/internal/gkr/gkrinfo" - "github.com/stretchr/testify/assert" -) - -func TestCompile2Cycles(t *testing.T) { - var d = gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - nil, - { - { - OutputWire: 0, - OutputInstance: 1, - InputInstance: 0, - }, - }, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{1}, - }, - { - Inputs: []int{}, - }, - }, - } - - expectedCompiled := gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - {{ - OutputWire: 1, - OutputInstance: 0, - InputInstance: 1, - }}, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{0}, - }}, - NbInstances: 2, - } - - expectedPermutations := gkrinfo.Permutations{ - SortedInstances: []int{1, 0}, - SortedWires: []int{1, 0}, - InstancesPermutation: []int{1, 0}, - WiresPermutation: []int{1, 0}, - } - - p, err := d.Compile(2) - assert.NoError(t, err) - assert.Equal(t, expectedPermutations, p) - assert.Equal(t, expectedCompiled, d) -} - -func TestCompile3Cycles(t *testing.T) { - var d = gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - nil, - { - { - OutputWire: 0, - OutputInstance: 2, - InputInstance: 0, - }, - { - OutputWire: 0, - OutputInstance: 1, - InputInstance: 2, - }, - }, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{2}, - }, - { - Inputs: []int{}, - }, - { - Inputs: []int{1}, - }, - }, - } - - expectedCompiled := gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - {{ - OutputWire: 2, - OutputInstance: 0, - InputInstance: 1, - }, { - OutputWire: 2, - OutputInstance: 1, - InputInstance: 2, - }}, - - nil, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{0}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{1}, - NbUniqueOutputs: 0, - }, - }, - NbInstances: 3, // not allowed if we were actually performing gkr - } - - expectedPermutations := gkrinfo.Permutations{ - SortedInstances: []int{1, 2, 0}, - SortedWires: []int{1, 2, 0}, - InstancesPermutation: []int{2, 0, 1}, - WiresPermutation: []int{2, 0, 1}, - } - - p, err := d.Compile(3) - assert.NoError(t, err) - assert.Equal(t, expectedPermutations, p) - assert.Equal(t, expectedCompiled, d) -} diff --git a/std/gkrapi/example_test.go b/std/gkrapi/example_test.go index 4078bb0b9e..49d0209192 100644 --- a/std/gkrapi/example_test.go +++ b/std/gkrapi/example_test.go @@ -10,8 +10,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" - stdHash "github.com/consensys/gnark/std/hash" - "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) @@ -20,18 +18,22 @@ func Example() { // This means that the imported fr and fp packages are the same, being from BW6-761 and BLS12-377 respectively. TODO @Tabaie delete if no longer have fp imported // It is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. // github.com/consensys/gnark-crypto/ecc/bls12-377 - const fsHashName = "MIMC" // register the gates: Doing so is not needed here because // the proof is being computed in the same session as the // SNARK circuit being compiled. // But in production applications it would be necessary. - assertNoError(gkrgates.Register(squareGate, 1)) - assertNoError(gkrgates.Register(sGate, 4)) - assertNoError(gkrgates.Register(zGate, 4)) - assertNoError(gkrgates.Register(xGate, 2)) - assertNoError(gkrgates.Register(yGate, 4)) + _, err := gkrgates.Register(squareGate, 1) + assertNoError(err) + _, err = gkrgates.Register(sGate, 4) + assertNoError(err) + _, err = gkrgates.Register(zGate, 4) + assertNoError(err) + _, err = gkrgates.Register(xGate, 2) + assertNoError(err) + _, err = gkrgates.Register(yGate, 4) + assertNoError(err) const nbInstances = 2 // create instances @@ -64,13 +66,12 @@ func Example() { } circuit := exampleCircuit{ - X: make([]frontend.Variable, nbInstances), - Y: make([]frontend.Variable, nbInstances), - Z: make([]frontend.Variable, nbInstances), - XOut: make([]frontend.Variable, nbInstances), - YOut: make([]frontend.Variable, nbInstances), - ZOut: make([]frontend.Variable, nbInstances), - fsHashName: fsHashName, + X: make([]frontend.Variable, nbInstances), + Y: make([]frontend.Variable, nbInstances), + Z: make([]frontend.Variable, nbInstances), + XOut: make([]frontend.Variable, nbInstances), + YOut: make([]frontend.Variable, nbInstances), + ZOut: make([]frontend.Variable, nbInstances), } assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) @@ -81,7 +82,6 @@ func Example() { type exampleCircuit struct { X, Y, Z []frontend.Variable // Jacobian coordinates for each point (input) XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) - fsHashName string // name of the hash function used for Fiat-Shamir in the GKR verifier } func (c *exampleCircuit) Define(api frontend.API) error { @@ -91,21 +91,10 @@ func (c *exampleCircuit) Define(api frontend.API) error { gkrApi := gkrapi.New() - // create GKR circuit variables based on the given assignments - X, err := gkrApi.Import(c.X) - if err != nil { - return err - } - - Y, err := gkrApi.Import(c.Y) - if err != nil { - return err - } - - Z, err := gkrApi.Import(c.Z) - if err != nil { - return err - } + // create the GKR circuit + X := gkrApi.NewInput() + Y := gkrApi.NewInput() + Z := gkrApi.NewInput() XX := gkrApi.Gate(squareGate, X) // 405: XX.Square(&p.X) YY := gkrApi.Gate(squareGate, Y) // 406: YY.Square(&p.Y) @@ -117,51 +106,31 @@ func (c *exampleCircuit) Define(api frontend.API) error { // 414: M.Double(&XX).Add(&M, &XX) // Note (but don't explicitly compute) that M = 3XX - Z = gkrApi.Gate(zGate, Z, Y, YY, ZZ) // 415 - 418 - X = gkrApi.Gate(xGate, XX, S) // 419-422 - Y = gkrApi.Gate(yGate, S, X, XX, YYYY) // 423 - 426 - - // have to duplicate X for it to be considered an output variable - X = gkrApi.NamedGate(gkr.Identity, X) - - // register the hash function used for verification (fiat shamir) - stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { - m, err := mimc.NewMiMC(api) - return &m, err - }) - - // solve and prove the circuit - solution, err := gkrApi.Solve(api) - if err != nil { - return err + ZOut := gkrApi.Gate(zGate, Z, Y, YY, ZZ) // 415 - 418 + XOut := gkrApi.Gate(xGate, XX, S) // 419-422 + YOut := gkrApi.Gate(yGate, S, XOut, XX, YYYY) // 423 - 426 + + // have to duplicate X for it to be considered an output variable; this is an implementation detail and will be fixed in the future [https://github.com/Consensys/gnark/issues/1452] + XOut = gkrApi.NamedGate(gkr.Identity, XOut) + + gkrCircuit := gkrApi.Compile(api, "MIMC") + + // add input and check output for correctness + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[X] = c.X[i] + instanceIn[Y] = c.Y[i] + instanceIn[Z] = c.Z[i] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return err + } + api.AssertIsEqual(instanceOut[XOut], c.XOut[i]) + api.AssertIsEqual(instanceOut[YOut], c.YOut[i]) + api.AssertIsEqual(instanceOut[ZOut], c.ZOut[i]) } - - // check the output - - XOut := solution.Export(X) - YOut := solution.Export(Y) - ZOut := solution.Export(Z) - for i := range XOut { - api.AssertIsEqual(XOut[i], c.XOut[i]) - api.AssertIsEqual(YOut[i], c.YOut[i]) - api.AssertIsEqual(ZOut[i], c.ZOut[i]) - } - - challenges := make([]frontend.Variable, 0, len(c.X)*6) - challenges = append(challenges, XOut...) - challenges = append(challenges, YOut...) - challenges = append(challenges, ZOut...) - challenges = append(challenges, c.X...) - challenges = append(challenges, c.Y...) - challenges = append(challenges, c.Z...) - - challenge, err := api.(frontend.Committer).Commit(challenges...) - if err != nil { - return err - } - - // verify the proof - return solution.Verify(c.fsHashName, challenge) + return nil } // custom gates From 700c152b6e8d670dbef5f2221cd1a81329d831da Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 12:12:03 -0500 Subject: [PATCH 05/92] refactor: solve hint called per instance --- constraint/bls12-377/solver.go | 6 +- internal/gkr/bls12-377/solver_hints.go | 115 +++++++++---------------- internal/gkr/gkrinfo/info.go | 30 +++++++ internal/gkr/gkrtypes/types.go | 56 ++---------- {std/gkrapi => internal/gkr}/hints.go | 101 ++++++++-------------- std/gkrapi/api.go | 18 +--- std/gkrapi/compile.go | 12 ++- 7 files changed, 130 insertions(+), 208 deletions(-) rename {std/gkrapi => internal/gkr}/hints.go (53%) diff --git a/constraint/bls12-377/solver.go b/constraint/bls12-377/solver.go index f79940e3be..206fea5702 100644 --- a/constraint/bls12-377/solver.go +++ b/constraint/bls12-377/solver.go @@ -51,14 +51,14 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 39547cff29..6353370a15 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,92 +22,65 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment + circuit gkrtypes.Circuit + workers *utils.WorkerPool + maxNbIn int // maximum number of inputs for a gate in the circuit + printsByInstance map[uint32][]gkrinfo.PrintInfo } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { + d := SolvingData{ + workers: utils.NewWorkerPool(), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + } - d.assignment = make(WireAssignment, len(d.circuit)) + d.circuit.SetNbUniqueOutputs() + d.maxNbIn = d.circuit.MaxGateNbIn() for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } + + return &d } // this module assumes that wire and instance indexes respect dependencies -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + if !ins[0].IsUint64() { // TODO use idiomatic printf tag + return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } - } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } - } + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end - } - - for _, p := range info.Prints { + prints := data.printsByInstance[uint32(instanceI)] + delete(data.printsByInstance, uint32(instanceI)) + for _, p := range prints { serializable := make([]any, len(p.Values)) for i, v := range p.Values { if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 @@ -117,9 +91,6 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } fmt.Println(serializable...) } - - setOuts(data.assignment, info.Circuit, outs) - return nil } } diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index 55fdb7e71d..81902df8c6 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -7,6 +7,7 @@ import ( "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/std/gkrapi/gkr" ) type ( @@ -139,3 +140,32 @@ func (d *StoringInfo) Is() bool { type ConstraintSystem interface { SetGkrInfo(info StoringInfo) error } + +func NewPrint(instance int, a ...any) PrintInfo { + isVar := make([]bool, len(a)) + vals := make([]any, len(a)) + for i := range a { + v, ok := a[i].(gkr.Variable) + isVar[i] = ok + if ok { + vals[i] = uint32(v) + } else { + vals[i] = a[i] + } + } + + return PrintInfo{ + Values: vals, + Instance: uint32(instance), + IsGkrVar: isVar, + } +} + +// NewPrintInfoMap partitions printInfo into map elements, indexed by instance +func NewPrintInfoMap(printInfo []PrintInfo) map[uint32][]PrintInfo { + res := make(map[uint32][]PrintInfo) + for i := range printInfo { + res[printInfo[i].Instance] = append(res[printInfo[i].Instance], printInfo[i]) + } + return res +} diff --git a/internal/gkr/gkrtypes/types.go b/internal/gkr/gkrtypes/types.go index 201f063952..12cdabf3d9 100644 --- a/internal/gkr/gkrtypes/types.go +++ b/internal/gkr/gkrtypes/types.go @@ -147,49 +147,10 @@ func (c Circuit) MemoryRequirements(nbInstances int) []int { } type SolvingInfo struct { - Circuit Circuit - Dependencies [][]gkrinfo.InputDependency - NbInstances int - HashName string - Prints []gkrinfo.PrintInfo -} - -// Chunks returns intervals of instances that are independent of each other and can be solved in parallel -func (info *SolvingInfo) Chunks() []int { - res := make([]int, 0, 1) - lastSeenDependencyI := make([]int, len(info.Circuit)) - - for start, end := 0, 0; start != info.NbInstances; start = end { - end = info.NbInstances - endWireI := -1 - for wI := range info.Circuit { - deps := info.Dependencies[wI] - if wDepI := lastSeenDependencyI[wI]; wDepI < len(deps) && deps[wDepI].InputInstance < end { - end = deps[wDepI].InputInstance - endWireI = wI - } - } - if endWireI != -1 { - lastSeenDependencyI[endWireI]++ - } - res = append(res, end) - } - return res -} - -// AssignmentOffsets describes the input layout of the Solve hint, by returning -// for each wire, the index of the first hint input element corresponding to it. -func (info *SolvingInfo) AssignmentOffsets() []int { - c := info.Circuit - res := make([]int, len(c)+1) - for i := range c { - nbExplicitAssignments := 0 - if c[i].IsInput() { - nbExplicitAssignments = info.NbInstances - len(info.Dependencies[i]) - } - res[i+1] = res[i] + nbExplicitAssignments - } - return res + Circuit Circuit + NbInstances int + HashName string + Prints []gkrinfo.PrintInfo } // OutputsList for each wire, returns the set of indexes of wires it is input to. @@ -282,11 +243,10 @@ func CircuitInfoToCircuit(info gkrinfo.Circuit, gateGetter func(name gkr.GateNam func StoringToSolvingInfo(info gkrinfo.StoringInfo, gateGetter func(name gkr.GateName) *Gate) (SolvingInfo, error) { circuit, err := CircuitInfoToCircuit(info.Circuit, gateGetter) return SolvingInfo{ - Circuit: circuit, - NbInstances: info.NbInstances, - HashName: info.HashName, - Dependencies: info.Dependencies, - Prints: info.Prints, + Circuit: circuit, + NbInstances: info.NbInstances, + HashName: info.HashName, + Prints: info.Prints, }, err } diff --git a/std/gkrapi/hints.go b/internal/gkr/hints.go similarity index 53% rename from std/gkrapi/hints.go rename to internal/gkr/hints.go index 577a4d6ed8..2c2621911e 100644 --- a/std/gkrapi/hints.go +++ b/internal/gkr/hints.go @@ -1,13 +1,10 @@ -package gkrapi +package gkr import ( "errors" - "fmt" "math/big" - "strings" "github.com/consensys/gnark-crypto/ecc" - gcHash "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" @@ -19,7 +16,6 @@ import ( bw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" - "github.com/consensys/gnark/internal/utils" ) var testEngineGkrSolvingData = make(map[string]any) @@ -28,6 +24,8 @@ func modKey(mod *big.Int) string { return mod.Text(32) } +// SolveHintPlaceholder solves one instance of a GKR circuit. +// The first input is the index of the instance. The rest are the inputs of the circuit, in their nominal order. func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) solver.Hint { return func(mod *big.Int, ins []*big.Int, outs []*big.Int) error { @@ -36,44 +34,42 @@ func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) solver.Hint { return err } + var hint solver.Hint + // TODO @Tabaie autogenerate this or decide not to if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - var data bls12377.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls12377.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - var data bls12381.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls12381.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - var data bls24315.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls24315.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - var data bls24317.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls24317.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - var data bn254.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bn254.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - var data bw6633.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bw6633.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - var data bw6761.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bw6761.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - - return errors.New("unsupported modulus") + data := bls12377.NewSolvingData(solvingInfo) + hint = bls12377.SolveHint(solvingInfo, data) + testEngineGkrSolvingData[modKey(mod)] = data + } else if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + data := bls12381.NewSolvingData(solvingInfo) + hint = bls12381.SolveHint(solvingInfo, data) + testEngineGkrSolvingData[modKey(mod)] = data + } else if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { + data := bls24315.NewSolvingData(solvingInfo) + hint = bls24315.SolveHint(solvingInfo, data) + testEngineGkrSolvingData[modKey(mod)] = data + } else if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { + data := bls24317.NewSolvingData(solvingInfo) + hint = bls24317.SolveHint(solvingInfo, data) + testEngineGkrSolvingData[modKey(mod)] = data + } else if mod.Cmp(ecc.BN254.ScalarField()) == 0 { + data := bn254.NewSolvingData(solvingInfo) + hint = bn254.SolveHint(solvingInfo, data) + testEngineGkrSolvingData[modKey(mod)] = data + } else if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { + data := bw6633.NewSolvingData(solvingInfo) + hint = bw6633.SolveHint(solvingInfo, data) + testEngineGkrSolvingData[modKey(mod)] = data + } else if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { + data := bw6761.NewSolvingData(solvingInfo) + hint = bw6761.SolveHint(solvingInfo, data) + testEngineGkrSolvingData[modKey(mod)] = data + } else { + return errors.New("unsupported modulus") + } + + return hint(mod, ins, outs) } } @@ -112,26 +108,3 @@ func ProveHintPlaceholder(hashName string) solver.Hint { return errors.New("unsupported modulus") } } - -func CheckHashHint(hashName string) solver.Hint { - return func(mod *big.Int, ins, outs []*big.Int) error { - if len(ins) != 2 || len(outs) != 1 { - return errors.New("invalid number of inputs/outputs") - } - - toHash := ins[0].Bytes() - expectedHash := ins[1] - - hsh := gcHash.NewHash(fmt.Sprintf("%s_%s", hashName, strings.ToUpper(utils.FieldToCurve(mod).String()))) - hsh.Write(toHash) - hashed := hsh.Sum(nil) - - if hashed := new(big.Int).SetBytes(hashed); hashed.Cmp(expectedHash) != 0 { - return fmt.Errorf("hash mismatch: expected %s, got %s", expectedHash.String(), hashed.String()) - } - - outs[0].SetBytes(hashed) - - return nil - } -} diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index 18a9b23279..d4f27f5c9f 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -62,21 +62,5 @@ func (api *API) Mul(i1, i2 gkr.Variable) gkr.Variable { // Println writes to the standard output. // instance determines which values are chosen for gkr.Variable input. func (api *API) Println(instance int, a ...any) { - isVar := make([]bool, len(a)) - vals := make([]any, len(a)) - for i := range a { - v, ok := a[i].(gkr.Variable) - isVar[i] = ok - if ok { - vals[i] = uint32(v) - } else { - vals[i] = a[i] - } - } - - api.toStore.Prints = append(api.toStore.Prints, gkrinfo.PrintInfo{ - Values: vals, - Instance: uint32(instance), - IsGkrVar: isVar, - }) + api.toStore.Prints = append(api.toStore.Prints, gkrinfo.NewPrint(instance, a...)) } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 0ba6213286..b51c9c8b6c 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -3,6 +3,7 @@ package gkrapi import ( "fmt" "math/bits" + "slices" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint/solver" @@ -88,8 +89,11 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio res.ins = append(res.ins, gkr.Variable(i)) } } - res.toStore.SolveHintID = solver.GetHintID(SolveHintPlaceholder(res.toStore)) - res.toStore.ProveHintID = solver.GetHintID(ProveHintPlaceholder(fiatshamirHashName)) + res.toStore.SolveHintID = solver.GetHintID(gadget.SolveHintPlaceholder(res.toStore)) + res.toStore.ProveHintID = solver.GetHintID(gadget.ProveHintPlaceholder(fiatshamirHashName)) + + // sort the prints before solving begins + slices.SortFunc(res.toStore.Prints, gkrinfo.PrintInfo.Cmp) parentApi.Compiler().Defer(res.verify) @@ -119,7 +123,7 @@ func (c *Circuit) AddInstance(input map[gkr.Variable]frontend.Variable) (map[gkr } c.toStore.NbInstances++ - solveHintPlaceholder := SolveHintPlaceholder(c.toStore) + solveHintPlaceholder := gadget.SolveHintPlaceholder(c.toStore) outsSerialized, err := c.api.Compiler().NewHint(solveHintPlaceholder, len(c.outs), hintIn...) if err != nil { return nil, fmt.Errorf("failed to create solve hint: %w", err) @@ -180,7 +184,7 @@ func (c *Circuit) verify(api frontend.API) error { copy(hintIns[1:], initialChallenges) - proveHintPlaceholder := ProveHintPlaceholder(c.toStore.HashName) + proveHintPlaceholder := gadget.ProveHintPlaceholder(c.toStore.HashName) if proofSerialized, err = api.Compiler().NewHint( proveHintPlaceholder, gadget.ProofSize(forSnark.circuit, logNbInstances), hintIns...); err != nil { return err From 30f5633b54e7766d4819e5146cc24d1876aee930 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 12:15:03 -0500 Subject: [PATCH 06/92] revert: newPrint back in std/gkr to avoid import cycle --- internal/gkr/gkrinfo/info.go | 21 --------------------- std/gkrapi/api.go | 22 +++++++++++++++++++++- std/gkrapi/compile.go | 4 ---- 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index 81902df8c6..5581221ab5 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -7,7 +7,6 @@ import ( "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/std/gkrapi/gkr" ) type ( @@ -141,26 +140,6 @@ type ConstraintSystem interface { SetGkrInfo(info StoringInfo) error } -func NewPrint(instance int, a ...any) PrintInfo { - isVar := make([]bool, len(a)) - vals := make([]any, len(a)) - for i := range a { - v, ok := a[i].(gkr.Variable) - isVar[i] = ok - if ok { - vals[i] = uint32(v) - } else { - vals[i] = a[i] - } - } - - return PrintInfo{ - Values: vals, - Instance: uint32(instance), - IsGkrVar: isVar, - } -} - // NewPrintInfoMap partitions printInfo into map elements, indexed by instance func NewPrintInfoMap(printInfo []PrintInfo) map[uint32][]PrintInfo { res := make(map[uint32][]PrintInfo) diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index d4f27f5c9f..5765fb98a1 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -62,5 +62,25 @@ func (api *API) Mul(i1, i2 gkr.Variable) gkr.Variable { // Println writes to the standard output. // instance determines which values are chosen for gkr.Variable input. func (api *API) Println(instance int, a ...any) { - api.toStore.Prints = append(api.toStore.Prints, gkrinfo.NewPrint(instance, a...)) + api.toStore.Prints = append(api.toStore.Prints, newPrint(instance, a...)) +} + +func newPrint(instance int, a ...any) gkrinfo.PrintInfo { + isVar := make([]bool, len(a)) + vals := make([]any, len(a)) + for i := range a { + v, ok := a[i].(gkr.Variable) + isVar[i] = ok + if ok { + vals[i] = uint32(v) + } else { + vals[i] = a[i] + } + } + + return gkrinfo.PrintInfo{ + Values: vals, + Instance: uint32(instance), + IsGkrVar: isVar, + } } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index b51c9c8b6c..606596fee3 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -3,7 +3,6 @@ package gkrapi import ( "fmt" "math/bits" - "slices" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint/solver" @@ -92,9 +91,6 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio res.toStore.SolveHintID = solver.GetHintID(gadget.SolveHintPlaceholder(res.toStore)) res.toStore.ProveHintID = solver.GetHintID(gadget.ProveHintPlaceholder(fiatshamirHashName)) - // sort the prints before solving begins - slices.SortFunc(res.toStore.Prints, gkrinfo.PrintInfo.Cmp) - parentApi.Compiler().Defer(res.verify) return &res From 44d6b4cc76b068b0765b5d492598ee2953ac74cb Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 12:19:18 -0500 Subject: [PATCH 07/92] refactor: remove circuit/instance rearranging --- internal/gkr/gkrinfo/info.go | 75 ------------------------------------ 1 file changed, 75 deletions(-) diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index 5581221ab5..f7cb5e589d 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -2,11 +2,7 @@ package gkrinfo import ( - "fmt" - "sort" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/internal/utils" ) type ( @@ -60,77 +56,6 @@ func (d *StoringInfo) NewInputVariable() int { return i } -// Compile sorts the Circuit wires, their dependencies and the instances -func (d *StoringInfo) Compile(nbInstances int) (Permutations, error) { - - var p Permutations - d.NbInstances = nbInstances - // sort the instances to decide the order in which they are to be solved - instanceDeps := make([][]int, nbInstances) - for i := range d.Circuit { - for _, dep := range d.Dependencies[i] { - instanceDeps[dep.InputInstance] = append(instanceDeps[dep.InputInstance], dep.OutputInstance) - } - } - - p.SortedInstances, _ = utils.TopologicalSort(instanceDeps) - p.InstancesPermutation = utils.InvertPermutation(p.SortedInstances) - - // this whole circuit sorting is a bit of a charade. if things are built using an api, there's no way it could NOT already be topologically sorted - // worth keeping for future-proofing? - - inputs := utils.Map(d.Circuit, func(w Wire) []int { - return w.Inputs - }) - - var uniqueOuts [][]int - p.SortedWires, uniqueOuts = utils.TopologicalSort(inputs) - p.WiresPermutation = utils.InvertPermutation(p.SortedWires) - wirePermutationAt := utils.SliceAt(p.WiresPermutation) - sorted := make([]Wire, len(d.Circuit)) // TODO: Directly manipulate d.circuit instead - sortedDeps := make([][]InputDependency, len(d.Circuit)) - - // go through the wires in the sorted order and fix the input and dependency indices according to the permutations - for newI, oldI := range p.SortedWires { - oldW := d.Circuit[oldI] - - for depI := range d.Dependencies[oldI] { - dep := &d.Dependencies[oldI][depI] - dep.OutputWire = p.WiresPermutation[dep.OutputWire] - dep.InputInstance = p.InstancesPermutation[dep.InputInstance] - dep.OutputInstance = p.InstancesPermutation[dep.OutputInstance] - } - sort.Slice(d.Dependencies[oldI], func(i, j int) bool { - return d.Dependencies[oldI][i].InputInstance < d.Dependencies[oldI][j].InputInstance - }) - for i := 1; i < len(d.Dependencies[oldI]); i++ { - if d.Dependencies[oldI][i].InputInstance == d.Dependencies[oldI][i-1].InputInstance { - return p, fmt.Errorf("an input wire can only have one dependency per instance") - } - } // TODO: Check that dependencies and explicit assignments cover all instances - - sortedDeps[newI] = d.Dependencies[oldI] - sorted[newI] = Wire{ - Gate: oldW.Gate, - Inputs: utils.Map(oldW.Inputs, wirePermutationAt), - NbUniqueOutputs: len(uniqueOuts[oldI]), - } - } - - // re-arrange the prints - for i := range d.Prints { - for j, isVar := range d.Prints[i].IsGkrVar { - if isVar { - d.Prints[i].Values[j] = uint32(p.WiresPermutation[d.Prints[i].Values[j].(uint32)]) - } - } - } - - d.Circuit, d.Dependencies = sorted, sortedDeps - - return p, nil -} - func (d *StoringInfo) Is() bool { return d.Circuit != nil } From f5f726ce3206d70d55254fe670299c96426ed7d3 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 12:32:02 -0500 Subject: [PATCH 08/92] chore: generify --- constraint/bls12-381/solver.go | 6 +- constraint/bls24-315/solver.go | 6 +- constraint/bls24-317/solver.go | 6 +- constraint/bn254/solver.go | 6 +- constraint/bw6-633/solver.go | 6 +- constraint/bw6-761/solver.go | 6 +- .../backend/template/gkr/solver_hints.go.tmpl | 114 +++++++----------- .../template/representations/solver.go.tmpl | 6 +- internal/gkr/bls12-377/solver_hints.go | 2 + internal/gkr/bls12-381/solver_hints.go | 113 +++++++---------- internal/gkr/bls24-315/solver_hints.go | 113 +++++++---------- internal/gkr/bls24-317/solver_hints.go | 113 +++++++---------- internal/gkr/bn254/solver_hints.go | 113 +++++++---------- internal/gkr/bw6-633/solver_hints.go | 113 +++++++---------- internal/gkr/bw6-761/solver_hints.go | 113 +++++++---------- internal/gkr/gkrtesting/gkrtesting.go | 5 +- 16 files changed, 328 insertions(+), 513 deletions(-) diff --git a/constraint/bls12-381/solver.go b/constraint/bls12-381/solver.go index 1bfa4c5884..4e7835fef1 100644 --- a/constraint/bls12-381/solver.go +++ b/constraint/bls12-381/solver.go @@ -51,14 +51,14 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bls24-315/solver.go b/constraint/bls24-315/solver.go index 4f5b72c776..5dc3fc2ef7 100644 --- a/constraint/bls24-315/solver.go +++ b/constraint/bls24-315/solver.go @@ -51,14 +51,14 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bls24-317/solver.go b/constraint/bls24-317/solver.go index 9462b5d3e4..f007d16494 100644 --- a/constraint/bls24-317/solver.go +++ b/constraint/bls24-317/solver.go @@ -51,14 +51,14 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bn254/solver.go b/constraint/bn254/solver.go index 4ccc03e7e0..a7674e9542 100644 --- a/constraint/bn254/solver.go +++ b/constraint/bn254/solver.go @@ -51,14 +51,14 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bw6-633/solver.go b/constraint/bw6-633/solver.go index 642369791f..f294f7d826 100644 --- a/constraint/bw6-633/solver.go +++ b/constraint/bw6-633/solver.go @@ -51,14 +51,14 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bw6-761/solver.go b/constraint/bw6-761/solver.go index a65445eb4c..e10f2f21d7 100644 --- a/constraint/bw6-761/solver.go +++ b/constraint/bw6-761/solver.go @@ -51,14 +51,14 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index e1d41e8cb8..472121423c 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -14,92 +15,68 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment + circuit gkrtypes.Circuit + workers *utils.WorkerPool + maxNbIn int // maximum number of inputs for a gate in the circuit + printsByInstance map[uint32][]gkrinfo.PrintInfo } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit +func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { + d := SolvingData{ + workers: utils.NewWorkerPool(), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + } + d.circuit.SetNbUniqueOutputs() + d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } + + return &d } // this module assumes that wire and instance indexes respect dependencies -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } - } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { +func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + instanceI := ins[0].Uint64() + if !ins[0].IsUint64() { // TODO use idiomatic printf tag + return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) + } - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } - } - } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ + } } - for _, p := range info.Prints { + prints := data.printsByInstance[uint32(instanceI)] + delete(data.printsByInstance, uint32(instanceI)) + for _, p := range prints { serializable := make([]any, len(p.Values)) for i, v := range p.Values { if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 @@ -110,9 +87,6 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } fmt.Println(serializable...) } - - setOuts(data.assignment, info.Circuit, outs) - return nil } } diff --git a/internal/generator/backend/template/representations/solver.go.tmpl b/internal/generator/backend/template/representations/solver.go.tmpl index fd685e6e21..202642b87a 100644 --- a/internal/generator/backend/template/representations/solver.go.tmpl +++ b/internal/generator/backend/template/representations/solver.go.tmpl @@ -43,14 +43,14 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, {{ if not .NoGKR -}} // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } {{ end -}} diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 6353370a15..18ad9c91e9 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -39,6 +39,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() + + d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index cb498c78b7..5c73505739 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,92 +22,67 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment + circuit gkrtypes.Circuit + workers *utils.WorkerPool + maxNbIn int // maximum number of inputs for a gate in the circuit + printsByInstance map[uint32][]gkrinfo.PrintInfo } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit +func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { + d := SolvingData{ + workers: utils.NewWorkerPool(), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + } + d.circuit.SetNbUniqueOutputs() + d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } + + return &d } // this module assumes that wire and instance indexes respect dependencies -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + if !ins[0].IsUint64() { // TODO use idiomatic printf tag + return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } - } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } - } + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end - } - - for _, p := range info.Prints { + prints := data.printsByInstance[uint32(instanceI)] + delete(data.printsByInstance, uint32(instanceI)) + for _, p := range prints { serializable := make([]any, len(p.Values)) for i, v := range p.Values { if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 @@ -117,9 +93,6 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } fmt.Println(serializable...) } - - setOuts(data.assignment, info.Circuit, outs) - return nil } } diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 914c8a9d61..b44cea8feb 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,92 +22,67 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment + circuit gkrtypes.Circuit + workers *utils.WorkerPool + maxNbIn int // maximum number of inputs for a gate in the circuit + printsByInstance map[uint32][]gkrinfo.PrintInfo } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit +func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { + d := SolvingData{ + workers: utils.NewWorkerPool(), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + } + d.circuit.SetNbUniqueOutputs() + d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } + + return &d } // this module assumes that wire and instance indexes respect dependencies -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + if !ins[0].IsUint64() { // TODO use idiomatic printf tag + return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } - } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } - } + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end - } - - for _, p := range info.Prints { + prints := data.printsByInstance[uint32(instanceI)] + delete(data.printsByInstance, uint32(instanceI)) + for _, p := range prints { serializable := make([]any, len(p.Values)) for i, v := range p.Values { if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 @@ -117,9 +93,6 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } fmt.Println(serializable...) } - - setOuts(data.assignment, info.Circuit, outs) - return nil } } diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index f6e1ad993d..b6c22f4eb6 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,92 +22,67 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment + circuit gkrtypes.Circuit + workers *utils.WorkerPool + maxNbIn int // maximum number of inputs for a gate in the circuit + printsByInstance map[uint32][]gkrinfo.PrintInfo } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit +func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { + d := SolvingData{ + workers: utils.NewWorkerPool(), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + } + d.circuit.SetNbUniqueOutputs() + d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } + + return &d } // this module assumes that wire and instance indexes respect dependencies -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + if !ins[0].IsUint64() { // TODO use idiomatic printf tag + return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } - } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } - } + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end - } - - for _, p := range info.Prints { + prints := data.printsByInstance[uint32(instanceI)] + delete(data.printsByInstance, uint32(instanceI)) + for _, p := range prints { serializable := make([]any, len(p.Values)) for i, v := range p.Values { if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 @@ -117,9 +93,6 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } fmt.Println(serializable...) } - - setOuts(data.assignment, info.Circuit, outs) - return nil } } diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 7bc3782932..8a8fa12a75 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,92 +22,67 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment + circuit gkrtypes.Circuit + workers *utils.WorkerPool + maxNbIn int // maximum number of inputs for a gate in the circuit + printsByInstance map[uint32][]gkrinfo.PrintInfo } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit +func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { + d := SolvingData{ + workers: utils.NewWorkerPool(), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + } + d.circuit.SetNbUniqueOutputs() + d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } + + return &d } // this module assumes that wire and instance indexes respect dependencies -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + if !ins[0].IsUint64() { // TODO use idiomatic printf tag + return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } - } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } - } + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end - } - - for _, p := range info.Prints { + prints := data.printsByInstance[uint32(instanceI)] + delete(data.printsByInstance, uint32(instanceI)) + for _, p := range prints { serializable := make([]any, len(p.Values)) for i, v := range p.Values { if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 @@ -117,9 +93,6 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } fmt.Println(serializable...) } - - setOuts(data.assignment, info.Circuit, outs) - return nil } } diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 57343d291f..e1df46828f 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,92 +22,67 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment + circuit gkrtypes.Circuit + workers *utils.WorkerPool + maxNbIn int // maximum number of inputs for a gate in the circuit + printsByInstance map[uint32][]gkrinfo.PrintInfo } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit +func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { + d := SolvingData{ + workers: utils.NewWorkerPool(), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + } + d.circuit.SetNbUniqueOutputs() + d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } + + return &d } // this module assumes that wire and instance indexes respect dependencies -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + if !ins[0].IsUint64() { // TODO use idiomatic printf tag + return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } - } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } - } + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end - } - - for _, p := range info.Prints { + prints := data.printsByInstance[uint32(instanceI)] + delete(data.printsByInstance, uint32(instanceI)) + for _, p := range prints { serializable := make([]any, len(p.Values)) for i, v := range p.Values { if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 @@ -117,9 +93,6 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } fmt.Println(serializable...) } - - setOuts(data.assignment, info.Circuit, outs) - return nil } } diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 606f13ec23..2104520d6b 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,92 +22,67 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment + circuit gkrtypes.Circuit + workers *utils.WorkerPool + maxNbIn int // maximum number of inputs for a gate in the circuit + printsByInstance map[uint32][]gkrinfo.PrintInfo } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit +func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { + d := SolvingData{ + workers: utils.NewWorkerPool(), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + } + d.circuit.SetNbUniqueOutputs() + d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } + + return &d } // this module assumes that wire and instance indexes respect dependencies -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + if !ins[0].IsUint64() { // TODO use idiomatic printf tag + return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } - } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } - } + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end - } - - for _, p := range info.Prints { + prints := data.printsByInstance[uint32(instanceI)] + delete(data.printsByInstance, uint32(instanceI)) + for _, p := range prints { serializable := make([]any, len(p.Values)) for i, v := range p.Values { if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 @@ -117,9 +93,6 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } fmt.Println(serializable...) } - - setOuts(data.assignment, info.Circuit, outs) - return nil } } diff --git a/internal/gkr/gkrtesting/gkrtesting.go b/internal/gkr/gkrtesting/gkrtesting.go index ce9ba88942..4c901f04a2 100644 --- a/internal/gkr/gkrtesting/gkrtesting.go +++ b/internal/gkr/gkrtesting/gkrtesting.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + "github.com/consensys/gnark" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -35,10 +36,10 @@ func NewCache() *Cache { res = api.Mul(res, sum) // sum^7 return res - }, 2, 7, -1) + }, 2, 7, -1, gnark.Curves()) gates["select-input-3"] = gkrtypes.NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return in[2] - }, 3, 1, 0) + }, 3, 1, 0, gnark.Curves()) return &Cache{ circuits: make(map[string]gkrtypes.Circuit), From 644fd6a6ee688da606b37505bc715273d91349f2 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 12:58:44 -0500 Subject: [PATCH 09/92] fix: registry duplicate detection --- constraint/solver/gkrgates/registry.go | 5 --- constraint/solver/gkrgates/registry_test.go | 46 ++++++++++++++++----- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/constraint/solver/gkrgates/registry.go b/constraint/solver/gkrgates/registry.go index 2e1d8642ef..71b4969883 100644 --- a/constraint/solver/gkrgates/registry.go +++ b/constraint/solver/gkrgates/registry.go @@ -121,11 +121,6 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) (register if g, ok := gates[s.name]; ok { // gate already registered - if reflect.ValueOf(f).Pointer() != reflect.ValueOf(gates[s.name].Evaluate).Pointer() { - return false, fmt.Errorf("gate \"%s\" already registered with a different function", s.name) - } - // it still might be an anonymous function with different parameters. - // need to test further if g.NbIn() != nbIn { return false, fmt.Errorf("gate \"%s\" already registered with a different number of inputs (%d != %d)", s.name, g.NbIn(), nbIn) } diff --git a/constraint/solver/gkrgates/registry_test.go b/constraint/solver/gkrgates/registry_test.go index ec41888ef3..5a9e6e871d 100644 --- a/constraint/solver/gkrgates/registry_test.go +++ b/constraint/solver/gkrgates/registry_test.go @@ -11,20 +11,38 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRegisterDegreeDetection(t *testing.T) { +func TestRegister(t *testing.T) { testGate := func(name gkr.GateName, f gkr.GateFunction, nbIn, degree int) { t.Run(string(name), func(t *testing.T) { name = name + "-register-gate-test" - assert.NoError(t, Register(f, nbIn, WithDegree(degree), WithName(name)), "given degree must be accepted") + added, err := Register(f, nbIn, WithDegree(degree), WithName(name+"_given")) + assert.NoError(t, err, "given degree must be accepted") + assert.True(t, added, "registration must succeed for given degree") - assert.Error(t, Register(f, nbIn, WithDegree(degree-1), WithName(name)), "lower degree must be rejected") + registered, err := Register(f, nbIn, WithDegree(degree-1), WithName(name+"_lower")) + assert.Error(t, err, "error must be returned for lower degree") + assert.False(t, registered, "registration must fail for lower degree") - assert.Error(t, Register(f, nbIn, WithDegree(degree+1), WithName(name)), "higher degree must be rejected") + registered, err = Register(f, nbIn, WithDegree(degree+1), WithName(name+"_higher")) + assert.Error(t, err, "error must be returned for higher degree") + assert.False(t, registered, "registration must fail for higher degree") - assert.NoError(t, Register(f, nbIn), "no degree must be accepted") + registered, err = Register(f, nbIn, WithName(name+"_no_degree")) + assert.NoError(t, err, "no error must be returned when no degree is specified") + assert.True(t, registered, "registration must succeed when no degree is specified") - assert.Equal(t, degree, Get(name).Degree(), "degree must be detected correctly") + assert.Equal(t, degree, Get(name+"_no_degree").Degree(), "degree must be detected correctly") + + added, err = Register(f, nbIn, WithDegree(degree), WithName(name+"_given")) + assert.NoError(t, err, "given degree must be accepted") + assert.False(t, added, "gate must not be re-registered") + + added, err = Register(func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(f(api, x...), 1) + }, nbIn, WithDegree(degree), WithName(name+"_given")) + assert.Error(t, err, "registering another function under the same name must fail") + assert.False(t, added, "gate must not be re-registered") }) } @@ -47,15 +65,23 @@ func TestRegisterDegreeDetection(t *testing.T) { ) }, 2, 1) - // zero polynomial must not be accepted t.Run("zero", func(t *testing.T) { const gateName gkr.GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, gkrtypes.ErrZeroFunction) + expectedError := fmt.Errorf("for gate \"%s\": %v", gateName, gkrtypes.ErrZeroFunction) zeroGate := func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Sub(x[0], x[0]) } - assert.Equal(t, expectedError, Register(zeroGate, 1, WithName(gateName))) - assert.Equal(t, expectedError, Register(zeroGate, 1, WithName(gateName), WithDegree(2))) + // Attempt to register the zero gate without specifying a degree + registered, err := Register(zeroGate, 1, WithName(gateName)) + assert.Error(t, err, "error must be returned for zero polynomial") + assert.Equal(t, expectedError, err, "error message must match expected error") + assert.False(t, registered, "registration must fail for zero polynomial") + + // Attempt to register the zero gate with a specified degree + registered, err = Register(zeroGate, 1, WithName(gateName), WithDegree(2)) + assert.Error(t, err, "error must be returned for zero polynomial with degree") + assert.Equal(t, expectedError, err, "error message must match expected error") + assert.False(t, registered, "registration must fail for zero polynomial with degree") }) } From 2001e8320931972a1ea2620fbdcdf52ea4c20caf Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 15:43:31 -0500 Subject: [PATCH 10/92] fix: solver hint id mismatch --- .../backend/template/gkr/solver_hints.go.tmpl | 2 +- internal/gkr/bls12-377/solver_hints.go | 2 +- internal/gkr/bls12-381/solver_hints.go | 2 +- internal/gkr/bls24-315/solver_hints.go | 2 +- internal/gkr/bls24-317/solver_hints.go | 2 +- internal/gkr/bn254/solver_hints.go | 2 +- internal/gkr/bw6-633/solver_hints.go | 2 +- internal/gkr/bw6-761/solver_hints.go | 2 +- internal/gkr/hints.go | 19 +++++++------- std/gkrapi/compile.go | 26 ++++++++++++++----- 10 files changed, 37 insertions(+), 24 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 472121423c..170c1807da 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -66,7 +66,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 18ad9c91e9..b63ceebb83 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -72,7 +72,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 5c73505739..93b65be2bf 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -72,7 +72,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index b44cea8feb..49317d446c 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -72,7 +72,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index b6c22f4eb6..c02c8ab135 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -72,7 +72,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 8a8fa12a75..5c0a2e3ef3 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -72,7 +72,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index e1df46828f..21bc7961a5 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -72,7 +72,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 2104520d6b..dacd7eb35f 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -72,7 +72,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/hints.go b/internal/gkr/hints.go index 2c2621911e..a6854e36df 100644 --- a/internal/gkr/hints.go +++ b/internal/gkr/hints.go @@ -26,8 +26,8 @@ func modKey(mod *big.Int) string { // SolveHintPlaceholder solves one instance of a GKR circuit. // The first input is the index of the instance. The rest are the inputs of the circuit, in their nominal order. -func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) solver.Hint { - return func(mod *big.Int, ins []*big.Int, outs []*big.Int) error { +func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) (solver.Hint, solver.HintID) { + hint := func(mod *big.Int, ins []*big.Int, outs []*big.Int) error { solvingInfo, err := gkrtypes.StoringToSolvingInfo(gkrInfo, gkrgates.Get) if err != nil { @@ -39,31 +39,31 @@ func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) solver.Hint { // TODO @Tabaie autogenerate this or decide not to if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { data := bls12377.NewSolvingData(solvingInfo) - hint = bls12377.SolveHint(solvingInfo, data) + hint = bls12377.SolveHint(data) testEngineGkrSolvingData[modKey(mod)] = data } else if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { data := bls12381.NewSolvingData(solvingInfo) - hint = bls12381.SolveHint(solvingInfo, data) + hint = bls12381.SolveHint(data) testEngineGkrSolvingData[modKey(mod)] = data } else if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { data := bls24315.NewSolvingData(solvingInfo) - hint = bls24315.SolveHint(solvingInfo, data) + hint = bls24315.SolveHint(data) testEngineGkrSolvingData[modKey(mod)] = data } else if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { data := bls24317.NewSolvingData(solvingInfo) - hint = bls24317.SolveHint(solvingInfo, data) + hint = bls24317.SolveHint(data) testEngineGkrSolvingData[modKey(mod)] = data } else if mod.Cmp(ecc.BN254.ScalarField()) == 0 { data := bn254.NewSolvingData(solvingInfo) - hint = bn254.SolveHint(solvingInfo, data) + hint = bn254.SolveHint(data) testEngineGkrSolvingData[modKey(mod)] = data } else if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { data := bw6633.NewSolvingData(solvingInfo) - hint = bw6633.SolveHint(solvingInfo, data) + hint = bw6633.SolveHint(data) testEngineGkrSolvingData[modKey(mod)] = data } else if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { data := bw6761.NewSolvingData(solvingInfo) - hint = bw6761.SolveHint(solvingInfo, data) + hint = bw6761.SolveHint(data) testEngineGkrSolvingData[modKey(mod)] = data } else { return errors.New("unsupported modulus") @@ -71,6 +71,7 @@ func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) solver.Hint { return hint(mod, ins, outs) } + return hint, solver.GetHintID(hint) } func ProveHintPlaceholder(hashName string) solver.Hint { diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 606596fee3..63ebe7c314 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -74,23 +74,35 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio api: parentApi, } - api.toStore.HashName = fiatshamirHashName + res.toStore.HashName = fiatshamirHashName for _, opt := range options { opt(&res) } + notOut := make([]bool, len(res.toStore.Circuit)) for i := range res.toStore.Circuit { - if res.toStore.Circuit[i].IsOutput() { - res.outs = append(res.ins, gkr.Variable(i)) - } if res.toStore.Circuit[i].IsInput() { res.ins = append(res.ins, gkr.Variable(i)) } + for _, inWI := range res.toStore.Circuit[i].Inputs { + notOut[inWI] = true + } } - res.toStore.SolveHintID = solver.GetHintID(gadget.SolveHintPlaceholder(res.toStore)) + + for i := range res.toStore.Circuit { + if !notOut[i] { + res.outs = append(res.outs, gkr.Variable(i)) + } + } + + _, res.toStore.SolveHintID = gadget.SolveHintPlaceholder(res.toStore) res.toStore.ProveHintID = solver.GetHintID(gadget.ProveHintPlaceholder(fiatshamirHashName)) + if err := parentApi.(gkrinfo.ConstraintSystem).SetGkrInfo(res.toStore); err != nil { + panic(err) + } + parentApi.Compiler().Defer(res.verify) return &res @@ -119,7 +131,7 @@ func (c *Circuit) AddInstance(input map[gkr.Variable]frontend.Variable) (map[gkr } c.toStore.NbInstances++ - solveHintPlaceholder := gadget.SolveHintPlaceholder(c.toStore) + solveHintPlaceholder, _ := gadget.SolveHintPlaceholder(c.toStore) outsSerialized, err := c.api.Compiler().NewHint(solveHintPlaceholder, len(c.outs), hintIn...) if err != nil { return nil, fmt.Errorf("failed to create solve hint: %w", err) @@ -203,7 +215,7 @@ func (c *Circuit) verify(api frontend.API) error { return err } - return api.(gkrinfo.ConstraintSystem).SetGkrInfo(c.toStore) + return nil } func slicePtrAt[T any](slice []T) func(int) *T { From d3ca3c12d9d498081b1dbe3f0daaf752b2c1183e Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 15:49:04 -0500 Subject: [PATCH 11/92] remove redundant make --- internal/gkr/bn254/solver_hints.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 5c0a2e3ef3..e6a971268d 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -40,7 +40,6 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } From 647448206a4c4a39b88bc4d23a9a4be72904db34 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 16:18:39 -0500 Subject: [PATCH 12/92] fix: works on plonk --- std/gkrapi/api.go | 6 ++++++ std/gkrapi/compile.go | 8 ++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index 5765fb98a1..6c6d4b0976 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -65,6 +65,12 @@ func (api *API) Println(instance int, a ...any) { api.toStore.Prints = append(api.toStore.Prints, newPrint(instance, a...)) } +// Println writes to the standard output. +// instance determines which values are chosen for gkr.Variable input. +func (c *Circuit) Println(instance int, a ...any) { + c.toStore.Prints = append(c.toStore.Prints, newPrint(instance, a...)) +} + func newPrint(instance int, a ...any) gkrinfo.PrintInfo { isVar := make([]bool, len(a)) vals := make([]any, len(a)) diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 63ebe7c314..05a12bb8a9 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -99,10 +99,6 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio _, res.toStore.SolveHintID = gadget.SolveHintPlaceholder(res.toStore) res.toStore.ProveHintID = solver.GetHintID(gadget.ProveHintPlaceholder(fiatshamirHashName)) - if err := parentApi.(gkrinfo.ConstraintSystem).SetGkrInfo(res.toStore); err != nil { - panic(err) - } - parentApi.Compiler().Defer(res.verify) return &res @@ -151,6 +147,10 @@ func (c *Circuit) verify(api frontend.API) error { panic("api mismatch") } + if err := api.(gkrinfo.ConstraintSystem).SetGkrInfo(c.toStore); err != nil { + return err + } + if len(c.outs) == 0 || len(c.assignments[0]) == 0 { return nil } From 0c0b6b0b3a032f69436d2634a9a6edb2bb1c4ff1 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 2 Jun 2025 19:49:00 -0500 Subject: [PATCH 13/92] refactor: solve hint for test engine --- .../backend/template/gkr/gkr.go.tmpl | 2 +- .../backend/template/gkr/solver_hints.go.tmpl | 2 +- internal/gkr/engine_hints.go | 181 ++++++++++++++++++ internal/gkr/hints.go | 111 ----------- std/gkrapi/compile.go | 19 +- 5 files changed, 194 insertions(+), 121 deletions(-) create mode 100644 internal/gkr/engine_hints.go delete mode 100644 internal/gkr/hints.go diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 5105b0a33d..3e3881d15f 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -735,7 +735,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod {{ .ElementType }} - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 170c1807da..b9f570e26e 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -47,7 +47,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { // TODO use idiomatic printf tag + if !ins[0].IsUint64() { return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go new file mode 100644 index 0000000000..ec7e10e00a --- /dev/null +++ b/internal/gkr/engine_hints.go @@ -0,0 +1,181 @@ +package gkr + +import ( + "errors" + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/frontend" + bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" + bls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" + bls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" + bls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" + bn254 "github.com/consensys/gnark/internal/gkr/bn254" + bw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" + bw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" + "github.com/consensys/gnark/internal/gkr/gkrinfo" + "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/utils" +) + +func modKey(mod *big.Int) string { + return mod.Text(32) +} + +type TestEngineHints struct { + assignment gkrtypes.WireAssignment + info *gkrinfo.StoringInfo + circuit gkrtypes.Circuit + gateIns []frontend.Variable +} + +func NewTestEngineHints(info *gkrinfo.StoringInfo) (*TestEngineHints, error) { + circuit, err := gkrtypes.CircuitInfoToCircuit(info.Circuit, gkrgates.Get) + if err != nil { + return nil, err + } + + return &TestEngineHints{ + info: info, + circuit: circuit, + gateIns: make([]frontend.Variable, circuit.MaxGateNbIn()), + }, + err +} + +// Solve solves one instance of a GKR circuit. +// The first input is the index of the instance. The rest are the inputs of the circuit, in their nominal order. +func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) error { + + // TODO handle prints + + if in0 := ins[0].Uint64(); !ins[0].IsUint64() || in0 >= uint64(len(h.info.Circuit)) || in0 > 0xffffffff { + return errors.New("first input must be a uint32 instance index") + } else if in0 != uint64(h.info.NbInstances) || h.info.NbInstances != len(h.assignment[0]) { + return errors.New("first input must equal the number of instances, and calls to Solve must be done in order of instance index") + } + + api := gateAPI{mod} + + inI := 1 + outI := 0 + for wI := range h.circuit { + w := &h.circuit[wI] + var val frontend.Variable + if w.IsInput() { + val = utils.FromInterface(ins[inI]) + inI++ + } else { + for gateInI, inWI := range w.Inputs { + h.gateIns[gateInI] = h.assignment[inWI][gateInI] + } + val = w.Gate.Evaluate(api, h.gateIns[:len(w.Inputs)]...) + } + if w.IsOutput() { + *outs[outI] = utils.FromInterface(val) + } + h.assignment[wI] = append(h.assignment[wI], val) + } + return nil +} + +func (h *TestEngineHints) Prove(mod *big.Int, ins, outs []*big.Int) error { + + // todo handle prints + k := modKey(mod) + data, ok := testEngineGkrSolvingData[k] + if !ok { + return errors.New("solving data not found") + } + delete(testEngineGkrSolvingData, k) + + // TODO @Tabaie autogenerate this or decide not to + if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + return bls12377.ProveHint(hashName, data.(*bls12377.SolvingData))(mod, ins, outs) + } + if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + return bls12381.ProveHint(hashName, data.(*bls12381.SolvingData))(mod, ins, outs) + } + if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { + return bls24315.ProveHint(hashName, data.(*bls24315.SolvingData))(mod, ins, outs) + } + if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { + return bls24317.ProveHint(hashName, data.(*bls24317.SolvingData))(mod, ins, outs) + } + if mod.Cmp(ecc.BN254.ScalarField()) == 0 { + return bn254.ProveHint(hashName, data.(*bn254.SolvingData))(mod, ins, outs) + } + if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { + return bw6633.ProveHint(hashName, data.(*bw6633.SolvingData))(mod, ins, outs) + } + if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { + return bw6761.ProveHint(hashName, data.(*bw6761.SolvingData))(mod, ins, outs) + } + + return errors.New("unsupported modulus") + +} + +type gateAPI struct{ *big.Int } + +func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + in1 := utils.FromInterface(i1) + in2 := utils.FromInterface(i2) + + in1.Add(&in1, &in2) + for _, v := range in { + inV := utils.FromInterface(v) + in1.Add(&in1, &inV) + } + return &in1 +} + +func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + x, y := utils.FromInterface(b), utils.FromInterface(c) + x.Mul(&x, &y) + y = utils.FromInterface(a) + x.Add(&x, &y) + return &x +} + +func (g gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + x.Neg(&x) + return &x +} + +func (g gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + y := utils.FromInterface(i2) + x.Sub(&x, &y) + for _, v := range in { + y = utils.FromInterface(v) + x.Sub(&x, &y) + } + return &x +} + +func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + y := utils.FromInterface(i2) + x.Mul(&x, &y) + for _, v := range in { + y = utils.FromInterface(v) + x.Mul(&x, &y) + } + return &x +} + +func (g gateAPI) Println(a ...frontend.Variable) { + strings := make([]string, len(a)) + for i := range a { + if s, ok := a[i].(fmt.Stringer); ok { + strings[i] = s.String() + } else { + bigInt := utils.FromInterface(a[i]) + strings[i] = bigInt.String() + } + } +} diff --git a/internal/gkr/hints.go b/internal/gkr/hints.go deleted file mode 100644 index a6854e36df..0000000000 --- a/internal/gkr/hints.go +++ /dev/null @@ -1,111 +0,0 @@ -package gkr - -import ( - "errors" - "math/big" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/constraint/solver/gkrgates" - bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" - bls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" - bls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" - bls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" - bn254 "github.com/consensys/gnark/internal/gkr/bn254" - bw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" - bw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" - "github.com/consensys/gnark/internal/gkr/gkrinfo" - "github.com/consensys/gnark/internal/gkr/gkrtypes" -) - -var testEngineGkrSolvingData = make(map[string]any) - -func modKey(mod *big.Int) string { - return mod.Text(32) -} - -// SolveHintPlaceholder solves one instance of a GKR circuit. -// The first input is the index of the instance. The rest are the inputs of the circuit, in their nominal order. -func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) (solver.Hint, solver.HintID) { - hint := func(mod *big.Int, ins []*big.Int, outs []*big.Int) error { - - solvingInfo, err := gkrtypes.StoringToSolvingInfo(gkrInfo, gkrgates.Get) - if err != nil { - return err - } - - var hint solver.Hint - - // TODO @Tabaie autogenerate this or decide not to - if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - data := bls12377.NewSolvingData(solvingInfo) - hint = bls12377.SolveHint(data) - testEngineGkrSolvingData[modKey(mod)] = data - } else if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - data := bls12381.NewSolvingData(solvingInfo) - hint = bls12381.SolveHint(data) - testEngineGkrSolvingData[modKey(mod)] = data - } else if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - data := bls24315.NewSolvingData(solvingInfo) - hint = bls24315.SolveHint(data) - testEngineGkrSolvingData[modKey(mod)] = data - } else if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - data := bls24317.NewSolvingData(solvingInfo) - hint = bls24317.SolveHint(data) - testEngineGkrSolvingData[modKey(mod)] = data - } else if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - data := bn254.NewSolvingData(solvingInfo) - hint = bn254.SolveHint(data) - testEngineGkrSolvingData[modKey(mod)] = data - } else if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - data := bw6633.NewSolvingData(solvingInfo) - hint = bw6633.SolveHint(data) - testEngineGkrSolvingData[modKey(mod)] = data - } else if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - data := bw6761.NewSolvingData(solvingInfo) - hint = bw6761.SolveHint(data) - testEngineGkrSolvingData[modKey(mod)] = data - } else { - return errors.New("unsupported modulus") - } - - return hint(mod, ins, outs) - } - return hint, solver.GetHintID(hint) -} - -func ProveHintPlaceholder(hashName string) solver.Hint { - return func(mod *big.Int, ins, outs []*big.Int) error { - k := modKey(mod) - data, ok := testEngineGkrSolvingData[k] - if !ok { - return errors.New("solving data not found") - } - delete(testEngineGkrSolvingData, k) - - // TODO @Tabaie autogenerate this or decide not to - if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - return bls12377.ProveHint(hashName, data.(*bls12377.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - return bls12381.ProveHint(hashName, data.(*bls12381.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - return bls24315.ProveHint(hashName, data.(*bls24315.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - return bls24317.ProveHint(hashName, data.(*bls24317.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - return bn254.ProveHint(hashName, data.(*bn254.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - return bw6633.ProveHint(hashName, data.(*bw6633.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - return bw6761.ProveHint(hashName, data.(*bw6761.SolvingData))(mod, ins, outs) - } - - return errors.New("unsupported modulus") - } -} diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 05a12bb8a9..32979de0fe 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -32,7 +32,8 @@ type Circuit struct { getInitialChallenges InitialChallengeGetter // optional getter for the initial Fiat-Shamir challenge ins []gkr.Variable outs []gkr.Variable - api frontend.API // the parent API used for hints + api frontend.API // the parent API used for hints + hints *gadget.TestEngineHints // hints for the GKR circuit, used for testing purposes } // New creates a new GKR API @@ -75,6 +76,7 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio } res.toStore.HashName = fiatshamirHashName + res.hints = gadget.NewTestEngineHints(&res.toStore) for _, opt := range options { opt(&res) @@ -96,8 +98,7 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio } } - _, res.toStore.SolveHintID = gadget.SolveHintPlaceholder(res.toStore) - res.toStore.ProveHintID = solver.GetHintID(gadget.ProveHintPlaceholder(fiatshamirHashName)) + res.toStore.ProveHintID = solver.GetHintID(res.hints.Prove) parentApi.Compiler().Defer(res.verify) @@ -126,9 +127,12 @@ func (c *Circuit) AddInstance(input map[gkr.Variable]frontend.Variable) (map[gkr } } + if c.toStore.NbInstances == 0 { + c.toStore.SolveHintID = solver.GetHintID(c.hints.Solve) + } + c.toStore.NbInstances++ - solveHintPlaceholder, _ := gadget.SolveHintPlaceholder(c.toStore) - outsSerialized, err := c.api.Compiler().NewHint(solveHintPlaceholder, len(c.outs), hintIn...) + outsSerialized, err := c.api.Compiler().NewHint(c.hints.Solve, len(c.outs), hintIn...) if err != nil { return nil, fmt.Errorf("failed to create solve hint: %w", err) } @@ -192,12 +196,11 @@ func (c *Circuit) verify(api frontend.API) error { copy(hintIns[1:], initialChallenges) - proveHintPlaceholder := gadget.ProveHintPlaceholder(c.toStore.HashName) if proofSerialized, err = api.Compiler().NewHint( - proveHintPlaceholder, gadget.ProofSize(forSnark.circuit, logNbInstances), hintIns...); err != nil { + c.hints.Prove, gadget.ProofSize(forSnark.circuit, logNbInstances), hintIns...); err != nil { return err } - c.toStore.ProveHintID = solver.GetHintID(proveHintPlaceholder) + c.toStore.ProveHintID = solver.GetHintID(c.hints.Prove) forSnarkSorted := utils.MapRange(0, len(c.toStore.Circuit), slicePtrAt(forSnark.circuit)) From 478d53d1e6385c498fa323cfc261e242114f2eb4 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 13:22:56 -0500 Subject: [PATCH 14/92] fix prove hint --- internal/gkr/bn254/solver_hints.go | 40 ++++++++++++++++++++++++++---- internal/gkr/engine_hints.go | 37 +++++++++++++-------------- std/gkrapi/compile.go | 7 +++++- 3 files changed, 60 insertions(+), 24 deletions(-) diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index e6a971268d..88d521f531 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -10,7 +10,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" @@ -24,14 +23,29 @@ import ( type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit - workers *utils.WorkerPool maxNbIn int // maximum number of inputs for a gate in the circuit printsByInstance map[uint32][]gkrinfo.PrintInfo } -func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} + +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment + } +} + +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } + d := SolvingData{ - workers: utils.NewWorkerPool(), circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), @@ -44,6 +58,22 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.assignment[i] = make([]fr.Element, info.NbInstances) } + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range d.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) + } + } + } + } + return &d } @@ -107,7 +137,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BN254") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index ec7e10e00a..dd6408c718 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -20,13 +20,9 @@ import ( "github.com/consensys/gnark/internal/utils" ) -func modKey(mod *big.Int) string { - return mod.Text(32) -} - type TestEngineHints struct { assignment gkrtypes.WireAssignment - info *gkrinfo.StoringInfo + info *gkrinfo.StoringInfo // we retain a reference to the solving info to allow the caller to modify it between calls to Solve and Prove circuit gkrtypes.Circuit gateIns []frontend.Variable } @@ -82,36 +78,41 @@ func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) e } func (h *TestEngineHints) Prove(mod *big.Int, ins, outs []*big.Int) error { - // todo handle prints - k := modKey(mod) - data, ok := testEngineGkrSolvingData[k] - if !ok { - return errors.New("solving data not found") + + info, err := gkrtypes.StoringToSolvingInfo(*h.info, gkrgates.Get) + if err != nil { + return fmt.Errorf("failed to convert storing info to solving info: %w", err) } - delete(testEngineGkrSolvingData, k) // TODO @Tabaie autogenerate this or decide not to if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - return bls12377.ProveHint(hashName, data.(*bls12377.SolvingData))(mod, ins, outs) + data := bls12377.NewSolvingData(info, bls12377.WithAssignment(h.assignment)) + return bls12377.ProveHint(info.HashName, data)(mod, ins, outs) } if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - return bls12381.ProveHint(hashName, data.(*bls12381.SolvingData))(mod, ins, outs) + data := bls12381.NewSolvingData(info, bls12381.WithAssignment(h.assignment)) + return bls12381.ProveHint(info.HashName, data)(mod, ins, outs) } if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - return bls24315.ProveHint(hashName, data.(*bls24315.SolvingData))(mod, ins, outs) + data := bls24315.NewSolvingData(info, bls24315.WithAssignment(h.assignment)) + return bls24315.ProveHint(info.HashName, data)(mod, ins, outs) } if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - return bls24317.ProveHint(hashName, data.(*bls24317.SolvingData))(mod, ins, outs) + data := bls24317.NewSolvingData(info, bls24317.WithAssignment(h.assignment)) + return bls24317.ProveHint(info.HashName, data)(mod, ins, outs) } if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - return bn254.ProveHint(hashName, data.(*bn254.SolvingData))(mod, ins, outs) + data := bn254.NewSolvingData(info, bn254.WithAssignment(h.assignment)) + return bn254.ProveHint(info.HashName, data)(mod, ins, outs) } if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - return bw6633.ProveHint(hashName, data.(*bw6633.SolvingData))(mod, ins, outs) + data := bw6633.NewSolvingData(info, bw6633.WithAssignment(h.assignment)) + return bw6633.ProveHint(info.HashName, data)(mod, ins, outs) } if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - return bw6761.ProveHint(hashName, data.(*bw6761.SolvingData))(mod, ins, outs) + data := bw6761.NewSolvingData(info, bw6761.WithAssignment(h.assignment)) + return bw6761.ProveHint(info.HashName, data)(mod, ins, outs) } return errors.New("unsupported modulus") diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 32979de0fe..87c863e7ee 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -76,7 +76,12 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio } res.toStore.HashName = fiatshamirHashName - res.hints = gadget.NewTestEngineHints(&res.toStore) + + var err error + res.hints, err = gadget.NewTestEngineHints(&res.toStore) + if err != nil { + panic(fmt.Errorf("failed to create GKR hints: %w", err)) + } for _, opt := range options { opt(&res) From 6c83740a95086e9921c5621b44553dade189f0f6 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 14:13:21 -0500 Subject: [PATCH 15/92] fix package tests --- .../backend/template/gkr/solver_hints.go.tmpl | 41 +++++++++++++++--- internal/gkr/bls12-377/gkr.go | 2 +- internal/gkr/bls12-377/solver_hints.go | 43 ++++++++++++++++--- internal/gkr/bls12-381/gkr.go | 2 +- internal/gkr/bls12-381/solver_hints.go | 43 ++++++++++++++++--- internal/gkr/bls24-315/gkr.go | 2 +- internal/gkr/bls24-315/solver_hints.go | 43 ++++++++++++++++--- internal/gkr/bls24-317/gkr.go | 2 +- internal/gkr/bls24-317/solver_hints.go | 43 ++++++++++++++++--- internal/gkr/bn254/gkr.go | 2 +- internal/gkr/bn254/solver_hints.go | 4 +- internal/gkr/bw6-633/gkr.go | 2 +- internal/gkr/bw6-633/solver_hints.go | 43 ++++++++++++++++--- internal/gkr/bw6-761/gkr.go | 2 +- internal/gkr/bw6-761/solver_hints.go | 43 ++++++++++++++++--- internal/gkr/engine_hints.go | 14 +++--- internal/gkr/gkrinfo/info.go | 9 +--- internal/gkr/gkrtypes/types.go | 3 +- internal/gkr/small_rational/gkr.go | 2 +- std/gkrapi/api_test.go | 4 +- 20 files changed, 275 insertions(+), 74 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index b9f570e26e..d8ff837ced 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -3,7 +3,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" @@ -17,20 +16,34 @@ import ( type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit - workers *utils.WorkerPool maxNbIn int // maximum number of inputs for a gate in the circuit printsByInstance map[uint32][]gkrinfo.PrintInfo } -func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} + +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment + } +} + +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } + d := SolvingData{ - workers: utils.NewWorkerPool(), circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } - d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) @@ -38,6 +51,22 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.assignment[i] = make([]fr.Element, info.NbInstances) } + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range d.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) + } + } + } + } + return &d } @@ -102,7 +131,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_{{.FieldID}}") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index f5dfad020e..b92ac1249d 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index b63ceebb83..d03272551d 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -10,7 +10,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" @@ -24,20 +23,34 @@ import ( type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit - workers *utils.WorkerPool maxNbIn int // maximum number of inputs for a gate in the circuit printsByInstance map[uint32][]gkrinfo.PrintInfo } -func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} + +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment + } +} + +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } + d := SolvingData{ - workers: utils.NewWorkerPool(), circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } - d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) @@ -45,6 +58,22 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.assignment[i] = make([]fr.Element, info.NbInstances) } + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range d.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) + } + } + } + } + return &d } @@ -53,7 +82,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { // TODO use idiomatic printf tag + if !ins[0].IsUint64() { return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } @@ -108,7 +137,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS12_377") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index f5617a59d4..82084049d9 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 93b65be2bf..2c6db19a72 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -10,7 +10,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" @@ -24,20 +23,34 @@ import ( type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit - workers *utils.WorkerPool maxNbIn int // maximum number of inputs for a gate in the circuit printsByInstance map[uint32][]gkrinfo.PrintInfo } -func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} + +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment + } +} + +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } + d := SolvingData{ - workers: utils.NewWorkerPool(), circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } - d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) @@ -45,6 +58,22 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.assignment[i] = make([]fr.Element, info.NbInstances) } + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range d.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) + } + } + } + } + return &d } @@ -53,7 +82,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { // TODO use idiomatic printf tag + if !ins[0].IsUint64() { return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } @@ -108,7 +137,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS12_381") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7d89baf7ef..f182c9176b 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 49317d446c..c67b58605b 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -10,7 +10,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" @@ -24,20 +23,34 @@ import ( type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit - workers *utils.WorkerPool maxNbIn int // maximum number of inputs for a gate in the circuit printsByInstance map[uint32][]gkrinfo.PrintInfo } -func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} + +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment + } +} + +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } + d := SolvingData{ - workers: utils.NewWorkerPool(), circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } - d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) @@ -45,6 +58,22 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.assignment[i] = make([]fr.Element, info.NbInstances) } + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range d.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) + } + } + } + } + return &d } @@ -53,7 +82,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { // TODO use idiomatic printf tag + if !ins[0].IsUint64() { return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } @@ -108,7 +137,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS24_315") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index fc9908b918..a284f14ae9 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index c02c8ab135..9482066d24 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -10,7 +10,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" @@ -24,20 +23,34 @@ import ( type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit - workers *utils.WorkerPool maxNbIn int // maximum number of inputs for a gate in the circuit printsByInstance map[uint32][]gkrinfo.PrintInfo } -func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} + +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment + } +} + +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } + d := SolvingData{ - workers: utils.NewWorkerPool(), circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } - d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) @@ -45,6 +58,22 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.assignment[i] = make([]fr.Element, info.NbInstances) } + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range d.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) + } + } + } + } + return &d } @@ -53,7 +82,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { // TODO use idiomatic printf tag + if !ins[0].IsUint64() { return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } @@ -108,7 +137,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS24_317") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 04cf3512af..14269151b3 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 88d521f531..53079dfd57 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -51,9 +51,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } - d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() + d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { d.assignment[i] = make([]fr.Element, info.NbInstances) } @@ -82,7 +82,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { // TODO use idiomatic printf tag + if !ins[0].IsUint64() { return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index cc1245e726..ec1067f736 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 21bc7961a5..0351a8cbe3 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -10,7 +10,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" @@ -24,20 +23,34 @@ import ( type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit - workers *utils.WorkerPool maxNbIn int // maximum number of inputs for a gate in the circuit printsByInstance map[uint32][]gkrinfo.PrintInfo } -func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} + +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment + } +} + +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } + d := SolvingData{ - workers: utils.NewWorkerPool(), circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } - d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) @@ -45,6 +58,22 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.assignment[i] = make([]fr.Element, info.NbInstances) } + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range d.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) + } + } + } + } + return &d } @@ -53,7 +82,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { // TODO use idiomatic printf tag + if !ins[0].IsUint64() { return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } @@ -108,7 +137,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BW6_633") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index f90f28114b..ad5197feef 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index dacd7eb35f..190d43caf2 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -10,7 +10,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" @@ -24,20 +23,34 @@ import ( type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit - workers *utils.WorkerPool maxNbIn int // maximum number of inputs for a gate in the circuit printsByInstance map[uint32][]gkrinfo.PrintInfo } -func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} + +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment + } +} + +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } + d := SolvingData{ - workers: utils.NewWorkerPool(), circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } - d.circuit.SetNbUniqueOutputs() d.maxNbIn = d.circuit.MaxGateNbIn() d.assignment = make(WireAssignment, len(d.circuit)) @@ -45,6 +58,22 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { d.assignment[i] = make([]fr.Element, info.NbInstances) } + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range d.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) + } + } + } + } + return &d } @@ -53,7 +82,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo) *SolvingData { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { // TODO use idiomatic printf tag + if !ins[0].IsUint64() { return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) } @@ -108,7 +137,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BW6_761") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index dd6408c718..9ee1635a45 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -34,9 +34,10 @@ func NewTestEngineHints(info *gkrinfo.StoringInfo) (*TestEngineHints, error) { } return &TestEngineHints{ - info: info, - circuit: circuit, - gateIns: make([]frontend.Variable, circuit.MaxGateNbIn()), + info: info, + circuit: circuit, + gateIns: make([]frontend.Variable, circuit.MaxGateNbIn()), + assignment: make(gkrtypes.WireAssignment, len(circuit)), }, err } @@ -47,9 +48,10 @@ func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) e // TODO handle prints - if in0 := ins[0].Uint64(); !ins[0].IsUint64() || in0 >= uint64(len(h.info.Circuit)) || in0 > 0xffffffff { + instanceI := len(h.assignment[0]) + if in0 := ins[0].Uint64(); !ins[0].IsUint64() || in0 > 0xffffffff { return errors.New("first input must be a uint32 instance index") - } else if in0 != uint64(h.info.NbInstances) || h.info.NbInstances != len(h.assignment[0]) { + } else if in0 != uint64(instanceI) || h.info.NbInstances-1 != instanceI { return errors.New("first input must equal the number of instances, and calls to Solve must be done in order of instance index") } @@ -65,7 +67,7 @@ func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) e inI++ } else { for gateInI, inWI := range w.Inputs { - h.gateIns[gateInI] = h.assignment[inWI][gateInI] + h.gateIns[gateInI] = h.assignment[inWI][instanceI] } val = w.Gate.Evaluate(api, h.gateIns[:len(w.Inputs)]...) } diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index f7cb5e589d..c8629f1db1 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -13,9 +13,8 @@ type ( } Wire struct { - Gate string - Inputs []int - NbUniqueOutputs int + Gate string + Inputs []int } Circuit []Wire @@ -46,10 +45,6 @@ func (w Wire) IsInput() bool { return len(w.Inputs) == 0 } -func (w Wire) IsOutput() bool { - return w.NbUniqueOutputs == 0 -} - func (d *StoringInfo) NewInputVariable() int { i := len(d.Circuit) d.Circuit = append(d.Circuit, Wire{}) diff --git a/internal/gkr/gkrtypes/types.go b/internal/gkr/gkrtypes/types.go index 12cdabf3d9..1a07cba46c 100644 --- a/internal/gkr/gkrtypes/types.go +++ b/internal/gkr/gkrtypes/types.go @@ -181,7 +181,7 @@ func (c Circuit) OutputsList() [][]int { return res } -func (c Circuit) SetNbUniqueOutputs() { +func (c Circuit) setNbUniqueOutputs() { for i := range c { c[i].NbUniqueOutputs = 0 @@ -237,6 +237,7 @@ func CircuitInfoToCircuit(info gkrinfo.Circuit, gateGetter func(name gkr.GateNam return nil, fmt.Errorf("gate \"%s\" not found", info[i].Gate) } } + resCircuit.setNbUniqueOutputs() return resCircuit, nil } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index e8e78f4b96..cdf62359f2 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod small_rational.SmallRational - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index 5cd9163ed8..ba8e595527 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -131,7 +131,7 @@ func (c *mulNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() x := gkrApi.NewInput() y := gkrApi.NewInput() - z := gkrApi.Add(x, y) + z := gkrApi.Mul(x, y) gkrCircuit := gkrApi.Compile(api, c.hashName) @@ -205,8 +205,8 @@ func (c *mulWithDependencyCircuit) Define(api frontend.API) error { return fmt.Errorf("failed to add instance: %w", err) } + api.AssertIsEqual(instanceOut[z], api.Mul(state, c.Y[i])) state = instanceOut[z] // update state for the next iteration - api.AssertIsEqual(state, api.Mul(state, c.Y[i])) } return nil } From d6e382f500c19582bd858d59fe07451fa4af371e Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 14:16:54 -0500 Subject: [PATCH 16/92] refactor: remove println --- internal/gkr/bls12-381/solver_hints.go | 26 ++++----------------- internal/gkr/gkrinfo/info.go | 9 -------- std/gkrapi/api.go | 32 -------------------------- 3 files changed, 5 insertions(+), 62 deletions(-) diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 2c6db19a72..50d1c2e6c3 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,10 +20,9 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit - printsByInstance map[uint32][]gkrinfo.PrintInfo + assignment WireAssignment + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit } type newSolvingDataSettings struct { @@ -46,9 +44,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), - printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -109,19 +106,6 @@ func SolveHint(data *SolvingData) hint.Hint { } } - prints := data.printsByInstance[uint32(instanceI)] - delete(data.printsByInstance, uint32(instanceI)) - for _, p := range prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v - } - } - fmt.Println(serializable...) - } return nil } } diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index c8629f1db1..690908ea35 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -59,12 +59,3 @@ func (d *StoringInfo) Is() bool { type ConstraintSystem interface { SetGkrInfo(info StoringInfo) error } - -// NewPrintInfoMap partitions printInfo into map elements, indexed by instance -func NewPrintInfoMap(printInfo []PrintInfo) map[uint32][]PrintInfo { - res := make(map[uint32][]PrintInfo) - for i := range printInfo { - res[printInfo[i].Instance] = append(res[printInfo[i].Instance], printInfo[i]) - } - return res -} diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index 6c6d4b0976..ae3c2b7954 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -58,35 +58,3 @@ func (api *API) Sub(i1, i2 gkr.Variable) gkr.Variable { func (api *API) Mul(i1, i2 gkr.Variable) gkr.Variable { return api.namedGate2PlusIn(gkr.Mul2, i1, i2) } - -// Println writes to the standard output. -// instance determines which values are chosen for gkr.Variable input. -func (api *API) Println(instance int, a ...any) { - api.toStore.Prints = append(api.toStore.Prints, newPrint(instance, a...)) -} - -// Println writes to the standard output. -// instance determines which values are chosen for gkr.Variable input. -func (c *Circuit) Println(instance int, a ...any) { - c.toStore.Prints = append(c.toStore.Prints, newPrint(instance, a...)) -} - -func newPrint(instance int, a ...any) gkrinfo.PrintInfo { - isVar := make([]bool, len(a)) - vals := make([]any, len(a)) - for i := range a { - v, ok := a[i].(gkr.Variable) - isVar[i] = ok - if ok { - vals[i] = uint32(v) - } else { - vals[i] = a[i] - } - } - - return gkrinfo.PrintInfo{ - Values: vals, - Instance: uint32(instance), - IsGkrVar: isVar, - } -} From a8f30a43d42a90831aa3d03bdee4ad3f2a87777b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 14:18:11 -0500 Subject: [PATCH 17/92] chore: generify print removal --- .../backend/template/gkr/solver_hints.go.tmpl | 16 ------------ internal/gkr/bls12-377/solver_hints.go | 26 ++++--------------- internal/gkr/bls24-315/solver_hints.go | 26 ++++--------------- internal/gkr/bls24-317/solver_hints.go | 26 ++++--------------- internal/gkr/bn254/solver_hints.go | 26 ++++--------------- internal/gkr/bw6-633/solver_hints.go | 26 ++++--------------- internal/gkr/bw6-761/solver_hints.go | 26 ++++--------------- 7 files changed, 30 insertions(+), 142 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index d8ff837ced..7892021532 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -5,7 +5,6 @@ import ( "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -17,7 +16,6 @@ type SolvingData struct { assignment WireAssignment circuit gkrtypes.Circuit maxNbIn int // maximum number of inputs for a gate in the circuit - printsByInstance map[uint32][]gkrinfo.PrintInfo } type newSolvingDataSettings struct { @@ -41,7 +39,6 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d := SolvingData{ circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), - printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -103,19 +100,6 @@ func SolveHint(data *SolvingData) hint.Hint { } } - prints := data.printsByInstance[uint32(instanceI)] - delete(data.printsByInstance, uint32(instanceI)) - for _, p := range prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v - } - } - fmt.Println(serializable...) - } return nil } } diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index d03272551d..e018ce2726 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,10 +20,9 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit - printsByInstance map[uint32][]gkrinfo.PrintInfo + assignment WireAssignment + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit } type newSolvingDataSettings struct { @@ -46,9 +44,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), - printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -109,19 +106,6 @@ func SolveHint(data *SolvingData) hint.Hint { } } - prints := data.printsByInstance[uint32(instanceI)] - delete(data.printsByInstance, uint32(instanceI)) - for _, p := range prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v - } - } - fmt.Println(serializable...) - } return nil } } diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index c67b58605b..285ca7b9f9 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,10 +20,9 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit - printsByInstance map[uint32][]gkrinfo.PrintInfo + assignment WireAssignment + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit } type newSolvingDataSettings struct { @@ -46,9 +44,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), - printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -109,19 +106,6 @@ func SolveHint(data *SolvingData) hint.Hint { } } - prints := data.printsByInstance[uint32(instanceI)] - delete(data.printsByInstance, uint32(instanceI)) - for _, p := range prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v - } - } - fmt.Println(serializable...) - } return nil } } diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index 9482066d24..b6e9047533 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,10 +20,9 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit - printsByInstance map[uint32][]gkrinfo.PrintInfo + assignment WireAssignment + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit } type newSolvingDataSettings struct { @@ -46,9 +44,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), - printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -109,19 +106,6 @@ func SolveHint(data *SolvingData) hint.Hint { } } - prints := data.printsByInstance[uint32(instanceI)] - delete(data.printsByInstance, uint32(instanceI)) - for _, p := range prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v - } - } - fmt.Println(serializable...) - } return nil } } diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 53079dfd57..50d7c11364 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,10 +20,9 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit - printsByInstance map[uint32][]gkrinfo.PrintInfo + assignment WireAssignment + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit } type newSolvingDataSettings struct { @@ -46,9 +44,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), - printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -109,19 +106,6 @@ func SolveHint(data *SolvingData) hint.Hint { } } - prints := data.printsByInstance[uint32(instanceI)] - delete(data.printsByInstance, uint32(instanceI)) - for _, p := range prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v - } - } - fmt.Println(serializable...) - } return nil } } diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 0351a8cbe3..65813d9ac0 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,10 +20,9 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit - printsByInstance map[uint32][]gkrinfo.PrintInfo + assignment WireAssignment + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit } type newSolvingDataSettings struct { @@ -46,9 +44,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), - printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -109,19 +106,6 @@ func SolveHint(data *SolvingData) hint.Hint { } } - prints := data.printsByInstance[uint32(instanceI)] - delete(data.printsByInstance, uint32(instanceI)) - for _, p := range prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v - } - } - fmt.Println(serializable...) - } return nil } } diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 190d43caf2..7971b5540a 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" algo_utils "github.com/consensys/gnark/internal/utils" @@ -21,10 +20,9 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit - printsByInstance map[uint32][]gkrinfo.PrintInfo + assignment WireAssignment + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit } type newSolvingDataSettings struct { @@ -46,9 +44,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), - printsByInstance: gkrinfo.NewPrintInfoMap(info.Prints), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -109,19 +106,6 @@ func SolveHint(data *SolvingData) hint.Hint { } } - prints := data.printsByInstance[uint32(instanceI)] - delete(data.printsByInstance, uint32(instanceI)) - for _, p := range prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v - } - } - fmt.Println(serializable...) - } return nil } } From 8ac5f9f2b9ef5b81530377995fdf70c8e9bdde15 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 14:49:22 -0500 Subject: [PATCH 18/92] feat: GetValue --- constraint/bls12-377/solver.go | 1 + constraint/bls12-381/solver.go | 1 + constraint/bls24-315/solver.go | 1 + constraint/bls24-317/solver.go | 1 + constraint/bn254/solver.go | 1 + constraint/bw6-633/solver.go | 1 + constraint/bw6-761/solver.go | 1 + .../backend/template/gkr/solver_hints.go.tmpl | 17 ++++ .../template/representations/solver.go.tmpl | 1 + internal/gkr/bls12-377/solver_hints.go | 17 ++++ internal/gkr/bls12-381/solver_hints.go | 17 ++++ internal/gkr/bls24-315/solver_hints.go | 17 ++++ internal/gkr/bls24-317/solver_hints.go | 17 ++++ internal/gkr/bn254/solver_hints.go | 17 ++++ internal/gkr/bw6-633/solver_hints.go | 17 ++++ internal/gkr/bw6-761/solver_hints.go | 17 ++++ internal/gkr/engine_hints.go | 11 +++ internal/gkr/gkrinfo/info.go | 17 ++-- internal/gkr/gkrtypes/types.go | 2 - std/gkrapi/api_test.go | 88 +++++++++---------- std/gkrapi/compile.go | 12 +++ 21 files changed, 216 insertions(+), 58 deletions(-) diff --git a/constraint/bls12-377/solver.go b/constraint/bls12-377/solver.go index 206fea5702..f57abfae52 100644 --- a/constraint/bls12-377/solver.go +++ b/constraint/bls12-377/solver.go @@ -57,6 +57,7 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, } gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } diff --git a/constraint/bls12-381/solver.go b/constraint/bls12-381/solver.go index 4e7835fef1..67f12ef2aa 100644 --- a/constraint/bls12-381/solver.go +++ b/constraint/bls12-381/solver.go @@ -57,6 +57,7 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, } gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } diff --git a/constraint/bls24-315/solver.go b/constraint/bls24-315/solver.go index 5dc3fc2ef7..05d4c6f11c 100644 --- a/constraint/bls24-315/solver.go +++ b/constraint/bls24-315/solver.go @@ -57,6 +57,7 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, } gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } diff --git a/constraint/bls24-317/solver.go b/constraint/bls24-317/solver.go index f007d16494..29af5c28b2 100644 --- a/constraint/bls24-317/solver.go +++ b/constraint/bls24-317/solver.go @@ -57,6 +57,7 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, } gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } diff --git a/constraint/bn254/solver.go b/constraint/bn254/solver.go index a7674e9542..5e9b70c548 100644 --- a/constraint/bn254/solver.go +++ b/constraint/bn254/solver.go @@ -57,6 +57,7 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, } gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } diff --git a/constraint/bw6-633/solver.go b/constraint/bw6-633/solver.go index f294f7d826..7fc43652f6 100644 --- a/constraint/bw6-633/solver.go +++ b/constraint/bw6-633/solver.go @@ -57,6 +57,7 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, } gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } diff --git a/constraint/bw6-761/solver.go b/constraint/bw6-761/solver.go index e10f2f21d7..d226b03a53 100644 --- a/constraint/bw6-761/solver.go +++ b/constraint/bw6-761/solver.go @@ -57,6 +57,7 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, } gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 7892021532..cd89a9ffff 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -70,6 +70,23 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) // this module assumes that wire and instance indexes respect dependencies +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + if !ins[0].IsUint64() || !ins[1].IsUint64() { + return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) + } + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} + func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() diff --git a/internal/generator/backend/template/representations/solver.go.tmpl b/internal/generator/backend/template/representations/solver.go.tmpl index 202642b87a..ddb0b7428c 100644 --- a/internal/generator/backend/template/representations/solver.go.tmpl +++ b/internal/generator/backend/template/representations/solver.go.tmpl @@ -49,6 +49,7 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, } gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index e018ce2726..43384769ac 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -76,6 +76,23 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) // this module assumes that wire and instance indexes respect dependencies +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + if !ins[0].IsUint64() || !ins[1].IsUint64() { + return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) + } + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} + func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 50d1c2e6c3..bcf449e99c 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -76,6 +76,23 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) // this module assumes that wire and instance indexes respect dependencies +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + if !ins[0].IsUint64() || !ins[1].IsUint64() { + return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) + } + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} + func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 285ca7b9f9..6f254299a0 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -76,6 +76,23 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) // this module assumes that wire and instance indexes respect dependencies +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + if !ins[0].IsUint64() || !ins[1].IsUint64() { + return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) + } + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} + func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index b6e9047533..6121504a68 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -76,6 +76,23 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) // this module assumes that wire and instance indexes respect dependencies +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + if !ins[0].IsUint64() || !ins[1].IsUint64() { + return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) + } + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} + func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 50d7c11364..6bd76425a1 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -76,6 +76,23 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) // this module assumes that wire and instance indexes respect dependencies +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + if !ins[0].IsUint64() || !ins[1].IsUint64() { + return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) + } + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} + func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 65813d9ac0..3678eab8bf 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -76,6 +76,23 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) // this module assumes that wire and instance indexes respect dependencies +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + if !ins[0].IsUint64() || !ins[1].IsUint64() { + return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) + } + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} + func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 7971b5540a..baa1303132 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -76,6 +76,23 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) // this module assumes that wire and instance indexes respect dependencies +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + if !ins[0].IsUint64() || !ins[1].IsUint64() { + return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) + } + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} + func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 9ee1635a45..3f5884bc13 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -118,7 +118,18 @@ func (h *TestEngineHints) Prove(mod *big.Int, ins, outs []*big.Int) error { } return errors.New("unsupported modulus") +} +// GetAssignment returns the assignment for a particular wire and instance. +func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big.Int) error { + if len(ins) != 3 || !ins[0].IsUint64() || !ins[1].IsUint64() { + return errors.New("expected 3 inputs: wire index, instance index, and dummy output from the same instance") + } + if len(outs) != 1 { + return errors.New("expected 1 output: the value of the wire at the given instance") + } + *outs[0] = utils.FromInterface(h.assignment[ins[0].Uint64()][ins[1].Uint64()]) + return nil } type gateAPI struct{ *big.Int } diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index 690908ea35..6d14e37dfb 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -19,18 +19,13 @@ type ( Circuit []Wire - PrintInfo struct { - Values []any - Instance uint32 - IsGkrVar []bool - } StoringInfo struct { - Circuit Circuit - NbInstances int - HashName string - SolveHintID solver.HintID - ProveHintID solver.HintID - Prints []PrintInfo + Circuit Circuit + NbInstances int + HashName string + GetAssignmentHintID solver.HintID + SolveHintID solver.HintID + ProveHintID solver.HintID } Permutations struct { diff --git a/internal/gkr/gkrtypes/types.go b/internal/gkr/gkrtypes/types.go index 1a07cba46c..3226f1aa2f 100644 --- a/internal/gkr/gkrtypes/types.go +++ b/internal/gkr/gkrtypes/types.go @@ -150,7 +150,6 @@ type SolvingInfo struct { Circuit Circuit NbInstances int HashName string - Prints []gkrinfo.PrintInfo } // OutputsList for each wire, returns the set of indexes of wires it is input to. @@ -247,7 +246,6 @@ func StoringToSolvingInfo(info gkrinfo.StoringInfo, gateGetter func(name gkr.Gat Circuit: circuit, NbInstances: info.NbInstances, HashName: info.HashName, - Prints: info.Prints, }, err } diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index ba8e595527..3b2be818bd 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -1,7 +1,6 @@ package gkrapi import ( - "bytes" "fmt" "hash" "math/big" @@ -26,7 +25,7 @@ import ( "github.com/stretchr/testify/require" ) -// compressThreshold --> if linear expressions are larger than this, the frontend will introduce +// compressThreshold → if linear expressions are larger than this, the frontend will introduce // intermediate constraints. The lower this number is, the faster compile time should be (to a point) // but resulting circuit will have more constraints (slower proving time). const compressThreshold = 1000 @@ -358,7 +357,7 @@ func mimcGate(api gkr.GateAPI, input ...frontend.Variable) frontend.Variable { } sum := api.Add(input[0], input[1] /*, m.Ark*/) - sumCubed := api.Mul(sum, sum, sum) // sum^3 + sumCubed := api.Mul(sum, sum, sum) // sum³ return api.Mul(sumCubed, sumCubed, sum) } @@ -522,12 +521,6 @@ func mimcNoGkrCircuits(mimcDepth, nbInstances int) (circuit, assignment frontend return } -func panicIfError(err error) { - if err != nil { - panic(err) - } -} - func assertSliceEqual[T comparable](t *testing.T, expected, seen []T) { assert.Equal(t, len(expected), len(seen)) for i := range seen { @@ -549,7 +542,7 @@ func (m MiMCCipherGate) Evaluate(api frontend.API, input ...frontend.Variable) f } sum := api.Add(input[0], input[1], m.Ark) - sumCubed := api.Mul(sum, sum, sum) // sum^3 + sumCubed := api.Mul(sum, sum, sum) // sum³ return api.Mul(sumCubed, sumCubed, sum) } @@ -602,46 +595,51 @@ func init() { } } -func ExamplePrintln() { +// pow3Circuit computes x⁴ and also checks the correctness of intermediate value x². +// This is to demonstrate the use of [Circuit.GetValue] and should not be done +// in production code, as it negates the performance benefits of using GKR in the first place. +type pow4Circuit struct { + X []frontend.Variable +} - circuit := &mulNoDependencyCircuit{ - X: make([]frontend.Variable, 2), - Y: make([]frontend.Variable, 2), - hashName: "MIMC", - } +func (c *pow4Circuit) Define(api frontend.API) error { + gkrApi := New() + x := gkrApi.NewInput() + x2 := gkrApi.Mul(x, x) // x² + x4 := gkrApi.Mul(x2, x2) // x⁴ - assignment := &mulNoDependencyCircuit{ - X: []frontend.Variable{10, 11}, - Y: []frontend.Variable{12, 13}, - } + gkrCircuit := gkrApi.Compile(api, "MIMC") + + for i := range c.X { + instanceIn := make(map[gkr.Variable]frontend.Variable) + instanceIn[x] = c.X[i] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + + api.AssertIsEqual(gkrCircuit.GetValue(x, i), c.X[i]) // x + + v := api.Mul(c.X[i], c.X[i]) // x² + api.AssertIsEqual(gkrCircuit.GetValue(x2, i), v) // x² - field := ecc.BN254.ScalarField() + v = api.Mul(v, v) // x⁴ + api.AssertIsEqual(gkrCircuit.GetValue(x4, i), v) // x⁴ + api.AssertIsEqual(instanceOut[x4], v) // x⁴ + } - // with test engine - err := test.IsSolved(circuit, assignment, field) - panicIfError(err) + return nil +} - // with groth16 / serialized CS - firstCs, err := frontend.Compile(field, r1cs.NewBuilder, circuit) - panicIfError(err) +func TestPow4Circuit_GetValue(t *testing.T) { + assignment := pow4Circuit{ + X: []frontend.Variable{1, 2, 3, 4, 5}, + } - var bb bytes.Buffer - _, err = firstCs.WriteTo(&bb) - panicIfError(err) - cs := groth16.NewCS(ecc.BN254) - _, err = cs.ReadFrom(&bb) - panicIfError(err) + circuit := pow4Circuit{ + X: make([]frontend.Variable, len(assignment.X)), + } - pk, _, err := groth16.Setup(cs) - panicIfError(err) - w, err := frontend.NewWitness(assignment, field) - panicIfError(err) - _, err = groth16.Prove(cs, pk, w) - panicIfError(err) - - // Output: - // values of x and y in instance number 0 10 12 - // value of z in instance number 1 143 - // values of x and y in instance number 0 10 12 - // value of z in instance number 1 143 + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 87c863e7ee..4cd6c804e5 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -104,6 +104,7 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio } res.toStore.ProveHintID = solver.GetHintID(res.hints.Prove) + res.toStore.GetAssignmentHintID = solver.GetHintID(res.hints.GetAssignment) parentApi.Compiler().Defer(res.verify) @@ -257,3 +258,14 @@ func init() { return &h, err }) } + +// GetValue is a debugging utility returning the value of variable v at instance i. +// While v can be an input or output variable, GetValue is most useful for querying intermediate values in the circuit. +func (c *Circuit) GetValue(v gkr.Variable, i int) frontend.Variable { + // last input to ensure the solver's work is done before GetAssignment is called + res, err := c.api.Compiler().NewHint(c.hints.GetAssignment, 1, int(v), i, c.assignments[c.outs[0]][i]) + if err != nil { + panic(err) + } + return res[0] +} From 5bb879dfbce7477b790ad9b6d4fe8b355c8d6641 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 14:51:59 -0500 Subject: [PATCH 19/92] fix all gkrapi tests pass --- internal/gkr/engine_hints.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 3f5884bc13..7e60a21547 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -46,8 +46,6 @@ func NewTestEngineHints(info *gkrinfo.StoringInfo) (*TestEngineHints, error) { // The first input is the index of the instance. The rest are the inputs of the circuit, in their nominal order. func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) error { - // TODO handle prints - instanceI := len(h.assignment[0]) if in0 := ins[0].Uint64(); !ins[0].IsUint64() || in0 > 0xffffffff { return errors.New("first input must be a uint32 instance index") @@ -73,6 +71,7 @@ func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) e } if w.IsOutput() { *outs[outI] = utils.FromInterface(val) + outI++ } h.assignment[wI] = append(h.assignment[wI], val) } @@ -80,7 +79,6 @@ func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) e } func (h *TestEngineHints) Prove(mod *big.Int, ins, outs []*big.Int) error { - // todo handle prints info, err := gkrtypes.StoringToSolvingInfo(*h.info, gkrgates.Get) if err != nil { From 7f9a0f1aaec5beb09a6d776afc972acc5d10c088 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 15:24:14 -0500 Subject: [PATCH 20/92] fix gkr-poseidon2 --- .../poseidon2/gkr-poseidon2/gkr.go | 178 +++++------------- .../poseidon2/gkr-poseidon2/gkr_test.go | 4 +- 2 files changed, 52 insertions(+), 130 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index 330efdd589..f05836a798 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -1,24 +1,16 @@ package gkr_poseidon2 import ( - "errors" "fmt" - "math/big" "sync" "github.com/consensys/gnark/constraint/solver/gkrgates" - "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark-crypto/ecc" - frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" poseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" - stdHash "github.com/consensys/gnark/std/hash" - "github.com/consensys/gnark/std/hash/mimc" ) // extKeyGate applies the external matrix mul, then adds the round key @@ -117,62 +109,67 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { } type GkrCompressions struct { - api frontend.API - ins1 []frontend.Variable - ins2 []frontend.Variable - outs []frontend.Variable + api frontend.API + gkrCircuit *gkrapi.Circuit + in1, in2, out gkr.Variable } -// NewGkrCompressions returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// NewGkrCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. -// Note that the solver will need the function RegisterGkrSolverOptions to be called with the desired curves -func NewGkrCompressions(api frontend.API) *GkrCompressions { - res := GkrCompressions{ - api: api, +// Note that the solver will need the function RegisterGkrGates to be called with the desired curves +func NewGkrCompressor(api frontend.API) *GkrCompressions { + if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { + panic("currently only BL12-377 is supported") + } + gkrApi, in1, in2, out, err := defineCircuitBls12377() + if err != nil { + panic(fmt.Errorf("failed to define GKR circuit: %v", err)) + } + return &GkrCompressions{ + api: api, + gkrCircuit: gkrApi.Compile(api, "MIMC"), + in1: in1, + in2: in2, + out: out, } - api.Compiler().Defer(res.finalize) - return &res } func (p *GkrCompressions) Compress(a, b frontend.Variable) frontend.Variable { - s, err := p.api.Compiler().NewHint(permuteHint, 1, a, b) + outs, err := p.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{p.in1: a, p.in2: b}) if err != nil { panic(err) } - p.ins1 = append(p.ins1, a) - p.ins2 = append(p.ins2, b) - p.outs = append(p.outs, s[0]) - return s[0] + + return outs[p.out] } -// defineCircuit defines the GKR circuit for the Poseidon2 permutation over BLS12-377 +// defineCircuitBls12377 defines the GKR circuit for the Poseidon2 permutation over BLS12-377 // insLeft and insRight are the inputs to the permutation // they must be padded to a power of 2 -func defineCircuit(insLeft, insRight []frontend.Variable) (*gkrapi.API, gkr.Variable, error) { +func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, err error) { // variable indexes const ( xI = iota yI ) + if err = registerGatesBls12377(); err != nil { + return + } + // poseidon2 parameters gateNamer := newRoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds halfRf := rF / 2 - gkrApi := gkrapi.New() + gkrApi = gkrapi.New() - x, err := gkrApi.Import(insLeft) - if err != nil { - return nil, -1, err - } - y, err := gkrApi.Import(insRight) - y0 := y // save to feed forward at the end - if err != nil { - return nil, -1, err - } + x := gkrApi.NewInput() + y := gkrApi.NewInput() + + in1, in2 = x, y // save to feed forward at the end // *** helper functions to register and apply gates *** @@ -241,86 +238,9 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkrapi.API, gkr.Vari } // apply the external matrix one last time to obtain the final value of y - y = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, y0) + out = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, in2) - return gkrApi, y, nil -} - -func (p *GkrCompressions) finalize(api frontend.API) error { - if p.api != api { - panic("unexpected API") - } - - // register MiMC to be used as a random oracle in the GKR proof - stdHash.Register("MIMC", func(api frontend.API) (stdHash.FieldHasher, error) { - m, err := mimc.NewMiMC(api) - return &m, err - }) - - // register gates - registerGkrSolverOptions(api) - - // pad instances into a power of 2 - // TODO @Tabaie the GKR API to do this automatically? - ins1Padded := make([]frontend.Variable, ecc.NextPowerOfTwo(uint64(len(p.ins1)))) - ins2Padded := make([]frontend.Variable, len(ins1Padded)) - copy(ins1Padded, p.ins1) - copy(ins2Padded, p.ins2) - for i := len(p.ins1); i < len(ins1Padded); i++ { - ins1Padded[i] = 0 - ins2Padded[i] = 0 - } - - gkrApi, y, err := defineCircuit(ins1Padded, ins2Padded) - if err != nil { - return err - } - - // connect to output - // TODO can we save 1 constraint per instance by giving the desired outputs to the gkr api? - solution, err := gkrApi.Solve(api) - if err != nil { - return err - } - yVals := solution.Export(y) - for i := range p.outs { - api.AssertIsEqual(yVals[i], p.outs[i]) - } - - // verify GKR proof - allVals := make([]frontend.Variable, 0, 3*len(p.ins1)) - allVals = append(allVals, p.ins1...) - allVals = append(allVals, p.ins2...) - allVals = append(allVals, p.outs...) - challenge, err := p.api.(frontend.Committer).Commit(allVals...) - if err != nil { - return err - } - return solution.Verify("MIMC", challenge) -} - -// registerGkrSolverOptions is a wrapper for RegisterGkrSolverOptions -// that performs the registration for the curve associated with api. -func registerGkrSolverOptions(api frontend.API) { - RegisterGkrSolverOptions(utils.FieldToCurve(api.Compiler().Field())) -} - -func permuteHint(m *big.Int, ins, outs []*big.Int) error { - if m.Cmp(ecc.BLS12_377.ScalarField()) != 0 { - return errors.New("only bls12-377 supported") - } - if len(ins) != 2 || len(outs) != 1 { - return errors.New("expected 2 inputs and 1 output") - } - var x [2]frBls12377.Element - x[0].SetBigInt(ins[0]) - x[1].SetBigInt(ins[1]) - y0 := x[1] - - err := bls12377Permutation().Permutation(x[:]) - x[1].Add(&x[1], &y0) // feed forward - x[1].BigInt(outs[0]) - return err + return } var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { @@ -328,16 +248,15 @@ var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { return poseidon2Bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto }) -// RegisterGkrSolverOptions registers the GKR gates corresponding to the given curves for the solver -func RegisterGkrSolverOptions(curves ...ecc.ID) { +// RegisterGkrGates registers the GKR gates corresponding to the given curves for the solver +func RegisterGkrGates(curves ...ecc.ID) { if len(curves) == 0 { panic("expected at least one curve") } - solver.RegisterHint(permuteHint) for _, curve := range curves { switch curve { case ecc.BLS12_377: - if err := registerGkrGatesBls12377(); err != nil { + if err := registerGatesBls12377(); err != nil { panic(err) } default: @@ -346,7 +265,7 @@ func RegisterGkrSolverOptions(curves ...ecc.ID) { } } -func registerGkrGatesBls12377() error { +func registerGatesBls12377() error { const ( x = iota y @@ -356,29 +275,31 @@ func registerGkrGatesBls12377() error { halfRf := p.NbFullRounds / 2 gateNames := newRoundGateNamer(p) - if err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar()); err != nil { + if _, err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar()); err != nil { + if _, err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar()); err != nil { + if _, err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar()); err != nil { + if _, err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0)); err != nil { + if _, err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } extKeySBox := func(round int, varIndex int) error { - return gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round))) + _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } intKeySBox2 := func(round int) error { - return gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round))) + _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } fullRound := func(i int) error { @@ -422,7 +343,8 @@ func registerGkrGatesBls12377() error { } } - return gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds))) + _, err := gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } type roundGateNamer string diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 1503054a59..22e9def87f 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -37,7 +37,7 @@ func TestGkrCompression(t *testing.T) { Outs: outs, } - RegisterGkrSolverOptions(ecc.BLS12_377) + RegisterGkrGates(ecc.BLS12_377) test.NewAssert(t).CheckCircuit(&testGkrPermutationCircuit{Ins: make([][2]frontend.Variable, len(ins)), Outs: make([]frontend.Variable, len(outs))}, test.WithValidAssignment(&circuit), test.WithCurves(ecc.BLS12_377)) } @@ -49,7 +49,7 @@ type testGkrPermutationCircuit struct { func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - pos2 := NewGkrCompressions(api) + pos2 := NewGkrCompressor(api) api.AssertIsEqual(len(c.Ins), len(c.Outs)) for i := range c.Ins { api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) From cfdb9d3d46730d1774fcd7dea93f0b78da1a023d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 15:48:15 -0500 Subject: [PATCH 21/92] fix: reduce in test engine --- internal/gkr/engine_hints.go | 2 ++ std/permutation/poseidon2/gkr-poseidon2/gkr.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 7e60a21547..7968c12793 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -147,6 +147,7 @@ func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { x, y := utils.FromInterface(b), utils.FromInterface(c) x.Mul(&x, &y) + x.Mod(&x, g.Int) // reduce y = utils.FromInterface(a) x.Add(&x, &y) return &x @@ -177,6 +178,7 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend y = utils.FromInterface(v) x.Mul(&x, &y) } + x.Mod(&x, g.Int) // reduce return &x } diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index f05836a798..04ca60b49a 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -38,7 +38,7 @@ func pow4Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { // pow4TimesGate computes a, b -> a⁴ * b func pow4TimesGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { - panic("expected 1 input") + panic("expected 2 input") } y := api.Mul(x[0], x[0]) y = api.Mul(y, y) From 12788f34d416474fb4229ba7dc481830d3a41067 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 16:01:00 -0500 Subject: [PATCH 22/92] fix: rename GkrCompressions -> GkrPermutations --- std/permutation/poseidon2/gkr-poseidon2/gkr.go | 10 +++++----- std/permutation/poseidon2/gkr-poseidon2/gkr_test.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index 04ca60b49a..d9a7dcfbbb 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -108,17 +108,17 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Add(api.Mul(x[0], 2), x[1], x[2]) } -type GkrCompressions struct { +type GkrPermutations struct { api frontend.API gkrCircuit *gkrapi.Circuit in1, in2, out gkr.Variable } -// NewGkrCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// NewGkrPermutations returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGkrGates to be called with the desired curves -func NewGkrCompressor(api frontend.API) *GkrCompressions { +func NewGkrPermutations(api frontend.API) *GkrPermutations { if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { panic("currently only BL12-377 is supported") } @@ -126,7 +126,7 @@ func NewGkrCompressor(api frontend.API) *GkrCompressions { if err != nil { panic(fmt.Errorf("failed to define GKR circuit: %v", err)) } - return &GkrCompressions{ + return &GkrPermutations{ api: api, gkrCircuit: gkrApi.Compile(api, "MIMC"), in1: in1, @@ -135,7 +135,7 @@ func NewGkrCompressor(api frontend.API) *GkrCompressions { } } -func (p *GkrCompressions) Compress(a, b frontend.Variable) frontend.Variable { +func (p *GkrPermutations) Compress(a, b frontend.Variable) frontend.Variable { outs, err := p.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{p.in1: a, p.in2: b}) if err != nil { panic(err) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 22e9def87f..1562aac070 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -49,7 +49,7 @@ type testGkrPermutationCircuit struct { func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - pos2 := NewGkrCompressor(api) + pos2 := NewGkrPermutations(api) api.AssertIsEqual(len(c.Ins), len(c.Outs)) for i := range c.Ins { api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) From 9a0bf0e7d171fbfcca15a6975c847b01701ae824 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 16:15:01 -0500 Subject: [PATCH 23/92] bench: gkrposeidon2 --- .../poseidon2/gkr-poseidon2/gkr_test.go | 46 +++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 1562aac070..7d64aedb92 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -2,6 +2,8 @@ package gkr_poseidon2 import ( "fmt" + "os" + "runtime/pprof" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -12,8 +14,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestGkrCompression(t *testing.T) { - const n = 2 +func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment testGkrPermutationCircuit) { var k int64 ins := make([][2]frontend.Variable, n) outs := make([]frontend.Variable, n) @@ -32,14 +33,19 @@ func TestGkrCompression(t *testing.T) { k += 2 } - circuit := testGkrPermutationCircuit{ - Ins: ins, - Outs: outs, - } + return testGkrPermutationCircuit{ + Ins: make([][2]frontend.Variable, len(ins)), + Outs: make([]frontend.Variable, len(outs)), + }, testGkrPermutationCircuit{ + Ins: ins, + Outs: outs, + } +} - RegisterGkrGates(ecc.BLS12_377) +func TestGkrCompression(t *testing.T) { + circuit, assignment := gkrPermutationsCircuits(t, 2) - test.NewAssert(t).CheckCircuit(&testGkrPermutationCircuit{Ins: make([][2]frontend.Variable, len(ins)), Outs: make([]frontend.Variable, len(outs))}, test.WithValidAssignment(&circuit), test.WithCurves(ecc.BLS12_377)) + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) } type testGkrPermutationCircuit struct { @@ -67,3 +73,27 @@ func TestGkrPermutationCompiles(t *testing.T) { require.NoError(t, err) fmt.Println(cs.GetNbConstraints(), "constraints") } + +func BenchmarkGkrPermutations(b *testing.B) { + circuit, assignmment := gkrPermutationsCircuits(b, 50000) + + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) + + witness, err := frontend.NewWitness(&assignmment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + // cpu profile + f, err := os.Create("cpu.pprof") + require.NoError(b, err) + defer func() { + require.NoError(b, f.Close()) + }() + + err = pprof.StartCPUProfile(f) + require.NoError(b, err) + defer pprof.StopCPUProfile() + + _, err = cs.Solve(witness) + require.NoError(b, err) +} From 618beffd04333607dd74f631d881cbb20707a435 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 18:54:39 -0500 Subject: [PATCH 24/92] fix pad for bls12377 --- internal/gkr/bls12-377/solver_hints.go | 9 +++-- internal/gkr/engine_hints.go | 2 +- internal/gkr/gkrtypes/types.go | 1 + internal/utils/slices.go | 12 +++++++ internal/utils/slices_test.go | 25 ++++++++++++++ std/gkrapi/compile.go | 46 +++++++++++++------------- 6 files changed, 69 insertions(+), 26 deletions(-) create mode 100644 internal/utils/slices_test.go diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 43384769ac..5a7c995b29 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -43,6 +44,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) opt(&s) } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + d := SolvingData{ circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), @@ -50,9 +53,8 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) + d.assignment[i] = make([]fr.Element, nbPaddedInstances) } if s.assignment != nil { @@ -68,6 +70,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } } diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 7968c12793..74b15c77ba 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -49,7 +49,7 @@ func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) e instanceI := len(h.assignment[0]) if in0 := ins[0].Uint64(); !ins[0].IsUint64() || in0 > 0xffffffff { return errors.New("first input must be a uint32 instance index") - } else if in0 != uint64(instanceI) || h.info.NbInstances-1 != instanceI { + } else if in0 != uint64(instanceI) || h.info.NbInstances != instanceI { return errors.New("first input must equal the number of instances, and calls to Solve must be done in order of instance index") } diff --git a/internal/gkr/gkrtypes/types.go b/internal/gkr/gkrtypes/types.go index 3226f1aa2f..d313a7bc59 100644 --- a/internal/gkr/gkrtypes/types.go +++ b/internal/gkr/gkrtypes/types.go @@ -228,6 +228,7 @@ func CircuitInfoToCircuit(info gkrinfo.Circuit, gateGetter func(name gkr.GateNam resCircuit := make(Circuit, len(info)) for i := range info { if info[i].Gate == "" && len(info[i].Inputs) == 0 { + resCircuit[i].Gate = Identity() // input wire continue } resCircuit[i].Inputs = info[i].Inputs diff --git a/internal/utils/slices.go b/internal/utils/slices.go index dd2e2db31f..f493bf4bca 100644 --- a/internal/utils/slices.go +++ b/internal/utils/slices.go @@ -16,3 +16,15 @@ func References[T any](v []T) []*T { } return res } + +// ExtendRepeatLast extends the slice s by repeating the last element until it reaches the length n. +func ExtendRepeatLast[T any](s []T, n int) []T { + if n <= len(s) { + return s[:n] + } + s = s[:len(s):len(s)] // ensure s is a slice with a capacity equal to its length + for len(s) < n { + s = append(s, s[len(s)-1]) // append the last element until the length is n + } + return s +} diff --git a/internal/utils/slices_test.go b/internal/utils/slices_test.go new file mode 100644 index 0000000000..f61ec18fed --- /dev/null +++ b/internal/utils/slices_test.go @@ -0,0 +1,25 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtendRepeatLast(t *testing.T) { + // normal case + s := []int{1, 2, 3} + u := ExtendRepeatLast(s, 5) + assert.Equal(t, []int{1, 2, 3, 3, 3}, u) + + // don't overwrite super-slice + s = []int{1, 2, 3} + u = ExtendRepeatLast(s[:1], 2) + assert.Equal(t, []int{1, 1}, u) + assert.Equal(t, []int{1, 2, 3}, s) + + // trim if n < len(s) + s = []int{1, 2, 3} + u = ExtendRepeatLast(s, 2) + assert.Equal(t, []int{1, 2}, u) +} diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 4cd6c804e5..898c65a41f 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -41,14 +41,6 @@ func New() *API { return &API{} } -// log2 returns -1 if x is not a power of 2 -func log2(x uint) int { - if bits.OnesCount(x) != 1 { - return -1 - } - return bits.TrailingZeros(x) -} - // NewInput creates a new input variable. func (api *API) NewInput() gkr.Variable { return gkr.Variable(api.toStore.NewInputVariable()) @@ -80,7 +72,7 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio var err error res.hints, err = gadget.NewTestEngineHints(&res.toStore) if err != nil { - panic(fmt.Errorf("failed to create GKR hints: %w", err)) + panic(fmt.Errorf("failed to call GKR hints: %w", err)) } for _, opt := range options { @@ -103,8 +95,9 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio } } - res.toStore.ProveHintID = solver.GetHintID(res.hints.Prove) res.toStore.GetAssignmentHintID = solver.GetHintID(res.hints.GetAssignment) + res.toStore.ProveHintID = solver.GetHintID(res.hints.Prove) + res.toStore.SolveHintID = solver.GetHintID(res.hints.Solve) parentApi.Compiler().Defer(res.verify) @@ -125,23 +118,20 @@ func (c *Circuit) AddInstance(input map[gkr.Variable]frontend.Variable) (map[gkr } hintIn := make([]frontend.Variable, 1+len(c.ins)) // first input denotes the instance number hintIn[0] = c.toStore.NbInstances - for hintInI, in := range c.ins { - if inV, ok := input[in]; !ok { - return nil, fmt.Errorf("missing entry for input variable %d", in) + for hintInI, wI := range c.ins { + if inV, ok := input[wI]; !ok { + return nil, fmt.Errorf("missing entry for input variable %d", wI) } else { hintIn[hintInI+1] = inV + c.assignments[wI] = append(c.assignments[wI], inV) } } - if c.toStore.NbInstances == 0 { - c.toStore.SolveHintID = solver.GetHintID(c.hints.Solve) - } - - c.toStore.NbInstances++ outsSerialized, err := c.api.Compiler().NewHint(c.hints.Solve, len(c.outs), hintIn...) if err != nil { - return nil, fmt.Errorf("failed to create solve hint: %w", err) + return nil, fmt.Errorf("failed to call solve hint: %w", err) } + c.toStore.NbInstances++ res := make(map[gkr.Variable]frontend.Variable, len(c.outs)) for i, v := range c.outs { res[v] = outsSerialized[i] @@ -157,11 +147,22 @@ func (c *Circuit) verify(api frontend.API) error { panic("api mismatch") } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(c.toStore.NbInstances))) + // pad instances to the next power of 2 by repeating the last instance + if c.toStore.NbInstances < nbPaddedInstances && c.toStore.NbInstances > 0 { + for _, wI := range c.ins { + c.assignments[wI] = utils.ExtendRepeatLast(c.assignments[wI], nbPaddedInstances) + } + for _, wI := range c.outs { + c.assignments[wI] = utils.ExtendRepeatLast(c.assignments[wI], nbPaddedInstances) + } + } + if err := api.(gkrinfo.ConstraintSystem).SetGkrInfo(c.toStore); err != nil { return err } - if len(c.outs) == 0 || len(c.assignments[0]) == 0 { + if len(c.outs) == 0 || len(c.assignments[0]) == 0 { // wire 0 is always an input wire return nil } @@ -194,7 +195,6 @@ func (c *Circuit) verify(api frontend.API) error { if err != nil { return fmt.Errorf("failed to create circuit data for snark: %w", err) } - logNbInstances := log2(uint(c.assignments.NbInstances())) hintIns := make([]frontend.Variable, len(initialChallenges)+1) // hack: adding one of the outputs of the solve hint to ensure "prove" is called after "solve" firstOutputAssignment := c.assignments[c.outs[0]] @@ -203,7 +203,7 @@ func (c *Circuit) verify(api frontend.API) error { copy(hintIns[1:], initialChallenges) if proofSerialized, err = api.Compiler().NewHint( - c.hints.Prove, gadget.ProofSize(forSnark.circuit, logNbInstances), hintIns...); err != nil { + c.hints.Prove, gadget.ProofSize(forSnark.circuit, bits.TrailingZeros(uint(nbPaddedInstances))), hintIns...); err != nil { return err } c.toStore.ProveHintID = solver.GetHintID(c.hints.Prove) @@ -253,7 +253,7 @@ func newCircuitDataForSnark(curve ecc.ID, info gkrinfo.StoringInfo, assignment g func init() { // TODO Move this to the hash package if the import cycle issue is fixed. - hash.Register("mimc", func(api frontend.API) (hash.FieldHasher, error) { + hash.Register("MIMC", func(api frontend.API) (hash.FieldHasher, error) { h, err := mimc.NewMiMC(api) return &h, err }) From 22b77f2d604842e012ac1c6e7beb80469ab0fee8 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 18:58:23 -0500 Subject: [PATCH 25/92] some more padding fixes --- .../backend/template/gkr/solver_hints.go.tmpl | 10 +++++++--- internal/gkr/bls12-377/solver_hints.go | 5 ++--- internal/gkr/bls12-381/solver_hints.go | 10 +++++++--- internal/gkr/bls24-315/solver_hints.go | 10 +++++++--- internal/gkr/bls24-317/solver_hints.go | 10 +++++++--- internal/gkr/bn254/solver_hints.go | 10 +++++++--- internal/gkr/bw6-633/solver_hints.go | 10 +++++++--- internal/gkr/bw6-761/solver_hints.go | 10 +++++++--- 8 files changed, 51 insertions(+), 24 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index cd89a9ffff..df8e758be1 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -2,6 +2,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -43,9 +44,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) + d.assignment[i] = make([]fr.Element, nbPaddedInstances) } if s.assignment != nil { @@ -56,11 +57,14 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) if len(s.assignment[i]) != info.NbInstances { panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) } - for j := range d.assignment[i] { + for j := range s.assignment[i] { if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } } diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 5a7c995b29..1e7c4e1f31 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -44,8 +44,6 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) opt(&s) } - nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) - d := SolvingData{ circuit: info.Circuit, assignment: make(WireAssignment, len(info.Circuit)), @@ -53,6 +51,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) for i := range d.assignment { d.assignment[i] = make([]fr.Element, nbPaddedInstances) } @@ -65,7 +64,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) if len(s.assignment[i]) != info.NbInstances { panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) } - for j := range d.assignment[i] { + for j := range s.assignment[i] { if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index bcf449e99c..fa88d12c42 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -50,9 +51,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) + d.assignment[i] = make([]fr.Element, nbPaddedInstances) } if s.assignment != nil { @@ -63,11 +64,14 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) if len(s.assignment[i]) != info.NbInstances { panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) } - for j := range d.assignment[i] { + for j := range s.assignment[i] { if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } } diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 6f254299a0..012be51dee 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -50,9 +51,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) + d.assignment[i] = make([]fr.Element, nbPaddedInstances) } if s.assignment != nil { @@ -63,11 +64,14 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) if len(s.assignment[i]) != info.NbInstances { panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) } - for j := range d.assignment[i] { + for j := range s.assignment[i] { if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } } diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index 6121504a68..3fda62407b 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -50,9 +51,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) + d.assignment[i] = make([]fr.Element, nbPaddedInstances) } if s.assignment != nil { @@ -63,11 +64,14 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) if len(s.assignment[i]) != info.NbInstances { panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) } - for j := range d.assignment[i] { + for j := range s.assignment[i] { if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } } diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 6bd76425a1..5c8d3e7b1d 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -50,9 +51,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) + d.assignment[i] = make([]fr.Element, nbPaddedInstances) } if s.assignment != nil { @@ -63,11 +64,14 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) if len(s.assignment[i]) != info.NbInstances { panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) } - for j := range d.assignment[i] { + for j := range s.assignment[i] { if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } } diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 3678eab8bf..43e90fdfde 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -50,9 +51,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) + d.assignment[i] = make([]fr.Element, nbPaddedInstances) } if s.assignment != nil { @@ -63,11 +64,14 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) if len(s.assignment[i]) != info.NbInstances { panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) } - for j := range d.assignment[i] { + for j := range s.assignment[i] { if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } } diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index baa1303132..82c9dd0eb8 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" @@ -50,9 +51,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) d.maxNbIn = d.circuit.MaxGateNbIn() - d.assignment = make(WireAssignment, len(d.circuit)) + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) + d.assignment[i] = make([]fr.Element, nbPaddedInstances) } if s.assignment != nil { @@ -63,11 +64,14 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) if len(s.assignment[i]) != info.NbInstances { panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) } - for j := range d.assignment[i] { + for j := range s.assignment[i] { if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } } From 250a35bc3cbbfe3fc7720d7d8f34ed7e4950c577 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 20:49:28 -0500 Subject: [PATCH 26/92] fix: padding issue in bn254 --- internal/gkr/bn254/solver_hints.go | 31 ++++++++++++++++++++++++------ std/gkrapi/api_test.go | 17 ++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 5c8d3e7b1d..484d2d7122 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -21,9 +21,10 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } type newSolvingDataSettings struct { @@ -45,8 +46,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -69,6 +71,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value } @@ -134,7 +137,13 @@ func SolveHint(data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -151,3 +160,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index 3b2be818bd..5128f337b6 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -17,6 +17,7 @@ import ( "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/gkrapi/gkr" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" @@ -643,3 +644,19 @@ func TestPow4Circuit_GetValue(t *testing.T) { test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) } + +func TestWitnessExtend(t *testing.T) { + circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, 3), hashName: "-1"} + assignment := doubleNoDependencyCircuit{X: []frontend.Variable{0, 0, 1}} + + cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(t, err) + + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + require.NoError(t, err) + + _, err = cs.Solve(witness) + require.NoError(t, err) + + //test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} From 15867602fb4ddbd9e15777ddcec4974fbcaf08e9 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 20:53:57 -0500 Subject: [PATCH 27/92] chore generify fix --- .../backend/template/gkr/solver_hints.go.tmpl | 31 +++++++++++++++---- internal/gkr/bls12-377/solver_hints.go | 31 +++++++++++++++---- internal/gkr/bls12-381/solver_hints.go | 31 +++++++++++++++---- internal/gkr/bls24-315/solver_hints.go | 31 +++++++++++++++---- internal/gkr/bls24-317/solver_hints.go | 31 +++++++++++++++---- internal/gkr/bw6-633/solver_hints.go | 31 +++++++++++++++---- internal/gkr/bw6-761/solver_hints.go | 31 +++++++++++++++---- 7 files changed, 175 insertions(+), 42 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index df8e758be1..917fef8f7f 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -14,9 +14,10 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } type newSolvingDataSettings struct { @@ -38,8 +39,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -62,6 +64,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value } @@ -128,7 +131,13 @@ func SolveHint(data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -144,4 +153,14 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { return proof.SerializeToBigInts(outs) } +} + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +{{ print "// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}}"}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } } \ No newline at end of file diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 1e7c4e1f31..258c4990e6 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -21,9 +21,10 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } type newSolvingDataSettings struct { @@ -45,8 +46,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -69,6 +71,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value } @@ -134,7 +137,13 @@ func SolveHint(data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -151,3 +160,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index fa88d12c42..81dadc3be5 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -21,9 +21,10 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } type newSolvingDataSettings struct { @@ -45,8 +46,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -69,6 +71,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value } @@ -134,7 +137,13 @@ func SolveHint(data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -151,3 +160,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 012be51dee..fe313bd479 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -21,9 +21,10 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } type newSolvingDataSettings struct { @@ -45,8 +46,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -69,6 +71,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value } @@ -134,7 +137,13 @@ func SolveHint(data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -151,3 +160,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index 3fda62407b..1734b80b6b 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -21,9 +21,10 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } type newSolvingDataSettings struct { @@ -45,8 +46,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -69,6 +71,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value } @@ -134,7 +137,13 @@ func SolveHint(data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -151,3 +160,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 43e90fdfde..3d4cee67f4 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -21,9 +21,10 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } type newSolvingDataSettings struct { @@ -45,8 +46,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -69,6 +71,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value } @@ -134,7 +137,13 @@ func SolveHint(data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -151,3 +160,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 82c9dd0eb8..c75a88eb4f 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -21,9 +21,10 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - maxNbIn int // maximum number of inputs for a gate in the circuit + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } type newSolvingDataSettings struct { @@ -45,8 +46,9 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) } d := SolvingData{ - circuit: info.Circuit, - assignment: make(WireAssignment, len(info.Circuit)), + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } d.maxNbIn = d.circuit.MaxGateNbIn() @@ -69,6 +71,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value } @@ -134,7 +137,13 @@ func SolveHint(data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -151,3 +160,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} From e593d9087a9e09bfadcb8c2176a8e9d18e83fab7 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 3 Jun 2025 23:55:49 -0500 Subject: [PATCH 28/92] Let Uint64 panic --- .../generator/backend/template/gkr/solver_hints.go.tmpl | 6 ------ internal/gkr/bls12-377/solver_hints.go | 6 ------ internal/gkr/bls12-381/solver_hints.go | 6 ------ internal/gkr/bls24-315/solver_hints.go | 6 ------ internal/gkr/bls24-317/solver_hints.go | 6 ------ internal/gkr/bn254/solver_hints.go | 6 ------ internal/gkr/bw6-633/solver_hints.go | 6 ------ internal/gkr/bw6-761/solver_hints.go | 6 ------ 8 files changed, 48 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 917fef8f7f..29698e0e3b 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -84,9 +84,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() - if !ins[0].IsUint64() || !ins[1].IsUint64() { - return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) - } data.assignment[wireI][instanceI].BigInt(outs[0]) @@ -97,9 +94,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { - return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) - } gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 258c4990e6..04c5f52586 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -90,9 +90,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() - if !ins[0].IsUint64() || !ins[1].IsUint64() { - return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) - } data.assignment[wireI][instanceI].BigInt(outs[0]) @@ -103,9 +100,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { - return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) - } gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 81dadc3be5..e92e543398 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -90,9 +90,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() - if !ins[0].IsUint64() || !ins[1].IsUint64() { - return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) - } data.assignment[wireI][instanceI].BigInt(outs[0]) @@ -103,9 +100,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { - return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) - } gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index fe313bd479..f57537b985 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -90,9 +90,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() - if !ins[0].IsUint64() || !ins[1].IsUint64() { - return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) - } data.assignment[wireI][instanceI].BigInt(outs[0]) @@ -103,9 +100,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { - return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) - } gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index 1734b80b6b..d2cc4d32b1 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -90,9 +90,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() - if !ins[0].IsUint64() || !ins[1].IsUint64() { - return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) - } data.assignment[wireI][instanceI].BigInt(outs[0]) @@ -103,9 +100,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { - return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) - } gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 484d2d7122..5813b89661 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -90,9 +90,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() - if !ins[0].IsUint64() || !ins[1].IsUint64() { - return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) - } data.assignment[wireI][instanceI].BigInt(outs[0]) @@ -103,9 +100,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { - return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) - } gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 3d4cee67f4..ef945e25f7 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -90,9 +90,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() - if !ins[0].IsUint64() || !ins[1].IsUint64() { - return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) - } data.assignment[wireI][instanceI].BigInt(outs[0]) @@ -103,9 +100,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { - return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) - } gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index c75a88eb4f..1a91928171 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -90,9 +90,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() - if !ins[0].IsUint64() || !ins[1].IsUint64() { - return fmt.Errorf("inputs to GetAssignmentHint must be the wire index and instance index; provided values %s and %s don't fit in 64 bits", ins[0], ins[1]) - } data.assignment[wireI][instanceI].BigInt(outs[0]) @@ -103,9 +100,6 @@ func GetAssignmentHint(data *SolvingData) hint.Hint { func SolveHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { instanceI := ins[0].Uint64() - if !ins[0].IsUint64() { - return fmt.Errorf("first input to solving hint must be the instance index; provided value %s doesn't fit in 64 bits", ins[0]) - } gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 From 9df619ab1dfa758975193dda6ced8ae8a4177790 Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 4 Jun 2025 07:59:54 -0500 Subject: [PATCH 29/92] Update constraint/solver/gkrgates/registry.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- constraint/solver/gkrgates/registry.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/constraint/solver/gkrgates/registry.go b/constraint/solver/gkrgates/registry.go index 71b4969883..23b821cf63 100644 --- a/constraint/solver/gkrgates/registry.go +++ b/constraint/solver/gkrgates/registry.go @@ -131,7 +131,7 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) (register return false, err } if !gateVer.equal(f, g.Evaluate, nbIn) { - return false, fmt.Errorf("mismatch with already registered gate \"%s\" (degree %d) over curve %s", s.name, s.degree, curve) + return false, fmt.Errorf("mismatch with already registered gate \"%s\" (degree %d) over curve %s", s.name, g.Degree(), curve) } } From 8c88c52a604a29682981d052a3943f01c82b56b4 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 4 Jun 2025 08:03:35 -0500 Subject: [PATCH 30/92] refactor: use assert.EqualError --- constraint/solver/gkrgates/registry_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/constraint/solver/gkrgates/registry_test.go b/constraint/solver/gkrgates/registry_test.go index 5a9e6e871d..7fe739b152 100644 --- a/constraint/solver/gkrgates/registry_test.go +++ b/constraint/solver/gkrgates/registry_test.go @@ -67,7 +67,7 @@ func TestRegister(t *testing.T) { t.Run("zero", func(t *testing.T) { const gateName gkr.GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate \"%s\": %v", gateName, gkrtypes.ErrZeroFunction) + expectedError := fmt.Errorf("for gate \"%s\": %v", gateName, gkrtypes.ErrZeroFunction).Error() zeroGate := func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Sub(x[0], x[0]) } @@ -75,13 +75,13 @@ func TestRegister(t *testing.T) { // Attempt to register the zero gate without specifying a degree registered, err := Register(zeroGate, 1, WithName(gateName)) assert.Error(t, err, "error must be returned for zero polynomial") - assert.Equal(t, expectedError, err, "error message must match expected error") + assert.EqualError(t, err, expectedError, "error message must match expected error") assert.False(t, registered, "registration must fail for zero polynomial") // Attempt to register the zero gate with a specified degree registered, err = Register(zeroGate, 1, WithName(gateName), WithDegree(2)) assert.Error(t, err, "error must be returned for zero polynomial with degree") - assert.Equal(t, expectedError, err, "error message must match expected error") + assert.EqualError(t, err, expectedError, "error message must match expected error") assert.False(t, registered, "registration must fail for zero polynomial with degree") }) } From 499ff522c0e2555e985e0c6ec6ccb31ff7f0b5d8 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 5 Jun 2025 17:58:58 -0500 Subject: [PATCH 31/92] refactor: import hash/all in test --- std/permutation/poseidon2/gkr-poseidon2/gkr_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 7d64aedb92..601d80cb70 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -10,6 +10,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" + _ "github.com/consensys/gnark/std/hash/all" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" ) From be747e027e35f5ceccaffe6ea835e75c1ba59f8c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 5 Jun 2025 18:26:51 -0500 Subject: [PATCH 32/92] refactor: shorten import name --- std/permutation/poseidon2/gkr-poseidon2/gkr.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index d9a7dcfbbb..471c1f3887 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -9,7 +9,7 @@ import ( "github.com/consensys/gnark/std/gkrapi/gkr" "github.com/consensys/gnark-crypto/ecc" - poseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" ) @@ -159,9 +159,9 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er } // poseidon2 parameters - gateNamer := newRoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) - rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds - rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds + gateNamer := newRoundGateNamer(bls12377.GetDefaultParameters()) + rF := bls12377.GetDefaultParameters().NbFullRounds + rP := bls12377.GetDefaultParameters().NbPartialRounds halfRf := rF / 2 gkrApi = gkrapi.New() @@ -243,9 +243,9 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er return } -var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { - params := poseidon2Bls12377.GetDefaultParameters() - return poseidon2Bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto +var bls12377Permutation = sync.OnceValue(func() *bls12377.Permutation { + params := bls12377.GetDefaultParameters() + return bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto }) // RegisterGkrGates registers the GKR gates corresponding to the given curves for the solver @@ -271,7 +271,7 @@ func registerGatesBls12377() error { y ) - p := poseidon2Bls12377.GetDefaultParameters() + p := bls12377.GetDefaultParameters() halfRf := p.NbFullRounds / 2 gateNames := newRoundGateNamer(p) From 4b43ce33f59ef939396d9f5253f5f7e989c7e06b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 11:12:18 -0500 Subject: [PATCH 33/92] perf: addition rather than multiplication in gates --- std/permutation/poseidon2/gkr-poseidon2/gkr.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index 471c1f3887..a9af651bda 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -20,7 +20,7 @@ func extKeyGate(roundKey frontend.Variable) gkr.GateFunction { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[0], 2), x[1], roundKey) + return api.Add(x[0], x[0], x[1], roundKey) } } @@ -71,7 +71,7 @@ func extGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[1], 2), x[0]) + return api.Add(x[1], x[1], x[0]) } // intKeyGate2 applies the internal matrix mul, then adds the round key @@ -80,7 +80,7 @@ func intKeyGate2(roundKey frontend.Variable) gkr.GateFunction { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[1], 3), x[0], roundKey) + return api.Add(x[1], x[1], x[1], x[0], roundKey) } } @@ -89,7 +89,7 @@ func intGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[1], 3), x[0]) + return api.Add(x[1], x[1], x[1], x[0]) } // extGate applies the first row of the external matrix @@ -97,7 +97,7 @@ func extGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[0], 2), x[1]) + return api.Add(x[0], x[0], x[1]) } // extAddGate applies the first row of the external matrix to the first two elements and adds the third @@ -105,7 +105,7 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 3 { panic("expected 3 inputs") } - return api.Add(api.Mul(x[0], 2), x[1], x[2]) + return api.Add(x[0], x[0], x[1], x[2]) } type GkrPermutations struct { From 6b86526d708e66cc5e3846452f3259a9a41b724f Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 12:57:43 -0500 Subject: [PATCH 34/92] Revert "perf: addition rather than multiplication in gates" This reverts commit 4b43ce33f59ef939396d9f5253f5f7e989c7e06b. --- std/permutation/poseidon2/gkr-poseidon2/gkr.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index a9af651bda..471c1f3887 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -20,7 +20,7 @@ func extKeyGate(roundKey frontend.Variable) gkr.GateFunction { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[0], x[0], x[1], roundKey) + return api.Add(api.Mul(x[0], 2), x[1], roundKey) } } @@ -71,7 +71,7 @@ func extGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[1], x[1], x[0]) + return api.Add(api.Mul(x[1], 2), x[0]) } // intKeyGate2 applies the internal matrix mul, then adds the round key @@ -80,7 +80,7 @@ func intKeyGate2(roundKey frontend.Variable) gkr.GateFunction { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[1], x[1], x[1], x[0], roundKey) + return api.Add(api.Mul(x[1], 3), x[0], roundKey) } } @@ -89,7 +89,7 @@ func intGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[1], x[1], x[1], x[0]) + return api.Add(api.Mul(x[1], 3), x[0]) } // extGate applies the first row of the external matrix @@ -97,7 +97,7 @@ func extGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[0], x[0], x[1]) + return api.Add(api.Mul(x[0], 2), x[1]) } // extAddGate applies the first row of the external matrix to the first two elements and adds the third @@ -105,7 +105,7 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 3 { panic("expected 3 inputs") } - return api.Add(x[0], x[0], x[1], x[2]) + return api.Add(api.Mul(x[0], 2), x[1], x[2]) } type GkrPermutations struct { From c4b01b579e908db2f1f22de7a66fc3309fab2b16 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:14:28 -0500 Subject: [PATCH 35/92] feat: generify poseidon2-gkr (not all s-Boxes available yet) --- .../poseidon2/gkr-poseidon2/gkr.go | 114 ++++--- .../poseidon2/gkr-poseidon2/gkr_test.go | 4 +- std/permutation/poseidon2/poseidon2.go | 281 +++++++++++++----- 3 files changed, 253 insertions(+), 146 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index 471c1f3887..dcc51a4031 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -1,15 +1,16 @@ package gkr_poseidon2 import ( + "errors" "fmt" - "sync" "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" + "github.com/consensys/gnark/std/permutation/poseidon2" "github.com/consensys/gnark-crypto/ecc" - bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" ) @@ -117,18 +118,18 @@ type GkrPermutations struct { // NewGkrPermutations returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. -// Note that the solver will need the function RegisterGkrGates to be called with the desired curves +// Note that the solver will need the function RegisterGates to be called with the desired curves func NewGkrPermutations(api frontend.API) *GkrPermutations { if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { panic("currently only BL12-377 is supported") } - gkrApi, in1, in2, out, err := defineCircuitBls12377() + gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { - panic(fmt.Errorf("failed to define GKR circuit: %v", err)) + panic(fmt.Errorf("failed to define GKR circuit: %w", err)) } return &GkrPermutations{ api: api, - gkrCircuit: gkrApi.Compile(api, "MIMC"), + gkrCircuit: gkrCircuit, in1: in1, in2: in2, out: out, @@ -144,27 +145,28 @@ func (p *GkrPermutations) Compress(a, b frontend.Variable) frontend.Variable { return outs[p.out] } -// defineCircuitBls12377 defines the GKR circuit for the Poseidon2 permutation over BLS12-377 +// defineCircuit defines the GKR circuit for the Poseidon2 permutation over BLS12-377 // insLeft and insRight are the inputs to the permutation // they must be padded to a power of 2 -func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, err error) { +func defineCircuit(api frontend.API) (gkrCircuit *gkrapi.Circuit, in1, in2, out gkr.Variable, err error) { // variable indexes const ( xI = iota yI ) - if err = registerGatesBls12377(); err != nil { + curve := utils.FieldToCurve(api.Compiler().Field()) + p, err := poseidon2.GetDefaultParameters(curve) + if err != nil { return } + gateNamer := newRoundGateNamer(&p, curve) - // poseidon2 parameters - gateNamer := newRoundGateNamer(bls12377.GetDefaultParameters()) - rF := bls12377.GetDefaultParameters().NbFullRounds - rP := bls12377.GetDefaultParameters().NbPartialRounds - halfRf := rF / 2 + if err = registerGates(&p, curve); err != nil { + return + } - gkrApi = gkrapi.New() + gkrApi := gkrapi.New() x := gkrApi.NewInput() y := gkrApi.NewInput() @@ -181,9 +183,17 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er // apply the s-Box to u // the s-Box gates: u¹⁷ = (u⁴)⁴ * u - sBox := func(u gkr.Variable) gkr.Variable { - v := gkrApi.Gate(pow4Gate, u) // u⁴ - return gkrApi.Gate(pow4TimesGate, v, u) // u¹⁷ + + var sBox func(gkr.Variable) gkr.Variable + switch p.DegreeSBox { + case 17: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow4Gate, u) // u⁴ + return gkrApi.Gate(pow4TimesGate, v, u) // u¹⁷ + } + default: + err = fmt.Errorf("unsupported s-Box degree %d", p.DegreeSBox) + return } // apply external matrix multiplication and round key addition @@ -208,89 +218,68 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er // *** construct the circuit *** - for i := range halfRf { + for i := range p.NbFullRounds / 2 { fullRound(i) } { // i = halfRf: first partial round // still using the external matrix, since the linear operation still belongs to a full (canonical) round - x1 := extKeySBox(halfRf, xI, x, y) + x1 := extKeySBox(p.NbFullRounds/2, xI, x, y) x, y = x1, gkrApi.Gate(extGate2, x, y) } - for i := halfRf + 1; i < halfRf+rP; i++ { + for i := p.NbFullRounds/2 + 1; i < p.NbFullRounds/2+p.NbPartialRounds; i++ { x1 := extKeySBox(i, xI, x, y) // the first row of the internal matrix is the same as that of the external matrix x, y = x1, gkrApi.Gate(intGate2, x, y) } { - i := halfRf + rP + i := p.NbFullRounds/2 + p.NbPartialRounds // first iteration of the final batch of full rounds // still using the internal matrix, since the linear operation still belongs to a partial (canonical) round x1 := extKeySBox(i, xI, x, y) x, y = x1, intKeySBox2(i, x, y) } - for i := halfRf + rP + 1; i < rP+rF; i++ { + for i := p.NbFullRounds/2 + p.NbPartialRounds + 1; i < p.NbPartialRounds+p.NbFullRounds; i++ { fullRound(i) } // apply the external matrix one last time to obtain the final value of y - out = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, in2) + out = gkrApi.Gate(extAddGate, y, x, in2) + + gkrCircuit = gkrApi.Compile(api, "MIMC") return } -var bls12377Permutation = sync.OnceValue(func() *bls12377.Permutation { - params := bls12377.GetDefaultParameters() - return bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto -}) - -// RegisterGkrGates registers the GKR gates corresponding to the given curves for the solver -func RegisterGkrGates(curves ...ecc.ID) { +// RegisterGates registers the GKR gates corresponding to the given curves for the solver. +func RegisterGates(curves ...ecc.ID) error { if len(curves) == 0 { - panic("expected at least one curve") + return errors.New("expected at least one curve") } for _, curve := range curves { - switch curve { - case ecc.BLS12_377: - if err := registerGatesBls12377(); err != nil { - panic(err) - } - default: - panic(fmt.Sprintf("curve %s not currently supported", curve)) + p, err := poseidon2.GetDefaultParameters(curve) + if err != nil { + return fmt.Errorf("failed to get default parameters for curve %s: %w", curve, err) + } + if err = registerGates(&p, curve); err != nil { + return fmt.Errorf("failed to register gates for curve %s: %w", curve, err) } } + return nil } -func registerGatesBls12377() error { +func registerGates(p *poseidon2.Parameters, curve ecc.ID) error { const ( x = iota y ) - p := bls12377.GetDefaultParameters() + gateNames := newRoundGateNamer(p, curve) halfRf := p.NbFullRounds / 2 - gateNames := newRoundGateNamer(p) - - if _, err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - - if _, err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } extKeySBox := func(round int, varIndex int) error { _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(ecc.BLS12_377)) @@ -343,15 +332,14 @@ func registerGatesBls12377() error { } } - _, err := gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds)), gkrgates.WithCurves(ecc.BLS12_377)) - return err + return nil } type roundGateNamer string // newRoundGateNamer returns an object that returns standardized names for gates in the GKR circuit -func newRoundGateNamer(p fmt.Stringer) roundGateNamer { - return roundGateNamer(p.String()) +func newRoundGateNamer(p *poseidon2.Parameters, curve ecc.ID) roundGateNamer { + return roundGateNamer(fmt.Sprintf("Poseidon2-%s[t=%d,rF=%d,rP=%d,d=%d]", curve.String(), p.Width, p.NbFullRounds, p.NbPartialRounds, p.DegreeSBox)) } // linear is the name of a gate where a polynomial of total degree 1 is applied to the input diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 601d80cb70..2d9ad7ccb5 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -76,12 +76,12 @@ func TestGkrPermutationCompiles(t *testing.T) { } func BenchmarkGkrPermutations(b *testing.B) { - circuit, assignmment := gkrPermutationsCircuits(b, 50000) + circuit, assignment := gkrPermutationsCircuits(b, 50000) cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) require.NoError(b, err) - witness, err := frontend.NewWitness(&assignmment, ecc.BLS12_377.ScalarField()) + witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) // cpu profile diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index 55afe73be5..317cc977dd 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -23,38 +23,157 @@ var ( type Permutation struct { api frontend.API - params parameters + params Parameters } -// parameters describing the poseidon2 implementation -type parameters struct { +// Parameters describing the poseidon2 implementation +type Parameters struct { // len(preimage)+len(digest)=len(preimage)+ceil(log(2*/r)) - width int + Width int // sbox degree - degreeSBox int + DegreeSBox int // number of full rounds (even number) - nbFullRounds int + NbFullRounds int // number of partial rounds - nbPartialRounds int + NbPartialRounds int // round keys: ordered by round then variable - roundKeys [][]big.Int + RoundKeys [][]big.Int +} + +func GetDefaultParameters(curve ecc.ID) (Parameters, error) { + switch curve { // TODO: assumes pairing based builder, reconsider when supporting other backends + case ecc.BN254: + p := poseidonbn254.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbn254.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS12_381: + p := poseidonbls12381.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls12381.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS12_377: + p := poseidonbls12377.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls12377.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BW6_761: + p := poseidonbw6761.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbw6761.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BW6_633: + p := poseidonbw6633.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbw6633.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS24_315: + p := poseidonbls24315.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls24315.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS24_317: + p := poseidonbls24317.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls24317.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + default: + return Parameters{}, fmt.Errorf("curve %s not supported", curve) + } } // NewPoseidon2 returns a new Poseidon2 hasher with default parameters as // defined in the gnark-crypto library. func NewPoseidon2(api frontend.API) (*Permutation, error) { - switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends - case ecc.BLS12_377: - params := poseidonbls12377.GetDefaultParameters() - return NewPoseidon2FromParameters(api, 2, params.NbFullRounds, params.NbPartialRounds) - // TODO: we don't have default parameters for other curves yet. Update this when we do. - default: - return nil, fmt.Errorf("field %s not supported", api.Compiler().Field().String()) + params, err := GetDefaultParameters(utils.FieldToCurve(api.Compiler().Field())) + if err != nil { + return nil, err } + return &Permutation{ + api: api, + params: params, + }, nil } // NewPoseidon2FromParameters returns a new Poseidon2 hasher with the given parameters. @@ -62,76 +181,76 @@ func NewPoseidon2(api frontend.API) (*Permutation, error) { // is deterministic and depends on the curve ID. See the corresponding NewParameters // function in the gnark-crypto library poseidon2 packages for more details. func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartialRounds int) (*Permutation, error) { - params := parameters{width: width, nbFullRounds: nbFullRounds, nbPartialRounds: nbPartialRounds} + params := Parameters{Width: width, NbFullRounds: nbFullRounds, NbPartialRounds: nbPartialRounds} switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends case ecc.BN254: - params.degreeSBox = poseidonbn254.DegreeSBox() + params.DegreeSBox = poseidonbn254.DegreeSBox() concreteParams := poseidonbn254.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS12_381: - params.degreeSBox = poseidonbls12381.DegreeSBox() + params.DegreeSBox = poseidonbls12381.DegreeSBox() concreteParams := poseidonbls12381.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS12_377: - params.degreeSBox = poseidonbls12377.DegreeSBox() + params.DegreeSBox = poseidonbls12377.DegreeSBox() concreteParams := poseidonbls12377.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BW6_761: - params.degreeSBox = poseidonbw6761.DegreeSBox() + params.DegreeSBox = poseidonbw6761.DegreeSBox() concreteParams := poseidonbw6761.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BW6_633: - params.degreeSBox = poseidonbw6633.DegreeSBox() + params.DegreeSBox = poseidonbw6633.DegreeSBox() concreteParams := poseidonbw6633.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS24_315: - params.degreeSBox = poseidonbls24315.DegreeSBox() + params.DegreeSBox = poseidonbls24315.DegreeSBox() concreteParams := poseidonbls24315.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS24_317: - params.degreeSBox = poseidonbls24317.DegreeSBox() + params.DegreeSBox = poseidonbls24317.DegreeSBox() concreteParams := poseidonbls24317.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } default: @@ -143,25 +262,25 @@ func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartial // sBox applies the sBox on buffer[index] func (h *Permutation) sBox(index int, input []frontend.Variable) { tmp := input[index] - if h.params.degreeSBox == 3 { + if h.params.DegreeSBox == 3 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(tmp, input[index]) - } else if h.params.degreeSBox == 5 { + } else if h.params.DegreeSBox == 5 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == 7 { + } else if h.params.DegreeSBox == 7 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == 17 { + } else if h.params.DegreeSBox == 17 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == -1 { + } else if h.params.DegreeSBox == -1 { input[index] = h.api.Inverse(input[index]) } } @@ -204,30 +323,30 @@ func (h *Permutation) matMulM4InPlace(s []frontend.Variable) { // see https://eprint.iacr.org/2023/323.pdf func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) { - if h.params.width == 2 { + if h.params.Width == 2 { tmp := h.api.Add(input[0], input[1]) input[0] = h.api.Add(tmp, input[0]) input[1] = h.api.Add(tmp, input[1]) - } else if h.params.width == 3 { + } else if h.params.Width == 3 { tmp := h.api.Add(input[0], input[1]) tmp = h.api.Add(tmp, input[2]) input[0] = h.api.Add(input[0], tmp) input[1] = h.api.Add(input[1], tmp) input[2] = h.api.Add(input[2], tmp) - } else if h.params.width == 4 { + } else if h.params.Width == 4 { h.matMulM4InPlace(input) } else { // at this stage t is supposed to be a multiple of 4 // the MDS matrix is circ(2M4,M4,..,M4) h.matMulM4InPlace(input) tmp := make([]frontend.Variable, 4) - for i := 0; i < h.params.width/4; i++ { + for i := 0; i < h.params.Width/4; i++ { tmp[0] = h.api.Add(tmp[0], input[4*i]) tmp[1] = h.api.Add(tmp[1], input[4*i+1]) tmp[2] = h.api.Add(tmp[2], input[4*i+2]) tmp[3] = h.api.Add(tmp[3], input[4*i+3]) } - for i := 0; i < h.params.width/4; i++ { + for i := 0; i < h.params.Width/4; i++ { input[4*i] = h.api.Add(input[4*i], tmp[0]) input[4*i+1] = h.api.Add(input[4*i], tmp[1]) input[4*i+2] = h.api.Add(input[4*i], tmp[2]) @@ -239,12 +358,12 @@ func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) { // when t=2,3 the matrix are respectively [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]] // otherwise the matrix is filled with ones except on the diagonal, func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { - if h.params.width == 2 { + if h.params.Width == 2 { sum := h.api.Add(input[0], input[1]) input[0] = h.api.Add(input[0], sum) input[1] = h.api.Mul(2, input[1]) input[1] = h.api.Add(input[1], sum) - } else if h.params.width == 3 { + } else if h.params.Width == 3 { sum := h.api.Add(input[0], input[1]) sum = h.api.Add(sum, input[2]) input[0] = h.api.Add(input[0], sum) @@ -259,10 +378,10 @@ func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { // var sum frontend.Variable // sum = input[0] - // for i := 1; i < h.params.width; i++ { + // for i := 1; i < h.params.Width; i++ { // sum = api.Add(sum, input[i]) // } - // for i := 0; i < h.params.width; i++ { + // for i := 0; i < h.params.Width; i++ { // input[i] = api.Mul(input[i], h.params.diagInternalMatrices[i]) // input[i] = api.Add(input[i], sum) // } @@ -272,40 +391,40 @@ func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { // addRoundKeyInPlace adds the round-th key to the buffer func (h *Permutation) addRoundKeyInPlace(round int, input []frontend.Variable) { - for i := 0; i < len(h.params.roundKeys[round]); i++ { - input[i] = h.api.Add(input[i], h.params.roundKeys[round][i]) + for i := 0; i < len(h.params.RoundKeys[round]); i++ { + input[i] = h.api.Add(input[i], h.params.RoundKeys[round][i]) } } // Permutation applies the permutation on input, and stores the result in input. func (h *Permutation) Permutation(input []frontend.Variable) error { - if len(input) != h.params.width { + if len(input) != h.params.Width { return ErrInvalidSizebuffer } // external matrix multiplication, cf https://eprint.iacr.org/2023/323.pdf page 14 (part 6) h.matMulExternalInPlace(input) - rf := h.params.nbFullRounds / 2 + rf := h.params.NbFullRounds / 2 for i := 0; i < rf; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) h.addRoundKeyInPlace(i, input) - for j := 0; j < h.params.width; j++ { + for j := 0; j < h.params.Width; j++ { h.sBox(j, input) } h.matMulExternalInPlace(input) } - for i := rf; i < rf+h.params.nbPartialRounds; i++ { + for i := rf; i < rf+h.params.NbPartialRounds; i++ { // one round = matMulInternal(sBox_sparse(addRoundKey)) h.addRoundKeyInPlace(i, input) h.sBox(0, input) h.matMulInternalInPlace(input) } - for i := rf + h.params.nbPartialRounds; i < h.params.nbFullRounds+h.params.nbPartialRounds; i++ { + for i := rf + h.params.NbPartialRounds; i < h.params.NbFullRounds+h.params.NbPartialRounds; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) h.addRoundKeyInPlace(i, input) - for j := 0; j < h.params.width; j++ { + for j := 0; j < h.params.Width; j++ { h.sBox(j, input) } h.matMulExternalInPlace(input) @@ -321,7 +440,7 @@ func (h *Permutation) Permutation(input []frontend.Variable) error { // Implements the [hash.Compressor] interface for building a Merkle-Damgard // hash construction. func (h *Permutation) Compress(left, right frontend.Variable) frontend.Variable { - if h.params.width != 2 { + if h.params.Width != 2 { panic("poseidon2: Compress can only be used when t=2") } vars := [2]frontend.Variable{left, right} From 4765d61a173989114a1ca98daa9d19d18c0f7798 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:19:38 -0500 Subject: [PATCH 36/92] revert incorrect renaming --- std/permutation/poseidon2/gkr-poseidon2/gkr.go | 10 +++++----- std/permutation/poseidon2/gkr-poseidon2/gkr_test.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index d9a7dcfbbb..800e79ed05 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -108,17 +108,17 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Add(api.Mul(x[0], 2), x[1], x[2]) } -type GkrPermutations struct { +type GkrCompressor struct { api frontend.API gkrCircuit *gkrapi.Circuit in1, in2, out gkr.Variable } -// NewGkrPermutations returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// NewGkrCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGkrGates to be called with the desired curves -func NewGkrPermutations(api frontend.API) *GkrPermutations { +func NewGkrCompressor(api frontend.API) *GkrCompressor { if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { panic("currently only BL12-377 is supported") } @@ -126,7 +126,7 @@ func NewGkrPermutations(api frontend.API) *GkrPermutations { if err != nil { panic(fmt.Errorf("failed to define GKR circuit: %v", err)) } - return &GkrPermutations{ + return &GkrCompressor{ api: api, gkrCircuit: gkrApi.Compile(api, "MIMC"), in1: in1, @@ -135,7 +135,7 @@ func NewGkrPermutations(api frontend.API) *GkrPermutations { } } -func (p *GkrPermutations) Compress(a, b frontend.Variable) frontend.Variable { +func (p *GkrCompressor) Compress(a, b frontend.Variable) frontend.Variable { outs, err := p.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{p.in1: a, p.in2: b}) if err != nil { panic(err) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 601d80cb70..ffa60d8ccb 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -56,7 +56,7 @@ type testGkrPermutationCircuit struct { func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - pos2 := NewGkrPermutations(api) + pos2 := NewGkrCompressor(api) api.AssertIsEqual(len(c.Ins), len(c.Outs)) for i := range c.Ins { api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) From 5342952398b3c4aa7e45ea3edcc37b3eab85293d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:20:46 -0500 Subject: [PATCH 37/92] fix test --- std/permutation/poseidon2/gkr-poseidon2/gkr_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 2d9ad7ccb5..9808d5f8dc 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" _ "github.com/consensys/gnark/std/hash/all" @@ -19,6 +20,8 @@ func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment tes var k int64 ins := make([][2]frontend.Variable, n) outs := make([]frontend.Variable, n) + params := poseidonbls12377.GetDefaultParameters() + permutation := poseidonbls12377.NewPermutation(params.Width, params.NbFullRounds, params.NbPartialRounds) for i := range n { var x [2]fr.Element ins[i] = [2]frontend.Variable{k, k + 1} @@ -27,7 +30,7 @@ func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment tes x[1].SetInt64(k + 1) y0 := x[1] - require.NoError(t, bls12377Permutation().Permutation(x[:])) + require.NoError(t, permutation.Permutation(x[:])) x[1].Add(&x[1], &y0) outs[i] = x[1] From 9d9b13b0cc833e503fc9abb693e60959f410cd1e Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:26:10 -0500 Subject: [PATCH 38/92] refactor generify tests --- .../poseidon2/gkr-poseidon2/gkr.go | 8 +-- .../poseidon2/gkr-poseidon2/gkr_test.go | 55 ++++++++----------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index dcc51a4031..c87cd450e9 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -119,13 +119,13 @@ type GkrPermutations struct { // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGates to be called with the desired curves -func NewGkrPermutations(api frontend.API) *GkrPermutations { +func NewGkrPermutations(api frontend.API) (*GkrPermutations, error) { if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { - panic("currently only BL12-377 is supported") + return nil, errors.New("currently only BL12-377 is supported") } gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { - panic(fmt.Errorf("failed to define GKR circuit: %w", err)) + return nil, fmt.Errorf("failed to define GKR circuit: %w", err) } return &GkrPermutations{ api: api, @@ -133,7 +133,7 @@ func NewGkrPermutations(api frontend.API) *GkrPermutations { in1: in1, in2: in2, out: out, - } + }, nil } func (p *GkrPermutations) Compress(a, b frontend.Variable) frontend.Variable { diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 9808d5f8dc..b510663959 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -7,62 +7,53 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" _ "github.com/consensys/gnark/std/hash/all" + "github.com/consensys/gnark/std/permutation/poseidon2" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" ) -func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment testGkrPermutationCircuit) { - var k int64 +func gkrPermutationsCircuits(n int) (circuit, assignment testGkrPermutationCircuit) { ins := make([][2]frontend.Variable, n) - outs := make([]frontend.Variable, n) - params := poseidonbls12377.GetDefaultParameters() - permutation := poseidonbls12377.NewPermutation(params.Width, params.NbFullRounds, params.NbPartialRounds) for i := range n { - var x [2]fr.Element - ins[i] = [2]frontend.Variable{k, k + 1} - - x[0].SetInt64(k) - x[1].SetInt64(k + 1) - y0 := x[1] - - require.NoError(t, permutation.Permutation(x[:])) - x[1].Add(&x[1], &y0) - outs[i] = x[1] - - k += 2 + ins[i] = [2]frontend.Variable{i * 2, i*2 + 1} } return testGkrPermutationCircuit{ - Ins: make([][2]frontend.Variable, len(ins)), - Outs: make([]frontend.Variable, len(outs)), + Ins: make([][2]frontend.Variable, len(ins)), }, testGkrPermutationCircuit{ - Ins: ins, - Outs: outs, + Ins: ins, } } func TestGkrCompression(t *testing.T) { - circuit, assignment := gkrPermutationsCircuits(t, 2) + circuit, assignment := gkrPermutationsCircuits(2) test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) } type testGkrPermutationCircuit struct { - Ins [][2]frontend.Variable - Outs []frontend.Variable + Ins [][2]frontend.Variable + skipCheck bool } func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - pos2 := NewGkrPermutations(api) - api.AssertIsEqual(len(c.Ins), len(c.Outs)) + gkr, err := NewGkrPermutations(api) + if err != nil { + return err + } + pos2, err := poseidon2.NewPoseidon2(api) + if err != nil { + return err + } for i := range c.Ins { - api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) + fromGkr := gkr.Compress(c.Ins[i][0], c.Ins[i][1]) + if !c.skipCheck { + api.AssertIsEqual(pos2.Compress(c.Ins[i][0], c.Ins[i][1]), fromGkr) + } } return nil @@ -71,15 +62,15 @@ func (c *testGkrPermutationCircuit) Define(api frontend.API) error { func TestGkrPermutationCompiles(t *testing.T) { // just measure the number of constraints cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrPermutationCircuit{ - Ins: make([][2]frontend.Variable, 52000), - Outs: make([]frontend.Variable, 52000), + Ins: make([][2]frontend.Variable, 52000), + skipCheck: true, }) require.NoError(t, err) fmt.Println(cs.GetNbConstraints(), "constraints") } func BenchmarkGkrPermutations(b *testing.B) { - circuit, assignment := gkrPermutationsCircuits(b, 50000) + circuit, assignment := gkrPermutationsCircuits(50000) cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) require.NoError(b, err) From 430661accfe2fc202f86c249d0de1f08d32b74df Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:36:56 -0500 Subject: [PATCH 39/92] feat: gkrposeidon2 compression for all curves --- .../poseidon2/gkr-poseidon2/gkr.go | 26 ++++++++++++++----- .../poseidon2/gkr-poseidon2/gkr_test.go | 2 +- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index c87cd450e9..fea7608ac6 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -47,6 +47,14 @@ func pow4TimesGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Mul(y, x[1]) } +// pow3Gate computes a -> a³ +func pow3Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + if len(x) != 1 { + panic("expected 1 input") + } + return api.Mul(x[0], x[0], x[0]) +} + // pow2Gate computes a -> a² func pow2Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 1 { @@ -120,9 +128,6 @@ type GkrPermutations struct { // The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGates to be called with the desired curves func NewGkrPermutations(api frontend.API) (*GkrPermutations, error) { - if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { - return nil, errors.New("currently only BL12-377 is supported") - } gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { return nil, fmt.Errorf("failed to define GKR circuit: %w", err) @@ -182,10 +187,19 @@ func defineCircuit(api frontend.API) (gkrCircuit *gkrapi.Circuit, in1, in2, out // in every round comes from the previous (canonical) round. // apply the s-Box to u - // the s-Box gates: u¹⁷ = (u⁴)⁴ * u var sBox func(gkr.Variable) gkr.Variable switch p.DegreeSBox { + case 5: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow2Gate, u) // u² + return gkrApi.Gate(pow2TimesGate, v, u) // u⁵ + } + case 7: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow3Gate, u) // u³ + return gkrApi.Gate(pow2TimesGate, v, u) // u⁷ + } case 17: sBox = func(u gkr.Variable) gkr.Variable { v := gkrApi.Gate(pow4Gate, u) // u⁴ @@ -282,12 +296,12 @@ func registerGates(p *poseidon2.Parameters, curve ecc.ID) error { halfRf := p.NbFullRounds / 2 extKeySBox := func(round int, varIndex int) error { - _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(ecc.BLS12_377)) + _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(curve)) return err } intKeySBox2 := func(round int) error { - _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(ecc.BLS12_377)) + _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(curve)) return err } diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index b510663959..592b9abf7c 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -31,7 +31,7 @@ func gkrPermutationsCircuits(n int) (circuit, assignment testGkrPermutationCircu func TestGkrCompression(t *testing.T) { circuit, assignment := gkrPermutationsCircuits(2) - test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) } type testGkrPermutationCircuit struct { From b75843dec8c8546984630e4d8372aad8fb6f94e7 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:44:18 -0500 Subject: [PATCH 40/92] feat gkr-poseidon2 hasher --- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 18 ++++++++++++++++++ std/hash/poseidon2/poseidon2.go | 6 +++--- std/hash/poseidon2/poseidon2_test.go | 16 ++++++++++++---- 3 files changed, 33 insertions(+), 7 deletions(-) create mode 100644 std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go new file mode 100644 index 0000000000..db1566bbc5 --- /dev/null +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -0,0 +1,18 @@ +package gkr_poseidon2 + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + _ "github.com/consensys/gnark/std/hash/all" + gkr_poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2" +) + +func NewGkrPoseidon2(api frontend.API) (hash.FieldHasher, error) { + f, err := gkr_poseidon2.NewGkrPermutations(api) + if err != nil { + return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) + } + return hash.NewMerkleDamgardHasher(api, f, 0), nil +} diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index 804740ff7c..a5b562fcb2 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -8,9 +8,9 @@ import ( "github.com/consensys/gnark/std/permutation/poseidon2" ) -// NewMerkleDamgardHasher returns a Poseidon2 hasher using the Merkle-Damgard +// NewPoseidon2 returns a Poseidon2 hasher using the Merkle-Damgard // construction with the default parameters. -func NewMerkleDamgardHasher(api frontend.API) (hash.FieldHasher, error) { +func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { f, err := poseidon2.NewPoseidon2(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) @@ -19,5 +19,5 @@ func NewMerkleDamgardHasher(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.POSEIDON2, NewMerkleDamgardHasher) + hash.Register(hash.POSEIDON2, NewPoseidon2) } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 1ce1d46fef..4a5374258c 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -1,11 +1,13 @@ -package poseidon2 +package poseidon2_test import ( "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/poseidon2" + gkr_poseidon2 "github.com/consensys/gnark/std/hash/poseidon2/gkr-poseidon2" "github.com/consensys/gnark/test" ) @@ -15,12 +17,18 @@ type Poseidon2Circuit struct { } func (c *Poseidon2Circuit) Define(api frontend.API) error { - hsh, err := NewMerkleDamgardHasher(api) + hsh, err := poseidon2.NewPoseidon2(api) + if err != nil { + return err + } + gkr, err := gkr_poseidon2.NewGkrPoseidon2(api) if err != nil { return err } hsh.Write(c.Input...) api.AssertIsEqual(hsh.Sum(), c.Expected) + gkr.Write(c.Input...) + api.AssertIsEqual(gkr.Sum(), c.Expected) return nil } @@ -29,7 +37,7 @@ func TestPoseidon2Hash(t *testing.T) { const nbInputs = 5 // prepare expected output - h := poseidon2.NewMerkleDamgardHasher() + h := poseidonbls12377.NewMerkleDamgardHasher() circInput := make([]frontend.Variable, nbInputs) for i := range nbInputs { _, err := h.Write([]byte{byte(i)}) From be8a07e29dc7019aca7370a9a480d7ec3d6c2ac8 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:49:38 -0500 Subject: [PATCH 41/92] fix more renaming --- std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go | 4 ++-- std/hash/poseidon2/poseidon2.go | 6 +++--- std/hash/poseidon2/poseidon2_test.go | 4 ++-- std/internal/mimc/encrypt.go | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go index db1566bbc5..ffc59dee74 100644 --- a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -9,8 +9,8 @@ import ( gkr_poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2" ) -func NewGkrPoseidon2(api frontend.API) (hash.FieldHasher, error) { - f, err := gkr_poseidon2.NewGkrPermutations(api) +func New(api frontend.API) (hash.FieldHasher, error) { + f, err := gkr_poseidon2.NewGkrCompressor(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) } diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index a5b562fcb2..f53b8716f3 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -8,9 +8,9 @@ import ( "github.com/consensys/gnark/std/permutation/poseidon2" ) -// NewPoseidon2 returns a Poseidon2 hasher using the Merkle-Damgard +// New returns a Poseidon2 hasher using the Merkle-Damgard // construction with the default parameters. -func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.FieldHasher, error) { f, err := poseidon2.NewPoseidon2(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) @@ -19,5 +19,5 @@ func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.POSEIDON2, NewPoseidon2) + hash.Register(hash.POSEIDON2, New) } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 4a5374258c..c3998ccc5b 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -17,11 +17,11 @@ type Poseidon2Circuit struct { } func (c *Poseidon2Circuit) Define(api frontend.API) error { - hsh, err := poseidon2.NewPoseidon2(api) + hsh, err := poseidon2.New(api) if err != nil { return err } - gkr, err := gkr_poseidon2.NewGkrPoseidon2(api) + gkr, err := gkr_poseidon2.New(api) if err != nil { return err } diff --git a/std/internal/mimc/encrypt.go b/std/internal/mimc/encrypt.go index 0d45a81506..9c499be976 100644 --- a/std/internal/mimc/encrypt.go +++ b/std/internal/mimc/encrypt.go @@ -106,7 +106,7 @@ func newMimcBW633(api frontend.API) MiMC { } // ------------------------------------------------------------------------------------------------- -// encryptions functions +// encryption functions func pow5(api frontend.API, x frontend.Variable) frontend.Variable { r := api.Mul(x, x) From 1894cfafae5c97778c8eef7eb89637f0a8221049 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 9 Jun 2025 17:49:28 -0500 Subject: [PATCH 42/92] feat: gkrmimc for sbox degree 5 --- std/permutation/gkr-mimc/gkr-mimc.go | 143 ++++++++++++++++++ .../{gkr.go => gkr-poseidon2.go} | 0 .../{gkr_test.go => gkr-poseidon2_test.go} | 0 3 files changed, 143 insertions(+) create mode 100644 std/permutation/gkr-mimc/gkr-mimc.go rename std/permutation/poseidon2/gkr-poseidon2/{gkr.go => gkr-poseidon2.go} (100%) rename std/permutation/poseidon2/gkr-poseidon2/{gkr_test.go => gkr-poseidon2_test.go} (100%) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go new file mode 100644 index 0000000000..f495f2b981 --- /dev/null +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -0,0 +1,143 @@ +package gkr_mimc + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" + "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/std/gkrapi" + "github.com/consensys/gnark/std/gkrapi/gkr" +) + +// mimcCompressor implements a compression function by applying +// the Miyaguchi–Preneel transformation to the MiMC encryption function. +type mimcCompressor struct { + gkrCircuit *gkrapi.Circuit + in0, in1, out gkr.Variable +} + +func newGkrCompressor(api frontend.API) (*mimcCompressor, error) { + gkrApi := gkrapi.New() + + in0 := gkrApi.NewInput() + in1 := gkrApi.NewInput() + + y := in1 + + curve := utils.FieldToCurve(api.Compiler().Field()) + params, _, err := getParams(curve) // params is only used for its length + if err != nil { + return nil, err + } + if err = RegisterGates(curve); err != nil { + return nil, err + } + gateNamer := newGateNamer(curve) + + for i := range len(params) - 1 { + y = gkrApi.NamedGate(gateNamer.round(i), in0, y) + } + + y = gkrApi.NamedGate(gateNamer.round(len(params)-1), in0, y, in1) + + return &mimcCompressor{ + gkrCircuit: gkrApi.Compile(api, "poseidon2"), + in0: in0, + in1: in1, + out: y, + }, nil +} + +func RegisterGates(curves ...ecc.ID) error { + for _, curve := range curves { + constants, deg, err := getParams(curve) + if err != nil { + return err + } + gateNamer := newGateNamer(curve) + var lastLayerSBox, nonLastLayerSBox func(*big.Int) gkr.GateFunction + switch deg { + case 5: + lastLayerSBox = addPow5Add + nonLastLayerSBox = addPow5 + default: + return fmt.Errorf("s-Box of degree %d not supported", deg) + } + + for i := range len(constants) - 1 { + if _, err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", i, curve, err) + } + } + + if _, err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", len(constants)-1, curve, err) + } + } + return nil +} + +// getParams returns the parameters for the MiMC encryption function for the given curve. +// It also returns the degree of the s-Box +func getParams(curve ecc.ID) ([]big.Int, int, error) { + switch curve { + case ecc.BN254: + return bn254.GetConstants(), 5, nil + case ecc.BLS12_381: + return bls12381.GetConstants(), 5, nil + case ecc.BLS12_377: + return bls12377.GetConstants(), 17, nil + case ecc.BLS24_315: + return bls24315.GetConstants(), 5, nil + case ecc.BLS24_317: + return bls24317.GetConstants(), 7, nil + case ecc.BW6_633: + return bw6633.GetConstants(), 5, nil + case ecc.BW6_761: + return bw6761.GetConstants(), 5, nil + default: + return nil, -1, fmt.Errorf("unsupported curve ID: %s", curve) + } +} + +type gateNamer string + +func newGateNamer(o fmt.Stringer) gateNamer { + return gateNamer("MiMC-" + o.String() + "-round-") +} +func (n gateNamer) round(i int) gkr.GateName { + return gkr.GateName(fmt.Sprintf("%s%d", string(n), i)) +} + +func addPow5(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Mul(t, t, s) + } +} + +// addPow5Add: (in[0]+in[1]+key)⁵ + in[0] + in[2] +func addPow5Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Add(api.Mul(t, t, s), in[0], in[2]) + } +} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go similarity index 100% rename from std/permutation/poseidon2/gkr-poseidon2/gkr.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go similarity index 100% rename from std/permutation/poseidon2/gkr-poseidon2/gkr_test.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go From 241d64533f491642261b5945029d4ac128ccf7da Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 9 Jun 2025 22:19:50 -0500 Subject: [PATCH 43/92] mimc length 1 works --- std/hash/mimc/gkr-mimc/gkr-mimc.go | 17 ++++++ std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 66 ++++++++++++++++++++++ std/permutation/gkr-mimc/gkr-mimc.go | 73 ++++++++++++++++++++++++- 3 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 std/hash/mimc/gkr-mimc/gkr-mimc.go create mode 100644 std/hash/mimc/gkr-mimc/gkr-mimc_test.go diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc.go b/std/hash/mimc/gkr-mimc/gkr-mimc.go new file mode 100644 index 0000000000..26be877f41 --- /dev/null +++ b/std/hash/mimc/gkr-mimc/gkr-mimc.go @@ -0,0 +1,17 @@ +package gkr_mimc + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + gkr_mimc "github.com/consensys/gnark/std/permutation/gkr-mimc" +) + +func New(api frontend.API) (hash.FieldHasher, error) { + f, err := gkr_mimc.NewCompressor(api) + if err != nil { + return nil, fmt.Errorf("could not create mimc hasher: %w", err) + } + return hash.NewMerkleDamgardHasher(api, f, 0), nil +} diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go new file mode 100644 index 0000000000..aa558dfd96 --- /dev/null +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -0,0 +1,66 @@ +package gkr_mimc + +import ( + "slices" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/consensys/gnark/test" +) + +func TestGkrMiMC(t *testing.T) { + lengths := []int{1, 2, 3} + vals := make([]frontend.Variable, len(lengths)*2) + for i := range vals { + vals[i] = i + 1 + } + + for _, length := range lengths[1:2] { + circuit := &testGkrMiMCCircuit{ + In: make([]frontend.Variable, length*2), + } + assignment := &testGkrMiMCCircuit{ + In: slices.Clone(vals[:length*2]), + } + + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254)) + } +} + +type testGkrMiMCCircuit struct { + In []frontend.Variable + skipCheck bool +} + +func (c *testGkrMiMCCircuit) Define(api frontend.API) error { + gkrmimc, err := New(api) + if err != nil { + return err + } + + plainMiMC, err := mimc.New(api) + if err != nil { + return err + } + + // first check that empty input is handled correctly + api.AssertIsEqual(gkrmimc.Sum(), plainMiMC.Sum()) + + ins := [][]frontend.Variable{c.In[:len(c.In)/2], c.In[len(c.In)/2:]} + for _, in := range ins { + gkrmimc.Reset() + gkrmimc.Write(in...) + res := gkrmimc.Sum() + + if !c.skipCheck { + plainMiMC.Reset() + plainMiMC.Write(in...) + expected := plainMiMC.Sum() + api.AssertIsEqual(res, expected) + } + } + + return nil +} diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index f495f2b981..c1d98d67f1 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -17,6 +17,8 @@ import ( "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" + "github.com/consensys/gnark/std/hash" + _ "github.com/consensys/gnark/std/hash/all" ) // mimcCompressor implements a compression function by applying @@ -26,7 +28,15 @@ type mimcCompressor struct { in0, in1, out gkr.Variable } -func newGkrCompressor(api frontend.API) (*mimcCompressor, error) { +func (c *mimcCompressor) Compress(x frontend.Variable, y frontend.Variable) frontend.Variable { + res, err := c.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{c.in0: x, c.in1: y}) + if err != nil { + panic(err) + } + return res[c.out] +} + +func NewCompressor(api frontend.API) (hash.Compressor, error) { gkrApi := gkrapi.New() in0 := gkrApi.NewInput() @@ -51,7 +61,7 @@ func newGkrCompressor(api frontend.API) (*mimcCompressor, error) { y = gkrApi.NamedGate(gateNamer.round(len(params)-1), in0, y, in1) return &mimcCompressor{ - gkrCircuit: gkrApi.Compile(api, "poseidon2"), + gkrCircuit: gkrApi.Compile(api, "POSEIDON2"), in0: in0, in1: in1, out: y, @@ -70,6 +80,12 @@ func RegisterGates(curves ...ecc.ID) error { case 5: lastLayerSBox = addPow5Add nonLastLayerSBox = addPow5 + case 7: + lastLayerSBox = addPow7Add + nonLastLayerSBox = addPow7 + case 17: + lastLayerSBox = addPow17Add + nonLastLayerSBox = addPow17 default: return fmt.Errorf("s-Box of degree %d not supported", deg) } @@ -141,3 +157,56 @@ func addPow5Add(key *big.Int) gkr.GateFunction { return api.Add(api.Mul(t, t, s), in[0], in[2]) } } + +func addPow7(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Mul(t, t, t, s) // s⁶ × s + } +} + +// addPow7Add: (in[0]+in[1]+key)⁷ + in[0] + in[2] +func addPow7Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Add(api.Mul(t, t, t, s), in[0], in[2]) // s⁶ × s + in[0] + in[2] + } +} + +// addPow17: (in[0]+in[1]+key)¹⁷ +func addPow17(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Mul(t, s) // s¹⁶ × s + } +} + +// addPow17Add: (in[0]+in[1]+key)¹⁷ + in[0] + in[2] +func addPow17Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Add(api.Mul(t, s), in[0], in[2]) // s¹⁶ × s + in[0] + in[2] + } +} From bb3645966028a73f97e9f8a716697b8ffd47b1b6 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 10 Jun 2025 16:20:36 -0500 Subject: [PATCH 44/92] fix final layer --- internal/gkr/bn254/gkr.go | 4 ++-- internal/gkr/engine_hints.go | 2 +- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 17 +++++++++++++++-- std/permutation/gkr-mimc/gkr-mimc.go | 20 +++++++++++--------- test/engine.go | 2 +- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 14269151b3..0174caa564 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 74b15c77ba..8c8bc1b797 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -187,7 +187,7 @@ func (g gateAPI) Println(a ...frontend.Variable) { for i := range a { if s, ok := a[i].(fmt.Stringer); ok { strings[i] = s.String() - } else { + } else if strings[i], ok = a[i].(string); !ok { bigInt := utils.FromInterface(a[i]) strings[i] = bigInt.String() } diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index aa558dfd96..559861af2d 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -1,13 +1,16 @@ package gkr_mimc import ( + "fmt" "slices" "testing" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" ) func TestGkrMiMC(t *testing.T) { @@ -25,7 +28,7 @@ func TestGkrMiMC(t *testing.T) { In: slices.Clone(vals[:length*2]), } - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254)) + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment)) } } @@ -58,9 +61,19 @@ func (c *testGkrMiMCCircuit) Define(api frontend.API) error { plainMiMC.Reset() plainMiMC.Write(in...) expected := plainMiMC.Sum() - api.AssertIsEqual(res, expected) + api.AssertIsEqual(expected, res) } } return nil } + +func TestGkrMiMCCompiles(t *testing.T) { + const n = 52000 + circuit := testGkrMiMCCircuit{ + In: make([]frontend.Variable, n), + } + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit, frontend.WithCapacity(27_000_000)) + require.NoError(t, err) + fmt.Println(cs.GetNbConstraints(), "constraints") +} diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index c1d98d67f1..df6517f7cf 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -146,7 +146,7 @@ func addPow5(key *big.Int) gkr.GateFunction { } } -// addPow5Add: (in[0]+in[1]+key)⁵ + in[0] + in[2] +// addPow5Add: (in[0]+in[1]+key)⁵ + 2*in[0] + in[2] func addPow5Add(key *big.Int) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { @@ -154,7 +154,9 @@ func addPow5Add(key *big.Int) gkr.GateFunction { } s := api.Add(in[0], in[1], key) t := api.Mul(s, s) - return api.Add(api.Mul(t, t, s), in[0], in[2]) + t = api.Mul(t, t, s) + + return api.Add(t, in[0], in[0], in[2]) } } @@ -169,7 +171,7 @@ func addPow7(key *big.Int) gkr.GateFunction { } } -// addPow7Add: (in[0]+in[1]+key)⁷ + in[0] + in[2] +// addPow7Add: (in[0]+in[1]+key)⁷ + 2*in[0] + in[2] func addPow7Add(key *big.Int) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { @@ -177,7 +179,7 @@ func addPow7Add(key *big.Int) gkr.GateFunction { } s := api.Add(in[0], in[1], key) t := api.Mul(s, s) - return api.Add(api.Mul(t, t, t, s), in[0], in[2]) // s⁶ × s + in[0] + in[2] + return api.Add(api.Mul(t, t, t, s), in[0], in[0], in[2]) // s⁶ × s + 2*in[0] + in[2] } } @@ -203,10 +205,10 @@ func addPow17Add(key *big.Int) gkr.GateFunction { panic("expected three input") } s := api.Add(in[0], in[1], key) - t := api.Mul(s, s) // s² - t = api.Mul(t, t) // s⁴ - t = api.Mul(t, t) // s⁸ - t = api.Mul(t, t) // s¹⁶ - return api.Add(api.Mul(t, s), in[0], in[2]) // s¹⁶ × s + in[0] + in[2] + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Add(api.Mul(t, s), in[0], in[0], in[2]) // s¹⁶ × s + 2*in[0] + in[2] } } diff --git a/test/engine.go b/test/engine.go index 79322af440..aaa63ac7ce 100644 --- a/test/engine.go +++ b/test/engine.go @@ -110,7 +110,7 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng defer func() { if r := recover(); r != nil { - err = fmt.Errorf("%v\n%s", r, string(debug.Stack())) + err = fmt.Errorf("%v\n%s", r, debug.Stack()) } }() From 311d9d9368d1fede362e4693cc87b59bebfa521c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 10 Jun 2025 16:22:20 -0500 Subject: [PATCH 45/92] chore generify Println changes --- internal/generator/backend/template/gkr/gkr.go.tmpl | 4 ++-- internal/gkr/bls12-377/gkr.go | 4 ++-- internal/gkr/bls12-381/gkr.go | 4 ++-- internal/gkr/bls24-315/gkr.go | 4 ++-- internal/gkr/bls24-317/gkr.go | 4 ++-- internal/gkr/bw6-633/gkr.go | 4 ++-- internal/gkr/bw6-761/gkr.go | 4 ++-- internal/gkr/small_rational/gkr.go | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 3e3881d15f..16d5eb970b 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -771,13 +771,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index b92ac1249d..b8ef9ea973 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 82084049d9..8f72898737 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index f182c9176b..7aee277ba4 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index a284f14ae9..7c679216fc 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index ec1067f736..2c2bda2037 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index ad5197feef..099b015b02 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index cdf62359f2..d085c6305f 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) From b3d1af83bba4f91aa2f9ac024ed7e76fde65c55d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 10 Jun 2025 17:38:55 -0500 Subject: [PATCH 46/92] fix: use multicommitter --- std/gkrapi/compile.go | 60 ++++++++++----------- std/lookup/logderivlookup/logderivlookup.go | 2 +- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index a5f81a283e..61d3b4a736 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -15,6 +15,7 @@ import ( fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/gkrapi/gkr" "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/multicommit" ) type circuitDataForSnark struct { @@ -98,7 +99,7 @@ func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, optio res.toStore.ProveHintID = solver.GetHintID(res.hints.Prove) res.toStore.SolveHintID = solver.GetHintID(res.hints.Solve) - parentApi.Compiler().Defer(res.verify) + parentApi.Compiler().Defer(res.finalize) return &res } @@ -140,12 +141,13 @@ func (c *Circuit) AddInstance(input map[gkr.Variable]frontend.Variable) (map[gkr return res, nil } -// verify encodes the verification circuitry for the GKR circuit -func (c *Circuit) verify(api frontend.API) error { +// finalize encodes the verification circuitry for the GKR circuit +func (c *Circuit) finalize(api frontend.API) error { if api != c.api { panic("api mismatch") } + // pad instances to the next power of 2 nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(c.toStore.NbInstances))) // pad instances to the next power of 2 by repeating the last instance if c.toStore.NbInstances < nbPaddedInstances && c.toStore.NbInstances > 0 { @@ -165,31 +167,27 @@ func (c *Circuit) verify(api frontend.API) error { return nil } - var ( - err error - proofSerialized []frontend.Variable - proof gadget.Proof - initialChallenges []frontend.Variable - ) - if c.getInitialChallenges != nil { - initialChallenges = c.getInitialChallenges() - } else { - // default initial challenge is a commitment to all input and output values - initialChallenges = make([]frontend.Variable, 0, (len(c.ins)+len(c.outs))*len(c.assignments[c.ins[0]])) - for _, in := range c.ins { - initialChallenges = append(initialChallenges, c.assignments[in]...) - } - for _, out := range c.outs { - initialChallenges = append(initialChallenges, c.assignments[out]...) - } + return c.verify(api, c.getInitialChallenges()) + } - if initialChallenges[0], err = api.(frontend.Committer).Commit(initialChallenges...); err != nil { - return fmt.Errorf("failed to commit to in/out values: %w", err) - } - initialChallenges = initialChallenges[:1] // use the commitment as the only initial challenge + // default initial challenge is a commitment to all input and output values + insOuts := make([]frontend.Variable, 0, (len(c.ins)+len(c.outs))*len(c.assignments[c.ins[0]])) + for _, in := range c.ins { + insOuts = append(insOuts, c.assignments[in]...) + } + for _, out := range c.outs { + insOuts = append(insOuts, c.assignments[out]...) } + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + return c.verify(api, []frontend.Variable{commitment}) + }, insOuts...) + + return nil +} + +func (c *Circuit) verify(api frontend.API, initialChallenges []frontend.Variable) error { forSnark, err := newCircuitDataForSnark(utils.FieldToCurve(api.Compiler().Field()), c.toStore, c.assignments) if err != nil { return fmt.Errorf("failed to create circuit data for snark: %w", err) @@ -201,8 +199,13 @@ func (c *Circuit) verify(api frontend.API) error { copy(hintIns[1:], initialChallenges) + var ( + proofSerialized []frontend.Variable + proof gadget.Proof + ) + if proofSerialized, err = api.Compiler().NewHint( - c.hints.Prove, gadget.ProofSize(forSnark.circuit, bits.TrailingZeros(uint(nbPaddedInstances))), hintIns...); err != nil { + c.hints.Prove, gadget.ProofSize(forSnark.circuit, bits.TrailingZeros(uint(len(c.assignments[0])))), hintIns...); err != nil { return err } c.toStore.ProveHintID = solver.GetHintID(c.hints.Prove) @@ -218,12 +221,7 @@ func (c *Circuit) verify(api frontend.API) error { return err } - err = gadget.Verify(api, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) - if err != nil { - return err - } - - return nil + return gadget.Verify(api, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) } func slicePtrAt[T any](slice []T) func(int) *T { diff --git a/std/lookup/logderivlookup/logderivlookup.go b/std/lookup/logderivlookup/logderivlookup.go index 63f2bc694d..dbeb042762 100644 --- a/std/lookup/logderivlookup/logderivlookup.go +++ b/std/lookup/logderivlookup/logderivlookup.go @@ -1,4 +1,4 @@ -// Package logderiv implements append-only lookups using log-derivative +// Package logderivlookup implements append-only lookups using log-derivative // argument. // // The lookup is based on log-derivative argument as described in [logderivarg]. From 815adc76b383667b65c74c6f9c59b2eaa94db076 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 10 Jun 2025 17:54:31 -0500 Subject: [PATCH 47/92] feat: use kvstore for caching instances --- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 2 +- std/permutation/gkr-mimc/gkr-mimc.go | 39 ++++++++++++++----- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 31 ++++++++++++--- .../gkr-poseidon2/gkr-poseidon2_test.go | 2 +- 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go index ffc59dee74..88c8baf260 100644 --- a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -10,7 +10,7 @@ import ( ) func New(api frontend.API) (hash.FieldHasher, error) { - f, err := gkr_poseidon2.NewGkrCompressor(api) + f, err := gkr_poseidon2.NewCompressor(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) } diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index df6517f7cf..266ee00e67 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -14,6 +14,7 @@ import ( bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" @@ -21,14 +22,14 @@ import ( _ "github.com/consensys/gnark/std/hash/all" ) -// mimcCompressor implements a compression function by applying +// compressor implements a compression function by applying // the Miyaguchi–Preneel transformation to the MiMC encryption function. -type mimcCompressor struct { +type compressor struct { gkrCircuit *gkrapi.Circuit in0, in1, out gkr.Variable } -func (c *mimcCompressor) Compress(x frontend.Variable, y frontend.Variable) frontend.Variable { +func (c *compressor) Compress(x frontend.Variable, y frontend.Variable) frontend.Variable { res, err := c.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{c.in0: x, c.in1: y}) if err != nil { panic(err) @@ -37,6 +38,20 @@ func (c *mimcCompressor) Compress(x frontend.Variable, y frontend.Variable) fron } func NewCompressor(api frontend.API) (hash.Compressor, error) { + + store, ok := api.(kvstore.Store) + if !ok { + return nil, fmt.Errorf("api of type %T does not implement kvstore.Store", api) + } + + cached := store.GetKeyValue(gkrMiMCKey{}) + if cached != nil { + if compressor, ok := cached.(*compressor); ok { + return compressor, nil + } + return nil, fmt.Errorf("cached value is of type %T, not a compressor", cached) + } + gkrApi := gkrapi.New() in0 := gkrApi.NewInput() @@ -60,12 +75,16 @@ func NewCompressor(api frontend.API) (hash.Compressor, error) { y = gkrApi.NamedGate(gateNamer.round(len(params)-1), in0, y, in1) - return &mimcCompressor{ - gkrCircuit: gkrApi.Compile(api, "POSEIDON2"), - in0: in0, - in1: in1, - out: y, - }, nil + res := + &compressor{ + gkrCircuit: gkrApi.Compile(api, "POSEIDON2"), + in0: in0, + in1: in1, + out: y, + } + + store.SetKeyValue(gkrMiMCKey{}, res) + return res, nil } func RegisterGates(curves ...ecc.ID) error { @@ -212,3 +231,5 @@ func addPow17Add(key *big.Int) gkr.GateFunction { return api.Add(api.Mul(t, s), in[0], in[0], in[2]) // s¹⁶ × s + 2*in[0] + in[2] } } + +type gkrMiMCKey struct{} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go index 9fbd53246c..4bb0b7a6a7 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -5,9 +5,11 @@ import ( "fmt" "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" + "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/permutation/poseidon2" "github.com/consensys/gnark-crypto/ecc" @@ -117,31 +119,46 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Add(api.Mul(x[0], 2), x[1], x[2]) } -type GkrCompressor struct { +type compressor struct { api frontend.API gkrCircuit *gkrapi.Circuit in1, in2, out gkr.Variable } -// NewGkrCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// NewCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGates to be called with the desired curves -func NewGkrCompressor(api frontend.API) (*GkrCompressor, error) { +func NewCompressor(api frontend.API) (hash.Compressor, error) { + store, ok := api.(kvstore.Store) + if !ok { + return nil, fmt.Errorf("api of type %T does not implement kvstore.Store", api) + } + + cached := store.GetKeyValue(gkrPoseidon2Key{}) + if cached != nil { + if compressor, ok := cached.(*compressor); ok { + return compressor, nil + } + return nil, fmt.Errorf("cached value is of type %T, not a mimcCompressor", cached) + } + gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { return nil, fmt.Errorf("failed to define GKR circuit: %w", err) } - return &GkrCompressor{ + res := &compressor{ api: api, gkrCircuit: gkrCircuit, in1: in1, in2: in2, out: out, - }, nil + } + store.SetKeyValue(gkrPoseidon2Key{}, res) + return res, nil } -func (p *GkrCompressor) Compress(a, b frontend.Variable) frontend.Variable { +func (p *compressor) Compress(a, b frontend.Variable) frontend.Variable { outs, err := p.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{p.in1: a, p.in2: b}) if err != nil { panic(err) @@ -365,3 +382,5 @@ func (n roundGateNamer) linear(varIndex, round int) gkr.GateName { func (n roundGateNamer) integrated(varIndex, round int) gkr.GateName { return gkr.GateName(fmt.Sprintf("x%d-i-op-round=%d;%s", varIndex, round, n)) } + +type gkrPoseidon2Key struct{} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go index c5214364b8..7c21281d66 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -41,7 +41,7 @@ type testGkrPermutationCircuit struct { func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - gkr, err := NewGkrCompressor(api) + gkr, err := NewCompressor(api) if err != nil { return err } From 7ab8702d5c8ae79c579f6bca65b45fc93ba890be Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 11 Jun 2025 10:34:50 -0500 Subject: [PATCH 48/92] feat: merkledamgard hasher as statestorer --- std/hash/hash.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/std/hash/hash.go b/std/hash/hash.go index c077fd0d37..564f122a48 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -5,6 +5,8 @@ package hash import ( + "fmt" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -110,7 +112,7 @@ type merkleDamgardHasher struct { // NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash // initialState is a value whose preimage is not known -func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) FieldHasher { +func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) StateStorer { return &merkleDamgardHasher{ state: initialState, iv: initialState, @@ -132,3 +134,18 @@ func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { func (h *merkleDamgardHasher) Sum() frontend.Variable { return h.state } + +func (h *merkleDamgardHasher) State() []frontend.Variable { + return []frontend.Variable{h.state} +} + +func (h *merkleDamgardHasher) SetState(state []frontend.Variable) error { + if h.state != h.iv { + return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state") + } + if len(state) != 1 { + return fmt.Errorf("expected one state variable, got %d", len(state)) + } + h.state = state[0] + return nil +} From 2fb7daa16a5dcff8806bc4a569621d734bccf321 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 11 Jun 2025 10:51:26 -0500 Subject: [PATCH 49/92] test: SetState --- std/hash/mimc/gkr-mimc/gkr-mimc.go | 2 +- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 2 +- std/hash/poseidon2/poseidon2.go | 6 +- std/hash/poseidon2/poseidon2_test.go | 60 ++++++++++++++++++- 4 files changed, 63 insertions(+), 7 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc.go b/std/hash/mimc/gkr-mimc/gkr-mimc.go index 26be877f41..8e6a8766d8 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc.go @@ -8,7 +8,7 @@ import ( gkr_mimc "github.com/consensys/gnark/std/permutation/gkr-mimc" ) -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { f, err := gkr_mimc.NewCompressor(api) if err != nil { return nil, fmt.Errorf("could not create mimc hasher: %w", err) diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go index 88c8baf260..bbbef1f87c 100644 --- a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -9,7 +9,7 @@ import ( gkr_poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2" ) -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { f, err := gkr_poseidon2.NewCompressor(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index f53b8716f3..e15c4ca587 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -10,7 +10,7 @@ import ( // New returns a Poseidon2 hasher using the Merkle-Damgard // construction with the default parameters. -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { f, err := poseidon2.NewPoseidon2(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) @@ -19,5 +19,7 @@ func New(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.POSEIDON2, New) + hash.Register(hash.POSEIDON2, func(api frontend.API) (hash.FieldHasher, error) { + return New(api) + }) } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index c3998ccc5b..f6c57736df 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -11,12 +11,12 @@ import ( "github.com/consensys/gnark/test" ) -type Poseidon2Circuit struct { +type poseidon2Circuit struct { Input []frontend.Variable Expected frontend.Variable `gnark:",public"` } -func (c *Poseidon2Circuit) Define(api frontend.API) error { +func (c *poseidon2Circuit) Define(api frontend.API) error { hsh, err := poseidon2.New(api) if err != nil { return err @@ -45,5 +45,59 @@ func TestPoseidon2Hash(t *testing.T) { circInput[i] = i } res := h.Sum(nil) - assert.CheckCircuit(&Poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&Poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 + assert.CheckCircuit(&poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 +} + +func TestStateStorer(t *testing.T) { + assignment := testStateStorerCircuit{ + Input: [][]frontend.Variable{ + {0, 1, 2, 3, 4}, + }, + } + + circuit := testStateStorerCircuit{ + Input: make([][]frontend.Variable, len(assignment.Input)), + } + for i := range assignment.Input { + circuit.Input[i] = make([]frontend.Variable, len(assignment.Input[i])) + } + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +type testStateStorerCircuit struct { + Input [][]frontend.Variable +} + +func (c *testStateStorerCircuit) Define(api frontend.API) error { + // hashes the whole input in one go + hshFull, err := poseidon2.New(api) + if err != nil { + return err + } + + // hashes the input in two parts + hshPartial, err := poseidon2.New(api) + if err != nil { + return err + } + + for _, input := range c.Input { + // compute desired output + hshFull.Reset() + hshFull.Write(input...) + digest := hshFull.Sum() + + hshPartial.Reset() + hshPartial.Write(input[:len(input)/2]...) + state := hshPartial.State() + hshPartial.Reset() + api.AssertIsEqual(hshPartial.State()[0], 0) + if err = hshPartial.SetState(state); err != nil { + return err + } + hshPartial.Write(input[len(input)/2:]...) + api.AssertIsEqual(hshPartial.Sum(), digest) + } + return nil } From de55a270eb9c81ca37cfefddb0ca52eca2bbe4ce Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 11 Jun 2025 11:01:41 -0500 Subject: [PATCH 50/92] feat: mimc.New to return StateStorer --- std/hash/mimc/mimc.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/std/hash/mimc/mimc.go b/std/hash/mimc/mimc.go index 9d8a98e306..db72f54429 100644 --- a/std/hash/mimc/mimc.go +++ b/std/hash/mimc/mimc.go @@ -34,7 +34,7 @@ func NewMiMC(api frontend.API) (MiMC, error) { // NB! See the package documentation for length extension attack consideration. // // [gnark-crypto]: https://pkg.go.dev/github.com/consensys/gnark-crypto/hash -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { h, err := NewMiMC(api) if err != nil { return nil, err @@ -43,5 +43,7 @@ func New(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.MIMC, New) + hash.Register(hash.MIMC, func(api frontend.API) (hash.FieldHasher, error) { + return New(api) + }) } From 31a0a598cfe7b562d3bd789c793f10b6b7cdb5dd Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 11 Jun 2025 13:41:31 -0500 Subject: [PATCH 51/92] fix: single instance and no instance edge cases --- std/gkrapi/api_test.go | 34 +++++++++++++++++++++++++++++++++- std/gkrapi/compile.go | 25 ++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index 53642bc6a7..0d22ec80a3 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -649,6 +649,38 @@ func TestWitnessExtend(t *testing.T) { _, err = cs.Solve(witness) require.NoError(t, err) +} + +func TestSingleInstance(t *testing.T) { + circuit := doubleNoDependencyCircuit{ + X: make([]frontend.Variable, 1), + hashName: "MIMC", + } + assignment := doubleNoDependencyCircuit{ + X: []frontend.Variable{10}, + } + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +func TestNoInstance(t *testing.T) { + var circuit testNoInstanceCircuit + assignment := testNoInstanceCircuit{0} - //test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +type testNoInstanceCircuit struct { + Dummy frontend.Variable // Plonk prover would fail on an empty witness +} + +func (c *testNoInstanceCircuit) Define(api frontend.API) error { + gkrApi := New() + x := gkrApi.NewInput() + y := gkrApi.NewInput() + gkrApi.Mul(x, y) + + gkrApi.Compile(api, "MIMC") + + return nil } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 61d3b4a736..99f8ab9dcc 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -147,6 +147,11 @@ func (c *Circuit) finalize(api frontend.API) error { panic("api mismatch") } + // if the circuit is empty or with no instances, there is nothing to do. + if len(c.outs) == 0 || len(c.assignments[0]) == 0 { // wire 0 is always an input wire + return nil + } + // pad instances to the next power of 2 nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(c.toStore.NbInstances))) // pad instances to the next power of 2 by repeating the last instance @@ -163,7 +168,25 @@ func (c *Circuit) finalize(api frontend.API) error { return err } - if len(c.outs) == 0 || len(c.assignments[0]) == 0 { // wire 0 is always an input wire + // if the circuit consists of only one instance, directly solve the circuit + if len(c.assignments[c.ins[0]]) == 1 { + circuit, err := gkrtypes.CircuitInfoToCircuit(c.toStore.Circuit, gkrgates.Get) + if err != nil { + return fmt.Errorf("failed to convert GKR info to circuit: %w", err) + } + gateIn := make([]frontend.Variable, circuit.MaxGateNbIn()) + for wI, w := range circuit { + if w.IsInput() { + continue + } + for inI, inWI := range w.Inputs { + gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance + } + res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) + if w.IsOutput() { + api.AssertIsEqual(res, c.assignments[gkr.Variable(wI)][0]) + } + } return nil } From db56157c9334e7715a2a9bd0c32d7e6bb444bfa5 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 11 Jun 2025 15:17:55 -0500 Subject: [PATCH 52/92] fix: single-instance, circuit with depth --- std/gkrapi/api_test.go | 27 +++++++++++++++++---------- std/gkrapi/compile.go | 4 +++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index 0d22ec80a3..1f3d187dae 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -275,8 +275,6 @@ func benchProof(b *testing.B, circuit, assignment frontend.Circuit) { _, err = groth16.Prove(cs, pk, fullWitness) require.NoError(b, err) fmt.Println("groth16 proved", id, "in", time.Now().UnixMicro()-start, "μs") - - fmt.Println("mimc total calls: fr=", mimcFrTotalCalls, ", snark=", mimcSnarkTotalCalls) } } @@ -364,8 +362,6 @@ func (c constPseudoHash) Write(...frontend.Variable) {} func (c constPseudoHash) Reset() {} -var mimcFrTotalCalls = 0 - type mimcNoGkrCircuit struct { X []frontend.Variable Y []frontend.Variable @@ -418,7 +414,15 @@ func (c *mimcNoDepCircuit) Define(api frontend.API) error { gkrApi := New() x := gkrApi.NewInput() y := gkrApi.NewInput() - gkrApi.Gate(mimcGate, x, y) + + if c.mimcDepth < 1 { + return fmt.Errorf("mimcDepth must be at least 1, got %d", c.mimcDepth) + } + + z := y + for range c.mimcDepth { + z = gkrApi.Gate(mimcGate, x, z) + } gkrCircuit := gkrApi.Compile(api, c.hashName) @@ -652,12 +656,15 @@ func TestWitnessExtend(t *testing.T) { } func TestSingleInstance(t *testing.T) { - circuit := doubleNoDependencyCircuit{ - X: make([]frontend.Variable, 1), - hashName: "MIMC", + circuit := mimcNoDepCircuit{ + X: make([]frontend.Variable, 1), + Y: make([]frontend.Variable, 1), + mimcDepth: 2, + hashName: "MIMC", } - assignment := doubleNoDependencyCircuit{ + assignment := mimcNoDepCircuit{ X: []frontend.Variable{10}, + Y: []frontend.Variable{2}, } test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) @@ -677,7 +684,7 @@ type testNoInstanceCircuit struct { func (c *testNoInstanceCircuit) Define(api frontend.API) error { gkrApi := New() x := gkrApi.NewInput() - y := gkrApi.NewInput() + y := gkrApi.Mul(x, x) gkrApi.Mul(x, y) gkrApi.Compile(api, "MIMC") diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 99f8ab9dcc..390647b89d 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -184,7 +184,9 @@ func (c *Circuit) finalize(api frontend.API) error { } res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) if w.IsOutput() { - api.AssertIsEqual(res, c.assignments[gkr.Variable(wI)][0]) + api.AssertIsEqual(res, c.assignments[wI][0]) + } else { + c.assignments[wI] = append(c.assignments[wI], res) } } return nil From 4ca453c5467bc681708288a88153d0a5b7505d05 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 8 Jul 2025 17:18:34 -0500 Subject: [PATCH 53/92] docs: address cursor comments --- .../backend/template/gkr/solver_hints.go.tmpl | 4 ++-- internal/gkr/bls12-377/solver_hints.go | 4 ++-- internal/gkr/bls12-381/solver_hints.go | 4 ++-- internal/gkr/bls24-315/solver_hints.go | 4 ++-- internal/gkr/bls24-317/solver_hints.go | 4 ++-- internal/gkr/bn254/solver_hints.go | 4 ++-- internal/gkr/bw6-633/solver_hints.go | 4 ++-- internal/gkr/bw6-761/solver_hints.go | 4 ++-- internal/utils/slices.go | 11 ++++++----- 9 files changed, 22 insertions(+), 21 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 29698e0e3b..04873fdcb8 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -80,7 +80,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func GetAssignmentHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { if len(ins) != 3 { - return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() @@ -149,7 +149,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } -// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// RepeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. {{ print "// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}}"}} func (a WireAssignment) RepeatUntilEnd(n int) { for i := range a { diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 04c5f52586..c977d4997a 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -86,7 +86,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func GetAssignmentHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { if len(ins) != 3 { - return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() @@ -155,7 +155,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } -// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// RepeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. // e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} func (a WireAssignment) RepeatUntilEnd(n int) { for i := range a { diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index e92e543398..81572a4ac4 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -86,7 +86,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func GetAssignmentHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { if len(ins) != 3 { - return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() @@ -155,7 +155,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } -// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// RepeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. // e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} func (a WireAssignment) RepeatUntilEnd(n int) { for i := range a { diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index f57537b985..783cc964c8 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -86,7 +86,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func GetAssignmentHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { if len(ins) != 3 { - return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() @@ -155,7 +155,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } -// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// RepeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. // e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} func (a WireAssignment) RepeatUntilEnd(n int) { for i := range a { diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index d2cc4d32b1..234a327324 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -86,7 +86,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func GetAssignmentHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { if len(ins) != 3 { - return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() @@ -155,7 +155,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } -// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// RepeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. // e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} func (a WireAssignment) RepeatUntilEnd(n int) { for i := range a { diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 5813b89661..f855222636 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -86,7 +86,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func GetAssignmentHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { if len(ins) != 3 { - return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() @@ -155,7 +155,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } -// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// RepeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. // e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} func (a WireAssignment) RepeatUntilEnd(n int) { for i := range a { diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index ef945e25f7..19d347d099 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -86,7 +86,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func GetAssignmentHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { if len(ins) != 3 { - return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() @@ -155,7 +155,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } -// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// RepeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. // e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} func (a WireAssignment) RepeatUntilEnd(n int) { for i := range a { diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 1a91928171..09e9c13f0f 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -86,7 +86,7 @@ func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) func GetAssignmentHint(data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { if len(ins) != 3 { - return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } wireI := ins[0].Uint64() instanceI := ins[1].Uint64() @@ -155,7 +155,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } -// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// RepeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. // e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} func (a WireAssignment) RepeatUntilEnd(n int) { for i := range a { diff --git a/internal/utils/slices.go b/internal/utils/slices.go index f493bf4bca..bdd86119fa 100644 --- a/internal/utils/slices.go +++ b/internal/utils/slices.go @@ -17,14 +17,15 @@ func References[T any](v []T) []*T { return res } -// ExtendRepeatLast extends the slice s by repeating the last element until it reaches the length n. +// ExtendRepeatLast extends a non-empty slice s by repeating the last element until it reaches the length n. func ExtendRepeatLast[T any](s []T, n int) []T { if n <= len(s) { return s[:n] } - s = s[:len(s):len(s)] // ensure s is a slice with a capacity equal to its length - for len(s) < n { - s = append(s, s[len(s)-1]) // append the last element until the length is n + res := make([]T, n) + copy(res, s) + for i := len(s); i < n; i++ { + res[i] = res[i-1] } - return s + return res } From 8947d9f74925b69b38fc6d7ed543df8556d1d0dc Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 19 Aug 2025 17:01:20 -0500 Subject: [PATCH 54/92] refactor: remove MapRange --- internal/utils/algo_utils.go | 70 ++----------------- std/gkrapi/compile.go | 8 +-- .../poseidon2/gkr-poseidon2/gkr_test.go | 42 +++-------- 3 files changed, 16 insertions(+), 104 deletions(-) diff --git a/internal/utils/algo_utils.go b/internal/utils/algo_utils.go index f836625370..4bee19443e 100644 --- a/internal/utils/algo_utils.go +++ b/internal/utils/algo_utils.go @@ -24,6 +24,7 @@ func Permute[T any](slice []T, permutation []int) { } } +// Map returns [f(in[0]), f(in[1]), ..., f(in[len(in)-1])] func Map[T, S any](in []T, f func(T) S) []S { out := make([]S, len(in)) for i, t := range in { @@ -32,41 +33,6 @@ func Map[T, S any](in []T, f func(T) S) []S { return out } -func MapRange[S any](begin, end int, f func(int) S) []S { - out := make([]S, end-begin) - for i := begin; i < end; i++ { - out[i] = f(i) - } - return out -} - -func SliceAt[T any](slice []T) func(int) T { - return func(i int) T { - return slice[i] - } -} - -func SlicePtrAt[T any](slice []T) func(int) *T { - return func(i int) *T { - return &slice[i] - } -} - -func MapAt[K comparable, V any](mp map[K]V) func(K) V { - return func(k K) V { - return mp[k] - } -} - -// InvertPermutation input permutation must contain exactly 0, ..., len(permutation)-1 -func InvertPermutation(permutation []int) []int { - res := make([]int, len(permutation)) - for i := range permutation { - res[permutation[i]] = i - } - return res -} - // TODO: Move this to gnark-crypto and use it for gkr there as well // TopologicalSort takes a list of lists of dependencies and proposes a sorting of the lists in order of dependence. Such that for any wire, any one it depends on @@ -143,33 +109,11 @@ func (d *topSortData) markDone(i int) { } } -// BinarySearch looks for toFind in a sorted slice, and returns the index at which it either is or would be were it to be inserted. -func BinarySearch(slice []int, toFind int) int { - var start int - for end := len(slice); start != end; { - mid := (start + end) / 2 - if toFind >= slice[mid] { - start = mid - } - if toFind <= slice[mid] { - end = mid - } +// SliceOfRefs returns [&slice[0], &slice[1], ..., &slice[len(slice)-1]] +func SliceOfRefs[T any](slice []T) []*T { + res := make([]*T, len(slice)) + for i := range slice { + res[i] = &slice[i] } - return start -} - -// BinarySearchFunc looks for toFind in an increasing function of domain 0 ... (end-1), and returns the index at which it either is or would be were it to be inserted. -func BinarySearchFunc(eval func(int) int, end int, toFind int) int { - var start int - for start != end { - mid := (start + end) / 2 - val := eval(mid) - if toFind >= val { - start = mid - } - if toFind <= val { - end = mid - } - } - return start + return res } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 390647b89d..64205b80b8 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -235,7 +235,7 @@ func (c *Circuit) verify(api frontend.API, initialChallenges []frontend.Variable } c.toStore.ProveHintID = solver.GetHintID(c.hints.Prove) - forSnarkSorted := utils.MapRange(0, len(c.toStore.Circuit), slicePtrAt(forSnark.circuit)) + forSnarkSorted := utils.SliceOfRefs(forSnark.circuit) if proof, err = gadget.DeserializeProof(forSnarkSorted, proofSerialized); err != nil { return err @@ -249,12 +249,6 @@ func (c *Circuit) verify(api frontend.API, initialChallenges []frontend.Variable return gadget.Verify(api, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) } -func slicePtrAt[T any](slice []T) func(int) *T { - return func(i int) *T { - return &slice[i] - } -} - func newCircuitDataForSnark(curve ecc.ID, info gkrinfo.StoringInfo, assignment gkrtypes.WireAssignment) (circuitDataForSnark, error) { circuit, err := gkrtypes.CircuitInfoToCircuit(info.Circuit, gkrgates.Get) if err != nil { diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index ffa60d8ccb..0a230c4381 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -2,8 +2,6 @@ package gkr_poseidon2 import ( "fmt" - "os" - "runtime/pprof" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -15,7 +13,7 @@ import ( "github.com/stretchr/testify/require" ) -func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment testGkrPermutationCircuit) { +func gkrCompressionCircuits(t require.TestingT, n int) (circuit, assignment testGkrCompressionCircuit) { var k int64 ins := make([][2]frontend.Variable, n) outs := make([]frontend.Variable, n) @@ -34,27 +32,27 @@ func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment tes k += 2 } - return testGkrPermutationCircuit{ + return testGkrCompressionCircuit{ Ins: make([][2]frontend.Variable, len(ins)), Outs: make([]frontend.Variable, len(outs)), - }, testGkrPermutationCircuit{ + }, testGkrCompressionCircuit{ Ins: ins, Outs: outs, } } func TestGkrCompression(t *testing.T) { - circuit, assignment := gkrPermutationsCircuits(t, 2) + circuit, assignment := gkrCompressionCircuits(t, 2) test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) } -type testGkrPermutationCircuit struct { +type testGkrCompressionCircuit struct { Ins [][2]frontend.Variable Outs []frontend.Variable } -func (c *testGkrPermutationCircuit) Define(api frontend.API) error { +func (c *testGkrCompressionCircuit) Define(api frontend.API) error { pos2 := NewGkrCompressor(api) api.AssertIsEqual(len(c.Ins), len(c.Outs)) @@ -65,36 +63,12 @@ func (c *testGkrPermutationCircuit) Define(api frontend.API) error { return nil } -func TestGkrPermutationCompiles(t *testing.T) { +func TestGkrCompressionCompiles(t *testing.T) { // just measure the number of constraints - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrPermutationCircuit{ + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrCompressionCircuit{ Ins: make([][2]frontend.Variable, 52000), Outs: make([]frontend.Variable, 52000), }) require.NoError(t, err) fmt.Println(cs.GetNbConstraints(), "constraints") } - -func BenchmarkGkrPermutations(b *testing.B) { - circuit, assignmment := gkrPermutationsCircuits(b, 50000) - - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) - require.NoError(b, err) - - witness, err := frontend.NewWitness(&assignmment, ecc.BLS12_377.ScalarField()) - require.NoError(b, err) - - // cpu profile - f, err := os.Create("cpu.pprof") - require.NoError(b, err) - defer func() { - require.NoError(b, f.Close()) - }() - - err = pprof.StartCPUProfile(f) - require.NoError(b, err) - defer pprof.StopCPUProfile() - - _, err = cs.Solve(witness) - require.NoError(b, err) -} From 8c37435765a54be6acdb4048a78f0b1c388955ba Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 13:27:37 -0500 Subject: [PATCH 55/92] fix: bad merge --- .../poseidon2/gkr-poseidon2/gkr-poseidon2_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go index f3af71983f..b224bf1414 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -76,11 +76,6 @@ func BenchmarkGkrCompressions(b *testing.B) { witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) - // cpu profile - defer func() { - require.NoError(b, f.Close()) - }() - _, err = cs.Solve(witness) require.NoError(b, err) } From 1e208500ab2139899e9c42090eba912e54b73358 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 13:31:18 -0500 Subject: [PATCH 56/92] perf: single-elem pool for bls12-377 --- internal/gkr/bls12-377/gate_testing.go | 5 ++ internal/gkr/bls12-377/gkr.go | 98 ++++++++++++++++---------- internal/gkr/bls12-377/solver_hints.go | 5 +- 3 files changed, 68 insertions(+), 40 deletions(-) diff --git a/internal/gkr/bls12-377/gate_testing.go b/internal/gkr/bls12-377/gate_testing.go index 9e5a3868f3..6088644ea0 100644 --- a/internal/gkr/bls12-377/gate_testing.go +++ b/internal/gkr/bls12-377/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index b8ef9ea973..16e7af5c12 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -104,8 +104,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) + api.freeElements() } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +270,13 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +674,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +731,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +796,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +804,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index c977d4997a..96b6636151 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input @@ -116,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) From 4c69b6518416ed6e5b1af41c89831d7ce9a1f9ab Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 13:47:13 -0500 Subject: [PATCH 57/92] build: generify changes --- .../backend/template/gkr/gate_testing.go.tmpl | 5 + .../backend/template/gkr/gkr.go.tmpl | 96 +++++++++++-------- .../backend/template/gkr/solver_hints.go.tmpl | 3 +- internal/gkr/bls12-377/gkr.go | 2 - internal/gkr/bls12-377/solver_hints.go | 2 +- internal/gkr/bls12-381/gate_testing.go | 5 + internal/gkr/bls12-381/gkr.go | 96 +++++++++++-------- internal/gkr/bls12-381/solver_hints.go | 3 +- internal/gkr/bls24-315/gate_testing.go | 5 + internal/gkr/bls24-315/gkr.go | 96 +++++++++++-------- internal/gkr/bls24-315/solver_hints.go | 3 +- internal/gkr/bls24-317/gate_testing.go | 5 + internal/gkr/bls24-317/gkr.go | 96 +++++++++++-------- internal/gkr/bls24-317/solver_hints.go | 3 +- internal/gkr/bn254/gate_testing.go | 5 + internal/gkr/bn254/gkr.go | 96 +++++++++++-------- internal/gkr/bn254/solver_hints.go | 3 +- internal/gkr/bw6-633/gate_testing.go | 5 + internal/gkr/bw6-633/gkr.go | 96 +++++++++++-------- internal/gkr/bw6-633/solver_hints.go | 3 +- internal/gkr/bw6-761/gate_testing.go | 5 + internal/gkr/bw6-761/gkr.go | 96 +++++++++++-------- internal/gkr/bw6-761/solver_hints.go | 3 +- internal/gkr/small_rational/gate_testing.go | 5 + internal/gkr/small_rational/gkr.go | 96 +++++++++++-------- 25 files changed, 519 insertions(+), 314 deletions(-) diff --git a/internal/generator/backend/template/gkr/gate_testing.go.tmpl b/internal/generator/backend/template/gkr/gate_testing.go.tmpl index a782015cfa..7c24b0d27a 100644 --- a/internal/generator/backend/template/gkr/gate_testing.go.tmpl +++ b/internal/generator/backend/template/gkr/gate_testing.go.tmpl @@ -15,6 +15,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -130,6 +131,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -139,11 +141,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -157,6 +161,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make({{.FieldPackageName}}.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 16d5eb970b..c5c5b21dc2 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -97,8 +97,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []{{ .ElementType for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*{{ .ElementType }})) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*{{ .ElementType }})) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -230,7 +230,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]{{ .ElementType }}, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step {{ .ElementType }} + var ( + step {{ .ElementType }} + api gateAPI + ) res := make([]{{ .ElementType }}, degGJ) @@ -260,11 +263,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*{{ .ElementType }}) + summand := wire.Gate.Evaluate(&api, gateInput...).(*{{ .ElementType }}) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -663,6 +667,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]{{ .ElementType }}, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -720,52 +725,54 @@ func frToBigInts(dst []*big.Int, src []{{ .ElementType }}) { // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*{{ .ElementType }} + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod {{ .ElementType }} - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x {{ .ElementType }} @@ -783,7 +790,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -791,22 +798,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .E return f(api, inVar...).(*{{ .ElementType }}) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *{{ .ElementType }} { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new({{ .ElementType }})) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...{{ .ElementType }}) *{{ .ElementType }} // convertFunc turns f into a function that accepts and returns {{ .ElementType }}. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...{{ .ElementType }}) *{{ .ElementType }} { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *{{ .ElementType }} { +func (api *gateAPI) cast(v frontend.Variable) *{{ .ElementType }} { if x, ok := v.(*{{ .ElementType }}); ok { // fast path, no extra heap allocation return x } - var x {{ .ElementType }} + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 04873fdcb8..eebc96850f 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -97,7 +97,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 16e7af5c12..c31d447691 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -106,7 +106,6 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb } var api gateAPI gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) - api.freeElements() } evaluation.Mul(&evaluation, &gateEvaluation) @@ -275,7 +274,6 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } - api.freeElements() } mu.Lock() diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 96b6636151..6c504f7692 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-381/gate_testing.go b/internal/gkr/bls12-381/gate_testing.go index 5b281fd634..275ce2efb0 100644 --- a/internal/gkr/bls12-381/gate_testing.go +++ b/internal/gkr/bls12-381/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 8f72898737..12b5aff144 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 81572a4ac4..372d3f2811 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bls24-315/gate_testing.go b/internal/gkr/bls24-315/gate_testing.go index 058b53cc06..c25c46bb4d 100644 --- a/internal/gkr/bls24-315/gate_testing.go +++ b/internal/gkr/bls24-315/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7aee277ba4..c93b8a3c95 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 783cc964c8..aa7d9cd19d 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bls24-317/gate_testing.go b/internal/gkr/bls24-317/gate_testing.go index ed418ff1b0..7aac990c1b 100644 --- a/internal/gkr/bls24-317/gate_testing.go +++ b/internal/gkr/bls24-317/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 7c679216fc..c697f94a7e 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index 234a327324..dc4fe325e4 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bn254/gate_testing.go b/internal/gkr/bn254/gate_testing.go index e9311a3ea5..5d5260ee19 100644 --- a/internal/gkr/bn254/gate_testing.go +++ b/internal/gkr/bn254/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 0174caa564..f5291406cb 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index f855222636..ab61f1ef43 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bw6-633/gate_testing.go b/internal/gkr/bw6-633/gate_testing.go index 8074b9621c..70d8aa7d4b 100644 --- a/internal/gkr/bw6-633/gate_testing.go +++ b/internal/gkr/bw6-633/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 2c2bda2037..5f3acb6842 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 19d347d099..de4a5c49c2 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/bw6-761/gate_testing.go b/internal/gkr/bw6-761/gate_testing.go index 0bae6258dc..f534002d83 100644 --- a/internal/gkr/bw6-761/gate_testing.go +++ b/internal/gkr/bw6-761/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -145,6 +149,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(fr.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 099b015b02..f063ca4fa0 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 09e9c13f0f..41ebcaf4c1 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -103,7 +103,8 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns := make([]frontend.Variable, data.maxNbIn) outsI := 0 - insI := 1 // skip the first input, which is the instance index + insI := 1 // skip the first input, which is the instance index + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wI := range data.circuit { w := &data.circuit[wI] if w.IsInput() { // read from provided input diff --git a/internal/gkr/small_rational/gate_testing.go b/internal/gkr/small_rational/gate_testing.go index 93c4ca4191..11c60d8e9c 100644 --- a/internal/gkr/small_rational/gate_testing.go +++ b/internal/gkr/small_rational/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -117,6 +118,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -126,11 +128,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) @@ -144,6 +148,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error // EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + var api gateAPI x := make(small_rational.Vector, nbIn) x.MustSetRandom() fFr := api.convertFunc(f) diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index d085c6305f..3be9191db4 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -104,8 +104,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []small_rational.S for i, uniqueI := range injectionLeftInv { // map from all to unique inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*small_rational.SmallRational)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*small_rational.SmallRational)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +236,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]small_rational.SmallRational, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step small_rational.SmallRational + var ( + step small_rational.SmallRational + api gateAPI + ) res := make([]small_rational.SmallRational, degGJ) @@ -266,11 +269,12 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*small_rational.SmallRational) + summand := wire.Gate.Evaluate(&api, gateInput...).(*small_rational.SmallRational) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } + api.freeElements() } mu.Lock() for i := range gJ { @@ -668,6 +672,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]small_rational.SmallRational, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,52 +729,54 @@ func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*small_rational.SmallRational + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod small_rational.SmallRational - prod.Mul(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x small_rational.SmallRational @@ -787,7 +794,7 @@ func (gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +802,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRation return f(api, inVar...).(*small_rational.SmallRational) } +// Done with the current task. Put back all the allocated slices. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *small_rational.SmallRational { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(small_rational.SmallRational)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...small_rational.SmallRational) *small_rational.SmallRational // convertFunc turns f into a function that accepts and returns small_rational.SmallRational. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...small_rational.SmallRational) *small_rational.SmallRational { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *small_rational.SmallRational { +func (api *gateAPI) cast(v frontend.Variable) *small_rational.SmallRational { if x, ok := v.(*small_rational.SmallRational); ok { // fast path, no extra heap allocation return x } - var x small_rational.SmallRational + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } From b12c0ee95a5536865a0e9f99a657af0bba58e2d0 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 13:56:40 -0500 Subject: [PATCH 58/92] fix: api pointer receiver --- internal/generator/backend/template/gkr/solver_hints.go.tmpl | 2 +- internal/gkr/bls12-377/solver_hints.go | 2 +- internal/gkr/bls12-381/solver_hints.go | 2 +- internal/gkr/bls24-315/solver_hints.go | 2 +- internal/gkr/bls24-317/solver_hints.go | 2 +- internal/gkr/bn254/solver_hints.go | 2 +- internal/gkr/bw6-633/solver_hints.go | 2 +- internal/gkr/bw6-761/solver_hints.go | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index eebc96850f..82b5c8927d 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -111,7 +111,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 6c504f7692..96b6636151 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 372d3f2811..416eb334e8 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index aa7d9cd19d..ff0267ad5f 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index dc4fe325e4..f2ebc7a410 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index ab61f1ef43..048895e003 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index de4a5c49c2..1e2b9aae00 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 41ebcaf4c1..a64e8ad154 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -117,7 +117,7 @@ func SolveHint(data *SolvingData) hint.Hint { gateIns[i] = &data.assignment[inWI][instanceI] } - data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) From e6a64bc46455e0925013574a3d1109fdfa66377a Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 19:15:13 +0000 Subject: [PATCH 59/92] bench: gkr mimc permutations --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 86 +++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 559861af2d..8a9381bfd7 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -1,11 +1,15 @@ package gkr_mimc import ( + "errors" "fmt" + "os" "slices" "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" @@ -77,3 +81,85 @@ func TestGkrMiMCCompiles(t *testing.T) { require.NoError(t, err) fmt.Println(cs.GetNbConstraints(), "constraints") } + +type hashTreeCircuit struct { + Leaves []frontend.Variable +} + +func (c hashTreeCircuit) Define(api frontend.API) error { + if len(c.Leaves) == 0 { + return errors.New("no hashing to do") + } + + hsh, err := New(api) + if err != nil { + return err + } + + layer := slices.Clone(c.Leaves) + + for len(layer) > 1 { + if len(layer)%2 == 1 { + layer = append(layer, 0) // pad with zero + } + + for i := range len(layer) / 2 { + hsh.Reset() + hsh.Write(layer[2*i], layer[2*i+1]) + layer[i] = hsh.Sum() + } + + layer = layer[:len(layer)/2] + } + + api.AssertIsDifferent(layer[0], 0) + return nil +} + +func loadCs(t require.TestingT, filename string, circuit frontend.Circuit) constraint.ConstraintSystem { + f, err := os.Open(filename) + + if os.IsNotExist(err) { + // actually compile + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, circuit) + require.NoError(t, err) + f, err = os.Create(filename) + require.NoError(t, err) + defer f.Close() + _, err = cs.WriteTo(f) + require.NoError(t, err) + return cs + } + + defer f.Close() + require.NoError(t, err) + + cs := plonk.NewCS(ecc.BLS12_377) + + _, err = cs.ReadFrom(f) + require.NoError(t, err) + + return cs +} + +func BenchmarkHashTree(b *testing.B) { + const size = 1 << 15 // about 2 ^ 16 total hashes + + circuit := hashTreeCircuit{ + Leaves: make([]frontend.Variable, size), + } + assignment := hashTreeCircuit{ + Leaves: make([]frontend.Variable, size), + } + + for i := range assignment.Leaves { + assignment.Leaves[i] = i + } + + cs := loadCs(b, "gkrmimc_hashtree.cs", &circuit) + + w, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + require.NoError(b, cs.IsSolved(w)) +} From d55dbf4238c216b39f3ca19a38c96cce684c6c2f Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 19:18:43 +0000 Subject: [PATCH 60/92] bench: gkr-mimc permutations --- std/permutation/gkr-mimc/gkr-mimc_test.go | 70 +++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 std/permutation/gkr-mimc/gkr-mimc_test.go diff --git a/std/permutation/gkr-mimc/gkr-mimc_test.go b/std/permutation/gkr-mimc/gkr-mimc_test.go new file mode 100644 index 0000000000..93143b1279 --- /dev/null +++ b/std/permutation/gkr-mimc/gkr-mimc_test.go @@ -0,0 +1,70 @@ +package gkr_mimc + +import ( + "errors" + "slices" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/stretchr/testify/require" +) + +type hashTreeCircuit struct { + Leaves []frontend.Variable +} + +func (c hashTreeCircuit) Define(api frontend.API) error { + if len(c.Leaves) == 0 { + return errors.New("no hashing to do") + } + + hsh, err := NewCompressor(api) + if err != nil { + return err + } + + layer := slices.Clone(c.Leaves) + + for len(layer) > 1 { + if len(layer)%2 == 1 { + layer = append(layer, 0) // pad with zero + } + + for i := range len(layer) / 2 { + layer[i] = hsh.Compress(layer[2*i], layer[2*i+1]) + } + + layer = layer[:len(layer)/2] + } + + api.AssertIsDifferent(layer[0], 0) + return nil +} + +func BenchmarkGkrPermutations(b *testing.B) { + circuit, assignment := hashTreeCircuits(50000) + + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) + + witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + _, err = cs.Solve(witness) + require.NoError(b, err) +} + +func hashTreeCircuits(n int) (circuit, assignment hashTreeCircuit) { + leaves := make([]frontend.Variable, n) + for i := range n { + leaves[i] = i + } + + return hashTreeCircuit{ + Leaves: make([]frontend.Variable, len(leaves)), + }, hashTreeCircuit{ + Leaves: leaves, + } +} From 466f16bbab16ca8754220073959dad62d86b9fcd Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 19:43:26 +0000 Subject: [PATCH 61/92] refactor: remove loadCs --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 32 ++----------------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 8a9381bfd7..751ecaa473 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -3,13 +3,10 @@ package gkr_mimc import ( "errors" "fmt" - "os" "slices" "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/plonk" - "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" @@ -116,32 +113,6 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func loadCs(t require.TestingT, filename string, circuit frontend.Circuit) constraint.ConstraintSystem { - f, err := os.Open(filename) - - if os.IsNotExist(err) { - // actually compile - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, circuit) - require.NoError(t, err) - f, err = os.Create(filename) - require.NoError(t, err) - defer f.Close() - _, err = cs.WriteTo(f) - require.NoError(t, err) - return cs - } - - defer f.Close() - require.NoError(t, err) - - cs := plonk.NewCS(ecc.BLS12_377) - - _, err = cs.ReadFrom(f) - require.NoError(t, err) - - return cs -} - func BenchmarkHashTree(b *testing.B) { const size = 1 << 15 // about 2 ^ 16 total hashes @@ -156,7 +127,8 @@ func BenchmarkHashTree(b *testing.B) { assignment.Leaves[i] = i } - cs := loadCs(b, "gkrmimc_hashtree.cs", &circuit) + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) w, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) From 57e95bcf7a621505abe3503a0a5ac1b087cec159 Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 20:01:43 +0000 Subject: [PATCH 62/92] perf: reset api in gkr solver --- internal/gkr/bls12-377/solver_hints.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 96b6636151..1b28a97da7 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) From 6755e03ad4ef5c888dab63ad7ad405b1b5039c26 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 15:36:55 -0500 Subject: [PATCH 63/92] build: generify api reset --- internal/generator/backend/template/gkr/solver_hints.go.tmpl | 1 + internal/gkr/bls12-381/solver_hints.go | 1 + internal/gkr/bls24-315/solver_hints.go | 1 + internal/gkr/bls24-317/solver_hints.go | 1 + internal/gkr/bn254/solver_hints.go | 1 + internal/gkr/bw6-633/solver_hints.go | 1 + internal/gkr/bw6-761/solver_hints.go | 1 + 7 files changed, 7 insertions(+) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index 82b5c8927d..bfa2d3114a 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -112,6 +112,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index 416eb334e8..e576d3994d 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index ff0267ad5f..c122606692 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index f2ebc7a410..256d6cf9dc 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 048895e003..164d353e9e 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 1e2b9aae00..4c57f6e651 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index a64e8ad154..679fc6270f 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -118,6 +118,7 @@ func SolveHint(data *SolvingData) hint.Hint { } data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { data.assignment[wI][instanceI].BigInt(outs[outsI]) From b1de7358250e74baf1fad62c5aca5db3b6d4fac4 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 15:42:15 -0500 Subject: [PATCH 64/92] `hashTree` -> `merkleTree` --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 10 +++++----- std/permutation/gkr-mimc/gkr-mimc_test.go | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 751ecaa473..95956ff6c4 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -79,11 +79,11 @@ func TestGkrMiMCCompiles(t *testing.T) { fmt.Println(cs.GetNbConstraints(), "constraints") } -type hashTreeCircuit struct { +type merkleTreeCircuit struct { Leaves []frontend.Variable } -func (c hashTreeCircuit) Define(api frontend.API) error { +func (c merkleTreeCircuit) Define(api frontend.API) error { if len(c.Leaves) == 0 { return errors.New("no hashing to do") } @@ -113,13 +113,13 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func BenchmarkHashTree(b *testing.B) { +func BenchmarkMerkleTree(b *testing.B) { const size = 1 << 15 // about 2 ^ 16 total hashes - circuit := hashTreeCircuit{ + circuit := merkleTreeCircuit{ Leaves: make([]frontend.Variable, size), } - assignment := hashTreeCircuit{ + assignment := merkleTreeCircuit{ Leaves: make([]frontend.Variable, size), } diff --git a/std/permutation/gkr-mimc/gkr-mimc_test.go b/std/permutation/gkr-mimc/gkr-mimc_test.go index 93143b1279..6cc1bde714 100644 --- a/std/permutation/gkr-mimc/gkr-mimc_test.go +++ b/std/permutation/gkr-mimc/gkr-mimc_test.go @@ -11,11 +11,11 @@ import ( "github.com/stretchr/testify/require" ) -type hashTreeCircuit struct { +type merkleTreeCircuit struct { Leaves []frontend.Variable } -func (c hashTreeCircuit) Define(api frontend.API) error { +func (c merkleTreeCircuit) Define(api frontend.API) error { if len(c.Leaves) == 0 { return errors.New("no hashing to do") } @@ -43,7 +43,7 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func BenchmarkGkrPermutations(b *testing.B) { +func BenchmarkMerkleTree(b *testing.B) { circuit, assignment := hashTreeCircuits(50000) cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) @@ -56,15 +56,15 @@ func BenchmarkGkrPermutations(b *testing.B) { require.NoError(b, err) } -func hashTreeCircuits(n int) (circuit, assignment hashTreeCircuit) { +func hashTreeCircuits(n int) (circuit, assignment merkleTreeCircuit) { leaves := make([]frontend.Variable, n) for i := range n { leaves[i] = i } - return hashTreeCircuit{ + return merkleTreeCircuit{ Leaves: make([]frontend.Variable, len(leaves)), - }, hashTreeCircuit{ + }, merkleTreeCircuit{ Leaves: leaves, } } From 6c4c5c665974c8140219a8f65012e7ca2093a559 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 15:55:28 -0500 Subject: [PATCH 65/92] docs: copilot-inspired explanation for `freeElements` --- internal/generator/backend/template/gkr/gkr.go.tmpl | 2 +- internal/gkr/bls12-377/gkr.go | 2 +- internal/gkr/bls12-381/gkr.go | 2 +- internal/gkr/bls24-315/gkr.go | 2 +- internal/gkr/bls24-317/gkr.go | 2 +- internal/gkr/bn254/gkr.go | 2 +- internal/gkr/bw6-633/gkr.go | 2 +- internal/gkr/bw6-761/gkr.go | 2 +- internal/gkr/small_rational/gkr.go | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index c5c5b21dc2..c2d2901aab 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -798,7 +798,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ . return f(api, inVar...).(*{{ .ElementType }}) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index c31d447691..af149027cb 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 12b5aff144..9f2b371057 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index c93b8a3c95..b9dab1c6fe 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index c697f94a7e..ecb1a6bcb8 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index f5291406cb..77cb8d085b 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 5f3acb6842..a443549110 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index f063ca4fa0..0ae90d3c41 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 3be9191db4..0557f938df 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -802,7 +802,7 @@ func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRatio return f(api, inVar...).(*small_rational.SmallRational) } -// Done with the current task. Put back all the allocated slices. +// Put all elements back in the pool. func (api *gateAPI) freeElements() { api.nbUsed = 0 } From 8d0b353e41734e9c6f740e1f93cba831850c97fd Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 22:55:44 +0000 Subject: [PATCH 66/92] perf: more pool freeing --- internal/generator/backend/template/gkr/gkr.go.tmpl | 2 +- internal/gkr/bls12-377/gkr.go | 2 +- internal/gkr/bls12-381/gkr.go | 2 +- internal/gkr/bls24-315/gkr.go | 2 +- internal/gkr/bls24-317/gkr.go | 2 +- internal/gkr/bn254/gkr.go | 2 +- internal/gkr/bw6-633/gkr.go | 2 +- internal/gkr/bw6-761/gkr.go | 2 +- internal/gkr/small_rational/gkr.go | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index c2d2901aab..2131aea723 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -267,8 +267,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index af149027cb..c28e3df8e3 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 9f2b371057..5c87cdd85d 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index b9dab1c6fe..acbebf56cf 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index ecb1a6bcb8..3d4a80d557 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 77cb8d085b..bed1d89329 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index a443549110..4962d1a4f9 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 0ae90d3c41..6ea38abf29 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 0557f938df..e65c94752e 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -273,8 +273,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } - api.freeElements() } mu.Lock() for i := range gJ { From 5a048b6b4fefe37ba271210086b3b33ccd47ad7b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 14:10:47 -0500 Subject: [PATCH 67/92] perf: dedicated exp function --- .../backend/template/gkr/gkr.go.tmpl | 26 ++++++++++++++++++ internal/gkr/bls12-377/gkr.go | 27 +++++++++++++++++++ internal/gkr/bls12-381/gkr.go | 26 ++++++++++++++++++ internal/gkr/bls24-315/gkr.go | 26 ++++++++++++++++++ internal/gkr/bls24-317/gkr.go | 26 ++++++++++++++++++ internal/gkr/bn254/gkr.go | 26 ++++++++++++++++++ internal/gkr/bw6-633/gkr.go | 26 ++++++++++++++++++ internal/gkr/bw6-761/gkr.go | 26 ++++++++++++++++++ internal/gkr/engine_hints.go | 12 ++++++--- internal/gkr/gkr.go | 23 +++++++++++++++- internal/gkr/small_rational/gkr.go | 26 ++++++++++++++++++ std/gkrapi/compile.go | 2 +- std/gkrapi/gkr/types.go | 3 +++ std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 5 +++- std/permutation/gkr-mimc/gkr-mimc.go | 7 +---- .../gkr-poseidon2/gkr-poseidon2_test.go | 16 +++++++++++ 16 files changed, 291 insertions(+), 12 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 2131aea723..32accd6dca 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -790,6 +790,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, n uint) frontend.Variable { + var res *{{ .ElementType }} + x := api.cast(i) + + if n % 2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + n /= 2 + + // square and multiply + for n != 0 { + res.Mul(res, res) + + if n % 2 != 0 { + res.Mul(res, x) + } + + n /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index c28e3df8e3..d96f91d7d1 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -794,6 +795,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + if e == 0 { + return 1 + } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) + + res := api.newElement() + x := api.cast(i) + *res = *x + + // square and multiply + for n != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + n-- + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 5c87cdd85d..0941a05deb 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index acbebf56cf..7f70f6a12c 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 3d4a80d557..0320deb5de 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index bed1d89329..a380373951 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 4962d1a4f9..66d957f8f4 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 6ea38abf29..35d6f86bad 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 8c8bc1b797..0e272f12a4 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -130,7 +130,7 @@ func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big. return nil } -type gateAPI struct{ *big.Int } +type gateAPI struct{ mod *big.Int } func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { in1 := utils.FromInterface(i1) @@ -147,7 +147,7 @@ func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { x, y := utils.FromInterface(b), utils.FromInterface(c) x.Mul(&x, &y) - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce y = utils.FromInterface(a) x.Add(&x, &y) return &x @@ -178,7 +178,13 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend y = utils.FromInterface(v) x.Mul(&x, &y) } - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce + return &x +} + +func (g gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + x := utils.FromInterface(i) + x.Exp(&x, big.NewInt(int64(e)), g.mod) return &x } diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 955ad8a354..6b8b5ebd4f 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -78,7 +78,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(api frontend.API, r inputEvaluations[i] = uniqueInputEvaluations[uniqueI] } - gateEvaluation = wire.Gate.Evaluate(api, inputEvaluations...) + gateEvaluation = wire.Gate.Evaluate(FrontendApiWrapper{api}, inputEvaluations...) } evaluation = api.Mul(evaluation, gateEvaluation) @@ -383,3 +383,24 @@ func DeserializeProof(sorted []*gkrtypes.Wire, serializedProof []frontend.Variab } return proof, nil } + +// FrontendApiWrapper implements additional functions to satisfy the gkr.GateAPI interface. +type FrontendApiWrapper struct { + frontend.API +} + +func (api FrontendApiWrapper) Exp(i frontend.Variable, e uint8) frontend.Variable { + res := frontend.Variable(1) + if e%2 != 0 { + res = i + } + e /= 2 + for e != 0 { + res = api.Mul(res, res) + if e%2 != 0 { + res = api.Mul(res, i) + } + e /= 2 + } + return res +} diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index e65c94752e..73cf93a146 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -794,6 +794,32 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + var res *small_rational.SmallRational + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x + } + + e /= 2 + + // square and multiply + for e != 0 { + res.Mul(res, res) + + if e%2 != 0 { + res.Mul(res, x) + } + + e /= 2 + } + + return res +} + func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 64205b80b8..5aca90cad2 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -182,7 +182,7 @@ func (c *Circuit) finalize(api frontend.API) error { for inI, inWI := range w.Inputs { gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance } - res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) + res := w.Gate.Evaluate(gadget.FrontendApiWrapper{API: api}, gateIn[:len(w.Inputs)]...) if w.IsOutput() { api.AssertIsEqual(res, c.assignments[wI][0]) } else { diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index af8a40fcd6..c9b8678b7b 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -35,6 +35,9 @@ type GateAPI interface { // Mul returns res = i1 * i2 * ... in Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable + // Exp returns res = iᵉ + Exp(i frontend.Variable, e uint8) frontend.Variable + // Println behaves like fmt.Println but accepts frontend.Variable as parameter // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 95956ff6c4..3ad685e791 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -6,6 +6,7 @@ import ( "slices" "testing" + "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" @@ -29,7 +30,9 @@ func TestGkrMiMC(t *testing.T) { In: slices.Clone(vals[:length*2]), } - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment)) + allCurves := gnark.Curves() + allCurves = []ecc.ID{ecc.BLS12_377} // TODO REMOVE + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(allCurves[0], allCurves[1:]...)) } } diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 266ee00e67..c98d100303 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -208,12 +208,7 @@ func addPow17(key *big.Int) gkr.GateFunction { if len(in) != 2 { panic("expected two input") } - s := api.Add(in[0], in[1], key) - t := api.Mul(s, s) // s² - t = api.Mul(t, t) // s⁴ - t = api.Mul(t, t) // s⁸ - t = api.Mul(t, t) // s¹⁶ - return api.Mul(t, s) // s¹⁶ × s + return api.Exp(api.Add(in[0], in[1], key), 17) } } diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go index b224bf1414..b76024c888 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -2,6 +2,8 @@ package gkr_poseidon2 import ( "fmt" + "math/bits" + "strings" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -79,3 +81,17 @@ func BenchmarkGkrCompressions(b *testing.B) { _, err = cs.Solve(witness) require.NoError(b, err) } + +func TestGenerateTable(t *testing.T) { + var sb strings.Builder + for n := range 256 { + if n%16 == 0 { + sb.WriteString("\"+\n\"") + } + b := uint8(n) + b <<= bits.LeadingZeros8(b) + b = bits.Reverse8(b) + sb.WriteString(fmt.Sprintf("\\x%x", b)) + } + fmt.Println(sb.String()) +} From c6b830c3fdf9137884b4015b4bdff72e3658884a Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 14:14:37 -0500 Subject: [PATCH 68/92] build: generify exp changes --- .../backend/template/gkr/gkr.go.tmpl | 23 ++++++++++--------- internal/gkr/bls12-381/gkr.go | 19 +++++++-------- internal/gkr/bls24-315/gkr.go | 19 +++++++-------- internal/gkr/bls24-317/gkr.go | 19 +++++++-------- internal/gkr/bn254/gkr.go | 19 +++++++-------- internal/gkr/bw6-633/gkr.go | 19 +++++++-------- internal/gkr/bw6-761/gkr.go | 19 +++++++-------- internal/gkr/small_rational/gkr.go | 19 +++++++-------- 8 files changed, 82 insertions(+), 74 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 32accd6dca..0d97343a33 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -6,6 +6,7 @@ import ( fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "math/big" + "math/bits" "strconv" "sync" "github.com/consensys/gnark/frontend" @@ -790,27 +791,27 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, n uint) frontend.Variable { - var res *{{ .ElementType }} - x := api.cast(i) - - if n % 2 == 0 { - res = api.cast(1) - } else { - *res = *x +func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - n /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply for n != 0 { res.Mul(res, res) - if n % 2 != 0 { + if e % 2 != 0 { res.Mul(res, x) } - n /= 2 + e /= 2 + n-- } return res diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 0941a05deb..05560b9014 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7f70f6a12c..8f6feb48b6 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 0320deb5de..495658b89d 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a380373951..01d2a5e9de 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 66d957f8f4..af578e59d7 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 35d6f86bad..63dabac899 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 73cf93a146..ffd730e69e 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "strconv" "sync" @@ -795,19 +796,18 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *small_rational.SmallRational - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x + if e == 0 { + return 1 } + n := bits.Len8(e) - 1 + e = bits.Reverse8(e) >> (8 - n) - e /= 2 + res := api.newElement() + x := api.cast(i) + *res = *x // square and multiply - for e != 0 { + for n != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,6 +815,7 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 + n-- } return res From 7bb78e7e8e9dcf8653021b8be566bad1d7e7cd82 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 14:18:48 -0500 Subject: [PATCH 69/92] fix: test for all curves --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 3ad685e791..1cd6579e0c 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" @@ -31,8 +32,12 @@ func TestGkrMiMC(t *testing.T) { } allCurves := gnark.Curves() - allCurves = []ecc.ID{ecc.BLS12_377} // TODO REMOVE - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(allCurves[0], allCurves[1:]...)) + test.NewAssert(t).CheckCircuit( + circuit, + test.WithValidAssignment(assignment), + test.WithCurves(allCurves[0], allCurves[1:]...), + test.WithBackends(backend.PLONK), + ) } } From 58a0fdb0083252917943b1e4f432b14ceee834dc Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Mon, 22 Sep 2025 21:33:41 +0000 Subject: [PATCH 70/92] perf: fastpath for ^17 --- internal/gkr/bls12-377/gkr.go | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index d96f91d7d1..efaa8f9a6f 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -795,15 +795,29 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } +var seventeen = big.NewInt(17) + func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { + + res := api.newElement() + x := api.cast(i) + + if e == 17 { + res.Mul(x, x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + res.Mul(res, x) // x¹⁷ + + return res + } + if e == 0 { return 1 } n := bits.Len8(e) - 1 e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) *res = *x // square and multiply From 4b2c10207523d766c4ea3a45711f6e745a78b687 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:36:56 -0500 Subject: [PATCH 71/92] fix: exp --- internal/gkr/gkr.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 6b8b5ebd4f..b115c56aa6 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -391,16 +391,13 @@ type FrontendApiWrapper struct { func (api FrontendApiWrapper) Exp(i frontend.Variable, e uint8) frontend.Variable { res := frontend.Variable(1) - if e%2 != 0 { - res = i - } - e /= 2 - for e != 0 { + + for range 8 { res = api.Mul(res, res) - if e%2 != 0 { + if e%128 != 0 { res = api.Mul(res, i) } - e /= 2 + e <<= 1 } return res } From 30c39b4fee11c6fa13ba67474451ea0d056ebd8d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:38:31 -0500 Subject: [PATCH 72/92] fix: exp. really --- internal/gkr/gkr.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index b115c56aa6..ee43864d7a 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -394,7 +394,7 @@ func (api FrontendApiWrapper) Exp(i frontend.Variable, e uint8) frontend.Variabl for range 8 { res = api.Mul(res, res) - if e%128 != 0 { + if e&128 != 0 { res = api.Mul(res, i) } e <<= 1 From 6c7a7fccb4365761f517044555c32d755b29f40b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:43:43 -0500 Subject: [PATCH 73/92] test: modernize benchmark --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 1cd6579e0c..569201ffa1 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -2,7 +2,6 @@ package gkr_mimc import ( "errors" - "fmt" "slices" "testing" @@ -77,16 +76,6 @@ func (c *testGkrMiMCCircuit) Define(api frontend.API) error { return nil } -func TestGkrMiMCCompiles(t *testing.T) { - const n = 52000 - circuit := testGkrMiMCCircuit{ - In: make([]frontend.Variable, n), - } - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit, frontend.WithCapacity(27_000_000)) - require.NoError(t, err) - fmt.Println(cs.GetNbConstraints(), "constraints") -} - type merkleTreeCircuit struct { Leaves []frontend.Variable } @@ -141,5 +130,9 @@ func BenchmarkMerkleTree(b *testing.B) { w, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) - require.NoError(b, cs.IsSolved(w)) + for b.Loop() { + s, err := cs.Solve(w) + require.NoError(b, err) + _ = s + } } From 266c5f6d057498a714b715fda19bd8b22bae968b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:44:53 -0500 Subject: [PATCH 74/92] Revert "build: generify exp changes" This reverts commit c6b830c3fdf9137884b4015b4bdff72e3658884a. --- .../backend/template/gkr/gkr.go.tmpl | 23 +++++++++---------- internal/gkr/bls12-381/gkr.go | 19 ++++++++------- internal/gkr/bls24-315/gkr.go | 19 ++++++++------- internal/gkr/bls24-317/gkr.go | 19 ++++++++------- internal/gkr/bn254/gkr.go | 19 ++++++++------- internal/gkr/bw6-633/gkr.go | 19 ++++++++------- internal/gkr/bw6-761/gkr.go | 19 ++++++++------- internal/gkr/small_rational/gkr.go | 19 ++++++++------- 8 files changed, 74 insertions(+), 82 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 0d97343a33..32accd6dca 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -6,7 +6,6 @@ import ( fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "math/big" - "math/bits" "strconv" "sync" "github.com/consensys/gnark/frontend" @@ -791,27 +790,27 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 +func (api *gateAPI) Exp(i frontend.Variable, n uint) frontend.Variable { + var res *{{ .ElementType }} + x := api.cast(i) + + if n % 2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + n /= 2 // square and multiply for n != 0 { res.Mul(res, res) - if e % 2 != 0 { + if n % 2 != 0 { res.Mul(res, x) } - e /= 2 - n-- + n /= 2 } return res diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 05560b9014..0941a05deb 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 8f6feb48b6..7f70f6a12c 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 495658b89d..0320deb5de 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 01d2a5e9de..a380373951 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index af578e59d7..66d957f8f4 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 63dabac899..35d6f86bad 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *fr.Element + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index ffd730e69e..73cf93a146 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -796,18 +795,19 @@ func (api *gateAPI) Println(a ...frontend.Variable) { } func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - if e == 0 { - return 1 + var res *small_rational.SmallRational + x := api.cast(i) + + if e%2 == 0 { + res = api.cast(1) + } else { + *res = *x } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - res := api.newElement() - x := api.cast(i) - *res = *x + e /= 2 // square and multiply - for n != 0 { + for e != 0 { res.Mul(res, res) if e%2 != 0 { @@ -815,7 +815,6 @@ func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { } e /= 2 - n-- } return res From 5dd1d61d207a854c477b53a886d99f379c28b956 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:47:02 -0500 Subject: [PATCH 75/92] revert: exp func --- .../backend/template/gkr/gkr.go.tmpl | 26 ------------ internal/gkr/bls12-377/gkr.go | 41 ------------------- internal/gkr/bls12-381/gkr.go | 26 ------------ internal/gkr/bls24-315/gkr.go | 26 ------------ internal/gkr/bls24-317/gkr.go | 26 ------------ internal/gkr/bn254/gkr.go | 26 ------------ internal/gkr/bw6-633/gkr.go | 26 ------------ internal/gkr/bw6-761/gkr.go | 26 ------------ internal/gkr/engine_hints.go | 12 ++---- internal/gkr/small_rational/gkr.go | 26 ------------ std/gkrapi/compile.go | 2 +- std/gkrapi/gkr/types.go | 3 -- std/permutation/gkr-mimc/gkr-mimc.go | 7 +++- .../gkr-poseidon2/gkr-poseidon2_test.go | 16 -------- 14 files changed, 10 insertions(+), 279 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 32accd6dca..2131aea723 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -790,32 +790,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, n uint) frontend.Variable { - var res *{{ .ElementType }} - x := api.cast(i) - - if n % 2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - n /= 2 - - // square and multiply - for n != 0 { - res.Mul(res, res) - - if n % 2 != 0 { - res.Mul(res, x) - } - - n /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index efaa8f9a6f..c28e3df8e3 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/bits" "strconv" "sync" @@ -795,46 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -var seventeen = big.NewInt(17) - -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - - res := api.newElement() - x := api.cast(i) - - if e == 17 { - res.Mul(x, x) // x² - res.Mul(res, res) // x⁴ - res.Mul(res, res) // x⁸ - res.Mul(res, res) // x¹⁶ - res.Mul(res, x) // x¹⁷ - - return res - } - - if e == 0 { - return 1 - } - n := bits.Len8(e) - 1 - e = bits.Reverse8(e) >> (8 - n) - - *res = *x - - // square and multiply - for n != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - n-- - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 0941a05deb..5c87cdd85d 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7f70f6a12c..acbebf56cf 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 0320deb5de..3d4a80d557 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a380373951..bed1d89329 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 66d957f8f4..4962d1a4f9 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 35d6f86bad..6ea38abf29 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *fr.Element - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 0e272f12a4..8c8bc1b797 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -130,7 +130,7 @@ func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big. return nil } -type gateAPI struct{ mod *big.Int } +type gateAPI struct{ *big.Int } func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { in1 := utils.FromInterface(i1) @@ -147,7 +147,7 @@ func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { x, y := utils.FromInterface(b), utils.FromInterface(c) x.Mul(&x, &y) - x.Mod(&x, g.mod) // reduce + x.Mod(&x, g.Int) // reduce y = utils.FromInterface(a) x.Add(&x, &y) return &x @@ -178,13 +178,7 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend y = utils.FromInterface(v) x.Mul(&x, &y) } - x.Mod(&x, g.mod) // reduce - return &x -} - -func (g gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - x := utils.FromInterface(i) - x.Exp(&x, big.NewInt(int64(e)), g.mod) + x.Mod(&x, g.Int) // reduce return &x } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 73cf93a146..e65c94752e 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -794,32 +794,6 @@ func (api *gateAPI) Println(a ...frontend.Variable) { fmt.Println(toPrint...) } -func (api *gateAPI) Exp(i frontend.Variable, e uint8) frontend.Variable { - var res *small_rational.SmallRational - x := api.cast(i) - - if e%2 == 0 { - res = api.cast(1) - } else { - *res = *x - } - - e /= 2 - - // square and multiply - for e != 0 { - res.Mul(res, res) - - if e%2 != 0 { - res.Mul(res, x) - } - - e /= 2 - } - - return res -} - func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { inVar := make([]frontend.Variable, len(in)) for i := range in { diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 5aca90cad2..64205b80b8 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -182,7 +182,7 @@ func (c *Circuit) finalize(api frontend.API) error { for inI, inWI := range w.Inputs { gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance } - res := w.Gate.Evaluate(gadget.FrontendApiWrapper{API: api}, gateIn[:len(w.Inputs)]...) + res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) if w.IsOutput() { api.AssertIsEqual(res, c.assignments[wI][0]) } else { diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index c9b8678b7b..af8a40fcd6 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -35,9 +35,6 @@ type GateAPI interface { // Mul returns res = i1 * i2 * ... in Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable - // Exp returns res = iᵉ - Exp(i frontend.Variable, e uint8) frontend.Variable - // Println behaves like fmt.Println but accepts frontend.Variable as parameter // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index c98d100303..266ee00e67 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -208,7 +208,12 @@ func addPow17(key *big.Int) gkr.GateFunction { if len(in) != 2 { panic("expected two input") } - return api.Exp(api.Add(in[0], in[1], key), 17) + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Mul(t, s) // s¹⁶ × s } } diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go index b76024c888..b224bf1414 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -2,8 +2,6 @@ package gkr_poseidon2 import ( "fmt" - "math/bits" - "strings" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -81,17 +79,3 @@ func BenchmarkGkrCompressions(b *testing.B) { _, err = cs.Solve(witness) require.NoError(b, err) } - -func TestGenerateTable(t *testing.T) { - var sb strings.Builder - for n := range 256 { - if n%16 == 0 { - sb.WriteString("\"+\n\"") - } - b := uint8(n) - b <<= bits.LeadingZeros8(b) - b = bits.Reverse8(b) - sb.WriteString(fmt.Sprintf("\\x%x", b)) - } - fmt.Println(sb.String()) -} From 1374bab9805cdb95bc199db8af8f98d188d59131 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:49:36 -0500 Subject: [PATCH 76/92] refactor: modulus name --- internal/gkr/engine_hints.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 8c8bc1b797..9e44dc7a3e 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -130,7 +130,7 @@ func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big. return nil } -type gateAPI struct{ *big.Int } +type gateAPI struct{ mod *big.Int } func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { in1 := utils.FromInterface(i1) @@ -147,7 +147,7 @@ func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { x, y := utils.FromInterface(b), utils.FromInterface(c) x.Mul(&x, &y) - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce y = utils.FromInterface(a) x.Add(&x, &y) return &x @@ -178,7 +178,7 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend y = utils.FromInterface(v) x.Mul(&x, &y) } - x.Mod(&x, g.Int) // reduce + x.Mod(&x, g.mod) // reduce return &x } From 9959f8a45cbc26be5dfb1d94a1c48880fe58d5be Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 22 Sep 2025 16:52:05 -0500 Subject: [PATCH 77/92] revert: remove FrontendApiWrapper --- internal/gkr/gkr.go | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index ee43864d7a..955ad8a354 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -78,7 +78,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(api frontend.API, r inputEvaluations[i] = uniqueInputEvaluations[uniqueI] } - gateEvaluation = wire.Gate.Evaluate(FrontendApiWrapper{api}, inputEvaluations...) + gateEvaluation = wire.Gate.Evaluate(api, inputEvaluations...) } evaluation = api.Mul(evaluation, gateEvaluation) @@ -383,21 +383,3 @@ func DeserializeProof(sorted []*gkrtypes.Wire, serializedProof []frontend.Variab } return proof, nil } - -// FrontendApiWrapper implements additional functions to satisfy the gkr.GateAPI interface. -type FrontendApiWrapper struct { - frontend.API -} - -func (api FrontendApiWrapper) Exp(i frontend.Variable, e uint8) frontend.Variable { - res := frontend.Variable(1) - - for range 8 { - res = api.Mul(res, res) - if e&128 != 0 { - res = api.Mul(res, i) - } - e <<= 1 - } - return res -} From 271cff9225786d4e1f27046a090fee99cc8f3e68 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 23 Sep 2025 18:12:53 -0500 Subject: [PATCH 78/92] perf: dedicated exp17 func --- .../generator/backend/template/gkr/gkr.go.tmpl | 10 ++++++++++ internal/gkr/bls12-377/gkr.go | 10 ++++++++++ internal/gkr/bls12-381/gkr.go | 10 ++++++++++ internal/gkr/bls24-315/gkr.go | 10 ++++++++++ internal/gkr/bls24-317/gkr.go | 10 ++++++++++ internal/gkr/bn254/gkr.go | 10 ++++++++++ internal/gkr/bw6-633/gkr.go | 10 ++++++++++ internal/gkr/bw6-761/gkr.go | 10 ++++++++++ internal/gkr/engine_hints.go | 7 +++++++ internal/gkr/gkr.go | 14 +++++++++++++- internal/gkr/small_rational/gkr.go | 10 ++++++++++ std/gkrapi/compile.go | 2 +- std/gkrapi/gkr/types.go | 2 ++ std/permutation/gkr-mimc/gkr-mimc.go | 7 +------ 14 files changed, 114 insertions(+), 8 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 2131aea723..31819961e0 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -772,6 +772,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x {{ .ElementType }} diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index c28e3df8e3..996f4a6b4b 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 5c87cdd85d..863a2f71c8 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index acbebf56cf..dc6eb854ad 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 3d4a80d557..6333662f83 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index bed1d89329..f2361557c6 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 4962d1a4f9..e8b1f9d27d 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 6ea38abf29..27315ad58d 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 9e44dc7a3e..35f8cd5c7c 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -182,6 +182,13 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend return &x } +func (g gateAPI) Exp17(i frontend.Variable) frontend.Variable { + x := utils.FromInterface(i) + var res big.Int + res.Exp(&x, big.NewInt(17), g.mod) + return &res +} + func (g gateAPI) Println(a ...frontend.Variable) { strings := make([]string, len(a)) for i := range a { diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 955ad8a354..dc2e0c9fd3 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -78,7 +78,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(api frontend.API, r inputEvaluations[i] = uniqueInputEvaluations[uniqueI] } - gateEvaluation = wire.Gate.Evaluate(api, inputEvaluations...) + gateEvaluation = wire.Gate.Evaluate(FrontendAPIWrapper{api}, inputEvaluations...) } evaluation = api.Mul(evaluation, gateEvaluation) @@ -383,3 +383,15 @@ func DeserializeProof(sorted []*gkrtypes.Wire, serializedProof []frontend.Variab } return proof, nil } + +type FrontendAPIWrapper struct { + frontend.API +} + +func (api FrontendAPIWrapper) Exp17(i frontend.Variable) frontend.Variable { + res := api.Mul(i, i) // i^2 + res = api.Mul(res, res) // i^4 + res = api.Mul(res, res) // i^8 + res = api.Mul(res, res) // i^16 + return api.Mul(res, i) // i^17 +} diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index e65c94752e..a884902d4c 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -776,6 +776,16 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } +func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { + res := api.newElement() + x := api.cast(i) + res.Mul(x, x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, x) // x^17 +} + func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x small_rational.SmallRational diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 64205b80b8..d8c0985ae5 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -182,7 +182,7 @@ func (c *Circuit) finalize(api frontend.API) error { for inI, inWI := range w.Inputs { gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance } - res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) + res := w.Gate.Evaluate(gadget.FrontendAPIWrapper{API: api}, gateIn[:len(w.Inputs)]...) if w.IsOutput() { api.AssertIsEqual(res, c.assignments[wI][0]) } else { diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index af8a40fcd6..5acc2cda9f 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -38,6 +38,8 @@ type GateAPI interface { // Println behaves like fmt.Println but accepts frontend.Variable as parameter // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) + + Exp17(a frontend.Variable) frontend.Variable } // GateFunction is a function that evaluates a polynomial over its inputs diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 266ee00e67..2d0e4ada69 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -208,12 +208,7 @@ func addPow17(key *big.Int) gkr.GateFunction { if len(in) != 2 { panic("expected two input") } - s := api.Add(in[0], in[1], key) - t := api.Mul(s, s) // s² - t = api.Mul(t, t) // s⁴ - t = api.Mul(t, t) // s⁸ - t = api.Mul(t, t) // s¹⁶ - return api.Mul(t, s) // s¹⁶ × s + return api.Exp17(api.Add(in[0], in[1], key)) } } From 1519d5276db1919b78a6f4627880e9dd439168ab Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 24 Sep 2025 10:02:13 -0500 Subject: [PATCH 79/92] perf: SumExp17 --- .../backend/template/gkr/gkr.go.tmpl | 12 +++++++--- internal/gkr/bls12-377/gkr.go | 23 +++++++++++++------ internal/gkr/bls12-381/gkr.go | 12 +++++++--- internal/gkr/bls24-315/gkr.go | 12 +++++++--- internal/gkr/bls24-317/gkr.go | 12 +++++++--- internal/gkr/bn254/gkr.go | 12 +++++++--- internal/gkr/bw6-633/gkr.go | 12 +++++++--- internal/gkr/bw6-761/gkr.go | 12 +++++++--- internal/gkr/engine_hints.go | 10 +++++--- internal/gkr/gkr.go | 3 ++- internal/gkr/small_rational/gkr.go | 12 +++++++--- std/gkrapi/gkr/types.go | 2 +- std/permutation/gkr-mimc/gkr-mimc.go | 2 +- 13 files changed, 99 insertions(+), 37 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 31819961e0..bf0fb4ecd9 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -772,9 +772,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a,b,c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 996f4a6b4b..7925291eeb 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -776,14 +776,23 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + var x fr.Element + + if _, err := x.SetInterface(c); err != nil { // a, b are expected to be *fr.Element but not c + panic(err) + } + + x.Add(&x, api.cast(a)) + x.Add(&x, api.cast(b)) + res := api.newElement() - x := api.cast(i) - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + + res.Mul(&x, &x) // x^2 + res.Mul(res, res) // x^4 + res.Mul(res, res) // x^8 + res.Mul(res, res) // x^16 + return res.Mul(res, &x) // x^17 } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 863a2f71c8..0b34a40e23 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index dc6eb854ad..a0a059e6df 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 6333662f83..42f1e5e56c 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index f2361557c6..a84276a35a 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index e8b1f9d27d..4daeb83900 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 27315ad58d..9821871edd 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 35f8cd5c7c..9d53255f4e 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -182,9 +182,13 @@ func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend return &x } -func (g gateAPI) Exp17(i frontend.Variable) frontend.Variable { - x := utils.FromInterface(i) - var res big.Int +func (g gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := utils.FromInterface(a) + res := utils.FromInterface(b) + + x.Add(&x, &res) + res = utils.FromInterface(c) + x.Add(&x, &res) res.Exp(&x, big.NewInt(17), g.mod) return &res } diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index dc2e0c9fd3..b2ea42e8ba 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -388,7 +388,8 @@ type FrontendAPIWrapper struct { frontend.API } -func (api FrontendAPIWrapper) Exp17(i frontend.Variable) frontend.Variable { +func (api FrontendAPIWrapper) SumExp17(a, b, c frontend.Variable) frontend.Variable { + i := api.Add(a, b, c) res := api.Mul(i, i) // i^2 res = api.Mul(res, res) // i^4 res = api.Mul(res, res) // i^8 diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index a884902d4c..37f9490950 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -776,9 +776,15 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (api *gateAPI) Exp17(i frontend.Variable) frontend.Variable { - res := api.newElement() - x := api.cast(i) +func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { + x := api.cast(a) + res := api.cast(b) + x.Add(res, x) + if _, err := res.SetInterface(c); err != nil { + panic(err) + } + x.Add(res, x) + res.Mul(x, x) // x^2 res.Mul(res, res) // x^4 res.Mul(res, res) // x^8 diff --git a/std/gkrapi/gkr/types.go b/std/gkrapi/gkr/types.go index 5acc2cda9f..3c03f75df1 100644 --- a/std/gkrapi/gkr/types.go +++ b/std/gkrapi/gkr/types.go @@ -39,7 +39,7 @@ type GateAPI interface { // whose value will be resolved at runtime when computed by the solver Println(a ...frontend.Variable) - Exp17(a frontend.Variable) frontend.Variable + SumExp17(a, b, c frontend.Variable) frontend.Variable } // GateFunction is a function that evaluates a polynomial over its inputs diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 2d0e4ada69..1fa4cbce93 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -208,7 +208,7 @@ func addPow17(key *big.Int) gkr.GateFunction { if len(in) != 2 { panic("expected two input") } - return api.Exp17(api.Add(in[0], in[1], key)) + return api.SumExp17(in[0], in[1], key) } } From 2e3be9f26ca8318d07f019d99a4615f67c81549c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 24 Sep 2025 10:38:14 -0500 Subject: [PATCH 80/92] perf: cache key as fr.Element --- internal/gkr/bls12-377/gkr.go | 8 ++------ std/permutation/gkr-mimc/gkr-mimc.go | 14 +++++++++++++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 7925291eeb..13923e0983 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -779,12 +779,8 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { var x fr.Element - if _, err := x.SetInterface(c); err != nil { // a, b are expected to be *fr.Element but not c - panic(err) - } - - x.Add(&x, api.cast(a)) - x.Add(&x, api.cast(b)) + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) res := api.newElement() diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 1fa4cbce93..b2c2b7527e 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -5,6 +5,7 @@ import ( "math/big" "github.com/consensys/gnark-crypto/ecc" + frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" @@ -204,11 +205,22 @@ func addPow7Add(key *big.Int) gkr.GateFunction { // addPow17: (in[0]+in[1]+key)¹⁷ func addPow17(key *big.Int) gkr.GateFunction { + var cachedKey frontend.Variable return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") } - return api.SumExp17(in[0], in[1], key) + if cachedKey == nil { + if _, ok := in[0].(*frBls12377.Element); ok { + var ck frBls12377.Element + ck.SetBigInt(key) + cachedKey = &ck + } else { + return api.SumExp17(in[0], in[1], key) + } + } + + return api.SumExp17(in[0], in[1], cachedKey) } } From a05f8d63e757d58cc41cbec6d4d7db5237e9f015 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 24 Sep 2025 10:59:08 -0500 Subject: [PATCH 81/92] build: generify --- .../backend/template/gkr/gkr.go.tmpl | 23 +++++++++---------- internal/gkr/bls12-377/gkr.go | 10 ++++---- internal/gkr/bls12-381/gkr.go | 23 +++++++++---------- internal/gkr/bls24-315/gkr.go | 23 +++++++++---------- internal/gkr/bls24-317/gkr.go | 23 +++++++++---------- internal/gkr/bn254/gkr.go | 23 +++++++++---------- internal/gkr/bw6-633/gkr.go | 23 +++++++++---------- internal/gkr/bw6-761/gkr.go | 23 +++++++++---------- internal/gkr/small_rational/gkr.go | 23 +++++++++---------- 9 files changed, 93 insertions(+), 101 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index bf0fb4ecd9..12b55c8d30 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -773,19 +773,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a,b,c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x {{ .ElementType }} + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 13923e0983..8f7532228e 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -784,11 +784,11 @@ func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { res := api.newElement() - res.Mul(&x, &x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, &x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 0b34a40e23..b4bdff0594 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index a0a059e6df..c095671e76 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 42f1e5e56c..bfa2fb778d 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a84276a35a..d075ced6cd 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 4daeb83900..e6932181a6 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 9821871edd..acb157dd5f 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x fr.Element + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 37f9490950..427bac1877 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -777,19 +777,18 @@ func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front } func (api *gateAPI) SumExp17(a, b, c frontend.Variable) frontend.Variable { - x := api.cast(a) - res := api.cast(b) - x.Add(res, x) - if _, err := res.SetInterface(c); err != nil { - panic(err) - } - x.Add(res, x) + var x small_rational.SmallRational + + x.Add(api.cast(a), api.cast(b)) + x.Add(&x, api.cast(c)) + + res := api.newElement() - res.Mul(x, x) // x^2 - res.Mul(res, res) // x^4 - res.Mul(res, res) // x^8 - res.Mul(res, res) // x^16 - return res.Mul(res, x) // x^17 + res.Mul(&x, &x) // x² + res.Mul(res, res) // x⁴ + res.Mul(res, res) // x⁸ + res.Mul(res, res) // x¹⁶ + return res.Mul(res, &x) // x¹⁷ } func (api *gateAPI) Println(a ...frontend.Variable) { From f6613ce75532aaac67454077c5bd30a4591c1481 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 30 Sep 2025 16:56:24 -0500 Subject: [PATCH 82/92] perf: store keys as fr.Elements instead of big.Int --- std/permutation/gkr-mimc/gkr-mimc.go | 100 ++++++++++++++++++--------- 1 file changed, 68 insertions(+), 32 deletions(-) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index b2c2b7527e..1645552881 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -2,16 +2,21 @@ package gkr_mimc import ( "fmt" - "math/big" "github.com/consensys/gnark-crypto/ecc" - frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + frbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + frbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + frbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + frbls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + frbn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + frbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + frbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" @@ -95,7 +100,7 @@ func RegisterGates(curves ...ecc.ID) error { return err } gateNamer := newGateNamer(curve) - var lastLayerSBox, nonLastLayerSBox func(*big.Int) gkr.GateFunction + var lastLayerSBox, nonLastLayerSBox func(frontend.Variable) gkr.GateFunction switch deg { case 5: lastLayerSBox = addPow5Add @@ -111,12 +116,12 @@ func RegisterGates(curves ...ecc.ID) error { } for i := range len(constants) - 1 { - if _, err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if _, err = gkrgates.Register(nonLastLayerSBox(constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", i, curve, err) } } - if _, err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if _, err = gkrgates.Register(lastLayerSBox(constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", len(constants)-1, curve, err) } } @@ -124,23 +129,65 @@ func RegisterGates(curves ...ecc.ID) error { } // getParams returns the parameters for the MiMC encryption function for the given curve. -// It also returns the degree of the s-Box -func getParams(curve ecc.ID) ([]big.Int, int, error) { +// It also returns the degree of the s-Box. +func getParams(curve ecc.ID) ([]frontend.Variable, int, error) { switch curve { case ecc.BN254: - return bn254.GetConstants(), 5, nil + c := bn254.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbn254.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS12_381: - return bls12381.GetConstants(), 5, nil + c := bls12381.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls12381.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS12_377: - return bls12377.GetConstants(), 17, nil + c := bls12377.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls12377.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 17, nil case ecc.BLS24_315: - return bls24315.GetConstants(), 5, nil + c := bls24315.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls24315.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BLS24_317: - return bls24317.GetConstants(), 7, nil + c := bls24317.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbls24317.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 7, nil case ecc.BW6_633: - return bw6633.GetConstants(), 5, nil + c := bw6633.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbw6633.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil case ecc.BW6_761: - return bw6761.GetConstants(), 5, nil + c := bw6761.GetConstants() + res := make([]frontend.Variable, len(c)) + for i := range res { + var v frbw6761.Element + res[i] = v.SetBigInt(&c[i]) + } + return res, 5, nil default: return nil, -1, fmt.Errorf("unsupported curve ID: %s", curve) } @@ -155,7 +202,7 @@ func (n gateNamer) round(i int) gkr.GateName { return gkr.GateName(fmt.Sprintf("%s%d", string(n), i)) } -func addPow5(key *big.Int) gkr.GateFunction { +func addPow5(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") @@ -167,7 +214,7 @@ func addPow5(key *big.Int) gkr.GateFunction { } // addPow5Add: (in[0]+in[1]+key)⁵ + 2*in[0] + in[2] -func addPow5Add(key *big.Int) gkr.GateFunction { +func addPow5Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input") @@ -180,7 +227,7 @@ func addPow5Add(key *big.Int) gkr.GateFunction { } } -func addPow7(key *big.Int) gkr.GateFunction { +func addPow7(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") @@ -192,7 +239,7 @@ func addPow7(key *big.Int) gkr.GateFunction { } // addPow7Add: (in[0]+in[1]+key)⁷ + 2*in[0] + in[2] -func addPow7Add(key *big.Int) gkr.GateFunction { +func addPow7Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input") @@ -204,28 +251,17 @@ func addPow7Add(key *big.Int) gkr.GateFunction { } // addPow17: (in[0]+in[1]+key)¹⁷ -func addPow17(key *big.Int) gkr.GateFunction { - var cachedKey frontend.Variable +func addPow17(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 2 { panic("expected two input") } - if cachedKey == nil { - if _, ok := in[0].(*frBls12377.Element); ok { - var ck frBls12377.Element - ck.SetBigInt(key) - cachedKey = &ck - } else { - return api.SumExp17(in[0], in[1], key) - } - } - - return api.SumExp17(in[0], in[1], cachedKey) + return api.SumExp17(in[0], in[1], key) } } // addPow17Add: (in[0]+in[1]+key)¹⁷ + in[0] + in[2] -func addPow17Add(key *big.Int) gkr.GateFunction { +func addPow17Add(key frontend.Variable) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { panic("expected three input") From 610d49ca3db9dfb9a31497d746d3d54605e90721 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 7 Jan 2026 10:56:35 -0600 Subject: [PATCH 83/92] fix: match api changes --- std/permutation/gkr-mimc/gkr-mimc.go | 16 +++- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 13 ++-- .../poseidon2/gkr-poseidon2/gkr_test.go | 74 ------------------- 3 files changed, 19 insertions(+), 84 deletions(-) delete mode 100644 std/permutation/poseidon2/gkr-poseidon2/gkr_test.go diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 266ee00e67..75736b22ef 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -52,7 +52,10 @@ func NewCompressor(api frontend.API) (hash.Compressor, error) { return nil, fmt.Errorf("cached value is of type %T, not a compressor", cached) } - gkrApi := gkrapi.New() + gkrApi, err := gkrapi.New(api) + if err != nil { + return nil, err + } in0 := gkrApi.NewInput() in1 := gkrApi.NewInput() @@ -75,9 +78,14 @@ func NewCompressor(api frontend.API) (hash.Compressor, error) { y = gkrApi.NamedGate(gateNamer.round(len(params)-1), in0, y, in1) + gkrCircuit, err := gkrApi.Compile("POSEIDON2") + if err != nil { + return nil, err + } + res := &compressor{ - gkrCircuit: gkrApi.Compile(api, "POSEIDON2"), + gkrCircuit: gkrCircuit, in0: in0, in1: in1, out: y, @@ -110,12 +118,12 @@ func RegisterGates(curves ...ecc.ID) error { } for i := range len(constants) - 1 { - if _, err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", i, curve, err) } } - if _, err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", len(constants)-1, curve, err) } } diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go index 4bb0b7a6a7..5b3d34b2ee 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -188,7 +188,10 @@ func defineCircuit(api frontend.API) (gkrCircuit *gkrapi.Circuit, in1, in2, out return } - gkrApi := gkrapi.New() + gkrApi, err := gkrapi.New(api) + if err != nil { + return + } x := gkrApi.NewInput() y := gkrApi.NewInput() @@ -281,7 +284,7 @@ func defineCircuit(api frontend.API) (gkrCircuit *gkrapi.Circuit, in1, in2, out // apply the external matrix one last time to obtain the final value of y out = gkrApi.Gate(extAddGate, y, x, in2) - gkrCircuit = gkrApi.Compile(api, "MIMC") + gkrCircuit, err = gkrApi.Compile("MIMC") return } @@ -313,13 +316,11 @@ func registerGates(p *poseidon2.Parameters, curve ecc.ID) error { halfRf := p.NbFullRounds / 2 extKeySBox := func(round int, varIndex int) error { - _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(curve)) - return err + return gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(curve)) } intKeySBox2 := func(round int) error { - _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(curve)) - return err + return gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(curve)) } fullRound := func(i int) error { diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go deleted file mode 100644 index 0a230c4381..0000000000 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package gkr_poseidon2 - -import ( - "fmt" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/scs" - _ "github.com/consensys/gnark/std/hash/all" - "github.com/consensys/gnark/test" - "github.com/stretchr/testify/require" -) - -func gkrCompressionCircuits(t require.TestingT, n int) (circuit, assignment testGkrCompressionCircuit) { - var k int64 - ins := make([][2]frontend.Variable, n) - outs := make([]frontend.Variable, n) - for i := range n { - var x [2]fr.Element - ins[i] = [2]frontend.Variable{k, k + 1} - - x[0].SetInt64(k) - x[1].SetInt64(k + 1) - y0 := x[1] - - require.NoError(t, bls12377Permutation().Permutation(x[:])) - x[1].Add(&x[1], &y0) - outs[i] = x[1] - - k += 2 - } - - return testGkrCompressionCircuit{ - Ins: make([][2]frontend.Variable, len(ins)), - Outs: make([]frontend.Variable, len(outs)), - }, testGkrCompressionCircuit{ - Ins: ins, - Outs: outs, - } -} - -func TestGkrCompression(t *testing.T) { - circuit, assignment := gkrCompressionCircuits(t, 2) - - test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) -} - -type testGkrCompressionCircuit struct { - Ins [][2]frontend.Variable - Outs []frontend.Variable -} - -func (c *testGkrCompressionCircuit) Define(api frontend.API) error { - - pos2 := NewGkrCompressor(api) - api.AssertIsEqual(len(c.Ins), len(c.Outs)) - for i := range c.Ins { - api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) - } - - return nil -} - -func TestGkrCompressionCompiles(t *testing.T) { - // just measure the number of constraints - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrCompressionCircuit{ - Ins: make([][2]frontend.Variable, 52000), - Outs: make([]frontend.Variable, 52000), - }) - require.NoError(t, err) - fmt.Println(cs.GetNbConstraints(), "constraints") -} From 8549a0bcba14f11d683510f98e900a1ff3b58274 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 7 Jan 2026 10:58:44 -0600 Subject: [PATCH 84/92] fix: api use --- std/gkrapi/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index 019b954cd8..54a1e11aba 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -30,7 +30,7 @@ func (api *API) NamedGate(gate gkr.GateName, in ...gkr.Variable) gkr.Variable { } func (api *API) Gate(gate gkr.GateFunction, in ...gkr.Variable) gkr.Variable { - if _, err := gkrgates.Register(gate, len(in)); err != nil { + if err := gkrgates.Register(gate, len(in)); err != nil { panic(err) } return api.NamedGate(gkrgates.GetDefaultGateName(gate), in...) From 8bcce8f4f64cfcb894a40462bd9a99ed671ef437 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 7 Jan 2026 11:01:19 -0600 Subject: [PATCH 85/92] remove engine_hints --- internal/gkr/engine_hints.go | 195 ----------------------------------- 1 file changed, 195 deletions(-) delete mode 100644 internal/gkr/engine_hints.go diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go deleted file mode 100644 index 8c8bc1b797..0000000000 --- a/internal/gkr/engine_hints.go +++ /dev/null @@ -1,195 +0,0 @@ -package gkr - -import ( - "errors" - "fmt" - "math/big" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/constraint/solver/gkrgates" - "github.com/consensys/gnark/frontend" - bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" - bls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" - bls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" - bls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" - bn254 "github.com/consensys/gnark/internal/gkr/bn254" - bw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" - bw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" - "github.com/consensys/gnark/internal/gkr/gkrinfo" - "github.com/consensys/gnark/internal/gkr/gkrtypes" - "github.com/consensys/gnark/internal/utils" -) - -type TestEngineHints struct { - assignment gkrtypes.WireAssignment - info *gkrinfo.StoringInfo // we retain a reference to the solving info to allow the caller to modify it between calls to Solve and Prove - circuit gkrtypes.Circuit - gateIns []frontend.Variable -} - -func NewTestEngineHints(info *gkrinfo.StoringInfo) (*TestEngineHints, error) { - circuit, err := gkrtypes.CircuitInfoToCircuit(info.Circuit, gkrgates.Get) - if err != nil { - return nil, err - } - - return &TestEngineHints{ - info: info, - circuit: circuit, - gateIns: make([]frontend.Variable, circuit.MaxGateNbIn()), - assignment: make(gkrtypes.WireAssignment, len(circuit)), - }, - err -} - -// Solve solves one instance of a GKR circuit. -// The first input is the index of the instance. The rest are the inputs of the circuit, in their nominal order. -func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) error { - - instanceI := len(h.assignment[0]) - if in0 := ins[0].Uint64(); !ins[0].IsUint64() || in0 > 0xffffffff { - return errors.New("first input must be a uint32 instance index") - } else if in0 != uint64(instanceI) || h.info.NbInstances != instanceI { - return errors.New("first input must equal the number of instances, and calls to Solve must be done in order of instance index") - } - - api := gateAPI{mod} - - inI := 1 - outI := 0 - for wI := range h.circuit { - w := &h.circuit[wI] - var val frontend.Variable - if w.IsInput() { - val = utils.FromInterface(ins[inI]) - inI++ - } else { - for gateInI, inWI := range w.Inputs { - h.gateIns[gateInI] = h.assignment[inWI][instanceI] - } - val = w.Gate.Evaluate(api, h.gateIns[:len(w.Inputs)]...) - } - if w.IsOutput() { - *outs[outI] = utils.FromInterface(val) - outI++ - } - h.assignment[wI] = append(h.assignment[wI], val) - } - return nil -} - -func (h *TestEngineHints) Prove(mod *big.Int, ins, outs []*big.Int) error { - - info, err := gkrtypes.StoringToSolvingInfo(*h.info, gkrgates.Get) - if err != nil { - return fmt.Errorf("failed to convert storing info to solving info: %w", err) - } - - // TODO @Tabaie autogenerate this or decide not to - if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - data := bls12377.NewSolvingData(info, bls12377.WithAssignment(h.assignment)) - return bls12377.ProveHint(info.HashName, data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - data := bls12381.NewSolvingData(info, bls12381.WithAssignment(h.assignment)) - return bls12381.ProveHint(info.HashName, data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - data := bls24315.NewSolvingData(info, bls24315.WithAssignment(h.assignment)) - return bls24315.ProveHint(info.HashName, data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - data := bls24317.NewSolvingData(info, bls24317.WithAssignment(h.assignment)) - return bls24317.ProveHint(info.HashName, data)(mod, ins, outs) - } - if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - data := bn254.NewSolvingData(info, bn254.WithAssignment(h.assignment)) - return bn254.ProveHint(info.HashName, data)(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - data := bw6633.NewSolvingData(info, bw6633.WithAssignment(h.assignment)) - return bw6633.ProveHint(info.HashName, data)(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - data := bw6761.NewSolvingData(info, bw6761.WithAssignment(h.assignment)) - return bw6761.ProveHint(info.HashName, data)(mod, ins, outs) - } - - return errors.New("unsupported modulus") -} - -// GetAssignment returns the assignment for a particular wire and instance. -func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big.Int) error { - if len(ins) != 3 || !ins[0].IsUint64() || !ins[1].IsUint64() { - return errors.New("expected 3 inputs: wire index, instance index, and dummy output from the same instance") - } - if len(outs) != 1 { - return errors.New("expected 1 output: the value of the wire at the given instance") - } - *outs[0] = utils.FromInterface(h.assignment[ins[0].Uint64()][ins[1].Uint64()]) - return nil -} - -type gateAPI struct{ *big.Int } - -func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - in1 := utils.FromInterface(i1) - in2 := utils.FromInterface(i2) - - in1.Add(&in1, &in2) - for _, v := range in { - inV := utils.FromInterface(v) - in1.Add(&in1, &inV) - } - return &in1 -} - -func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - x, y := utils.FromInterface(b), utils.FromInterface(c) - x.Mul(&x, &y) - x.Mod(&x, g.Int) // reduce - y = utils.FromInterface(a) - x.Add(&x, &y) - return &x -} - -func (g gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - x := utils.FromInterface(i1) - x.Neg(&x) - return &x -} - -func (g gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - x := utils.FromInterface(i1) - y := utils.FromInterface(i2) - x.Sub(&x, &y) - for _, v := range in { - y = utils.FromInterface(v) - x.Sub(&x, &y) - } - return &x -} - -func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - x := utils.FromInterface(i1) - y := utils.FromInterface(i2) - x.Mul(&x, &y) - for _, v := range in { - y = utils.FromInterface(v) - x.Mul(&x, &y) - } - x.Mod(&x, g.Int) // reduce - return &x -} - -func (g gateAPI) Println(a ...frontend.Variable) { - strings := make([]string, len(a)) - for i := range a { - if s, ok := a[i].(fmt.Stringer); ok { - strings[i] = s.String() - } else if strings[i], ok = a[i].(string); !ok { - bigInt := utils.FromInterface(a[i]) - strings[i] = bigInt.String() - } - } -} From a045f1171e63744dd1503756ab7b942dfe0f456d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 7 Jan 2026 11:06:31 -0600 Subject: [PATCH 86/92] revert change to engine.go --- test/engine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/engine.go b/test/engine.go index b37bc68acc..ff7f9e2c7a 100644 --- a/test/engine.go +++ b/test/engine.go @@ -110,7 +110,7 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng defer func() { if r := recover(); r != nil { - err = fmt.Errorf("%v\n%s", r, debug.Stack()) + err = fmt.Errorf("%v\n%s", r, string(debug.Stack())) } }() From 9c35fd7220595baee2f52c67618fe6e9a74cb841 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 7 Jan 2026 14:21:28 -0600 Subject: [PATCH 87/92] fix: mimc tests --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 8a9381bfd7..a6c2085833 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "os" + "path/filepath" "slices" "testing" @@ -24,7 +25,7 @@ func TestGkrMiMC(t *testing.T) { vals[i] = i + 1 } - for _, length := range lengths[1:2] { + for _, length := range lengths { circuit := &testGkrMiMCCircuit{ In: make([]frontend.Variable, length*2), } @@ -116,23 +117,29 @@ func (c hashTreeCircuit) Define(api frontend.API) error { return nil } -func loadCs(t require.TestingT, filename string, circuit frontend.Circuit) constraint.ConstraintSystem { - f, err := os.Open(filename) +func loadCs(t require.TestingT, fileTitle string, circuit frontend.Circuit) constraint.ConstraintSystem { + filename := filepath.Join(os.TempDir(), fileTitle) + _, err := os.Stat(filename) if os.IsNotExist(err) { // actually compile cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, circuit) require.NoError(t, err) - f, err = os.Create(filename) + f, err := os.Create(filename) require.NoError(t, err) - defer f.Close() + defer func() { + require.NoError(t, f.Close()) + }() _, err = cs.WriteTo(f) require.NoError(t, err) return cs } - defer f.Close() + f, err := os.Open(filename) require.NoError(t, err) + defer func() { + require.NoError(t, f.Close()) + }() cs := plonk.NewCS(ecc.BLS12_377) From dd4e36a8482aebd7a486fe0a4b70aedc1810820a Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 7 Jan 2026 14:38:14 -0600 Subject: [PATCH 88/92] fix: error message --- std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go index 5b3d34b2ee..a6e8c7bb1d 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -140,7 +140,7 @@ func NewCompressor(api frontend.API) (hash.Compressor, error) { if compressor, ok := cached.(*compressor); ok { return compressor, nil } - return nil, fmt.Errorf("cached value is of type %T, not a mimcCompressor", cached) + return nil, fmt.Errorf("cached value is of type %T, not a gkr-poseidon2.Compressor", cached) } gkrCircuit, in1, in2, out, err := defineCircuit(api) From 560e80efa336bf0c9fec75cfff6588cbebee2f2a Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 8 Jan 2026 13:02:32 -0600 Subject: [PATCH 89/92] fix: error on empty list --- std/permutation/gkr-mimc/gkr-mimc.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 75736b22ef..4a0bc0355e 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -1,6 +1,7 @@ package gkr_mimc import ( + "errors" "fmt" "math/big" @@ -96,6 +97,9 @@ func NewCompressor(api frontend.API) (hash.Compressor, error) { } func RegisterGates(curves ...ecc.ID) error { + if len(curves) == 0 { + return errors.New("expected at least one curve") + } for _, curve := range curves { constants, deg, err := getParams(curve) if err != nil { From c80eff5cb961464ca8ed651d30b069042450c192 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 11 Jan 2026 17:21:08 -0600 Subject: [PATCH 90/92] revert: api in the solve hint --- .../backend/template/gkr/solver_hints.go.tmpl | 9 +++++---- internal/gkr/bls12-377/solver_hints.go | 11 ++++++----- internal/gkr/bls12-381/solver_hints.go | 11 ++++++----- internal/gkr/bls24-315/solver_hints.go | 11 ++++++----- internal/gkr/bls24-317/solver_hints.go | 11 ++++++----- internal/gkr/bn254/solver_hints.go | 11 ++++++----- internal/gkr/bw6-633/solver_hints.go | 11 ++++++----- internal/gkr/bw6-761/solver_hints.go | 11 ++++++----- 8 files changed, 47 insertions(+), 39 deletions(-) diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index d7e5c7f7a5..278ca9fdce 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -138,6 +138,7 @@ func SolveHint(data []SolvingData) hint.Hint { // indices for reading inputs and outputs outsI := 0 insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. // we can now iterate over all the wires in the circuit. The wires are already topologically sorted, // i.e. all inputs of a gate appear before the gate itself. So it is safe to iterate linearly. @@ -153,10 +154,10 @@ func SolveHint(data []SolvingData) hint.Hint { for i, inWI := range w.Inputs { gateIns[i] = &data.assignment[inWI][instanceI] } - // evaluate the gate on the inputs - eval := w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*{{ .ElementType }}) - // store the result in the assignment (for the following gates to use) - data.assignment[wI][instanceI].Set(eval) + + // evaluate the gate on the inputs and store the result in the assignment (for the following gates to use) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { // write to provided output. diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 0231c46dfa..7c37f29b2f 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -146,7 +146,8 @@ func SolveHint(data []SolvingData) hint.Hint { // indices for reading inputs and outputs outsI := 0 - insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. // we can now iterate over all the wires in the circuit. The wires are already topologically sorted, // i.e. all inputs of a gate appear before the gate itself. So it is safe to iterate linearly. @@ -162,10 +163,10 @@ func SolveHint(data []SolvingData) hint.Hint { for i, inWI := range w.Inputs { gateIns[i] = &data.assignment[inWI][instanceI] } - // evaluate the gate on the inputs - eval := w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element) - // store the result in the assignment (for the following gates to use) - data.assignment[wI][instanceI].Set(eval) + + // evaluate the gate on the inputs and store the result in the assignment (for the following gates to use) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { // write to provided output. diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index b2924f152d..b270dfe45d 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -146,7 +146,8 @@ func SolveHint(data []SolvingData) hint.Hint { // indices for reading inputs and outputs outsI := 0 - insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. // we can now iterate over all the wires in the circuit. The wires are already topologically sorted, // i.e. all inputs of a gate appear before the gate itself. So it is safe to iterate linearly. @@ -162,10 +163,10 @@ func SolveHint(data []SolvingData) hint.Hint { for i, inWI := range w.Inputs { gateIns[i] = &data.assignment[inWI][instanceI] } - // evaluate the gate on the inputs - eval := w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element) - // store the result in the assignment (for the following gates to use) - data.assignment[wI][instanceI].Set(eval) + + // evaluate the gate on the inputs and store the result in the assignment (for the following gates to use) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { // write to provided output. diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 21d00c720d..3a9b7eda63 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -146,7 +146,8 @@ func SolveHint(data []SolvingData) hint.Hint { // indices for reading inputs and outputs outsI := 0 - insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. // we can now iterate over all the wires in the circuit. The wires are already topologically sorted, // i.e. all inputs of a gate appear before the gate itself. So it is safe to iterate linearly. @@ -162,10 +163,10 @@ func SolveHint(data []SolvingData) hint.Hint { for i, inWI := range w.Inputs { gateIns[i] = &data.assignment[inWI][instanceI] } - // evaluate the gate on the inputs - eval := w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element) - // store the result in the assignment (for the following gates to use) - data.assignment[wI][instanceI].Set(eval) + + // evaluate the gate on the inputs and store the result in the assignment (for the following gates to use) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { // write to provided output. diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index 04826fdacf..fc18d24659 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -146,7 +146,8 @@ func SolveHint(data []SolvingData) hint.Hint { // indices for reading inputs and outputs outsI := 0 - insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. // we can now iterate over all the wires in the circuit. The wires are already topologically sorted, // i.e. all inputs of a gate appear before the gate itself. So it is safe to iterate linearly. @@ -162,10 +163,10 @@ func SolveHint(data []SolvingData) hint.Hint { for i, inWI := range w.Inputs { gateIns[i] = &data.assignment[inWI][instanceI] } - // evaluate the gate on the inputs - eval := w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element) - // store the result in the assignment (for the following gates to use) - data.assignment[wI][instanceI].Set(eval) + + // evaluate the gate on the inputs and store the result in the assignment (for the following gates to use) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { // write to provided output. diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index f28cdf8eae..01b14882e7 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -146,7 +146,8 @@ func SolveHint(data []SolvingData) hint.Hint { // indices for reading inputs and outputs outsI := 0 - insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. // we can now iterate over all the wires in the circuit. The wires are already topologically sorted, // i.e. all inputs of a gate appear before the gate itself. So it is safe to iterate linearly. @@ -162,10 +163,10 @@ func SolveHint(data []SolvingData) hint.Hint { for i, inWI := range w.Inputs { gateIns[i] = &data.assignment[inWI][instanceI] } - // evaluate the gate on the inputs - eval := w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element) - // store the result in the assignment (for the following gates to use) - data.assignment[wI][instanceI].Set(eval) + + // evaluate the gate on the inputs and store the result in the assignment (for the following gates to use) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { // write to provided output. diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 9ee13e82dc..42a801f1fc 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -146,7 +146,8 @@ func SolveHint(data []SolvingData) hint.Hint { // indices for reading inputs and outputs outsI := 0 - insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. // we can now iterate over all the wires in the circuit. The wires are already topologically sorted, // i.e. all inputs of a gate appear before the gate itself. So it is safe to iterate linearly. @@ -162,10 +163,10 @@ func SolveHint(data []SolvingData) hint.Hint { for i, inWI := range w.Inputs { gateIns[i] = &data.assignment[inWI][instanceI] } - // evaluate the gate on the inputs - eval := w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element) - // store the result in the assignment (for the following gates to use) - data.assignment[wI][instanceI].Set(eval) + + // evaluate the gate on the inputs and store the result in the assignment (for the following gates to use) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { // write to provided output. diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 4cd5f99a99..dc5456fdb5 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -146,7 +146,8 @@ func SolveHint(data []SolvingData) hint.Hint { // indices for reading inputs and outputs outsI := 0 - insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + insI := 2 // skip the first two input, which are the circuit and instance indices, respectively. + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. // we can now iterate over all the wires in the circuit. The wires are already topologically sorted, // i.e. all inputs of a gate appear before the gate itself. So it is safe to iterate linearly. @@ -162,10 +163,10 @@ func SolveHint(data []SolvingData) hint.Hint { for i, inWI := range w.Inputs { gateIns[i] = &data.assignment[inWI][instanceI] } - // evaluate the gate on the inputs - eval := w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element) - // store the result in the assignment (for the following gates to use) - data.assignment[wI][instanceI].Set(eval) + + // evaluate the gate on the inputs and store the result in the assignment (for the following gates to use) + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + api.freeElements() } if w.IsOutput() { // write to provided output. From 1488ac96c97b4793afa098c67c71e3acb66cbf02 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 11 Jan 2026 20:58:00 -0600 Subject: [PATCH 91/92] fix: constant is var --- std/permutation/gkr-mimc/gkr-mimc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 119f434769..94217a4b1a 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -128,12 +128,12 @@ func RegisterGates(curves ...ecc.ID) error { } for i := range len(constants) - 1 { - if err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if err = gkrgates.Register(nonLastLayerSBox(constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", i, curve, err) } } - if err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + if err = gkrgates.Register(lastLayerSBox(constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", len(constants)-1, curve, err) } } From 167604987e29efc9549350a9e9dfd93e9d9d63ae Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 12 Jan 2026 19:04:19 -0600 Subject: [PATCH 92/92] fix: imports --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 3 --- std/permutation/gkr-mimc/gkr-mimc.go | 9 +-------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 84c26124d0..23da56f007 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -3,16 +3,13 @@ package gkr_mimc import ( "errors" "slices" - "testing" - "slices" "testing" "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index 3f897014b0..be33eae301 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -18,15 +18,8 @@ import ( frbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" frbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - "github.com/consensys/gnark-crypto/ecc" - bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" - bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" - bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" - bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" - bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" - bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" + "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/kvstore"