Skip to content

Commit

Permalink
update: leafs to leaves
Browse files Browse the repository at this point in the history
  • Loading branch information
olivmath committed Dec 24, 2023
1 parent b618223 commit ac53503
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 52 deletions.
68 changes: 33 additions & 35 deletions merkly/mtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,44 @@
validate_leafs,
)


class MerkleTree:
"""
# 🌳 Merkle Tree implementation
## Args:
- leafs: List of raw data
- hash_function (Callable[[str], str], optional): Function that hashes the data.
* Defaults to `keccak` if not provided. It must have the signature (data: str) -> str.
- leaves: List of raw data
- hash_function (Callable[[bytes, bytes], bytes], optional): Function that hashes the data.
* Defaults to `keccak` if not provided
"""

def __init__(
self,
leafs: List[str],
leaves: List[str],
hash_function: Callable[[bytes, bytes], bytes] = lambda x, y: keccak(x + y),
) -> None:
validate_leafs(leafs)
validate_leafs(leaves)
validate_hash_function(hash_function)
self.hash_function: Callable[[bytes, bytes], bytes] = hash_function
self.raw_leafs: List[str] = leafs
self.leafs: List[str] = self.__hash_leafs(leafs)
self.short_leafs: List[str] = self.short(self.leafs)
self.raw_leaves: List[str] = leaves
self.leaves: List[str] = self.__hash_leaves(leaves)
self.short_leaves: List[str] = self.short(self.leaves)

def __hash_leafs(self, leafs: List[str]) -> List[str]:
return list(map(lambda x: self.hash_function(x.encode(), bytes()), leafs))
def __hash_leaves(self, leaves: List[str]) -> List[str]:
return list(map(lambda x: self.hash_function(x.encode(), bytes()), leaves))

def __repr__(self) -> str:
return f"""MerkleTree(\nraw_leafs: {self.raw_leafs}\nleafs: {self.leafs}\nshort_leafs: {self.short(self.leafs)})"""
return f"""MerkleTree(\nraw_leaves: {self.raw_leaves}\nleaves: {self.leaves}\nshort_leaves: {self.short(self.leaves)})"""

def short(self, data: List[str]) -> List[str]:
return [x[:2] for x in data]

@property
def root(self) -> bytes:
return self.make_root(self.leafs)
return self.make_root(self.leaves)

def proof(self, raw_leaf: str) -> List[Node]:
return self.make_proof(
self.leafs, [], self.hash_function(raw_leaf.encode(), bytes())
self.leaves, [], self.hash_function(raw_leaf.encode(), bytes())
)

def verify(self, proof: List[bytes], raw_leaf: str) -> bool:
Expand All @@ -80,31 +78,31 @@ def concat_nodes(left: Node, right: Node) -> Node:

return reduce(concat_nodes, full_proof).data == self.root

def make_root(self, leafs: List[bytes]) -> List[str]:
while len(leafs) > 1:
def make_root(self, leaves: List[bytes]) -> bytes:
while len(leaves) > 1:
next_level = []
for i in range(0, len(leafs) - 1, 2):
next_level.append(self.hash_function(leafs[i], leafs[i + 1]))
for i in range(0, len(leaves) - 1, 2):
next_level.append(self.hash_function(leaves[i], leaves[i + 1]))

if len(leafs) % 2 == 1:
next_level.append(leafs[-1])
if len(leaves) % 2 == 1:
next_level.append(leaves[-1])

leafs = next_level
leaves = next_level

return leafs[0]
return leaves[0]

def make_proof(
self, leafs: List[bytes], proof: List[Node], leaf: bytes
self, leaves: List[bytes], proof: List[Node], leaf: bytes
) -> List[Node]:
"""
# Make a proof
## Dev:
- if the `leaf` index is less than half the size of the `leafs`
- if the `leaf` index is less than half the size of the `leaves`
list then the right side must reach root and vice versa
## Args:
- leafs: List of leafs
- leaves: List of leaves
- proof: Accumulated proof
- leaf: Leaf for which to create the proof
Expand All @@ -113,25 +111,25 @@ def make_proof(
"""

try:
index = leafs.index(leaf)
index = leaves.index(leaf)
except ValueError as err:
msg = f"Leaf: {leaf} does not exist in the tree: {leafs}"
msg = f"Leaf: {leaf} does not exist in the tree: {leaves}"
raise ValueError(msg) from err

if is_power_2(len(leafs)) is False:
return self.mix_tree(leafs, [], index)
if is_power_2(len(leaves)) is False:
return self.mix_tree(leaves, [], index)

if len(leafs) == 2:
if len(leaves) == 2:
if index == 1:
proof.append(Node(data=leafs[0], side=Side.LEFT))
proof.append(Node(data=leaves[0], side=Side.LEFT))
else:
proof.append(Node(data=leafs[1], side=Side.RIGHT))
proof.append(Node(data=leaves[1], side=Side.RIGHT))
proof.reverse()
return proof

left, right = half(leafs)
left, right = half(leaves)

if index < len(leafs) / 2:
if index < len(leaves) / 2:
proof.append(Node(data=self.make_root(right), side=Side.RIGHT))
return self.make_proof(left, proof, leaf)
else:
Expand Down
34 changes: 17 additions & 17 deletions test/merkle_root/test_merkle_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@


def test_simple_merkle_tree_constructor():
leafs = ["a", "b", "c", "d"]
tree = MerkleTree(leafs)
leaves = ["a", "b", "c", "d"]
tree = MerkleTree(leaves)

assert tree.raw_leafs == leafs
assert tree.raw_leaves == leaves

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
for i, j in zip(
tree.short_leafs,
tree.short_leaves,
[
bytes.fromhex("3ac2"),
bytes.fromhex("b555"),
Expand All @@ -19,7 +19,7 @@ def test_simple_merkle_tree_constructor():
],
):
assert i == j
assert tree.leafs == [
assert tree.leaves == [

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
bytes.fromhex(
"3ac225168df54212a25c1c01fd35bebfea408fdac2e31ddd6f80a4bbf9a5f1cb"
),
Expand All @@ -41,7 +41,7 @@ def test_simple_merkle_tree_constructor():


@mark.parametrize(
"leafs, root",
"leaves, root",
[
(
["a", "b", "c", "d", "e", "f", "g", "h", "1"],
Expand All @@ -65,14 +65,14 @@ def test_simple_merkle_tree_constructor():
),
],
)
def test_simple_merkle_root_with_keccak256(leafs: List[str], root: str):
tree = MerkleTree(leafs)
def test_simple_merkle_root_with_keccak256(leaves: List[str], root: str):
tree = MerkleTree(leaves)
result = tree.root.hex()
assert result == root


@mark.parametrize(
"leafs, root",
"leaves, root",
[
(
["a", "b", "c", "d", "e", "f", "g", "h", "1"],
Expand All @@ -96,20 +96,20 @@ def test_simple_merkle_root_with_keccak256(leafs: List[str], root: str):
),
],
)
def test_simple_merkle_root_with_sha_256(leafs: List[str], root: str):
def test_simple_merkle_root_with_sha_256(leaves: List[str], root: str):
def sha_256(x: bytes, y: bytes) -> bytes:
data = x + y
h = hashlib.sha256()
h.update(data)
return h.digest()

tree = MerkleTree(leafs, hash_function=sha_256)
tree = MerkleTree(leaves, hash_function=sha_256)
result = tree.root.hex()
assert result == root


@mark.parametrize(
"leafs, root",
"leaves, root",
[
(
["a", "b", "c", "d", "e", "f", "g", "h", "1"],
Expand All @@ -133,20 +133,20 @@ def sha_256(x: bytes, y: bytes) -> bytes:
),
],
)
def test_simple_merkle_root_with_shake256(leafs: List[str], root: str):
def test_simple_merkle_root_with_shake256(leaves: List[str], root: str):
def shake_256(x: bytes, y: bytes) -> bytes:
data = x + y
h = hashlib.shake_256()
h.update(data)
return h.digest(32)

tree = MerkleTree(leafs, hash_function=shake_256)
tree = MerkleTree(leaves, hash_function=shake_256)
result = tree.root.hex()
assert result == root


@mark.parametrize(
"leafs, root",
"leaves, root",
[
(
["a", "b", "c", "d", "e", "f", "g", "h", "1"],
Expand All @@ -170,13 +170,13 @@ def shake_256(x: bytes, y: bytes) -> bytes:
),
],
)
def test_simple_merkle_root_with_sha3_256(leafs: List[str], root: str):
def test_simple_merkle_root_with_sha3_256(leaves: List[str], root: str):
def sha3_256(x: bytes, y: bytes) -> bytes:
data = x + y
h = hashlib.sha3_256()
h.update(data)
return h.digest()

tree = MerkleTree(leafs, hash_function=sha3_256)
tree = MerkleTree(leaves, hash_function=sha3_256)
result = tree.root.hex()
assert result == root

0 comments on commit ac53503

Please sign in to comment.