Skip to content

Commit

Permalink
update: leafs to leaves
Browse files Browse the repository at this point in the history
  • Loading branch information
olivmath committed Dec 24, 2023
1 parent b618223 commit 3a8d38c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 52 deletions.
68 changes: 33 additions & 35 deletions merkly/mtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,44 @@
validate_leafs,
)


class MerkleTree:
"""
# 🌳 Merkle Tree implementation
## Args:
- leafs: List of raw data
- hash_function (Callable[[str], str], optional): Function that hashes the data.
* Defaults to `keccak` if not provided. It must have the signature (data: str) -> str.
- leaves: List of raw data
- hash_function (Callable[[bytes, bytes], bytes], optional): Function that hashes the data.
* Defaults to `keccak` if not provided
"""

def __init__(
self,
leafs: List[str],
leaves: List[str],
hash_function: Callable[[bytes, bytes], bytes] = lambda x, y: keccak(x + y),
) -> None:
validate_leafs(leafs)
validate_leafs(leaves)
validate_hash_function(hash_function)
self.hash_function: Callable[[bytes, bytes], bytes] = hash_function
self.raw_leafs: List[str] = leafs
self.leafs: List[str] = self.__hash_leafs(leafs)
self.short_leafs: List[str] = self.short(self.leafs)
self.raw_leaves: List[str] = leaves
self.leaves: List[str] = self.__hash_leaves(leaves)
self.short_leaves: List[str] = self.short(self.leaves)

def __hash_leafs(self, leafs: List[str]) -> List[str]:
return list(map(lambda x: self.hash_function(x.encode(), bytes()), leafs))
def __hash_leaves(self, leaves: List[str]) -> List[str]:
return list(map(lambda x: self.hash_function(x.encode(), bytes()), leaves))

def __repr__(self) -> str:
return f"""MerkleTree(\nraw_leafs: {self.raw_leafs}\nleafs: {self.leafs}\nshort_leafs: {self.short(self.leafs)})"""
return f"""MerkleTree(\nraw_leaves: {self.raw_leaves}\nleaves: {self.leaves}\nshort_leaves: {self.short(self.leaves)})"""

def short(self, data: List[str]) -> List[str]:
return [x[:2] for x in data]

@property
def root(self) -> bytes:
return self.make_root(self.leafs)
return self.make_root(self.leaves)

def proof(self, raw_leaf: str) -> List[Node]:
return self.make_proof(
self.leafs, [], self.hash_function(raw_leaf.encode(), bytes())
self.leaves, [], self.hash_function(raw_leaf.encode(), bytes())
)

def verify(self, proof: List[bytes], raw_leaf: str) -> bool:
Expand All @@ -80,31 +78,31 @@ def concat_nodes(left: Node, right: Node) -> Node:

return reduce(concat_nodes, full_proof).data == self.root

def make_root(self, leafs: List[bytes]) -> List[str]:
while len(leafs) > 1:
def make_root(self, leaves: List[bytes]) -> bytes:
while len(leaves) > 1:
next_level = []
for i in range(0, len(leafs) - 1, 2):
next_level.append(self.hash_function(leafs[i], leafs[i + 1]))
for i in range(0, len(leaves) - 1, 2):
next_level.append(self.hash_function(leaves[i], leaves[i + 1]))

if len(leafs) % 2 == 1:
next_level.append(leafs[-1])
if len(leaves) % 2 == 1:
next_level.append(leaves[-1])

leafs = next_level
leaves = next_level

return leafs[0]
return leaves[0]

def make_proof(
self, leafs: List[bytes], proof: List[Node], leaf: bytes
self, leaves: List[bytes], proof: List[Node], leaf: bytes
) -> List[Node]:
"""
# Make a proof
## Dev:
- if the `leaf` index is less than half the size of the `leafs`
- if the `leaf` index is less than half the size of the `leaves`
list then the right side must reach root and vice versa
## Args:
- leafs: List of leafs
- leaves: List of leaves
- proof: Accumulated proof
- leaf: Leaf for which to create the proof
Expand All @@ -113,25 +111,25 @@ def make_proof(
"""

try:
index = leafs.index(leaf)
index = leaves.index(leaf)
except ValueError as err:
msg = f"Leaf: {leaf} does not exist in the tree: {leafs}"
msg = f"Leaf: {leaf} does not exist in the tree: {leaves}"
raise ValueError(msg) from err

if is_power_2(len(leafs)) is False:
return self.mix_tree(leafs, [], index)
if is_power_2(len(leaves)) is False:
return self.mix_tree(leaves, [], index)

if len(leafs) == 2:
if len(leaves) == 2:
if index == 1:
proof.append(Node(data=leafs[0], side=Side.LEFT))
proof.append(Node(data=leaves[0], side=Side.LEFT))
else:
proof.append(Node(data=leafs[1], side=Side.RIGHT))
proof.append(Node(data=leaves[1], side=Side.RIGHT))
proof.reverse()
return proof

left, right = half(leafs)
left, right = half(leaves)

if index < len(leafs) / 2:
if index < len(leaves) / 2:
proof.append(Node(data=self.make_root(right), side=Side.RIGHT))
return self.make_proof(left, proof, leaf)
else:
Expand Down
34 changes: 17 additions & 17 deletions test/merkle_root/test_merkle_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@


def test_simple_merkle_tree_constructor():
leafs = ["a", "b", "c", "d"]
tree = MerkleTree(leafs)
leaves = ["a", "b", "c", "d"]
tree = MerkleTree(leaves)

assert tree.raw_leafs == leafs
assert tree.raw_leaves == leaves
for i, j in zip(
tree.short_leafs,
tree.short_leaves,
[
bytes.fromhex("3ac2"),
bytes.fromhex("b555"),
Expand All @@ -19,7 +19,7 @@ def test_simple_merkle_tree_constructor():
],
):
assert i == j
assert tree.leafs == [
assert tree.leaves == [
bytes.fromhex(
"3ac225168df54212a25c1c01fd35bebfea408fdac2e31ddd6f80a4bbf9a5f1cb"
),
Expand All @@ -41,7 +41,7 @@ def test_simple_merkle_tree_constructor():


@mark.parametrize(
"leafs, root",
"leaves, root",
[
(
["a", "b", "c", "d", "e", "f", "g", "h", "1"],
Expand All @@ -65,14 +65,14 @@ def test_simple_merkle_tree_constructor():
),
],
)
def test_simple_merkle_root_with_keccak256(leafs: List[str], root: str):
tree = MerkleTree(leafs)
def test_simple_merkle_root_with_keccak256(leaves: List[str], root: str):
tree = MerkleTree(leaves)
result = tree.root.hex()
assert result == root


@mark.parametrize(
"leafs, root",
"leaves, root",
[
(
["a", "b", "c", "d", "e", "f", "g", "h", "1"],
Expand All @@ -96,20 +96,20 @@ def test_simple_merkle_root_with_keccak256(leafs: List[str], root: str):
),
],
)
def test_simple_merkle_root_with_sha_256(leafs: List[str], root: str):
def test_simple_merkle_root_with_sha_256(leaves: List[str], root: str):
def sha_256(x: bytes, y: bytes) -> bytes:
data = x + y
h = hashlib.sha256()
h.update(data)
return h.digest()

tree = MerkleTree(leafs, hash_function=sha_256)
tree = MerkleTree(leaves, hash_function=sha_256)
result = tree.root.hex()
assert result == root


@mark.parametrize(
"leafs, root",
"leaves, root",
[
(
["a", "b", "c", "d", "e", "f", "g", "h", "1"],
Expand All @@ -133,20 +133,20 @@ def sha_256(x: bytes, y: bytes) -> bytes:
),
],
)
def test_simple_merkle_root_with_shake256(leafs: List[str], root: str):
def test_simple_merkle_root_with_shake256(leaves: List[str], root: str):
def shake_256(x: bytes, y: bytes) -> bytes:
data = x + y
h = hashlib.shake_256()
h.update(data)
return h.digest(32)

tree = MerkleTree(leafs, hash_function=shake_256)
tree = MerkleTree(leaves, hash_function=shake_256)
result = tree.root.hex()
assert result == root


@mark.parametrize(
"leafs, root",
"leaves, root",
[
(
["a", "b", "c", "d", "e", "f", "g", "h", "1"],
Expand All @@ -170,13 +170,13 @@ def shake_256(x: bytes, y: bytes) -> bytes:
),
],
)
def test_simple_merkle_root_with_sha3_256(leafs: List[str], root: str):
def test_simple_merkle_root_with_sha3_256(leaves: List[str], root: str):
def sha3_256(x: bytes, y: bytes) -> bytes:
data = x + y
h = hashlib.sha3_256()
h.update(data)
return h.digest()

tree = MerkleTree(leafs, hash_function=sha3_256)
tree = MerkleTree(leaves, hash_function=sha3_256)
result = tree.root.hex()
assert result == root

0 comments on commit 3a8d38c

Please sign in to comment.