summaryrefslogtreecommitdiffstats
path: root/src/merkletree.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/merkletree.rs')
-rw-r--r--src/merkletree.rs57
1 files changed, 46 insertions, 11 deletions
diff --git a/src/merkletree.rs b/src/merkletree.rs
index 6173c2f..9d2ff84 100644
--- a/src/merkletree.rs
+++ b/src/merkletree.rs
@@ -8,7 +8,6 @@ use crate::{hasher::Hasher, node::Node};
/// Merkle trees are hash-based data structures used for secure and efficient data verification.
/// Each leaf node contains the hash of a data item, and each internal node contains the hash
/// of the concatenation of its children's hashes.
-#[derive(Debug)]
pub struct MerkleTree {
/// Leaf nodes at the base of the tree (may include a duplicate for even pairing).
leaves: Vec<Node>,
@@ -34,10 +33,11 @@ impl MerkleTree {
///
/// If the number of leaf nodes is odd, the last node is duplicated to ensure all internal
/// nodes have exactly two children.
- pub fn new<I, T>(hasher: &dyn Hasher, data: I) -> Self
+ pub fn new<I, T, H>(hasher: H, data: I) -> Self
where
I: IntoIterator<Item = T>,
T: AsRef<[u8]>,
+ H: Hasher + 'static,
{
let owned_data: Vec<T> = data.into_iter().collect();
let data_slices: Vec<&[u8]> = owned_data.iter().map(|item| item.as_ref()).collect();
@@ -49,7 +49,7 @@ impl MerkleTree {
let mut leaves: Vec<Node> = data_slices
.iter()
- .map(|x| Node::new_leaf(hasher, x))
+ .map(|data| Node::new_leaf(data, hasher.hash(data)))
.collect();
if leaves.len() % 2 != 0 {
@@ -60,7 +60,10 @@ impl MerkleTree {
}
/// Constructs the internal nodes of the tree from the leaves upward and computes the root.
- fn build(hasher: &dyn Hasher, mut nodes: Vec<Node>) -> Self {
+ fn build<H>(hasher: H, mut nodes: Vec<Node>) -> Self
+ where
+ H: Hasher + 'static,
+ {
let leaves = nodes.clone();
let mut height = 0;
@@ -74,7 +77,11 @@ impl MerkleTree {
for pair in nodes.chunks(2) {
let (left, right) = (pair[0].clone(), pair[1].clone());
- next_level.push(Node::new_internal(hasher, left, right));
+ let mut buffer = Vec::<u8>::new();
+ buffer.extend_from_slice(left.hash().as_bytes());
+ buffer.extend_from_slice(right.hash().as_bytes());
+ let hash = hasher.hash(&buffer);
+ next_level.push(Node::new_internal(&buffer, hash, left, right));
}
nodes = next_level;
height += 1;
@@ -112,14 +119,14 @@ impl MerkleTree {
#[cfg(test)]
mod tests {
- use crate::hasher::*;
-
use super::*;
+ use crate::hasher::*;
+ use crate::proof::*;
#[test]
fn test_merkle_tree_with_default_hasher() {
let data = &["hello".as_bytes(), "world".as_bytes()];
- let tree = MerkleTree::new(&DummyHasher, data);
+ let tree = MerkleTree::new(DummyHasher, data);
assert_eq!(tree.height(), 2);
assert_eq!(tree.root().hash(), "0xc0ff3");
@@ -129,7 +136,7 @@ mod tests {
#[cfg(feature = "sha256")]
fn test_merkle_tree_hashing() {
let data = &["hello".as_bytes(), "world".as_bytes()];
- let tree = MerkleTree::new(&SHA256Hasher, data);
+ let tree = MerkleTree::new(SHA256Hasher::new(), data);
assert_eq!(tree.height(), 2);
assert_eq!(
@@ -142,7 +149,7 @@ mod tests {
#[cfg(feature = "sha256")]
fn test_merkle_tree_single_leaf() {
let data = &["hello".as_bytes()];
- let tree = MerkleTree::new(&SHA256Hasher, data);
+ let tree = MerkleTree::new(SHA256Hasher::new(), data);
assert_eq!(tree.height(), 2);
assert_eq!(tree.len(), 2);
@@ -158,7 +165,7 @@ mod tests {
let inputs = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
let data: Vec<&[u8]> = inputs.iter().map(|s| s.as_bytes()).collect();
- let tree = MerkleTree::new(&SHA256Hasher, &data);
+ let tree = MerkleTree::new(SHA256Hasher::new(), &data);
assert_eq!(tree.height(), 5); // 10 elements padded to 16 → log2(16) + 1 = 5
@@ -171,4 +178,32 @@ mod tests {
"9da1ff0dfa79217bdbea9ec96407b1e693646cc493f64059fa27182a37cadf94"
);
}
+
+ #[test]
+ fn test_proof_generation_and_verification_dummy() {
+ let hasher = DummyHasher;
+ let data = vec!["a", "b", "c", "d"];
+ let tree = MerkleTree::new(hasher.clone(), data.clone());
+ let proofer = DefaultProofer::new(hasher.clone(), tree.leaves.clone());
+
+ for (index, item) in data.iter().enumerate() {
+ let proof = proofer.generate(index).unwrap();
+
+ assert!(proofer.verify(&proof, item, tree.root().hash(), &hasher));
+ }
+ }
+ #[test]
+ #[cfg(feature = "sha256")]
+ fn test_proof_generation_and_verification_sha256() {
+ let hasher = SHA256Hasher::new();
+ let data = vec!["a", "b", "c", "d"];
+ let tree = MerkleTree::new(hasher.clone(), data.clone());
+ let proofer = DefaultProofer::new(hasher.clone(), tree.leaves.clone());
+
+ for (index, item) in data.iter().enumerate() {
+ let proof = proofer.generate(index).unwrap();
+
+ assert!(proofer.verify(&proof, item, tree.root().hash(), &hasher));
+ }
+ }
}