diff --git a/src/lib.rs b/src/lib.rs index 3ef20de..8030eb3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,10 +2,7 @@ // // SPDX-License-Identifier: AGPL-3.0-or-later -use aes::{ - cipher::{generic_array::GenericArray, BlockDecrypt, BlockEncrypt, KeyInit}, - Block, -}; +use aes::cipher::{generic_array::GenericArray, BlockDecrypt, BlockEncrypt, KeyInit}; use hkdf::Hkdf; use sha2::Sha256; use treemath::{ @@ -88,11 +85,6 @@ impl ParentNodePlaintext { } } -pub struct PublicTree { - leaf_nodes: Vec>, - parent_nodes: Vec>, -} - #[derive(Debug, PartialEq, Clone)] enum Resolution { Empty, @@ -100,12 +92,64 @@ enum Resolution { Both(TreeNodeIndex, TreeNodeIndex), } +trait Resolver { + fn has_node(&self, index: usize) -> bool; + + /// Calculates the resolution of a node + fn resolution(&self, index: TreeNodeIndex) -> Resolution { + return match index { + TreeNodeIndex::Leaf(l) => { + if self.has_node(l.usize()) { + return Resolution::One(TreeNodeIndex::from(l)); + } else { + return Resolution::Empty; + } + } + TreeNodeIndex::Parent(p) => { + let left = TreeNodeIndex::from(left(p)); + let right = TreeNodeIndex::from(right(p)); + match (self.resolution(left), self.resolution(right)) { + (Resolution::One(l), Resolution::One(r)) => Resolution::Both(l, r), + (Resolution::One(l), Resolution::Empty) => Resolution::One(l), + (Resolution::Empty, Resolution::One(r)) => Resolution::One(r), + (Resolution::Empty, Resolution::Empty) => Resolution::Empty, + (Resolution::Both(_, _), Resolution::Empty) => Resolution::One(left), + (Resolution::Empty, Resolution::Both(_, _)) => Resolution::One(right), + (Resolution::Both(_, _), Resolution::Both(_, _)) => { + Resolution::Both(left, right) + } + (Resolution::One(x), Resolution::Both(_, _)) => Resolution::Both(x, right), + (Resolution::Both(_, _), Resolution::One(x)) => Resolution::Both(left, x), + } + } + }; + } +} + +#[derive(Debug, PartialEq, Clone)] +pub struct PublicTree { + leaf_nodes: Vec>, + parent_nodes: Vec>, +} + +impl Resolver for PublicTree { + fn has_node(&self, index: usize) -> bool { + self.leaf_nodes[index].is_some() + } +} + pub struct WrappingTree { init_secret: InitSecret, leaf_nodes: Vec, parent_nodes: Vec, } +impl Resolver for WrappingTree { + fn has_node(&self, index: usize) -> bool { + self.leaf_nodes[index].content.is_some() + } +} + impl WrappingTree { /// Create a new empty tree pub fn new(init_secret: InitSecret) -> WrappingTree { @@ -134,7 +178,26 @@ impl WrappingTree { // We populate the tree with the public nodes and decrypt them // starting from the root - tree.expand_nodes(root, derive_node_key(&init_secret, root), &public_tree); + println!( + "Expanding nodes with root key {:?}", + derive_node_key(&init_secret, root) + ); + match public_tree.resolution(root) { + Resolution::Empty => { + // The root node is blank, nothing to do + println!("Root node's resolution is empty"); + } + Resolution::One(child) => { + // The root node has only one child, we start from there + println!("Root node has only one child: {:?}", child); + tree.expand_nodes(child, derive_node_key(&init_secret, child), &public_tree); + } + Resolution::Both(_, _) => { + // The root node has two children, we stay at the root + println!("Root node has two children"); + tree.expand_nodes(root, derive_node_key(&init_secret, root), &public_tree); + } + } tree } @@ -147,7 +210,17 @@ impl WrappingTree { // If the leaf node is present in the public tree, we decrypt it // and store it in the tree if let Some(ciphertext) = public_tree.leaf_nodes[l.usize()] { + println!("Decrypting leaf node {} with key {:?}", l.usize(), key); self.leaf_nodes[l.usize()] = self.decrypt_leaf_node(ciphertext, &key); + debug_assert_eq!( + self.leaf_nodes[l.usize()] + .clone() + .content + .unwrap() + .plaintext + .key[..32], + [1u8; 32] + ); } } TreeNodeIndex::Parent(p) => { @@ -155,7 +228,8 @@ impl WrappingTree { // and store it in the tree. We also extract the keys for the children. let (left_key, right_key) = if let Some(ciphertext) = public_tree.parent_nodes[p.usize()] { - let node = self.expand_parent_node(ciphertext, &key); + println!("Decrypting parent node {} with key {:?}", p.usize(), key); + let node = self.decrypt_parent_node(ciphertext, &key); let left_key = node.content.as_ref().map(|c| c.plaintext.left_key).unwrap(); let right_key = node .content @@ -179,7 +253,7 @@ impl WrappingTree { } /// Decrypt a parent node from the public tree - fn expand_parent_node(&self, ciphertext: ParentCiphertext, key: &AesKey) -> ParentNode { + fn decrypt_parent_node(&self, ciphertext: ParentCiphertext, key: &AesKey) -> ParentNode { let plaintext = decrypt(key, &ciphertext); ParentNode { content: Some(ParentNodeContent { @@ -200,9 +274,20 @@ impl WrappingTree { } } + #[inline(always)] + fn size(&self) -> TreeSize { + TreeSize::from_leaf_count(self.leaf_nodes.len()) + } + /// Create a new epoch pub fn new_epoch(&mut self, init_secret: InitSecret) { self.init_secret = init_secret; + + // Re-encrypt the root node + match root(self.size()) { + TreeNodeIndex::Leaf(l) => self.rewrap_leaf(l), + TreeNodeIndex::Parent(p) => self.rewrap_parent(p), + } } /// Add a new node to the tree @@ -222,6 +307,11 @@ impl WrappingTree { // Wrap the leaf node let leaf_key = derive_node_key(&self.init_secret, TreeNodeIndex::Leaf(index)); + println!( + "Encrypting leaf node {} with key {:?}", + index.usize(), + leaf_key + ); self.leaf_nodes[index.usize()] = encrypt_leaf_node(&leaf_key, leaf_node_plaintext); // Wrap the parent nodes up to the root @@ -232,7 +322,23 @@ impl WrappingTree { /// Remove a node from the tree pub fn remove(&mut self, index: LeafNodeIndex) { + println!( + "Removing node {:?}, leaf node length: {}", + index, + self.leaf_nodes.len() + ); + if self.leaf_nodes.len() == 0 { + println!("Tree is empty, nothing to remove"); + // The tree is empty, nothing to do + return; + } + if index.usize() >= self.leaf_nodes.len() { + // The node is outside the current tree, nothing to do + return; + } + + if self.leaf_nodes[index.usize()].content.is_none() { // The node is already blank, nothing to do return; } @@ -250,8 +356,10 @@ impl WrappingTree { } let right_most_leaf_index = if let Some(right_most_leaf_index) = right_most_leaf_index { + println!("Right most leaf index: {:?}", right_most_leaf_index); right_most_leaf_index } else { + println!("Tree is empty, clearing it"); // The tree is empty, we can clear it self.leaf_nodes.clear(); self.parent_nodes.clear(); @@ -260,8 +368,13 @@ impl WrappingTree { // Wrap the parent nodes up to the root if self.leaf_nodes.len() > 1 { + println!("Wrapping up from node {:?}", index); self.wrap_up(index) - }; + } else { + // If the tree only has one node, we rewrap it since it's the root + println!("Rewrapping root/leaf node after removal"); + self.rewrap_leaf(LeafNodeIndex::new(0)); + } let tree_size = TreeSize::from_leaf_count(self.leaf_nodes.len()); let desired_tree_size = TreeSize::new_with_index(right_most_leaf_index); @@ -278,13 +391,69 @@ impl WrappingTree { } } + fn rewrap_leaf(&mut self, index: LeafNodeIndex) { + if index.usize() >= self.leaf_nodes.len() { + println!("Rewrap leaf: Index is outside the tree, nothing to do"); + // The index is outside the tree, nothing to do + return; + } + + let key = derive_node_key(&self.init_secret, TreeNodeIndex::Leaf(index)); + + println!("Rewrapping leaf node {:?}", index); + + if let Some(ref leaf_node) = self.leaf_nodes[index.usize()].content { + self.leaf_nodes[index.usize()] = encrypt_leaf_node(&key, leaf_node.plaintext.clone()); + } + } + + fn rewrap_parent(&mut self, index: ParentNodeIndex) { + println!("Rewrapping parent node {:?}", index); + if index.usize() >= self.parent_nodes.len() { + println!("Rewrap parent: Index is outside the tree, nothing to do"); + // The index is outside the tree, nothing to do + return; + } + + let res = self.resolution(TreeNodeIndex::Parent(index)); + + match res { + Resolution::Empty => { + panic!("rewrap_parent: Resolution is empty"); + } + Resolution::One(child) => { + println!("Rewrapping partial parent node {:?}", child); + + match child { + TreeNodeIndex::Leaf(leaf_node_index) => { + self.rewrap_leaf(leaf_node_index); + } + TreeNodeIndex::Parent(parent_node_index) => { + self.rewrap_parent(parent_node_index); + } + } + } + Resolution::Both(_, _) => { + let key = derive_node_key(&self.init_secret, TreeNodeIndex::Parent(index)); + + println!("Rewrapping full parent node {:?}", index); + + if let Some(ref parent_node) = self.parent_nodes[index.usize()].content { + self.parent_nodes[index.usize()] = + encrypt_parent_node(&key, parent_node.plaintext.clone()); + } + } + } + } + /// Wrap up from a given index, which has either been set or blanked. /// This function will rewrap the parent nodes up to the root and skip /// the nodes that only have one child. fn wrap_up(&mut self, index: LeafNodeIndex) { - let tree_size = TreeSize::from_leaf_count(self.leaf_nodes.len()); + let root = root(self.size()); if index.usize() >= self.leaf_nodes.len() { + println!("Wrap up: Index is outside the tree, nothing to do"); // The index is outside the tree, nothing to do return; } @@ -294,47 +463,96 @@ impl WrappingTree { // We check the resolution of the nodes in the direct path and either // skip or rewrap them - for node_index in direct_path(index, tree_size) { + for node_index in direct_path(index, self.size()) { match self.resolution(node_index.into()) { Resolution::Empty => { // The node is blank, nothing to do continue; } - Resolution::One(_) => { + Resolution::One(child) => { + println!("Wrap up: one with node {:?}", node_index); // The node has only one child, we need to skip it - // and we blank it - self.parent_nodes[node_index.usize()] = ParentNode { content: None }; - continue; + // and we blank it, unless it is the root node. + if TreeNodeIndex::from(node_index) != root { + // Extract the keys from the parent node + if let Some(node_content) = + self.parent_nodes[node_index.usize()].content.as_ref() + { + last_key = if child.u32() < TreeNodeIndex::from(node_index).u32() { + // The child was on the left, hence we keep the left key + println!("\tThe child was on the left"); + node_content.plaintext.left_key + } else { + // The child was on the right, hence we keep the right key + println!("\tThe child was on the right"); + node_content.plaintext.right_key + }; + println!("\tExtracting key from parent node {:?}", last_key); + } + + println!("\tBlanking node {:?}", node_index); + self.parent_nodes[node_index.usize()] = ParentNode { content: None }; + + continue; + } else { + println!("Wrapped up to root node, but there is only one child"); + // Blank the root node + self.parent_nodes[node_index.usize()] = ParentNode { content: None }; + // Rewrap child + println!("\tRewrapping child node {:?}", child); + match child { + TreeNodeIndex::Leaf(leaf_node_index) => { + self.rewrap_leaf(leaf_node_index); + } + TreeNodeIndex::Parent(parent_node_index) => { + self.rewrap_parent(parent_node_index); + } + } + + break; + } } - Resolution::Both(_, _) => { + Resolution::Both(left_index, right_index) => { // The node has two children, we need to rewrap it. We only // want to derive a key for the node in the direct path and // keep the key for the other one. + println!( + "Wrap up: both with node {:?}\n\tLeft: {:?}\n\tRight: {:?}", + node_index, left_index, right_index + ); + let left_key; let right_key; if last_index.u32() < TreeNodeIndex::from(node_index).u32() { // We came from the left, hence we keep the right key + println!("We came from the left"); left_key = last_key; right_key = self .parent_nodes .get(node_index.usize()) .and_then(|n| n.content.as_ref()) .map(|c| c.plaintext.right_key) - .unwrap_or_else(|| self.next_or_derive(node_index)); + .unwrap_or_else(|| self.next_or_derive(right_index)); // find next key instead } else { // We came from the right, hence we keep the left key + println!("We came from the right"); left_key = self .parent_nodes .get(node_index.usize()) .and_then(|n| n.content.as_ref()) .map(|c| c.plaintext.left_key) - .unwrap_or_else(|| self.next_or_derive(node_index)); + .unwrap_or_else(|| self.next_or_derive(left_index)); right_key = last_key; } + println!( + "Parent node {:?}: \n\tLeft key: {:?}, \n\tRight key: {:?}", + node_index, left_key, right_key + ); + let parent_node_plaintext = ParentNodePlaintext { left_key, right_key, @@ -342,6 +560,11 @@ impl WrappingTree { // Encrypt the two keys and store the new parent node last_key = derive_node_key(&self.init_secret, node_index.into()); + println!( + "Encrypting parent node {} with key {:?}", + node_index.usize(), + last_key + ); let new_node = encrypt_parent_node(&last_key, parent_node_plaintext); self.parent_nodes[node_index.usize()] = new_node; } @@ -350,36 +573,6 @@ impl WrappingTree { } } - /// Calculates the resolution of a node - fn resolution(&self, index: TreeNodeIndex) -> Resolution { - return match index { - TreeNodeIndex::Leaf(l) => { - if self.leaf_nodes[l.usize()].content.is_some() { - return Resolution::One(TreeNodeIndex::from(l)); - } else { - return Resolution::Empty; - } - } - TreeNodeIndex::Parent(p) => { - let left = TreeNodeIndex::from(left(p)); - let right = TreeNodeIndex::from(right(p)); - match (self.resolution(left), self.resolution(right)) { - (Resolution::One(l), Resolution::One(r)) => Resolution::Both(l, r), - (Resolution::One(l), Resolution::Empty) => Resolution::One(l), - (Resolution::Empty, Resolution::One(r)) => Resolution::One(r), - (Resolution::Empty, Resolution::Empty) => Resolution::Empty, - (Resolution::Both(_, _), Resolution::Empty) => Resolution::One(left), - (Resolution::Empty, Resolution::Both(_, _)) => Resolution::One(right), - (Resolution::Both(_, _), Resolution::Both(_, _)) => { - Resolution::Both(left, right) - } - (Resolution::One(x), Resolution::Both(_, _)) => Resolution::Both(x, right), - (Resolution::Both(_, _), Resolution::One(x)) => Resolution::Both(left, x), - } - } - }; - } - /// Export the public tree, i.e. the tree without the plaintexts pub fn export_public_tree(&self) -> PublicTree { PublicTree { @@ -396,24 +589,34 @@ impl WrappingTree { } } - /// Looks for the key in the next node in the direct path, if it is not found - /// we assume it must be the root node and we derive the key - fn next_or_derive(&self, index: ParentNodeIndex) -> AesKey { + /// Looks for the key in the next non-blank node in the direct path, if it + /// is not found we assume it must be the root node and we derive the key + fn next_or_derive(&self, index: TreeNodeIndex) -> AesKey { let root = root(TreeSize::from_leaf_count(self.leaf_nodes.len())); - let mut index = TreeNodeIndex::Parent(index); + let mut path_index = index; + + println!("next_or_derive: {:?}", index); - while index != root { - let parent = parent(index); + while path_index != root { + let parent = parent(path_index); + println!("Looking for key in node {:?}", parent); if let Some(content) = &self.parent_nodes[parent.usize()].content { - if index > TreeNodeIndex::Parent(parent) { + if path_index > TreeNodeIndex::Parent(parent) { + println!("Found right key"); return content.plaintext.right_key; } else { + println!("Found left key"); return content.plaintext.left_key; } } - index = TreeNodeIndex::Parent(parent); + path_index = TreeNodeIndex::Parent(parent); } - derive_node_key(&self.init_secret, root) + println!( + "Deriving key for index {:?}: {:?}", + index, + derive_node_key(&self.init_secret, index) + ); + derive_node_key(&self.init_secret, index) } /// Export the init secret @@ -441,27 +644,51 @@ fn derive_node_key(init_secret: &InitSecret, index: TreeNodeIndex) -> AesKey { fn encrypt(key: &AesKey, plaintext: &[u8; 64]) -> [u8; 64] { let mut buffer = [0u8; 64]; buffer.copy_from_slice(plaintext); - let b1 = Block::from_slice(&buffer[..16]).to_owned(); - let b2 = Block::from_slice(&buffer[16..32]).to_owned(); - let b3 = Block::from_slice(&buffer[32..48]).to_owned(); - let b4 = Block::from_slice(&buffer[48..64]).to_owned(); - let cipher = aes::Aes256::new(&GenericArray::from_slice(key)); - cipher.encrypt_blocks(&mut [b1, b2, b3, b4]); + let mut b1 = GenericArray::from_slice(buffer[..16].as_ref()).to_owned(); + let mut b2 = GenericArray::from_slice(buffer[16..32].as_ref()).to_owned(); + let mut b3 = GenericArray::from_slice(buffer[32..48].as_ref()).to_owned(); + let mut b4 = GenericArray::from_slice(buffer[48..64].as_ref()).to_owned(); + + let cipher = aes::Aes256::new(&GenericArray::from(*key)); + + cipher.encrypt_block(&mut b1); + cipher.encrypt_block(&mut b2); + cipher.encrypt_block(&mut b3); + cipher.encrypt_block(&mut b4); + + // Copy blocks into buffer + let mut buffer = [0u8; 64]; + buffer[..16].copy_from_slice(b1.as_slice()); + buffer[16..32].copy_from_slice(b2.as_slice()); + buffer[32..48].copy_from_slice(b3.as_slice()); + buffer[48..64].copy_from_slice(b4.as_slice()); buffer } -fn decrypt(key: &[u8], ciphertext: &[u8; 64]) -> [u8; 64] { +fn decrypt(key: &AesKey, ciphertext: &[u8; 64]) -> [u8; 64] { let mut buffer = [0u8; 64]; buffer.copy_from_slice(ciphertext); - let b1 = Block::from_slice(&buffer[..16]).to_owned(); - let b2 = Block::from_slice(&buffer[16..32]).to_owned(); - let b3 = Block::from_slice(&buffer[32..48]).to_owned(); - let b4 = Block::from_slice(&buffer[48..64]).to_owned(); - let cipher = aes::Aes256::new(&GenericArray::from_slice(key)); - cipher.decrypt_blocks(&mut [b1, b2, b3, b4]); + let mut b1 = GenericArray::from_slice(buffer[..16].as_ref()).to_owned(); + let mut b2 = GenericArray::from_slice(buffer[16..32].as_ref()).to_owned(); + let mut b3 = GenericArray::from_slice(buffer[32..48].as_ref()).to_owned(); + let mut b4 = GenericArray::from_slice(buffer[48..64].as_ref()).to_owned(); + + let cipher = aes::Aes256::new(&GenericArray::from(*key)); + + cipher.decrypt_block(&mut b1); + cipher.decrypt_block(&mut b2); + cipher.decrypt_block(&mut b3); + cipher.decrypt_block(&mut b4); + + // Copy blocks into buffer + let mut buffer = [0u8; 64]; + buffer[..16].copy_from_slice(b1.as_slice()); + buffer[16..32].copy_from_slice(b2.as_slice()); + buffer[32..48].copy_from_slice(b3.as_slice()); + buffer[48..64].copy_from_slice(b4.as_slice()); buffer } @@ -553,38 +780,122 @@ fn remove_leaf() { #[test] fn fuzz() { - const LEAF_COUNT: u32 = 100; - const OPERATION_COUNT: usize = 1_000; + fn eval_tree(tree: &WrappingTree) { + let public_tree = tree.export_public_tree(); + let root_secret = tree.export_root_secret(); + let new_tree = WrappingTree::from_public_tree(public_tree.clone(), *root_secret); + let leaf_node_plaintexts = tree.export_leaf_node_plaintexts(); + let new_leaf_node_plaintexts = new_tree.export_leaf_node_plaintexts(); + assert_eq!(leaf_node_plaintexts, new_leaf_node_plaintexts); + } + + const LEAF_COUNT: u32 = 32; + const OPERATION_COUNT: usize = 100; + const EPOCH_COUNT: usize = 10; let init_secret = [0u8; 32]; let mut tree = WrappingTree::new(init_secret); let leaf_node_plaintext = LeafNodePlaintext { - key: [0u8; 32], - mac: [0u8; 32], + key: [1u8; 32], + mac: [2u8; 32], }; - for _ in 0..OPERATION_COUNT { - let index = LeafNodeIndex::new(rand::random::() % LEAF_COUNT); - let add = rand::random::(); + for epoch in 0..EPOCH_COUNT { + println!("--- Epoch: {}", epoch); - if add { - tree.add(index, leaf_node_plaintext.clone()); - } else { - tree.remove(index); - } + for operation in 0..OPERATION_COUNT { + println!("-- Operation: {}", operation); + let index = LeafNodeIndex::new(rand::random::() % LEAF_COUNT); + let add = rand::random::(); - let public_tree = tree.export_public_tree(); - let root_secret = tree.export_root_secret(); - - let new_tree = WrappingTree::from_public_tree(public_tree, *root_secret); - - let leaf_node_plaintexts = tree.export_leaf_node_plaintexts(); - let new_leaf_node_plaintexts = new_tree.export_leaf_node_plaintexts(); + if add { + println!("+ {:?}", index); + tree.add(index, leaf_node_plaintext.clone()); + } else { + println!("- {:?}", index); + tree.remove(index); + } - assert_eq!(leaf_node_plaintexts.len(), new_leaf_node_plaintexts.len()); - assert_eq!(leaf_node_plaintexts, new_leaf_node_plaintexts); + eval_tree(&tree); + } let new_init_secret = [rand::random::(); 32]; tree.new_epoch(new_init_secret); + + eval_tree(&tree); } } + +#[test] +fn pathological_case2() { + let init_secret = [3u8; 32]; + let mut tree = WrappingTree::new(init_secret); + let leaf_node_plaintext = LeafNodePlaintext { + key: [1u8; 32], + mac: [2u8; 32], + }; + + tree.add(LeafNodeIndex::new(3), leaf_node_plaintext.clone()); + tree.add(LeafNodeIndex::new(1), leaf_node_plaintext.clone()); + tree.add(LeafNodeIndex::new(0), leaf_node_plaintext.clone()); + + println!("First tree: {:?}", tree.export_public_tree().leaf_nodes); + + let public_tree = tree.export_public_tree(); + let root_secret = tree.export_root_secret(); + let new_tree = WrappingTree::from_public_tree(public_tree, *root_secret); + let leaf_node_plaintexts = tree.export_leaf_node_plaintexts(); + let new_leaf_node_plaintexts = new_tree.export_leaf_node_plaintexts(); + assert_eq!(leaf_node_plaintexts, new_leaf_node_plaintexts); + + println!("Second epoch"); + let new_init_secret = [5u8; 32]; + tree.new_epoch(new_init_secret); + + println!("Exporting tree"); + + let public_tree = tree.export_public_tree(); + let root_secret = tree.export_root_secret(); + let new_tree = WrappingTree::from_public_tree(public_tree, *root_secret); + let leaf_node_plaintexts = tree.export_leaf_node_plaintexts(); + let new_leaf_node_plaintexts = new_tree.export_leaf_node_plaintexts(); + assert_eq!(leaf_node_plaintexts, new_leaf_node_plaintexts); + + println!("Second tree: {:?}", tree.export_public_tree().leaf_nodes); + + let remove_node = 0; + + println!("Removing node {}", remove_node); + + tree.remove(LeafNodeIndex::new(remove_node)); + + println!("Node {} removed", remove_node); + + println!("Third tree: {:?}", tree.export_public_tree().leaf_nodes); + + let public_tree = tree.export_public_tree(); + let root_secret = tree.export_root_secret(); + let new_tree = WrappingTree::from_public_tree(public_tree, *root_secret); + let leaf_node_plaintexts = tree.export_leaf_node_plaintexts(); + let new_leaf_node_plaintexts = new_tree.export_leaf_node_plaintexts(); + assert_eq!(leaf_node_plaintexts, new_leaf_node_plaintexts); + + let new_init_secret = [7u8; 32]; + tree.new_epoch(new_init_secret); + + let public_tree = tree.export_public_tree(); + let root_secret = tree.export_root_secret(); + let new_tree = WrappingTree::from_public_tree(public_tree, *root_secret); + let leaf_node_plaintexts = tree.export_leaf_node_plaintexts(); + let new_leaf_node_plaintexts = new_tree.export_leaf_node_plaintexts(); + assert_eq!(leaf_node_plaintexts, new_leaf_node_plaintexts); +} + +#[test] +fn root_test() { + assert_eq!(root(TreeSize::new(1)).u32(), 0); + assert_eq!(TreeSize::new(1).parent_count(), 0); + assert_eq!(TreeSize::new(1).leaf_count(), 1); + assert_eq!(TreeSize::new(0).parent_count(), 0); + assert_eq!(TreeSize::new(0).leaf_count(), 0); +} diff --git a/src/treemath.rs b/src/treemath.rs index 87b8e9d..b47849a 100644 --- a/src/treemath.rs +++ b/src/treemath.rs @@ -134,7 +134,10 @@ impl TreeSize { /// Creates a new `TreeSize` from a specific leaf count pub(crate) fn from_leaf_count(leaf_count: usize) -> Self { - TreeSize::new((leaf_count * 2 - 1) as u32) + match leaf_count { + 0 => TreeSize::new(1), + _ => TreeSize::new((leaf_count * 2 - 1) as u32), + } } /// Return the number of leaf nodes in the tree.