Skip to content
Merged
47 changes: 34 additions & 13 deletions std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gkr_poseidon2
import (
"errors"
"fmt"
"sync"

"github.com/consensys/gnark/constraint/solver/gkrgates"
"github.com/consensys/gnark/internal/kvstore"
Expand Down Expand Up @@ -103,14 +104,6 @@ func intGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable {
return api.Add(api.Mul(x[1], 3), x[0])
}

// extGate applies the first row of the external matrix
func extGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable {
if len(x) != 2 {
panic("expected 2 inputs")
}
return api.Add(api.Mul(x[0], 2), x[1])
}

// extAddGate applies the first row of the external matrix to the first two elements and adds the third
func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable {
if len(x) != 3 {
Expand Down Expand Up @@ -184,6 +177,7 @@ func defineCircuit(api frontend.API) (gkrCircuit *gkrapi.Circuit, in1, in2, out
}
gateNamer := newRoundGateNamer(&p, curve)

registerStaticGates()
if err = registerGates(&p, curve); err != nil {
return
}
Expand Down Expand Up @@ -294,6 +288,7 @@ func RegisterGates(curves ...ecc.ID) error {
if len(curves) == 0 {
return errors.New("expected at least one curve")
}
registerStaticGates()
for _, curve := range curves {
p, err := poseidon2.GetDefaultParameters(curve)
if err != nil {
Expand All @@ -306,6 +301,37 @@ func RegisterGates(curves ...ecc.ID) error {
return nil
}

var staticGatesOnce sync.Once

func registerStaticGates() {
staticGatesOnce.Do(func() {
if err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4)); err != nil {
panic(err)
}
if err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5)); err != nil {
panic(err)
}
if err := gkrgates.Register(pow3Gate, 1, gkrgates.WithUnverifiedDegree(3)); err != nil {
panic(err)
}
if err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2)); err != nil {
panic(err)
}
if err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3)); err != nil {
panic(err)
}
if err := gkrgates.Register(extGate2, 2, gkrgates.WithUnverifiedDegree(1)); err != nil {
panic(err)
}
if err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1)); err != nil {
panic(err)
}
if err := gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1)); err != nil {
panic(err)
}
})
}

func registerGates(p *poseidon2.Parameters, curve ecc.ID) error {
const (
x = iota
Expand Down Expand Up @@ -379,9 +405,4 @@ func (n roundGateNamer) linear(varIndex, round int) gkr.GateName {
return gkr.GateName(fmt.Sprintf("x%d-l-op-round=%d;%s", varIndex, round, n))
}

// integrated is the name of a gate where a polynomial of total degree 1 is applied to the input, followed by an S-box
func (n roundGateNamer) integrated(varIndex, round int) gkr.GateName {
return gkr.GateName(fmt.Sprintf("x%d-i-op-round=%d;%s", varIndex, round, n))
}

type gkrPoseidon2Key struct{}