summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/proof.rs58
1 files changed, 30 insertions, 28 deletions
diff --git a/src/proof.rs b/src/proof.rs
index ed8c251..ff29978 100644
--- a/src/proof.rs
+++ b/src/proof.rs
@@ -54,7 +54,7 @@ pub trait Proofer {
pub struct DefaultProofer<H: Hasher> {
hasher: H,
- leaves: Vec<Node>,
+ levels: Vec<Vec<Node>>,
}
impl<H> DefaultProofer<H>
@@ -62,7 +62,29 @@ where
H: Hasher,
{
pub fn new(hasher: H, leaves: Vec<Node>) -> Self {
- Self { hasher, leaves }
+ let mut levels = Vec::new();
+ levels.push(leaves.clone());
+
+ let mut current_level = leaves;
+ while current_level.len() > 1 {
+ if current_level.len() % 2 != 0 {
+ current_level.push(current_level.last().unwrap().clone());
+ }
+ let next_level: Vec<Node> = current_level
+ .par_chunks(2)
+ .map(|pair| {
+ let (left, right) = (&pair[0], &pair[1]);
+ let combined = [left.hash().as_bytes(), right.hash().as_bytes()].concat();
+ let hash = hasher.hash(&combined);
+ Node::new_internal(hash, left.clone(), right.clone())
+ })
+ .collect();
+
+ levels.push(next_level.clone());
+ current_level = next_level;
+ }
+
+ Self { hasher, levels }
}
pub fn verify_hash(&self, proof: &MerkleProof, hash: String, root_hash: &str) -> bool {
@@ -86,22 +108,19 @@ where
H: Hasher,
{
fn generate(&self, index: usize) -> Option<MerkleProof> {
- if index >= self.leaves.len() {
+ if index >= self.levels[0].len() {
return None;
}
let mut path = Vec::new();
let mut current_index = index;
- let mut current_level = self.leaves.clone();
- while current_level.len() > 1 {
- if current_level.len() % 2 != 0 {
- current_level.push(current_level.last()?.clone());
- }
+ for level in &self.levels[..self.levels.len() - 1] {
+ // Flip the last bit and ensures that it never goes out-of-bounds
+ let sibling_index = (current_index ^ 1).min(level.len() - 1);
+
+ let sibling = &level[sibling_index];
- // Flip index to get sibling
- let sibling_index = current_index ^ 1;
- let sibling = &current_level[sibling_index];
let child_type = if sibling_index < current_index {
NodeChildType::Left
} else {
@@ -113,23 +132,6 @@ where
child_type,
});
- // Move to the next level
- current_level = current_level
- .par_chunks(2)
- .map(|pair| {
- let (left, right) = (&pair[0], &pair[1]);
- let (left_hash, right_hash) = (left.hash().as_bytes(), right.hash().as_bytes());
-
- let mut buffer = Vec::with_capacity(left_hash.len() + right_hash.len());
- buffer.extend_from_slice(left_hash);
- buffer.extend_from_slice(right_hash);
-
- let hash = self.hasher.hash(&buffer);
- Node::new_internal(hash, left.clone(), right.clone())
- })
- .collect();
-
- // Faster way to make "divide by 2"
current_index >>= 1;
}