Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized implementation of the likelihood #22

Merged
merged 53 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
afa0462
Add files for listing children and subtrees
laukeller Sep 22, 2023
cc40f5e
added simulation of trees and comparison plots (not correct yet)
laukeller Sep 25, 2023
6f51d22
removed plots and warmup directory, added plotting/csv related files …
laukeller Sep 26, 2023
9d00b71
Merge branch 'main' into warmup
pawel-czyz Sep 27, 2023
e2c3abd
changed _simulate.py, write_csv.py and created a few unit tests
laukeller Sep 27, 2023
40c8a76
Merge branch 'warmup' of github.com:cbg-ethz/pMHN into warmup
laukeller Sep 27, 2023
d154117
changed comment in _simulate.py
laukeller Sep 27, 2023
39907ab
minor changes
laukeller Sep 27, 2023
c4f96d3
moved files to warmup dir
laukeller Sep 28, 2023
cb57917
remove not needed files
laukeller Sep 28, 2023
e6ea6be
change: draw new sampling time if tree is discarded, mean_sampling_ti…
laukeller Sep 28, 2023
3131c24
reformatted files with black
laukeller Sep 28, 2023
ea7cd53
changed unit tests
laukeller Sep 28, 2023
a206230
reformat with black
laukeller Sep 29, 2023
4bb2929
Merge branch 'main' of github.com:cbg-ethz/pMHN
laukeller Sep 29, 2023
8773dec
Merge branch 'warmup'
laukeller Sep 29, 2023
34b0dc9
Remove poetry.lock
pawel-czyz Sep 30, 2023
9d6fc97
fixed ruff errors
laukeller Sep 30, 2023
41cd4c1
fixed ruff errors
laukeller Sep 30, 2023
bbe8ffc
Merge branch 'warmup' of github.com:cbg-ethz/pMHN into warmup
laukeller Sep 30, 2023
c7596a0
Merge branch 'warmup'
laukeller Sep 30, 2023
05bb425
pyright fixed
laukeller Sep 30, 2023
1c3bf80
Merge branch 'warmup'
laukeller Sep 30, 2023
ca76812
modified _backend and added _tree_utils, created unit tests for both
laukeller Oct 4, 2023
d6cd286
added files for likelihood comparison, changed absolute paths to rela…
laukeller Oct 7, 2023
2da22a5
small change
laukeller Oct 7, 2023
6fedbdc
added likelihood tests, modified test_tree_utils.py
laukeller Oct 8, 2023
01beb2f
minor changes
laukeller Oct 9, 2023
f1cdb6b
minor change
laukeller Oct 9, 2023
52eea92
implemented 2 new versions of the likelihood calculation which relies…
laukeller Oct 12, 2023
b839afd
memoization for finding all subtrees
laukeller Oct 13, 2023
42ccc90
minor changes
laukeller Oct 16, 2023
3479452
small change
laukeller Oct 16, 2023
6f36518
small change
laukeller Oct 16, 2023
64bdc6a
small change
laukeller Oct 17, 2023
6dc78fb
small change
laukeller Oct 17, 2023
f490594
small change
laukeller Oct 17, 2023
6d492a6
changed io file and unit tests for it
laukeller Oct 18, 2023
a6c4836
genotype
laukeller Oct 20, 2023
20a2f1e
small change
laukeller Oct 21, 2023
b6b97ee
small change
laukeller Oct 21, 2023
f96265a
small change in _tree_utils_geno.py
laukeller Oct 22, 2023
76aeba2
small change in diag_entry
laukeller Oct 22, 2023
18769d4
small change in diag_entry function
laukeller Oct 22, 2023
502f07f
small change in _tree_utils.py (memoization not needed)
laukeller Oct 22, 2023
614719c
small change
laukeller Oct 22, 2023
8377756
implemented pawel's suggestions
laukeller Oct 23, 2023
1974f3a
Merge branch 'main' into likelihood_optimized
laurenzkeller Oct 25, 2023
25a6f75
Apply Black formatter.
pawel-czyz Oct 30, 2023
955de94
Remove redundant code from original backend
pawel-czyz Oct 30, 2023
71ca142
Fix unit tests
pawel-czyz Oct 30, 2023
263235a
Update the warmup files.
pawel-czyz Oct 30, 2023
04a27ef
Ignore Pyright false positive
pawel-czyz Oct 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# We use convention that poetry.lock is not committed
poetry.lock

# Jupyter Notebooks are disallowed by default
# Use Quarto Notebook (*.qmd) instead
*.ipynb
Expand Down
107 changes: 98 additions & 9 deletions src/pmhn/_trees/_backend.py
laurenzkeller marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@


import numpy as np
from pmhn._trees._interfaces import Tree
from pmhn._trees._tree_utils import create_all_subtrees, bfs_compare
from anytree import Node, LevelOrderGroupIter


from pmhn._trees._interfaces import Tree
class LoglikelihoodSingleTree:
def __init__(self, tree: Node):
self._subtrees_dict = create_all_subtrees(tree)

_subtrees_dict: dict[Node, int]
laurenzkeller marked this conversation as resolved.
Show resolved Hide resolved


class IndividualTreeMHNBackendInterface(Protocol):
Expand Down Expand Up @@ -57,26 +64,108 @@ def gradient_and_loglikelihood(


class OriginalTreeMHNBackend(IndividualTreeMHNBackendInterface):
def __init__(self, jitter: float = 1e-10):
self._jitter = jitter

_jitter: float
laurenzkeller marked this conversation as resolved.
Show resolved Hide resolved

def diag_entry(self, tree: Node, theta: np.ndarray, all_mut: set[int]) -> float:
"""
Calculates a diagonal entry of the V matrix.

Args:
tree: a tree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
all_mut: set containing all possible mutations
Returns:
the diagonal entry of the V matrix corresponding to tree
"""
lamb_sum = 0

for level in LevelOrderGroupIter(tree):
for node in level:
tree_mutations = {n.name for n in node.path}.union(
{c.name for c in node.children}
)
exit_mutations = set(all_mut).difference(tree_mutations)

for mutation in exit_mutations:
lamb = 0
exit_subclone = {
anc.name for anc in node.path if anc.parent is not None
}.union({mutation})
for j in exit_subclone:
lamb += theta[mutation - 1][j - 1]
lamb = np.exp(lamb)
lamb_sum -= lamb

return lamb_sum

def off_diag_entry(self, tree1: Node, tree2: Node, theta: np.ndarray) -> float:
"""
Calculates an off-diagonal entry of the V matrix.

Args:
tree1: the first tree
tree2: the second tree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
Returns:
the off-diagonal entry of the V matrix corresponding to tree1 and tree2
"""
exit_node = bfs_compare(tree1, tree2)
lamb = 0
if exit_node is None:
return lamb
else:
for j in [
node.name # type: ignore
for node in exit_node.path
if node.parent is not None
]:
lamb += theta[exit_node.name - 1][j - 1]
lamb = np.exp(lamb)
return float(lamb)

def loglikelihood(
self,
tree: Tree,
theta: np.ndarray,
self, tree: LoglikelihoodSingleTree, theta: np.ndarray, sampling_rate: float
) -> float:
"""Calculates loglikelihood `log P(tree | theta)`.
"""
Calculates loglikelihood `log P(tree | theta)`.

Args:
tree: a tree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)

sampling_rate: a scalar of type float
Returns:
loglikelihood of the tree
"""
# TODO(Pawel): this is part of https://github.com/cbg-ethz/pMHN/issues/15
# It can be implemented in any way.
raise NotImplementedError

def gradient(self, tree: Tree, theta: np.ndarray) -> np.ndarray:
subtrees_size = len(tree._subtrees_dict)
x = np.zeros(subtrees_size)
x[0] = 1
n_mutations = len(theta)
all_mut = set(i + 1 for i in range(n_mutations))
for i, (subtree_i, subtree_size_i) in enumerate(tree._subtrees_dict.items()):
V_col = {}
V_diag = 0.0
for j, (subtree_j, subtree_size_j) in enumerate(
tree._subtrees_dict.items()
):
if subtree_size_i - subtree_size_j == 1:
V_col[j] = -self.off_diag_entry(subtree_j, subtree_i, theta)
elif i == j:
V_diag = sampling_rate - self.diag_entry(subtree_i, theta, all_mut)
for index, val in V_col.items():
x[i] -= val * x[index]
x[i] /= V_diag

return np.log(x[-1] + self._jitter) + np.log(sampling_rate)

def gradient(self, tree: Node, theta: np.ndarray) -> np.ndarray:
"""Calculates the partial derivatives of `log P(tree | theta)`
with respect to `theta`.

Expand Down
236 changes: 236 additions & 0 deletions src/pmhn/_trees/_backend_geno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
from typing import Protocol, Optional

import numpy as np
from pmhn._trees._interfaces import Tree
from pmhn._trees._tree_utils_geno import create_mappings
from anytree import Node


class LoglikelihoodSingleTree:
def __init__(self, tree: Node):
(
self._genotype_subtree_node_map,
self._index_subclone_map,
) = create_mappings(tree)

_genotype_subtree_node_map: dict[tuple[tuple[Node, int]], tuple[int, int]]
laurenzkeller marked this conversation as resolved.
Show resolved Hide resolved
_index_subclone_map: dict[int, tuple[int]]
laurenzkeller marked this conversation as resolved.
Show resolved Hide resolved


class IndividualTreeMHNBackendInterface(Protocol):
def loglikelihood(
self,
tree: Tree,
theta: np.ndarray,
) -> float:
"""Calculates loglikelihood `log P(tree | theta)`.

Args:
tree: a tree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)

Returns:
loglikelihood of the tree
"""
raise NotImplementedError

def gradient(
self,
tree: Tree,
theta: np.ndarray,
) -> np.ndarray:
"""Calculates the partial derivatives of `log P(tree | theta)`
with respect to `theta`.

Args:
tree: a tree
theta: real-valued matrix,
shape (n_mutations, n_mutatations)

Returns:
gradient `d log P(tree | theta) / d theta`,
shape (n_mutations, n_mutations)
"""
raise NotImplementedError

def gradient_and_loglikelihood(
self, tree: Tree, theta: np.ndarray
) -> tuple[np.ndarray, float]:
"""Returns the gradient and the loglikelihood.

Note:
This function may be faster than calling `gradient` and `loglikelihood`
separately.
"""
return self.gradient(tree, theta), self.loglikelihood(tree, theta)


class OriginalTreeMHNBackend(IndividualTreeMHNBackendInterface):
def __init__(self, jitter: float = 1e-10):
self._jitter = jitter

_jitter: float

def diag_entry(
self,
tree: LoglikelihoodSingleTree,
genotype: tuple[tuple[Node, int]],
theta: np.ndarray,
all_mut: set[int],
) -> float:
"""
Calculates a diagonal entry of the V matrix.

Args:
tree: a tree
genotype: the genotype of a subtree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
all_mut: a set containing all possible mutations
Returns:
the diagonal entry of the V matrix corresponding to
genotype
"""
lamb_sum = 0
for i, (node, val) in enumerate(genotype):
if val:
lineage = tree._index_subclone_map[i]
lineage = list(lineage)
tree_mutations = set(lineage + [c.name for c in node.children])

exit_mutations = all_mut.difference(tree_mutations)

for mutation in exit_mutations:
lamb = 0
lamb += theta[mutation - 1][mutation - 1]
for j in lineage:
if j != 0:
lamb += theta[mutation - 1][j - 1]
lamb = np.exp(lamb)
lamb_sum -= lamb
return lamb_sum

def find_single_difference(
self, arr1: np.ndarray, arr2: np.ndarray
) -> Optional[int]:
"""
Checks if two binary arrays of equal size differ in only one entry.
If so, the index of the differing entry is returned, otherwise None.

Args:
arr1: the first array
arr2: the second array
Returns:
the index of the differing entry if there's
a single difference, otherwise None.
"""
differing_indices = np.nonzero(np.bitwise_xor(arr1, arr2))[0]

return differing_indices[0] if len(differing_indices) == 1 else None

def off_diag_entry(
self,
tree: LoglikelihoodSingleTree,
genotype_i: np.ndarray,
genotype_j: np.ndarray,
theta: np.ndarray,
) -> float:
"""
Calculates an off-diagonal entry of the V matrix.

Args:
tree: the original tree
genotype_i: the genotype of a subtree
genotype_j: the genotype of another subtree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
Returns:
an off-diagonal entry of the V matrix corresponding to
the genotype_i and genotype_j
"""
index = self.find_single_difference(genotype_i, genotype_j)
if index is None:
return 0
else:
lamb = 0
lineage = tree._index_subclone_map[index]
exit_mutation = lineage[-1]
for mutation in lineage:
if mutation != 0:
lamb += theta[exit_mutation - 1][mutation - 1]
lamb = np.exp(lamb)
return float(lamb)

def loglikelihood(
self,
tree: LoglikelihoodSingleTree,
theta: np.ndarray,
sampling_rate: float,
all_mut: set[int],
) -> float:
"""
Calculates loglikelihood `log P(tree | theta)`.

Args:
tree: a tree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
sampling_rate: a scalar of type float
all_mut: a set containing all possible mutations
Returns:
the loglikelihood of tree
"""
# TODO(Pawel): this is part of https://github.com/cbg-ethz/pMHN/issues/15
# It can be implemented in any way.
subtrees_size = len(tree._genotype_subtree_node_map)
x = np.zeros(subtrees_size)
x[0] = 1
genotype_lists = []
for genotype in tree._genotype_subtree_node_map.keys():
genotype_lists.append(np.array([item[1] for item in genotype]))
for genotype_i, (i, subtree_size_i) in tree._genotype_subtree_node_map.items():
V_col = []
V_diag = 0.0
for j, subtree_size_j in tree._genotype_subtree_node_map.values():
if subtree_size_i - subtree_size_j == 1:
V_col.append(
(
j,
-self.off_diag_entry(
tree, genotype_lists[j], genotype_lists[i], theta
),
)
)
elif i == j:
V_diag = sampling_rate - self.diag_entry(
tree, genotype_i, theta, all_mut
)
for index, val in V_col:
x[i] -= val * x[index]
x[i] /= V_diag
return np.log(x[-1] + self._jitter) + np.log(sampling_rate)

def gradient(self, tree: Node, theta: np.ndarray) -> np.ndarray:
"""Calculates the partial derivatives of `log P(tree | theta)`
with respect to `theta`.

Args:
tree: a tree
theta: real-valued matrix, shape (n_mutations, n_mutatations)

Returns:
gradient `d log P(tree | theta) / d theta`,
shape (n_mutations, n_mutations)
"""
# TODO(Pawel): This is part of
# https://github.com/cbg-ethz/pMHN/issues/18,
# but it is *not* a priority.
# We will try to do the modelling as soon as possible,
# starting with a sequential Monte Carlo sampler
# and Metropolis transitions.
# Only after initial experiments
# (we will probably see that it's not scalable),
# we'll consider switching to Hamiltonian Monte Carlo,
# which requires gradients.
raise NotImplementedError
Loading