diff options
| author | Santo Cariotti <santo@dcariotti.me> | 2025-09-04 12:09:00 +0000 |
|---|---|---|
| committer | Santo Cariotti <santo@dcariotti.me> | 2025-09-04 12:09:00 +0000 |
| commit | 42518b55cf65abbb048a2771ce0c173fa035bf33 (patch) | |
| tree | 052d8b314e8e883210f1a7eee440231c8a63a507 /src/proof.rs | |
| parent | bc89fb1bfc39276abb029c0774567dce18ee6666 (diff) | |
Use levels, not leaves, to increase speed of proofing
Diffstat (limited to 'src/proof.rs')
| -rw-r--r-- | src/proof.rs | 58 |
1 files changed, 30 insertions, 28 deletions
diff --git a/src/proof.rs b/src/proof.rs index ed8c251..ff29978 100644 --- a/src/proof.rs +++ b/src/proof.rs @@ -54,7 +54,7 @@ pub trait Proofer { pub struct DefaultProofer<H: Hasher> { hasher: H, - leaves: Vec<Node>, + levels: Vec<Vec<Node>>, } impl<H> DefaultProofer<H> @@ -62,7 +62,29 @@ where H: Hasher, { pub fn new(hasher: H, leaves: Vec<Node>) -> Self { - Self { hasher, leaves } + let mut levels = Vec::new(); + levels.push(leaves.clone()); + + let mut current_level = leaves; + while current_level.len() > 1 { + if current_level.len() % 2 != 0 { + current_level.push(current_level.last().unwrap().clone()); + } + let next_level: Vec<Node> = current_level + .par_chunks(2) + .map(|pair| { + let (left, right) = (&pair[0], &pair[1]); + let combined = [left.hash().as_bytes(), right.hash().as_bytes()].concat(); + let hash = hasher.hash(&combined); + Node::new_internal(hash, left.clone(), right.clone()) + }) + .collect(); + + levels.push(next_level.clone()); + current_level = next_level; + } + + Self { hasher, levels } } pub fn verify_hash(&self, proof: &MerkleProof, hash: String, root_hash: &str) -> bool { @@ -86,22 +108,19 @@ where H: Hasher, { fn generate(&self, index: usize) -> Option<MerkleProof> { - if index >= self.leaves.len() { + if index >= self.levels[0].len() { return None; } let mut path = Vec::new(); let mut current_index = index; - let mut current_level = self.leaves.clone(); - while current_level.len() > 1 { - if current_level.len() % 2 != 0 { - current_level.push(current_level.last()?.clone()); - } + for level in &self.levels[..self.levels.len() - 1] { + // Flip the last bit and ensures that it never goes out-of-bounds + let sibling_index = (current_index ^ 1).min(level.len() - 1); + + let sibling = &level[sibling_index]; - // Flip index to get sibling - let sibling_index = current_index ^ 1; - let sibling = ¤t_level[sibling_index]; let child_type = if sibling_index < current_index { NodeChildType::Left } else { @@ -113,23 +132,6 @@ where child_type, }); - // Move to the next level - current_level = current_level - .par_chunks(2) - .map(|pair| { - let (left, right) = (&pair[0], &pair[1]); - let (left_hash, right_hash) = (left.hash().as_bytes(), right.hash().as_bytes()); - - let mut buffer = Vec::with_capacity(left_hash.len() + right_hash.len()); - buffer.extend_from_slice(left_hash); - buffer.extend_from_slice(right_hash); - - let hash = self.hasher.hash(&buffer); - Node::new_internal(hash, left.clone(), right.clone()) - }) - .collect(); - - // Faster way to make "divide by 2" current_index >>= 1; } |
