Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Likelihood #21

Merged
merged 34 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
afa0462
Add files for listing children and subtrees
laukeller Sep 22, 2023
cc40f5e
added simulation of trees and comparison plots (not correct yet)
laukeller Sep 25, 2023
6f51d22
removed plots and warmup directory, added plotting/csv related files …
laukeller Sep 26, 2023
9d00b71
Merge branch 'main' into warmup
pawel-czyz Sep 27, 2023
e2c3abd
changed _simulate.py, write_csv.py and created a few unit tests
laukeller Sep 27, 2023
40c8a76
Merge branch 'warmup' of github.com:cbg-ethz/pMHN into warmup
laukeller Sep 27, 2023
d154117
changed comment in _simulate.py
laukeller Sep 27, 2023
39907ab
minor changes
laukeller Sep 27, 2023
c4f96d3
moved files to warmup dir
laukeller Sep 28, 2023
cb57917
remove not needed files
laukeller Sep 28, 2023
e6ea6be
change: draw new sampling time if tree is discarded, mean_sampling_ti…
laukeller Sep 28, 2023
3131c24
reformatted files with black
laukeller Sep 28, 2023
ea7cd53
changed unit tests
laukeller Sep 28, 2023
a206230
reformat with black
laukeller Sep 29, 2023
4bb2929
Merge branch 'main' of github.com:cbg-ethz/pMHN
laukeller Sep 29, 2023
8773dec
Merge branch 'warmup'
laukeller Sep 29, 2023
34b0dc9
Remove poetry.lock
pawel-czyz Sep 30, 2023
9d6fc97
fixed ruff errors
laukeller Sep 30, 2023
41cd4c1
fixed ruff errors
laukeller Sep 30, 2023
bbe8ffc
Merge branch 'warmup' of github.com:cbg-ethz/pMHN into warmup
laukeller Sep 30, 2023
c7596a0
Merge branch 'warmup'
laukeller Sep 30, 2023
05bb425
pyright fixed
laukeller Sep 30, 2023
1c3bf80
Merge branch 'warmup'
laukeller Sep 30, 2023
ca76812
modified _backend and added _tree_utils, created unit tests for both
laukeller Oct 4, 2023
d6cd286
added files for likelihood comparison, changed absolute paths to rela…
laukeller Oct 7, 2023
2da22a5
small change
laukeller Oct 7, 2023
6fedbdc
added likelihood tests, modified test_tree_utils.py
laukeller Oct 8, 2023
01beb2f
minor changes
laukeller Oct 9, 2023
f1cdb6b
minor change
laukeller Oct 9, 2023
e4ff57a
implemented suggestions from pawel
laukeller Oct 13, 2023
4a7a735
Merge branch 'main' into likelihood
laurenzkeller Oct 16, 2023
7e0e3e7
changed io and test_io
laukeller Oct 17, 2023
0ce34b7
Merge branch 'likelihood' of github.com:cbg-ethz/pMHN into likelihood
laukeller Oct 17, 2023
11a584f
deleted files
laukeller Oct 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 121 additions & 9 deletions src/pmhn/_trees/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@


import numpy as np


from scipy.sparse.linalg import spsolve_triangular
from scipy.sparse import csr_matrix
from pmhn._trees._interfaces import Tree
from pmhn._trees._tree_utils import create_all_subtrees, bfs_compare
from anytree import Node


class IndividualTreeMHNBackendInterface(Protocol):
Expand Down Expand Up @@ -57,26 +59,136 @@ def gradient_and_loglikelihood(


class OriginalTreeMHNBackend(IndividualTreeMHNBackendInterface):
def __init__(self, jitter: float = 1e-10):
    """Creates the backend.

    Args:
        jitter: small constant kept on the instance as ``_jitter``;
            ``loglikelihood`` adds it to the solved probability
            before taking the logarithm.
    """
    self._jitter = jitter

_jitter: float

def create_V_Mat(
    self, tree: Node, theta: np.ndarray, sampling_rate: float
) -> np.ndarray:
    """Builds the V matrix ``sampling_rate * I - Q`` for a tree.

    ``Q`` has one row and one column per subtree returned by
    ``create_all_subtrees(tree)``. Its diagonal is filled by
    ``diag_entry`` and every other cell by ``off_diag_entry``.

    Args:
        tree: a tree
        theta: real-valued (i.e., log-theta) matrix,
          shape (n_mutations, n_mutations)
        sampling_rate: a scalar of type float
    Returns:
        the V matrix, shape (n_subtrees, n_subtrees).
    """
    all_subtrees = create_all_subtrees(tree)
    n_subtrees = len(all_subtrees)

    rate_matrix = np.zeros((n_subtrees, n_subtrees))
    for row, source in enumerate(all_subtrees):
        for col, target in enumerate(all_subtrees):
            rate_matrix[row, col] = (
                self.diag_entry(source, theta)
                if row == col
                else self.off_diag_entry(source, target, theta)
            )

    return np.eye(n_subtrees) * sampling_rate - rate_matrix

def diag_entry(self, tree: Node, theta: np.ndarray) -> float:
    """Calculates a diagonal entry of the V matrix.

    The tree is walked level by level. For every node, each mutation
    label in ``1..n_mutations`` that is neither on the node's path from
    the root nor one of its direct children contributes a rate

        exp(sum_j theta[mutation - 1][j - 1])

    where ``j`` runs over the labels of the non-root nodes on that path
    (the node itself included) followed by ``mutation`` itself. The
    entry is the negated sum of all these rates.

    Args:
        tree: a tree
        theta: real-valued (i.e., log-theta) matrix,
          shape (n_mutations, n_mutations)

    Returns:
        the diagonal entry of the V matrix corresponding to tree
    """
    # Mutation labels are 1-based (hence the ``- 1`` when indexing theta).
    # The full label set does not depend on the node, so build it once.
    all_mutations = set(range(1, len(theta) + 1))

    total = 0.0
    level = [tree]
    while level:
        next_level = []
        for node in level:
            # Labels already used at this position: everything on the
            # path from the root plus the node's direct children.
            occupied = {
                used.name  # type: ignore
                for used in (*node.path, *node.children)
            }
            # Non-root labels on the path; identical for every candidate
            # mutation of this node, so computed once per node.
            lineage = [
                anc.name  # type: ignore
                for anc in node.path
                if anc.parent is not None
            ]
            for mutation in all_mutations - occupied:
                log_rate = sum(
                    theta[mutation - 1][j - 1] for j in (*lineage, mutation)
                )
                total -= np.exp(log_rate)
            next_level.extend(node.children)
        level = next_level
    return float(total)

def off_diag_entry(self, tree1: Node, tree2: Node, theta: np.ndarray) -> float:
    """Calculates an off-diagonal entry of the V matrix.

    ``bfs_compare(tree1, tree2)`` yields either ``None`` or a node. For
    a node, the entry is ``exp`` of the sum of
    ``theta[node.name - 1][j - 1]`` over the labels ``j`` of the
    non-root nodes on the node's path from the root (the node itself
    included).

    Args:
        tree1: the first tree
        tree2: the second tree
        theta: real-valued (i.e., log-theta) matrix,
          shape (n_mutations, n_mutations)

    Returns:
        the off-diagonal entry of the V matrix corresponding to tree1 and
        tree2; 0 when ``bfs_compare`` returns ``None``
    """
    added_node = bfs_compare(tree1, tree2)
    if added_node is None:
        return 0

    log_rate = sum(
        theta[added_node.name - 1][ancestor.name - 1]  # type: ignore
        for ancestor in added_node.path
        if ancestor.parent is not None
    )
    return float(np.exp(log_rate))

def loglikelihood(
    self, tree: Node, theta: np.ndarray, sampling_rate: float
) -> float:
    """Calculates loglikelihood `log P(tree | theta)`.

    Builds ``V`` with ``create_V_Mat``, solves ``V^T x = e_0`` (``e_0``
    being the first unit vector) and returns
    ``log(x[-1] + jitter) + log(sampling_rate)``.

    Args:
        tree: a tree
        theta: real-valued (i.e., log-theta) matrix,
          shape (n_mutations, n_mutations)
        sampling_rate: a scalar of type float
    Returns:
        loglikelihood of the tree
    """
    V = self.create_V_Mat(tree=tree, theta=theta, sampling_rate=sampling_rate)
    n_states = V.shape[0]

    # Right-hand side: unit vector with a one in the first position.
    rhs = np.zeros(n_states)
    rhs[0] = 1

    # Forward substitution on V^T. NOTE(review): this relies on V^T being
    # lower triangular, i.e. on the subtree ordering produced by
    # create_all_subtrees -- confirm against that helper.
    x = spsolve_triangular(csr_matrix(V.transpose()), rhs, lower=True)

    # The jitter keeps the logarithm finite if the solved value is zero.
    return float(np.log(x[-1] + self._jitter) + np.log(sampling_rate))

def gradient(self, tree: Tree, theta: np.ndarray) -> np.ndarray:
def gradient(self, tree: Node, theta: np.ndarray) -> np.ndarray:
"""Calculates the partial derivatives of `log P(tree | theta)`
with respect to `theta`.

Expand Down
9 changes: 6 additions & 3 deletions src/pmhn/_trees/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class TreeNaming:

node: str = "Node_ID"
parent: str = "Parent_ID"
mutation: str = "Mutation_ID"
data: dict[str, str] = dataclasses.field(
default_factory=lambda: {"Mutation_ID": "mutation"}
)
Expand All @@ -52,7 +53,7 @@ class ForestNaming:


def parse_tree(df: pd.DataFrame, naming: TreeNaming) -> anytree.Node:
"""Parses a data frame into a tree.
"""Parses a data frame into a tree

Args:
df: data frame with columns specified in `naming`.
Expand All @@ -79,10 +80,12 @@ def parse_tree(df: pd.DataFrame, naming: TreeNaming) -> anytree.Node:
f"Root is {root}, but {node_id} == {parent_id} "
"also looks like a root."
)
root = anytree.Node(node_id, parent=None, **values)
root = anytree.Node(row[naming.mutation], parent=None, **values)
nodes[node_id] = root
else:
nodes[node_id] = anytree.Node(node_id, parent=nodes[parent_id], **values)
nodes[node_id] = anytree.Node(
row[naming.mutation], parent=nodes[parent_id], **values
)

if root is None:
raise ValueError("No root found.")
Expand Down
2 changes: 1 addition & 1 deletion src/pmhn/_trees/_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _find_possible_mutations(old_mutations: list[int], n_mutations: int) -> list
)

possible_mutations = list(
set([i + 1 for i in range(n_mutations)]).difference(set(old_mutations))
set(range(1, n_mutations + 1)).difference(set(old_mutations))
)
return possible_mutations

Expand Down
Loading
Loading