diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index a01fdd8de5..77db185837 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -232,6 +232,16 @@ func (f *Field[T]) mulMod(a, b *Element[T], _ uint, p *Element[T]) *Element[T] { if a.isStrictZero() || b.isStrictZero() { return f.Zero() } + if p == nil { + // fast path - constant multiplication can be folded directly without + // creating a hinted reduction or carrying synthetic overflow metadata. + if ba, aConst := f.constantValue(a); aConst { + if bb, bConst := f.constantValue(b); bConst { + ba.Mul(ba, bb).Mod(ba, f.fParams.Modulus()) + return newConstElement[T](f.api.Compiler().Field(), ba, false) + } + } + } f.enforceWidthConditional(a) f.enforceWidthConditional(b) @@ -748,6 +758,15 @@ func (f *Field[T]) MulNoReduce(a, b *Element[T]) *Element[T] { } func (f *Field[T]) mulNoReduce(a, b *Element[T], nextoverflow uint) *Element[T] { + // fast path - constant multiplication stays constant even on the + // non-reducing path, so avoid growing overflow on a value the compiler can + // still recognize as constant. + if ba, aConst := f.constantValue(a); aConst { + if bb, bConst := f.constantValue(b); bConst { + ba.Mul(ba, bb).Mod(ba, f.fParams.Modulus()) + return newConstElement[T](f.api.Compiler().Field(), ba, false) + } + } resLimbs := make([]frontend.Variable, nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs))) for i := range resLimbs { resLimbs[i] = 0 diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index 9f6190c0f3..192f1b36ae 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -32,6 +32,24 @@ func (f *Field[T]) divPreCond(a, b *Element[T]) (nextOverflow uint, err error) { } func (f *Field[T]) div(a, b *Element[T], _ uint) *Element[T] { + // fast path - constant division can be folded eagerly and avoids calling a + // hint for a value already known at compile time. + if ba, aConst := f.constantValue(a); aConst { + if bb, bConst := f.constantValue(b); bConst { + if bb.Sign() == 0 { + panic("division by zero") + } + if !f.fParams.IsPrime() { + panic("modulus not a prime") + } + inv := new(big.Int).ModInverse(bb, f.fParams.Modulus()) + if inv == nil { + panic("division undefined") + } + ba.Mul(ba, inv).Mod(ba, f.fParams.Modulus()) + return newConstElement[T](f.api.Compiler().Field(), ba, false) + } + } // omit width assertion as for a is done in AssertIsEqual and for b is done in Mul below if !f.fParams.IsPrime() { // TODO shouldn't we still try to do a classic int div in a hint, constraint the result, and let it fail? @@ -66,6 +84,18 @@ func (f *Field[T]) inversePreCond(a, _ *Element[T]) (nextOverflow uint, err erro } func (f *Field[T]) inverse(a, _ *Element[T], _ uint) *Element[T] { + // fast path - constant inversion can be computed directly without using a + // hint or emitting a multiplication check. + if ba, aConst := f.constantValue(a); aConst { + if !f.fParams.IsPrime() { + panic("modulus not a prime") + } + inv := new(big.Int).ModInverse(ba, f.fParams.Modulus()) + if inv == nil { + panic("inverse undefined") + } + return newConstElement[T](f.api.Compiler().Field(), inv, false) + } // omit width assertion as is done in Mul below if !f.fParams.IsPrime() { panic("modulus not a prime") @@ -97,6 +127,18 @@ func (f *Field[T]) sqrtPreCond(a, _ *Element[T]) (nextOverflow uint, err error) } func (f *Field[T]) sqrt(a, _ *Element[T], _ uint) *Element[T] { + // fast path - constant square roots can be computed eagerly when they + // exist, avoiding a hint round-trip for compile-time values. + if ba, aConst := f.constantValue(a); aConst { + if !f.fParams.IsPrime() { + panic("modulus not a prime") + } + root := new(big.Int).ModSqrt(ba, f.fParams.Modulus()) + if root == nil { + panic("no square root") + } + return newConstElement[T](f.api.Compiler().Field(), root, false) + } // omit width assertion as is done in Mul below if !f.fParams.IsPrime() { panic("modulus not a prime") diff --git a/std/math/emulated/field_test.go b/std/math/emulated/field_test.go index a0a355f7fe..5ae183d8f4 100644 --- a/std/math/emulated/field_test.go +++ b/std/math/emulated/field_test.go @@ -131,3 +131,159 @@ func TestSubConstantCircuit(t *testing.T) { _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.IgnoreUnconstrainedInputs()) assert.NoError(err) } + +type SmallMulConstantFastPathCircuit struct { + Dummy frontend.Variable +} + +func (c *SmallMulConstantFastPathCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.Dummy, c.Dummy) + f, err := NewField[Goldilocks](api) + if err != nil { + return err + } + res := f.Mul(f.One(), f.One()) + if res.overflow != 0 { + return fmt.Errorf("mul overflow %d != 0", res.overflow) + } + if _, ok := f.constantValue(res); !ok { + return fmt.Errorf("mul should be constant") + } + f.AssertIsEqual(res, f.One()) + return nil +} + +func TestSmallMulConstantFastPathCircuit(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit(&SmallMulConstantFastPathCircuit{}, test.WithValidAssignment(&SmallMulConstantFastPathCircuit{Dummy: 1}), test.NoTestEngine()) +} + +type SmallMulNoReduceConstantFastPathCircuit struct { + Dummy frontend.Variable +} + +func (c *SmallMulNoReduceConstantFastPathCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.Dummy, c.Dummy) + f, err := NewField[Goldilocks](api) + if err != nil { + return err + } + res := f.MulNoReduce(f.NewElement(7), f.NewElement(9)) + if res.overflow != 0 { + return fmt.Errorf("mulNoReduce overflow %d != 0", res.overflow) + } + if _, ok := f.constantValue(res); !ok { + return fmt.Errorf("mulNoReduce should be constant") + } + f.AssertIsEqual(res, f.NewElement(63)) + return nil +} + +func TestSmallMulNoReduceConstantFastPathCircuit(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit(&SmallMulNoReduceConstantFastPathCircuit{}, test.WithValidAssignment(&SmallMulNoReduceConstantFastPathCircuit{Dummy: 1}), test.NoTestEngine()) +} + +type DivConstantFastPathCircuit struct { + Dummy frontend.Variable +} + +func (c *DivConstantFastPathCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.Dummy, c.Dummy) + f, err := NewField[Goldilocks](api) + if err != nil { + return err + } + res := f.Div(f.NewElement(21), f.NewElement(3)) + if res.overflow != 0 { + return fmt.Errorf("div overflow %d != 0", res.overflow) + } + if _, ok := f.constantValue(res); !ok { + return fmt.Errorf("div should be constant") + } + f.AssertIsEqual(res, f.NewElement(7)) + return nil +} + +func TestDivConstantFastPathCircuit(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit(&DivConstantFastPathCircuit{}, test.WithValidAssignment(&DivConstantFastPathCircuit{Dummy: 1}), test.NoTestEngine()) +} + +type InverseConstantFastPathCircuit struct { + Dummy frontend.Variable +} + +func (c *InverseConstantFastPathCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.Dummy, c.Dummy) + f, err := NewField[Goldilocks](api) + if err != nil { + return err + } + res := f.Inverse(f.NewElement(7)) + if res.overflow != 0 { + return fmt.Errorf("inverse overflow %d != 0", res.overflow) + } + if _, ok := f.constantValue(res); !ok { + return fmt.Errorf("inverse should be constant") + } + f.AssertIsEqual(f.Mul(res, f.NewElement(7)), f.One()) + return nil +} + +func TestInverseConstantFastPathCircuit(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit(&InverseConstantFastPathCircuit{}, test.WithValidAssignment(&InverseConstantFastPathCircuit{Dummy: 1}), test.NoTestEngine()) +} + +type SqrtConstantFastPathCircuit struct { + Dummy frontend.Variable +} + +func (c *SqrtConstantFastPathCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.Dummy, c.Dummy) + f, err := NewField[Goldilocks](api) + if err != nil { + return err + } + res := f.Sqrt(f.NewElement(9)) + if res.overflow != 0 { + return fmt.Errorf("sqrt overflow %d != 0", res.overflow) + } + if _, ok := f.constantValue(res); !ok { + return fmt.Errorf("sqrt should be constant") + } + f.AssertIsEqual(f.Mul(res, res), f.NewElement(9)) + return nil +} + +func TestSqrtConstantFastPathCircuit(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit(&SqrtConstantFastPathCircuit{}, test.WithValidAssignment(&SqrtConstantFastPathCircuit{Dummy: 1}), test.NoTestEngine()) +} + +type LargeMulConstantFastPathCircuit struct { + Dummy frontend.Variable +} + +func (c *LargeMulConstantFastPathCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.Dummy, c.Dummy) + f, err := NewField[Secp256k1Fp](api) + if err != nil { + return err + } + res := f.Mul(f.NewElement(7), f.NewElement(9)) + if res.overflow != 0 { + return fmt.Errorf("mulLarge overflow %d != 0", res.overflow) + } + if _, ok := f.constantValue(res); !ok { + return fmt.Errorf("mulLarge should be constant") + } + f.AssertIsEqual(res, f.NewElement(63)) + return nil +} + +func TestLargeMulConstantFastPathCircuit(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit(&LargeMulConstantFastPathCircuit{}, test.WithValidAssignment(&LargeMulConstantFastPathCircuit{Dummy: 1}), test.NoTestEngine()) +} diff --git a/std/math/emulated/regression_test.go b/std/math/emulated/regression_test.go index 440cbc00aa..48116899cd 100644 --- a/std/math/emulated/regression_test.go +++ b/std/math/emulated/regression_test.go @@ -1,6 +1,7 @@ package emulated import ( + "fmt" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -118,3 +119,63 @@ func TestIssue1021(t *testing.T) { err := test.IsSolved(&testIssue1021Circuit{}, &testIssue1021Circuit{A: ValueOf[BN254Fp](10)}, ecc.BN254.ScalarField()) assert.NoError(err) } + +// testIssueNNExpOneCircuit is a minimized regression for a fuzz-discovered bug +// in small-field emulation. The original reproducer showed that Exp(x, 1) +// could panic during compilation because repeated squaring kept the value equal +// to the constant one while still growing the overflow metadata. +type testIssueNNExpOneCircuit struct { + X Element[emparams.Goldilocks] +} + +func (c *testIssueNNExpOneCircuit) Define(api frontend.API) error { + f, err := NewField[emparams.Goldilocks](api) + if err != nil { + return err + } + res := f.Exp(&c.X, f.One()) + f.AssertIsEqual(res, &c.X) + return nil +} + +func TestRegressionExpOneKeepsVariable(t *testing.T) { + assert := test.NewAssert(t) + circuit := &testIssueNNExpOneCircuit{} + witness := &testIssueNNExpOneCircuit{X: ValueOf[emparams.Goldilocks](42)} + assert.CheckCircuit(circuit, test.WithValidAssignment(witness)) +} + +// testIssueNNMulOneCircuit isolates the lower-level invariant break behind the +// Exp(x, 1) failure. In small-field mode, Mul(1, 1) used to return an element +// that was still recognized as a constant but carried non-zero overflow. A +// subsequent multiplication then tried to reduce that "constant with overflow" +// and panicked. +type testIssueNNMulOneCircuit struct { + Dummy frontend.Variable +} + +func (c *testIssueNNMulOneCircuit) Define(api frontend.API) error { + // add a dummy assertion to ensure we wouldn't have empty circuit + api.AssertIsEqual(c.Dummy, c.Dummy) + f, err := NewField[emparams.Goldilocks](api) + if err != nil { + return err + } + x := f.Mul(f.One(), f.One()) + if x.overflow != 0 { + return fmt.Errorf("Mul(1,1) returned overflow %d", x.overflow) + } + if _, ok := f.constantValue(x); !ok { + return fmt.Errorf("Mul(1,1) should stay constant") + } + y := f.Mul(x, x) + f.AssertIsEqual(y, f.One()) + return nil +} + +func TestRegressionMulOneReductionPath(t *testing.T) { + assert := test.NewAssert(t) + var circuit testIssueNNMulOneCircuit + witness := testIssueNNMulOneCircuit{Dummy: 1} + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.NoTestEngine()) +}