Skip to content

Commit 3135296

Browse files
committed
impl Copy for Field (3% to 5% faster end2end)
1 parent dd0341a commit 3135296

15 files changed

Lines changed: 333 additions & 350 deletions

File tree

crates/backend/air/src/symbolic.rs

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use core::fmt::Debug;
44
use core::iter::{Product, Sum};
55
use core::marker::PhantomData;
66
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
7-
use std::rc::Rc;
7+
use std::cell::RefCell;
88

99
use field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing};
1010

@@ -66,11 +66,44 @@ pub enum SymbolicOperation {
6666
Neg,
6767
}
6868

69-
#[derive(Clone, Debug, PartialEq, Eq)]
70-
pub enum SymbolicExpression<F> {
69+
/// A single operation node of a symbolic expression: an operator plus its
/// two operand expressions.
#[derive(Copy, Clone, Debug)]
70+
pub struct SymbolicNode<F: Copy> {
71+
/// The operation applied by this node.
pub op: SymbolicOperation,
72+
/// Left operand; for `Neg` this is the expression being negated.
pub lhs: SymbolicExpression<F>,
73+
/// Right operand; unused for `Neg`, which stores a placeholder here.
pub rhs: SymbolicExpression<F>, // dummy (ZERO) for Neg
74+
}
75+
76+
// We use an arena as a trick to allow SymbolicExpression to be Copy
77+
// (ugly trick but fine in practice since SymbolicExpression is only used once at the start of the program)
78+
// Per-thread byte buffer holding the raw bytes of every `SymbolicNode` written by
// `alloc_node`; `SymbolicExpression::Operation` stores an offset into this buffer.
thread_local! {
79+
static ARENA: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
80+
}
81+
82+
/// Appends `node` to the thread-local `ARENA` and returns the position it was
/// written at.
///
/// The returned value is a byte offset into the arena (the buffer length
/// before the write), not a node count. Each call grows the buffer by
/// `size_of::<SymbolicNode<F>>()` bytes.
// NOTE(review): `idx as u32` silently truncates once the arena exceeds u32::MAX
// bytes — presumably unreachable in practice; confirm or use a checked conversion.
fn alloc_node<F: Field>(node: SymbolicNode<F>) -> u32 {
83+
ARENA.with(|arena| {
84+
let mut bytes = arena.borrow_mut();
85+
let node_size = std::mem::size_of::<SymbolicNode<F>>();
86+
// The current end of the buffer is where the new node will live.
let idx = bytes.len();
87+
// Reserve zero-filled space for the node so the raw write below stays in bounds.
bytes.resize(idx + node_size, 0);
88+
// SAFETY: the `resize` above guarantees that `idx..idx + node_size` lies inside
// the buffer, and `write_unaligned` places no alignment requirement on the
// destination, so writing into a `Vec<u8>` is permitted.
unsafe {
89+
std::ptr::write_unaligned(bytes.as_mut_ptr().add(idx) as *mut SymbolicNode<F>, node);
90+
}
91+
idx as u32
92+
})
93+
}
94+
95+
/// Reads the `SymbolicNode<F>` stored at byte offset `idx` of the thread-local
/// `ARENA` and returns it by value.
///
/// `idx` is expected to be an offset previously returned by `alloc_node::<F>`
/// on the same thread, with the same `F`.
// NOTE(review): this is a safe `pub fn`, yet nothing here checks that
// `idx + size_of::<SymbolicNode<F>>()` is within the buffer or that the bytes were
// written for this `F`. An offset obtained before the arena was cleared, on another
// thread, or for a different field type would read out of bounds or reinterpret
// unrelated bytes — confirm that callers uphold this, or add a bounds check.
pub fn get_node<F: Field>(idx: u32) -> SymbolicNode<F> {
96+
ARENA.with(|arena| {
97+
let bytes = arena.borrow();
98+
// SAFETY: relies on the caller contract above — `idx` must point at a node that
// `alloc_node::<F>` wrote into this thread's arena. `read_unaligned` is used
// because the byte buffer gives no alignment guarantee for `SymbolicNode<F>`.
unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(idx as usize) as *const SymbolicNode<F>) }
99+
})
100+
}
101+
102+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
103+
pub enum SymbolicExpression<F: Copy> {
71104
Variable(SymbolicVariable<F>),
72105
Constant(F),
73-
Operation(Rc<(SymbolicOperation, Vec<Self>)>),
106+
Operation(u32), // index into thread-local arena
74107
}
75108

76109
impl<F: Field> Default for SymbolicExpression<F> {
@@ -119,7 +152,11 @@ where
119152
fn add(self, rhs: T) -> Self {
120153
match (self, rhs.into()) {
121154
(Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs + rhs),
122-
(lhs, rhs) => Self::Operation(Rc::new((SymbolicOperation::Add, vec![lhs, rhs]))),
155+
(lhs, rhs) => Self::Operation(alloc_node(SymbolicNode {
156+
op: SymbolicOperation::Add,
157+
lhs,
158+
rhs,
159+
})),
123160
}
124161
}
125162
}
@@ -129,7 +166,7 @@ where
129166
T: Into<Self>,
130167
{
131168
fn add_assign(&mut self, rhs: T) {
132-
*self = self.clone() + rhs.into();
169+
*self = *self + rhs.into();
133170
}
134171
}
135172

@@ -148,7 +185,11 @@ impl<F: Field, T: Into<Self>> Sub<T> for SymbolicExpression<F> {
148185
fn sub(self, rhs: T) -> Self {
149186
match (self, rhs.into()) {
150187
(Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs - rhs),
151-
(lhs, rhs) => Self::Operation(Rc::new((SymbolicOperation::Sub, vec![lhs, rhs]))),
188+
(lhs, rhs) => Self::Operation(alloc_node(SymbolicNode {
189+
op: SymbolicOperation::Sub,
190+
lhs,
191+
rhs,
192+
})),
152193
}
153194
}
154195
}
@@ -158,7 +199,7 @@ where
158199
T: Into<Self>,
159200
{
160201
fn sub_assign(&mut self, rhs: T) {
161-
*self = self.clone() - rhs.into();
202+
*self = *self - rhs.into();
162203
}
163204
}
164205

@@ -168,7 +209,11 @@ impl<F: Field> Neg for SymbolicExpression<F> {
168209
fn neg(self) -> Self {
169210
match self {
170211
Self::Constant(c) => Self::Constant(-c),
171-
expr => Self::Operation(Rc::new((SymbolicOperation::Neg, vec![expr]))),
212+
expr => Self::Operation(alloc_node(SymbolicNode {
213+
op: SymbolicOperation::Neg,
214+
lhs: expr,
215+
rhs: Self::ZERO, // dummy
216+
})),
172217
}
173218
}
174219
}
@@ -179,7 +224,11 @@ impl<F: Field, T: Into<Self>> Mul<T> for SymbolicExpression<F> {
179224
fn mul(self, rhs: T) -> Self {
180225
match (self, rhs.into()) {
181226
(Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs * rhs),
182-
(lhs, rhs) => Self::Operation(Rc::new((SymbolicOperation::Mul, vec![lhs, rhs]))),
227+
(lhs, rhs) => Self::Operation(alloc_node(SymbolicNode {
228+
op: SymbolicOperation::Mul,
229+
lhs,
230+
rhs,
231+
})),
183232
}
184233
}
185234
}
@@ -189,7 +238,7 @@ where
189238
T: Into<Self>,
190239
{
191240
fn mul_assign(&mut self, rhs: T) {
192-
*self = self.clone() * rhs.into();
241+
*self = *self * rhs.into();
193242
}
194243
}
195244

@@ -258,7 +307,7 @@ impl<F: Field> AirBuilder for SymbolicAirBuilder<F> {
258307
fn declare_values(&mut self, values: &[Self::F]) {
259308
if self.bus_flag_value.is_none() {
260309
assert_eq!(values.len(), 1);
261-
self.bus_flag_value = Some(values[0].clone());
310+
self.bus_flag_value = Some(values[0]);
262311
} else {
263312
assert!(self.bus_data_values.is_none());
264313
self.bus_data_values = Some(values.to_vec());
@@ -276,6 +325,9 @@ pub fn get_symbolic_constraints_and_bus_data_values<F: Field, A: Air>(
276325
where
277326
A::ExtraData: Default,
278327
{
328+
// Clear the arena before building constraints
329+
ARENA.with(|arena| arena.borrow_mut().clear());
330+
279331
let mut builder = SymbolicAirBuilder::<F>::new(air.n_columns(), air.n_down_columns());
280332
air.eval(&mut builder, &Default::default());
281333
(

crates/backend/field/src/exponentiation.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ pub fn exp_1420470955<R: PrimeCharacteristicRing>(val: R) -> R {
1717
// Suspect it's possible to improve this with enough effort.
1818
let p1 = val;
1919
let p100 = p1.exp_power_of_2(2);
20-
let p101 = p100.clone() * p1.clone();
20+
let p101 = p100 * p1;
2121
let p10000 = p100.exp_power_of_2(2);
2222
let p10101 = p10000 * p101;
2323
let p10101000000 = p10101.exp_power_of_2(6);
24-
let p10101010101 = p10101000000.clone() * p10101.clone();
25-
let p101010010101 = p10101000000 * p10101010101.clone();
24+
let p10101010101 = p10101000000 * p10101;
25+
let p101010010101 = p10101000000 * p10101010101;
2626
let p101010010101000000000000 = p101010010101.exp_power_of_2(12);
2727
let p101010010101010101010101 = p101010010101000000000000 * p10101010101;
2828
let p101010010101010101010101000000 = p101010010101010101010101.exp_power_of_2(6);

crates/backend/field/src/field.rs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ use crate::{Packable, PackedFieldExtension, PackedValue};
5757
pub trait PrimeCharacteristicRing:
5858
Sized
5959
+ Default
60-
+ Clone
60+
+ Copy
6161
+ Add<Output = Self>
6262
+ AddAssign
6363
+ Sub<Output = Self>
@@ -138,7 +138,7 @@ pub trait PrimeCharacteristicRing:
138138
#[must_use]
139139
#[inline(always)]
140140
fn double(&self) -> Self {
141-
self.clone() + self.clone()
141+
*self + *self
142142
}
143143

144144
/// The elementary function `halve(a) = a/2`.
@@ -152,7 +152,7 @@ pub trait PrimeCharacteristicRing:
152152
// is circular when PrimeSubfield = Self. It should also be overwritten by
153153
// most rings to avoid the multiplication.
154154
let half = Self::from_prime_subfield(Self::PrimeSubfield::ONE.halve());
155-
self.clone() * half
155+
*self * half
156156
}
157157

158158
/// The elementary function `square(a) = a^2`.
@@ -161,7 +161,7 @@ pub trait PrimeCharacteristicRing:
161161
#[must_use]
162162
#[inline(always)]
163163
fn square(&self) -> Self {
164-
self.clone() * self.clone()
164+
*self * *self
165165
}
166166

167167
/// The elementary function `cube(a) = a^3`.
@@ -170,7 +170,7 @@ pub trait PrimeCharacteristicRing:
170170
#[must_use]
171171
#[inline(always)]
172172
fn cube(&self) -> Self {
173-
self.square() * self.clone()
173+
self.square() * *self
174174
}
175175

176176
/// Computes the arithmetic generalization of boolean `xor`.
@@ -179,7 +179,7 @@ pub trait PrimeCharacteristicRing:
179179
#[must_use]
180180
#[inline(always)]
181181
fn xor(&self, y: &Self) -> Self {
182-
self.clone() + y.clone() - self.clone() * y.clone().double()
182+
*self + *y - *self * y.double()
183183
}
184184

185185
/// Computes the arithmetic generalization of a triple `xor`.
@@ -197,7 +197,7 @@ pub trait PrimeCharacteristicRing:
197197
#[must_use]
198198
#[inline(always)]
199199
fn andn(&self, y: &Self) -> Self {
200-
(Self::ONE - self.clone()) * y.clone()
200+
(Self::ONE - *self) * *y
201201
}
202202

203203
/// The vanishing polynomial for boolean values: `x * (1 - x)`.
@@ -219,12 +219,12 @@ pub trait PrimeCharacteristicRing:
219219
#[must_use]
220220
#[inline]
221221
fn exp_u64(&self, power: u64) -> Self {
222-
let mut current = self.clone();
222+
let mut current = *self;
223223
let mut product = Self::ONE;
224224

225225
for j in 0..bits_u64(power) {
226226
if (power >> j) & 1 != 0 {
227-
product *= current.clone();
227+
product *= current;
228228
}
229229
current = current.square();
230230
}
@@ -242,15 +242,15 @@ pub trait PrimeCharacteristicRing:
242242
fn exp_const_u64<const POWER: u64>(&self) -> Self {
243243
match POWER {
244244
0 => Self::ONE,
245-
1 => self.clone(),
245+
1 => *self,
246246
2 => self.square(),
247247
3 => self.cube(),
248248
4 => self.square().square(),
249-
5 => self.square().square() * self.clone(),
249+
5 => self.square().square() * *self,
250250
6 => self.square().cube(),
251251
7 => {
252252
let x2 = self.square();
253-
let x3 = x2.clone() * self.clone();
253+
let x3 = x2 * *self;
254254
let x4 = x2.square();
255255
x3 * x4
256256
}
@@ -264,7 +264,7 @@ pub trait PrimeCharacteristicRing:
264264
#[must_use]
265265
#[inline]
266266
fn exp_power_of_2(&self, power_log: usize) -> Self {
267-
let mut res = self.clone();
267+
let mut res = *self;
268268
for _ in 0..power_log {
269269
res = res.square();
270270
}
@@ -279,7 +279,7 @@ pub trait PrimeCharacteristicRing:
279279
fn mul_2exp_u64(&self, exp: u64) -> Self {
280280
// Some rings might want to reimplement this to avoid the
281281
// exponentiations (and potentially even the multiplication).
282-
self.clone() * Self::TWO.exp_u64(exp)
282+
*self * Self::TWO.exp_u64(exp)
283283
}
284284

285285
/// Divide by a given power of two. `div_2exp_u64(a, exp) = a/2^exp`
@@ -291,7 +291,7 @@ pub trait PrimeCharacteristicRing:
291291
fn div_2exp_u64(&self, exp: u64) -> Self {
292292
// Some rings might want to reimplement this to avoid the
293293
// exponentiations (and potentially even the multiplication).
294-
self.clone() * Self::from_prime_subfield(Self::PrimeSubfield::ONE.halve().exp_u64(exp))
294+
*self * Self::from_prime_subfield(Self::PrimeSubfield::ONE.halve().exp_u64(exp))
295295
}
296296

297297
/// Construct an iterator which returns powers of `self`: `self^0, self^1, self^2, ...`.
@@ -306,7 +306,7 @@ pub trait PrimeCharacteristicRing:
306306
#[inline]
307307
fn shifted_powers(&self, start: Self) -> Powers<Self> {
308308
Powers {
309-
base: self.clone(),
309+
base: *self,
310310
current: start,
311311
}
312312
}
@@ -315,7 +315,7 @@ pub trait PrimeCharacteristicRing:
315315
#[must_use]
316316
#[inline]
317317
fn dot_product<const N: usize>(u: &[Self; N], v: &[Self; N]) -> Self {
318-
u.iter().zip(v).map(|(x, y)| x.clone() * y.clone()).sum()
318+
u.iter().zip(v).map(|(x, y)| *x * *y).sum()
319319
}
320320

321321
/// Compute the sum of a slice of elements whose length is a compile time constant.
@@ -344,10 +344,10 @@ pub trait PrimeCharacteristicRing:
344344
// I only tested this on `AVX2` though so there might be a better value for other architectures.
345345
match N {
346346
0 => Self::ZERO,
347-
1 => input[0].clone(),
348-
2 => input[0].clone() + input[1].clone(),
349-
3 => input[0].clone() + input[1].clone() + input[2].clone(),
350-
4 => (input[0].clone() + input[1].clone()) + (input[2].clone() + input[3].clone()),
347+
1 => input[0],
348+
2 => input[0] + input[1],
349+
3 => input[0] + input[1] + input[2],
350+
4 => (input[0] + input[1]) + (input[2] + input[3]),
351351
5 => Self::sum_array::<4>(&input[..4]) + Self::sum_array::<1>(&input[4..]),
352352
6 => Self::sum_array::<4>(&input[..4]) + Self::sum_array::<2>(&input[4..]),
353353
7 => Self::sum_array::<4>(&input[..4]) + Self::sum_array::<3>(&input[4..]),
@@ -447,7 +447,7 @@ pub trait BasedVectorSpace<F: PrimeCharacteristicRing>: Sized {
447447
#[must_use]
448448
#[inline]
449449
fn from_basis_coefficients_slice(slice: &[F]) -> Option<Self> {
450-
Self::from_basis_coefficients_iter(slice.iter().cloned())
450+
Self::from_basis_coefficients_iter(slice.iter().copied())
451451
}
452452

453453
/// Fixes a basis for the algebra `A` and uses this to
@@ -958,8 +958,8 @@ impl<R: PrimeCharacteristicRing> Iterator for Powers<R> {
958958
type Item = R;
959959

960960
fn next(&mut self) -> Option<R> {
961-
let result = self.current.clone();
962-
self.current *= self.base.clone();
961+
let result = self.current;
962+
self.current *= self.base;
963963
Some(result)
964964
}
965965
}

crates/backend/koala-bear/src/monty_31/poseidon2_monty.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ pub trait InternalLayerBaseParameters<MP: MontyParameters, const WIDTH: usize>:
2727
#[inline(always)]
2828
fn generic_internal_linear_layer<R: PrimeCharacteristicRing>(state: &mut [R; WIDTH]) {
2929
// We mostly delegate to internal_layer_mat_mul but have to handle state[0] separately.
30-
let part_sum: R = state[1..].iter().cloned().sum();
31-
let full_sum = part_sum.clone() + state[0].clone();
32-
state[0] = part_sum - state[0].clone();
30+
let part_sum: R = state[1..].iter().copied().sum();
31+
let full_sum = part_sum + state[0];
32+
state[0] = part_sum - state[0];
3333
Self::internal_layer_mat_mul(state, full_sum);
3434
}
3535
}

0 commit comments

Comments
 (0)