Skip to content

Commit 6834816

Browse files
committed
small recursion improvements
1 parent a438c22 commit 6834816

2 files changed

Lines changed: 4 additions & 14 deletions

File tree

crates/rec_aggregation/recursion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def continue_recursion_ordered(
460460
fs, public_memory_random_point = fs_sample_many_ef(fs, INNER_PUBLIC_MEMORY_LOG_SIZE)
461461
poly_eq_public_mem = poly_eq_extension(public_memory_random_point, INNER_PUBLIC_MEMORY_LOG_SIZE)
462462
public_memory_eval = Array(DIM)
463-
dot_product_be_const(inner_public_memory, poly_eq_public_mem, public_memory_eval, 2**INNER_PUBLIC_MEMORY_LOG_SIZE)
463+
dot_product_be(inner_public_memory, poly_eq_public_mem, public_memory_eval, 2**INNER_PUBLIC_MEMORY_LOG_SIZE)
464464

465465
# WHIR BASE
466466
combination_randomness_gen: Mut
@@ -661,6 +661,7 @@ def fingerprint_2(table_index, data_1, data_2, logup_alphas_eq_poly):
661661
return res
662662

663663

664+
@inline
664665
def fingerprint_bytecode(instr_evals, eval_on_pc, logup_alphas_eq_poly):
665666
res: Mut = dot_product_ee_ret(instr_evals, logup_alphas_eq_poly, N_INSTRUCTION_COLUMNS)
666667
res = add_extension_ret(res, mul_extension_ret(eval_on_pc, logup_alphas_eq_poly + N_INSTRUCTION_COLUMNS * DIM))

crates/rec_aggregation/utils.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -189,28 +189,17 @@ def eval_multilinear_coeffs_rev(coeffs, point, n: Const):
189189
return result
190190

191191

192+
@inline
192193
def dot_product_be_dynamic(a, b, res, n):
193194
debug_assert(n < 400)
194195
match_range(n, range(1, 400), lambda i: dot_product_be(a, b, res, i))
195196
return
196197

197-
198-
def dot_product_be_const(a, b, res, n: Const):
199-
dot_product_be(a, b, res, n)
200-
return
201-
202-
203198
def dot_product_ee_dynamic(a, b, res, n):
204199
debug_assert(n < 400)
205200
match_range(n, range(1, 400), lambda i: dot_product_ee(a, b, res, i))
206201
return
207202

208-
209-
def dot_product_ee_const(a, b, res, n: Const):
210-
dot_product_ee(a, b, res, n)
211-
return
212-
213-
214203
def mle_of_01234567_etc(point, n):
215204
if n == 0:
216205
return ZERO_VEC_PTR
@@ -642,7 +631,7 @@ def dot_product_ee_ret(a, b, n):
642631
def sum_continuous_ef(slice_ef, len):
643632
debug_assert(len <= NUM_REPEATED_ONES)
644633
res = Array(DIM)
645-
dot_product_be_dynamic(REPEATED_ONES_PTR, slice_ef, res, len)
634+
dot_product_be(REPEATED_ONES_PTR, slice_ef, res, len)
646635
return res
647636

648637

0 commit comments

Comments
 (0)