From 1ecb23182013dd3dfbbb8bfd690b870bcb84616f Mon Sep 17 00:00:00 2001 From: raphaelrobert Date: Mon, 15 Apr 2024 23:43:19 +0200 Subject: [PATCH] Correctness tweaks --- src/lib.rs | 93 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 28 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 29123e4..3ef20de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,11 @@ use aes::{ }; use hkdf::Hkdf; use sha2::Sha256; -use treemath::{direct_path, left, right, root, LeafNodeIndex, TreeNodeIndex, TreeSize}; +use treemath::{ + direct_path, left, right, root, LeafNodeIndex, ParentNodeIndex, TreeNodeIndex, TreeSize, +}; + +use crate::treemath::parent; mod treemath; @@ -103,7 +107,7 @@ pub struct WrappingTree { } impl WrappingTree { - // Create a new empty tree + /// Create a new empty tree pub fn new(init_secret: InitSecret) -> WrappingTree { WrappingTree { init_secret, @@ -112,7 +116,7 @@ impl WrappingTree { } } - // Create a new tree from a public tree and a root secret + /// Create a new tree from a public tree and a root secret pub fn from_public_tree(public_tree: PublicTree, init_secret: InitSecret) -> WrappingTree { // If the public tree is empty, we return a new empty tree if public_tree.leaf_nodes.is_empty() { @@ -135,8 +139,8 @@ impl WrappingTree { tree } - // Expand a node and its children by decrypting the nodes from the public - // tree and storing them in the tree + /// Expand a node and its children by decrypting the nodes from the public + /// tree and storing them in the tree fn expand_nodes(&mut self, index: TreeNodeIndex, key: AesKey, public_tree: &PublicTree) { match index { TreeNodeIndex::Leaf(l) => { @@ -196,12 +200,12 @@ impl WrappingTree { } } - // Create a new epoch + /// Create a new epoch pub fn new_epoch(&mut self, init_secret: InitSecret) { self.init_secret = init_secret; } - // Add a new node to the tree + /// Add a new node to the tree pub fn add(&mut self, index: LeafNodeIndex, leaf_node_plaintext: LeafNodePlaintext) { // Extend the nodes vector if necessary if index.usize() >= self.leaf_nodes.len() { @@ -218,7 +222,7 @@ impl WrappingTree { // Wrap the leaf node let leaf_key = derive_node_key(&self.init_secret, TreeNodeIndex::Leaf(index)); - self.leaf_nodes[index.usize()] = wrap_leaf(&leaf_key, leaf_node_plaintext); + self.leaf_nodes[index.usize()] = encrypt_leaf_node(&leaf_key, leaf_node_plaintext); // Wrap the parent nodes up to the root if self.leaf_nodes.len() > 1 { @@ -226,7 +230,7 @@ impl WrappingTree { }; } - // Remove a node from the tree + /// Remove a node from the tree pub fn remove(&mut self, index: LeafNodeIndex) { if index.usize() >= self.leaf_nodes.len() { // The node is already blank, nothing to do @@ -248,11 +252,17 @@ impl WrappingTree { let right_most_leaf_index = if let Some(right_most_leaf_index) = right_most_leaf_index { right_most_leaf_index } else { + // The tree is empty, we can clear it self.leaf_nodes.clear(); self.parent_nodes.clear(); return; }; + // Wrap the parent nodes up to the root + if self.leaf_nodes.len() > 1 { + self.wrap_up(index) + }; + let tree_size = TreeSize::from_leaf_count(self.leaf_nodes.len()); let desired_tree_size = TreeSize::new_with_index(right_most_leaf_index); @@ -266,11 +276,6 @@ impl WrappingTree { ParentNode { content: None }, ); } - - // Wrap the parent nodes up to the root - if self.leaf_nodes.len() > 1 { - self.wrap_up(index) - }; } /// Wrap up from a given index, which has either been set or blanked. @@ -279,7 +284,13 @@ impl WrappingTree { fn wrap_up(&mut self, index: LeafNodeIndex) { let tree_size = TreeSize::from_leaf_count(self.leaf_nodes.len()); - let mut last_index = TreeNodeIndex::Leaf(LeafNodeIndex::new(0)); + if index.usize() >= self.leaf_nodes.len() { + // The index is outside the tree, nothing to do + return; + } + + let mut last_index = TreeNodeIndex::Leaf(index); + let mut last_key = derive_node_key(&self.init_secret, last_index); // We check the resolution of the nodes in the direct path and either // skip or rewrap them @@ -295,7 +306,7 @@ impl WrappingTree { self.parent_nodes[node_index.usize()] = ParentNode { content: None }; continue; } - Resolution::Both(left, right) => { + Resolution::Both(_, _) => { // The node has two children, we need to rewrap it. We only // want to derive a key for the node in the direct path and // keep the key for the other one. @@ -305,13 +316,14 @@ impl WrappingTree { if last_index.u32() < TreeNodeIndex::from(node_index).u32() { // We came from the left, hence we keep the right key - left_key = derive_node_key(&self.init_secret, left); + left_key = last_key; right_key = self .parent_nodes .get(node_index.usize()) .and_then(|n| n.content.as_ref()) .map(|c| c.plaintext.right_key) - .unwrap_or_else(|| derive_node_key(&self.init_secret, right)); + .unwrap_or_else(|| self.next_or_derive(node_index)); + // find next key instead } else { // We came from the right, hence we keep the left key left_key = self @@ -319,16 +331,18 @@ impl WrappingTree { .get(node_index.usize()) .and_then(|n| n.content.as_ref()) .map(|c| c.plaintext.left_key) - .unwrap_or_else(|| derive_node_key(&self.init_secret, left)); - right_key = derive_node_key(&self.init_secret, right); + .unwrap_or_else(|| self.next_or_derive(node_index)); + right_key = last_key; } let parent_node_plaintext = ParentNodePlaintext { left_key, right_key, }; - let key = derive_node_key(&self.init_secret, node_index.into()); - let new_node = wrap_parent(&key, parent_node_plaintext); + + // Encrypt the two keys and store the new parent node + last_key = derive_node_key(&self.init_secret, node_index.into()); + let new_node = encrypt_parent_node(&last_key, parent_node_plaintext); self.parent_nodes[node_index.usize()] = new_node; } } @@ -336,7 +350,7 @@ impl WrappingTree { } } - // Calculates the resolution of a node + /// Calculates the resolution of a node fn resolution(&self, index: TreeNodeIndex) -> Resolution { return match index { TreeNodeIndex::Leaf(l) => { @@ -366,7 +380,7 @@ impl WrappingTree { }; } - // Export the public tree, i.e. the tree without the plaintexts + /// Export the public tree, i.e. the tree without the plaintexts pub fn export_public_tree(&self) -> PublicTree { PublicTree { leaf_nodes: self @@ -382,12 +396,32 @@ impl WrappingTree { } } - // Export the init secret + /// Looks for the key in the next node in the direct path, if it is not found + /// we assume it must be the root node and we derive the key + fn next_or_derive(&self, index: ParentNodeIndex) -> AesKey { + let root = root(TreeSize::from_leaf_count(self.leaf_nodes.len())); + let mut index = TreeNodeIndex::Parent(index); + + while index != root { + let parent = parent(index); + if let Some(content) = &self.parent_nodes[parent.usize()].content { + if index > TreeNodeIndex::Parent(parent) { + return content.plaintext.right_key; + } else { + return content.plaintext.left_key; + } + } + index = TreeNodeIndex::Parent(parent); + } + derive_node_key(&self.init_secret, root) + } + + /// Export the init secret pub fn export_root_secret(&self) -> &InitSecret { &self.init_secret } - // Export the leaf node plaintexts along their index + /// Export the leaf node plaintexts along their index pub fn export_leaf_node_plaintexts(&self) -> Vec<(usize, LeafNodePlaintext)> { self.leaf_nodes .iter() @@ -432,7 +466,7 @@ fn decrypt(key: &[u8], ciphertext: &[u8; 64]) -> [u8; 64] { buffer } -fn wrap_leaf(key: &[u8; 32], plaintext: LeafNodePlaintext) -> LeafNode { +fn encrypt_leaf_node(key: &[u8; 32], plaintext: LeafNodePlaintext) -> LeafNode { LeafNode { content: Some(LeafNodeContent { ciphertext: encrypt(key, &plaintext.serialize()), @@ -441,7 +475,7 @@ fn wrap_leaf(key: &[u8; 32], plaintext: LeafNodePlaintext) -> LeafNode { } } -fn wrap_parent(key: &[u8; 32], plaintext: ParentNodePlaintext) -> ParentNode { +fn encrypt_parent_node(key: &[u8; 32], plaintext: ParentNodePlaintext) -> ParentNode { ParentNode { content: Some(ParentNodeContent { ciphertext: encrypt(key, &plaintext.serialize()), @@ -549,5 +583,8 @@ fn fuzz() { assert_eq!(leaf_node_plaintexts.len(), new_leaf_node_plaintexts.len()); assert_eq!(leaf_node_plaintexts, new_leaf_node_plaintexts); + + let new_init_secret = [rand::random::(); 32]; + tree.new_epoch(new_init_secret); } }