From 2ef7371f7a4eefe7478cad43cb4922efaa12876a Mon Sep 17 00:00:00 2001 From: Santo Cariotti Date: Wed, 25 Jun 2025 10:40:24 +0200 Subject: Use `rayon` for parallelization --- src/merkletree.rs | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) (limited to 'src/merkletree.rs') diff --git a/src/merkletree.rs b/src/merkletree.rs index c8962f4..c3c9fd7 100644 --- a/src/merkletree.rs +++ b/src/merkletree.rs @@ -2,6 +2,7 @@ //! with binary Merkle trees using custom hashers. use crate::{hasher::Hasher, node::Node}; +use rayon::prelude::*; /// A binary Merkle tree implementation. /// @@ -37,7 +38,7 @@ impl MerkleTree { where I: IntoIterator, T: AsRef<[u8]>, - H: Hasher + 'static, + H: Hasher + 'static + std::marker::Sync, { let owned_data: Vec = data.into_iter().collect(); let data_slices: Vec<&[u8]> = owned_data.iter().map(|item| item.as_ref()).collect(); @@ -62,7 +63,7 @@ impl MerkleTree { /// Constructs the internal nodes of the tree from the leaves upward and computes the root. fn build(hasher: H, nodes: Vec) -> Self where - H: Hasher + 'static, + H: Hasher + 'static + std::marker::Sync, { let leaves = nodes.clone(); let mut current_level = nodes; @@ -76,19 +77,21 @@ impl MerkleTree { } next_level.clear(); + next_level = current_level + .par_chunks(2) + .map(|pair| { + let (left, right) = (&pair[0], &pair[1]); - for pair in current_level.chunks(2) { - let (left, right) = (&pair[0], &pair[1]); + let (left_hash, right_hash) = (left.hash().as_bytes(), right.hash().as_bytes()); - let (left_hash, right_hash) = (left.hash().as_bytes(), right.hash().as_bytes()); + let mut buffer = Vec::with_capacity(left_hash.len() + right_hash.len()); + buffer.extend_from_slice(left_hash); + buffer.extend_from_slice(right_hash); - let mut buffer = Vec::::with_capacity(left_hash.len() + right_hash.len()); - buffer.extend_from_slice(left_hash); - buffer.extend_from_slice(right_hash); - - let hash = hasher.hash(&buffer); - next_level.push(Node::new_internal(hash, left.clone(), right.clone())); - } + let hash = hasher.hash(&buffer); + Node::new_internal(hash, left.clone(), right.clone()) + }) + .collect(); std::mem::swap(&mut current_level, &mut next_level); height += 1; -- cgit v1.2.3-71-g8e6c