Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/merkle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ ssz = [
"dep:tree_hash_derive",
"dep:ssz_codegen",
]
legacy_compact = []

[[bench]]
name = "mmr_comparison"
Expand Down
6 changes: 4 additions & 2 deletions crates/merkle/benches/mmr_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#![allow(missing_docs)]
#![allow(unused_crate_dependencies)]

#[cfg(not(feature = "legacy_compact"))]
compile_error!("build this target with `--features legacy_compact`");

// stupid linter issue
use criterion as _;
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use sha2::{Digest, Sha256};
use strata_merkle::mmr::CompactMmr64;
use strata_merkle::{Mmr, Sha256Hasher};
use strata_merkle::{CompactMmr64, Mmr, Sha256Hasher};

type Hash32 = [u8; 32];

Expand Down
64 changes: 2 additions & 62 deletions crates/merkle/src/codec_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,8 @@
use strata_codec::{Codec, CodecError, Decoder, Encoder, VarVec};

use crate::hasher::MerkleHash;
use crate::mmr::CompactMmr64;
use crate::proof::{MerkleProof, RawMerkleProof};

// CompactMmr64

impl<H> Codec for CompactMmr64<H>
where
H: MerkleHash + Codec,
{
fn decode(dec: &mut impl Decoder) -> Result<Self, CodecError> {
let entries = u64::decode(dec)?;
let cap_log2 = u8::decode(dec)?;
// Number of roots equals popcount of entries (one per peak)
let roots_len = entries.count_ones() as usize;
let mut roots = Vec::with_capacity(roots_len);
for _ in 0..roots_len {
roots.push(H::decode(dec)?);
}
Ok(Self {
entries,
cap_log2,
roots,
})
}

fn encode(&self, enc: &mut impl Encoder) -> Result<(), CodecError> {
self.entries.encode(enc)?;
self.cap_log2.encode(enc)?;
// Validate roots length matches expected popcount to avoid misalignment
let expected = self.entries.count_ones() as usize;
if self.roots.len() != expected {
return Err(CodecError::MalformedField("CompactMmr64.roots"));
}
for h in &self.roots {
h.encode(enc)?;
}
Ok(())
}
}

// RawMerkleProof

impl<H> Codec for RawMerkleProof<H>
Expand Down Expand Up @@ -78,28 +40,19 @@ where

fn encode(&self, enc: &mut impl Encoder) -> Result<(), CodecError> {
self.inner.encode(enc)?;
self.index.encode(enc)
self.index.encode(enc)?;
Ok(())
}
}

#[cfg(test)]
mod tests {
use proptest::prelude::*;
use sha2::Sha256;
use strata_codec::{decode_buf_exact, encode_to_vec};

use super::*;
use crate::Mmr;

type H = [u8; 32];
type Hasher = crate::Sha256Hasher;

fn make_hashes(n: usize) -> Vec<H> {
use sha2::Digest;
(0..n)
.map(|i| Sha256::digest(i.to_be_bytes()).into())
.collect()
}

fn arb_hash() -> impl Strategy<Value = H> {
any::<[u8; 32]>()
Expand All @@ -125,18 +78,5 @@ mod tests {
let de: MerkleProof<H> = decode_buf_exact(&bytes).expect("deserialize proof");
prop_assert_eq!(proof, de);
}

#[test]
fn roundtrip_compact_mmr(num_leaves in 1usize..=64) {
let mut mmr = CompactMmr64::<H>::new(8);
let leaves = make_hashes(num_leaves);
for h in leaves.iter() {
Mmr::<Hasher>::add_leaf(&mut mmr, *h).expect("add leaf");
}

let bytes = encode_to_vec(&mmr).expect("serialize compact");
let de: CompactMmr64<H> = decode_buf_exact(&bytes).expect("deserialize compact");
prop_assert_eq!(mmr, de);
}
}
}
59 changes: 58 additions & 1 deletion crates/merkle/src/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use crate::traits::MmrState;
/// are hashed. The blanket implementation provides the actual MMR algorithms that
/// work with any state backend implementing [`MmrState`].
pub trait Mmr<MH: MerkleHasher>: MmrState<MH::Hash> {
/// Creates a new instance with some arbitrary number of repeated leafs
/// pre-inserted.
fn new_repeated(leaf: MH::Hash, count: u64) -> Self;

/// Returns if the MMR is empty.
fn is_empty(&self) -> bool;

Expand Down Expand Up @@ -69,6 +73,32 @@ where
MH: MerkleHasher,
S: MmrState<MH::Hash>,
{
fn new_repeated(leaf: MH::Hash, count: u64) -> Self {
let mut this = Self::new_empty();

let mut cur = leaf;
for i in 0..u64::BITS {
let shr = count >> i;

// If the bit is set then set the value.
if shr & 1 != 0 {
this.set_peak(i as u8, cur);
}

// If there's no more set bits then break so that we don't keep
// hashing uselessly.
if shr == 0 {
break;
}

// We hash on every iteration because we need to compute the root at
// each step *anyways*.
cur = MH::hash_node(cur, cur);
}

this
}

fn is_empty(&self) -> bool {
self.num_entries() == 0
}
Expand Down Expand Up @@ -311,7 +341,7 @@ mod tests {

use super::*;
use crate::Sha256Hasher;
use crate::mmr::CompactMmr64;
use crate::legacy_compact_mmr::CompactMmr64;
use crate::proof::MerkleProof;
use crate::traits::MmrState;

Expand Down Expand Up @@ -776,6 +806,33 @@ mod tests {
}
}

#[test]
fn test_new_repeated_matches_add_leaf() {
let leaf = make_hash(b"repeated_leaf");

let mut acc = CompactMmr64::<Hash32>::new(64);

for count in 0..=10000 {
let batch = <CompactMmr64<Hash32> as Mmr<Sha256Hasher>>::new_repeated(leaf, count);

assert_eq!(
acc.num_entries(),
batch.num_entries(),
"test: num_entries mismatch for count={count}"
);

assert!(
acc.iter_peaks()
.zip(batch.iter_peaks())
.all(|((h1, v1), (h2, v2))| h1 == h2 && v1 == v2),
"test: peaks mismatch for count={count}"
);

// Add the leaf at the end so it's there for the next iteration.
Mmr::<Sha256Hasher>::add_leaf(&mut acc, leaf).unwrap();
}
}

#[test]
fn test_empty_mmr_cross_impl() {
// Verify all implementations handle empty state correctly
Expand Down
2 changes: 1 addition & 1 deletion crates/merkle/src/hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl<const LEN: usize> MerkleHash for [u8; LEN] {
}

fn eq_ct(a: &Self, b: &Self) -> bool {
// Attempt to constant-time comparison. This is *really hard* to do in
// Attempt at constant-time comparison. This is *really hard* to do in
// Rust, because LLVM likes to obliterate unnecessary instructions.
//
// I could use some of the more advanced libraries for this, but this
Expand Down
Loading
Loading