Skip to content

Commit

Permalink
Add dependency on AnyTree.
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Jul 31, 2023
1 parent 6be89eb commit 35241f6
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 9 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ matplotlib = "^3.7.1"
seaborn = "^0.12.2"
pydantic = "^1.10.9"
netcdf4 = "^1.6.4"
anytree = "^2.9.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.2.1"
Expand Down
4 changes: 4 additions & 0 deletions src/pmhn/_trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
"""

from pmhn._trees._simulate import simulate_trees
from pmhn._trees._interfaces import Tree
from pmhn._trees._backend import OriginalTreeMHNBackend

__all__ = [
"simulate_trees",
"Tree",
"OriginalTreeMHNBackend",
]
37 changes: 36 additions & 1 deletion src/pmhn/_trees/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pmhn._trees._interfaces import Tree


class IndividualTreeMHNBackend(Protocol):
class IndividualTreeMHNBackendInterface(Protocol):
def loglikelihood(
self,
tree: Tree,
Expand Down Expand Up @@ -53,3 +53,38 @@ def gradient_and_loglikelihood(
separately.
"""
return self.gradient(tree, theta), self.loglikelihood(tree, theta)


class OriginalTreeMHNBackend(IndividualTreeMHNBackendInterface):
def loglikelihood(
self,
tree: Tree,
theta: np.ndarray,
) -> float:
"""Calculates loglikelihood `log P(tree | theta)`.
Args:
tree: a tree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
Returns:
loglikelihood of the tree
"""
# TODO(Laurenz): This implementation is missing.
raise NotImplementedError

def gradient(self, tree: Tree, theta: np.ndarray) -> np.ndarray:
"""Calculates the partial derivatives of `log P(tree | theta)`
with respect to `theta`.
Args:
tree: a tree
theta: real-valued matrix, shape (n_mutations, n_mutatations)
Returns:
gradient `d log P(tree | theta) / d theta`,
shape (n_mutations, n_mutations)
"""
# TODO(Laurenz): This implementation is missing.
raise NotImplementedError
11 changes: 3 additions & 8 deletions src/pmhn/_trees/_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
"""Interfaces."""
from typing import TypeAlias
from anytree import Node


class Tree:
"""This is a mock class.
It will be replaced by the actual Tree class,
build around AnyTree's `Node` class:
https://anytree.readthedocs.io/
"""

pass
Tree: TypeAlias = Node

0 comments on commit 35241f6

Please sign in to comment.