Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions std/math/emulated/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ type Element[T FieldParams] struct {
// on the bit-representation of the element (ToBits, exponentiation etc.).
modReduced bool

// bitsDecomposition caches the bit decomposition of the element to avoid
// redundant ToBits calls. Once computed, the bits are stored here and
// reused on subsequent ToBits calls on the same element.
// bitsOverflow stores the overflow value when bits were computed, to ensure
// cached bits are only used when overflow hasn't changed.
bitsDecomposition []frontend.Variable `gnark:"-"`
bitsOverflow uint `gnark:"-"`

isEvaluated bool
evaluation frontend.Variable `gnark:"-"`

Expand Down Expand Up @@ -148,6 +156,9 @@ func (e *Element[T]) Initialize(field *big.Int) {
// second compilation we may take a shortPath where we assume that modReduce
// flag is set.
e.modReduced = false
// reset bitsDecomposition to avoid stale cached bits from previous compilation
e.bitsDecomposition = nil
Comment thread
cursor[bot] marked this conversation as resolved.
e.bitsOverflow = 0
}

// copy makes a deep copy of the element.
Expand All @@ -158,6 +169,11 @@ func (e *Element[T]) copy() *Element[T] {
r.overflow = e.overflow
r.internal = e.internal
r.modReduced = e.modReduced
if e.bitsDecomposition != nil {
r.bitsDecomposition = make([]frontend.Variable, len(e.bitsDecomposition))
copy(r.bitsDecomposition, e.bitsDecomposition)
r.bitsOverflow = e.bitsOverflow
}
r.isEvaluated = e.isEvaluated
r.evaluation = e.evaluation
if e.witnessValue != nil {
Expand Down
12 changes: 12 additions & 0 deletions std/math/emulated/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,13 @@ func (c *ToBinaryCircuit[T]) Define(api frontend.API) error {
if len(bits) != len(c.Bits) {
return fmt.Errorf("got %d bits, expected %d", len(bits), len(c.Bits))
}
// check that the cached bits match the bits we get from ToBits. This is important as AssertIsInRange relies on the cached bits to be correct.
newBits := f.ToBits(&c.Value)
if len(newBits) != len(bits) {
return fmt.Errorf("got %d bits, expected %d", len(newBits), len(bits))
}
for i := range bits {
api.AssertIsEqual(bits[i], newBits[i])
api.AssertIsEqual(bits[i], c.Bits[i])
}
return nil
Expand Down Expand Up @@ -1180,7 +1186,13 @@ func (c *ToBitsCanonicalCircuit[T]) Define(api frontend.API) error {
}
el := f.newInternalElement(c.Limbs, 0)
bts := f.ToBitsCanonical(el)
// ensure that the bit caching is working correctly by calling ToBitsCanonical twice and comparing the results.
newBts := f.ToBitsCanonical(el)
if len(newBts) != len(bts) {
return fmt.Errorf("got %d bits, expected %d", len(bts), len(c.Expected))
}
for i := range bts {
api.AssertIsEqual(bts[i], newBts[i])
api.AssertIsEqual(bts[i], c.Expected[i])
}
return nil
Expand Down
17 changes: 12 additions & 5 deletions std/math/emulated/field_assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,16 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) {
}
eBits := f.ToBits(e)
aBits := f.ToBits(a)
ff := func(xbits, ybits []frontend.Variable) []frontend.Variable {
f.assertIsLessOrEqualBits(eBits, aBits)

profile.RecordOperation("emulated.AssertIsLessOrEqual", 4*(len(eBits)+len(aBits)))
}

// assertIsLessOrEqualBits asserts that the value represented by eBits is less
// or equal to the value represented by aBits. Both are in little-endian bit
// order. The slices are padded to the same length internally.
func (f *Field[T]) assertIsLessOrEqualBits(eBits, aBits []frontend.Variable) {
padBits := func(xbits, ybits []frontend.Variable) []frontend.Variable {
diff := len(xbits) - len(ybits)
ybits = append(ybits, make([]frontend.Variable, diff)...)
for i := len(ybits) - diff; i < len(ybits); i++ {
Expand All @@ -90,9 +99,9 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) {
return ybits
}
if len(eBits) > len(aBits) {
aBits = ff(eBits, aBits)
aBits = padBits(eBits, aBits)
} else {
eBits = ff(aBits, eBits)
eBits = padBits(aBits, eBits)
}
p := make([]frontend.Variable, len(eBits)+1)
p[len(eBits)] = 1
Expand All @@ -104,8 +113,6 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) {
ll := f.api.Mul(l, eBits[i])
f.api.AssertIsEqual(ll, 0)
}

profile.RecordOperation("emulated.AssertIsLessOrEqual", 4*(len(eBits)+len(aBits)))
}

// AssertIsInRange ensures that a is less than the emulated modulus. When we
Expand Down
56 changes: 45 additions & 11 deletions std/math/emulated/field_binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@ import (
// first) order. The returned bits are constrained to be 0-1. The number of
// returned bits is nbLimbs*nbBits+overflow. To obtain the bits of the canonical
// representation of Element, use method [Field.ToBitsCanonical].
//
// The bit decomposition is cached in the Element to avoid redundant computation
// when the same element is decomposed multiple times.
func (f *Field[T]) ToBits(a *Element[T]) []frontend.Variable {
// Save cached bits and overflow BEFORE enforceWidthConditional, which calls
// Initialize and resets the cache for deterministic recompilation. This
// matches the pattern used by modReduced flag.
cachedBits := a.bitsDecomposition
Comment thread
cursor[bot] marked this conversation as resolved.
cachedOverflow := a.bitsOverflow

f.enforceWidthConditional(a)
ba, aConst := f.constantValue(a)
if aConst {
Expand All @@ -20,6 +29,19 @@ func (f *Field[T]) ToBits(a *Element[T]) []frontend.Variable {
}
return res
}

// Check if we had cached bits that are still valid (same overflow value).
// Overflow can change (e.g., AssertIsInRange sets overflow=0), which affects
// the bit count, so we must verify the cached bits match current overflow.
if cachedBits != nil && cachedOverflow == a.overflow {
// Restore cache and return a copy to prevent callers from mutating
a.bitsDecomposition = cachedBits
a.bitsOverflow = cachedOverflow
res := make([]frontend.Variable, len(cachedBits))
copy(res, cachedBits)
return res
Comment thread
cursor[bot] marked this conversation as resolved.
}
Comment thread
yelhousni marked this conversation as resolved.

var carry frontend.Variable = 0
var fullBits []frontend.Variable
var limbBits []frontend.Variable
Expand All @@ -32,6 +54,10 @@ func (f *Field[T]) ToBits(a *Element[T]) []frontend.Variable {
}
fullBits = append(fullBits, limbBits[f.fParams.BitsPerLimb():f.fParams.BitsPerLimb()+a.overflow]...)

// Cache the bits and overflow in the element for future use
a.bitsDecomposition = fullBits
a.bitsOverflow = a.overflow

// Record operation for profiling
profile.RecordOperation("emulated.ToBits", 4*len(fullBits))
return fullBits
Expand All @@ -40,23 +66,31 @@ func (f *Field[T]) ToBits(a *Element[T]) []frontend.Variable {
// ToBitsCanonical represents the unique bit representation in the canonical
// format (less that the modulus).
func (f *Field[T]) ToBitsCanonical(a *Element[T]) []frontend.Variable {
// TODO: implement a inline version of this function. We perform binary
// decomposition both in the `ReduceStrict` and `ToBits` methods, but we can
// essentially do them at the same time.
//
// If we do this, then also check in places where we use `Reduce` and
// `ToBits` after that manually (e.g. in point and scalar marshaling) and
// replace them with this method.

nbBits := f.fParams.Modulus().BitLen()
// when the modulus is a power of 2, then we can remove the most significant
// bit as it is always zero.
if f.fParams.Modulus().TrailingZeroBits() == uint(nbBits-1) {
nbBits--
}
ca := f.ReduceStrict(a)
bts := f.ToBits(ca)
return bts[:nbBits]

// Reduce the element first using strict reduction (always performs mulMod).
// This ensures the value is actually reduced mod p, not just has overflow=0.
ca := f.reduce(a, true)

// Get bits of reduced element
caBits := f.ToBits(ca)

// Get bits of modulus-1 (this is cached as a constant, so ToBits is cheap)
modPrev := f.modulusPrev()
modPrevBits := f.ToBits(modPrev)

// Assert that the reduced element is less than the modulus (ca <= modulus-1).
// This avoids calling ToBits again on the same element (which is what
// the original ReduceStrict + AssertIsInRange path would do).
f.assertIsLessOrEqualBits(caBits, modPrevBits)

profile.RecordOperation("emulated.ToBitsCanonical", 4*(len(caBits)+len(modPrevBits)))
return caBits[:nbBits]
}

// FromBits returns a new Element given the bits is little-endian order.
Expand Down