Skip to content

Commit f875957

Browse files
authored
fix avx512 (panicked on small instances) (#193)
* fix avx512 (panicked on small instances) * add `test_aggregation` --------- Co-authored-by: Tom Wambsgans <TomWambsgans@users.noreply.github.com>
1 parent e560774 commit f875957

6 files changed

Lines changed: 47 additions & 19 deletions

File tree

crates/backend/poly/src/utils.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ pub const fn packing_width<EF: Field>() -> usize {
6161
PFPacking::<EF>::WIDTH
6262
}
6363

64+
pub const fn must_unpack_multilinears<EF: Field>(n_vars: usize) -> bool {
65+
n_vars <= 1 + packing_log_width::<EF>()
66+
}
67+
6468
pub fn batch_fold_multilinears<
6569
EF: PrimeCharacteristicRing + Copy + Send + Sync,
6670
IF: Copy + Sub<Output = IF> + Send + Sync,

crates/backend/sumcheck/src/prove.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,8 @@ where
115115

116116
let mut challenges = Vec::new();
117117
for _ in 0..n_rounds {
118-
// If Packing is enabled, and there are too little variables, we unpack everything:
119-
if multilinears.by_ref().is_packed() && n_vars <= 1 + packing_log_width::<EF>() {
120-
// unpack
118+
if multilinears.by_ref().is_packed() && must_unpack_multilinears::<EF>(n_vars) {
121119
multilinears = multilinears.by_ref().unpack().as_owned_or_clone().into();
122-
// SplitEq handles unpacking transparently via get_unpacked
123120
}
124121

125122
let ps = compute_and_send_polynomial(

crates/backend/sumcheck/src/sc_computation.rs

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -386,17 +386,25 @@ where
386386
|sc, pf, ed| sc.eval_packed_extension(&pf, ed),
387387
packing_unpack_sum,
388388
),
389-
MleGroupRef::BasePacked(multilinears) => sumcheck_compute_core(
390-
multilinears,
391-
degree,
392-
|i| split_eq.map(|seq| seq.get_packed(i)),
393-
computation,
394-
extra_data,
395-
missing_mul_factor,
396-
packed_fold_size,
397-
|sc, pf, ed| sc.eval_packed_base(&pf, ed),
398-
packing_unpack_sum,
399-
),
389+
MleGroupRef::BasePacked(multilinears) => {
390+
if let Some(seq) = split_eq {
391+
assert!(
392+
!seq.is_remainder_mode(),
393+
"BasePacked sumcheck received SplitEq in remainder mode"
394+
);
395+
}
396+
sumcheck_compute_core(
397+
multilinears,
398+
degree,
399+
|i| split_eq.map(|seq| seq.get_packed(i)),
400+
computation,
401+
extra_data,
402+
missing_mul_factor,
403+
packed_fold_size,
404+
|sc, pf, ed| sc.eval_packed_base(&pf, ed),
405+
packing_unpack_sum,
406+
)
407+
}
400408
MleGroupRef::Base(multilinears) => sumcheck_compute_core(
401409
multilinears,
402410
degree,
@@ -507,6 +515,12 @@ where
507515
)
508516
}
509517
MleGroupRef::BasePacked(multilinears) => {
518+
if let Some(seq) = split_eq {
519+
assert!(
520+
!seq.is_remainder_mode(),
521+
"BasePacked fold-and-compute received SplitEq in remainder mode"
522+
);
523+
}
510524
let prev_folded_size = multilinears[0].len() / 2;
511525
let prev_folding_factor_packed = EFPacking::<EF>::from(prev_folding_factor);
512526
sumcheck_fold_and_compute_core(

crates/backend/sumcheck/src/split_eq.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ impl<EF: ExtensionField<PF<EF>>> SplitEq<EF> {
1414
pub fn new(eq_point: &[EF]) -> Self {
1515
let n = eq_point.len();
1616

17-
if n <= packing_log_width::<EF>() * 2 {
17+
if must_unpack_multilinears::<EF>(n + 1) {
1818
return Self {
1919
eq_lo: vec![EF::ONE],
2020
eq_hi_packed: Vec::new(),
@@ -77,7 +77,7 @@ impl<EF: ExtensionField<PF<EF>>> SplitEq<EF> {
7777

7878
#[inline(always)]
7979
pub fn get_packed(&self, i: usize) -> EFPacking<EF> {
80-
debug_assert!(!self.is_remainder_mode(), "get_packed called in remainder mode");
80+
assert!(!self.is_remainder_mode(), "get_packed called in remainder mode");
8181
let packed_hi = self.eq_hi_packed.len();
8282
if self.eq_lo.len() > 1 {
8383
EFPacking::<EF>::from(self.eq_lo[i >> self.log_packed_hi]) * self.eq_hi_packed[i & (packed_hi - 1)]

crates/sub_protocols/src/air_sumcheck.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ where
7979
}
8080

8181
fn compute_bare_round_poly(&mut self) -> DensePolynomial<EF> {
82-
if self.multilinears.is_packed() && self.multilinears.n_vars() <= 1 + packing_log_width::<EF>() {
82+
if self.multilinears.is_packed() && must_unpack_multilinears::<EF>(self.multilinears.n_vars()) {
8383
let old = std::mem::replace(
8484
&mut self.multilinears,
8585
MleGroup::Owned(MleGroupOwned::Extension(vec![])),

tests/test_lean_multisig.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
use lean_multisig::{AggregatedXMSS, setup_prover, xmss_aggregate, xmss_verify_aggregation};
1+
use lean_multisig::{AggregatedXMSS, AggregationTopology, setup_prover, xmss_aggregate, xmss_verify_aggregation};
22
use rand::{RngExt, SeedableRng, rngs::StdRng};
3+
use rec_aggregation::benchmark::run_aggregation_benchmark;
34
use xmss::{
45
signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark},
56
xmss_key_gen, xmss_sign, xmss_verify,
@@ -18,6 +19,18 @@ fn test_xmss_signature() {
1819
xmss_verify(&pub_key, &msg, &signature, slot).unwrap();
1920
}
2021

22+
#[test]
23+
fn test_aggregation() {
24+
for n_signatures in [1, 2, 4, 8, 16, 32, 64, 128] {
25+
let topology = AggregationTopology {
26+
raw_xmss: n_signatures,
27+
children: vec![],
28+
log_inv_rate: 1,
29+
};
30+
run_aggregation_benchmark(&topology, 0, false);
31+
}
32+
}
33+
2134
#[test]
2235
fn test_recursive_aggregation() {
2336
setup_prover();

0 commit comments

Comments
 (0)