Skip to content

Commit

Permalink
Optimized implementation of the likelihood (#22)
Browse files Browse the repository at this point in the history
Co-authored-by: Laurenz Keller <[email protected]>
Co-authored-by: Paweł Czyż <[email protected]>
Co-authored-by: Paweł Czyż <[email protected]>
  • Loading branch information
4 people authored Oct 30, 2023
1 parent 3777ce4 commit 9b0b2e2
Show file tree
Hide file tree
Showing 10 changed files with 794 additions and 476 deletions.
199 changes: 51 additions & 148 deletions src/pmhn/_trees/_backend.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,21 @@
from typing import Protocol


import numpy as np
from scipy.sparse.linalg import spsolve_triangular
from scipy.sparse import csr_matrix
from pmhn._trees._interfaces import Tree
from pmhn._trees._tree_utils import create_all_subtrees, bfs_compare
from anytree import Node


class IndividualTreeMHNBackendInterface(Protocol):
def loglikelihood(
self,
tree: Tree,
theta: np.ndarray,
) -> float:
"""Calculates loglikelihood `log P(tree | theta)`.
Args:
tree: a tree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)

Returns:
loglikelihood of the tree
"""
raise NotImplementedError

def gradient(
self,
tree: Tree,
theta: np.ndarray,
) -> np.ndarray:
"""Calculates the partial derivatives of `log P(tree | theta)`
with respect to `theta`.
Args:
tree: a tree
theta: real-valued matrix,
shape (n_mutations, n_mutations)
from pmhn._trees._tree_utils import create_all_subtrees, bfs_compare
from anytree import Node, LevelOrderGroupIter

Returns:
gradient `d log P(tree | theta) / d theta`,
shape (n_mutations, n_mutations)
"""
raise NotImplementedError

def gradient_and_loglikelihood(
self, tree: Tree, theta: np.ndarray
) -> tuple[np.ndarray, float]:
"""Returns the gradient and the loglikelihood.
class TreeWrapper:
"""A wrapper for a tree which stores all subtrees."""

Note:
This function may be faster than calling `gradient` and `loglikelihood`
separately.
"""
return self.gradient(tree, theta), self.loglikelihood(tree, theta)
def __init__(self, tree: Node):
self._subtrees_dict: dict[Node, int] = create_all_subtrees(tree)


class OriginalTreeMHNBackend(IndividualTreeMHNBackendInterface):
class OriginalTreeMHNBackend:
def __init__(self, jitter: float = 1e-10):
self._jitter = jitter

_jitter: float
self._jitter: float = jitter

def create_V_Mat(
self, tree: Node, theta: np.ndarray, sampling_rate: float
) -> np.ndarray:
"""Builds the dense matrix `V = sampling_rate * I - Q` over all subtrees.
The subtrees are obtained from `create_all_subtrees(tree)`; `Q` has one
row and one column per subtree. `Q[i][i]` is `diag_entry(subtrees[i], theta)`
and `Q[i][j]` (for `i != j`) is
`off_diag_entry(subtrees[i], subtrees[j], theta)`.
Args:
tree: a tree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
sampling_rate: a scalar of type float, used as the multiple of the
identity matrix from which Q is subtracted
Returns:
the V matrix, a square array with one row/column per subtree.
"""

# The subtrees are indexed by position below; their order fixes the
# row/column order of Q and V.
subtrees = create_all_subtrees(tree)
subtrees_size = len(subtrees)
Q = np.zeros((subtrees_size, subtrees_size))
# Every (i, j) pair of subtrees is evaluated, so Q is filled densely.
for i in range(subtrees_size):
for j in range(subtrees_size):
if i == j:
Q[i][j] = self.diag_entry(subtrees[i], theta)
else:
Q[i][j] = self.off_diag_entry(subtrees[i], subtrees[j], theta)
# Scale the identity by the sampling rate and subtract Q.
V = np.eye(subtrees_size) * sampling_rate - Q
return V

def diag_entry(self, tree: Node, theta: np.ndarray) -> float:
def _diag_entry(self, tree: Node, theta: np.ndarray, all_mut: set[int]) -> float:
"""
Calculates a diagonal entry of the V matrix.
Expand All @@ -99,43 +24,34 @@ def diag_entry(self, tree: Node, theta: np.ndarray) -> float:
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
all_mut: set containing all possible mutations
Returns:
the diagonal entry of the V matrix corresponding to tree
"""
lamb_sum = 0
n_mutations = len(theta)
current_nodes = [tree]
while len(current_nodes) != 0:
next_nodes = []
for node in current_nodes:
tree_mutations = list(node.path) + list(node.children)
exit_mutations = list(
set([i + 1 for i in range(n_mutations)]).difference(
set(
[
tree_mutation.name # type: ignore
for tree_mutation in tree_mutations
]
)
)

for level in LevelOrderGroupIter(tree):
for node in level:
tree_mutations = {n.name for n in node.path}.union(
{c.name for c in node.children}
)
exit_mutations = set(all_mut).difference(tree_mutations)

for mutation in exit_mutations:
lamb = 0
exit_subclone = [
anc.name # type: ignore
for anc in node.path
if anc.parent is not None
] + [mutation]
exit_subclone = {
anc.name for anc in node.path if anc.parent is not None
}.union({mutation})

for j in exit_subclone:
lamb += theta[mutation - 1][j - 1]
lamb = np.exp(lamb)
lamb_sum -= lamb
for child in node.children:
next_nodes.append(child)
current_nodes = next_nodes

return lamb_sum

def off_diag_entry(self, tree1: Node, tree2: Node, theta: np.ndarray) -> float:
def _off_diag_entry(self, tree1: Node, tree2: Node, theta: np.ndarray) -> float:
"""
Calculates an off-diagonal entry of the V matrix.
Expand Down Expand Up @@ -163,51 +79,38 @@ def off_diag_entry(self, tree1: Node, tree2: Node, theta: np.ndarray) -> float:
return float(lamb)

def loglikelihood(
self, tree: Node, theta: np.ndarray, sampling_rate: float
self, tree_wrapper: TreeWrapper, theta: np.ndarray, sampling_rate: float
) -> float:
"""
Calculates loglikelihood `log P(tree | theta)`.
"""Calculates loglikelihood `log P(tree | theta)`.
Args:
tree: a tree
tree: a wrapper storing a tree (and its subtrees)
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
sampling_rate: a scalar of type float
Returns:
loglikelihood of the tree
"""
# TODO(Pawel): this is part of https://github.com/cbg-ethz/pMHN/issues/15
# It can be implemented in any way.
V = self.create_V_Mat(tree=tree, theta=theta, sampling_rate=sampling_rate)
V_size = V.shape[0]
b = np.zeros(V_size)
b[0] = 1
V_transposed = V.transpose()
V_csr = csr_matrix(V_transposed)
x = spsolve_triangular(V_csr, b, lower=True)

return np.log(x[V_size - 1] + self._jitter) + np.log(sampling_rate)

def gradient(self, tree: Node, theta: np.ndarray) -> np.ndarray:
"""Calculates the partial derivatives of `log P(tree | theta)`
with respect to `theta`.
Args:
tree: a tree
theta: real-valued matrix, shape (n_mutations, n_mutations)
sampling_rate: a scalar representing sampling rate
Returns:
gradient `d log P(tree | theta) / d theta`,
shape (n_mutations, n_mutations)
loglikelihood of the tree
"""
# TODO(Pawel): This is part of
# https://github.com/cbg-ethz/pMHN/issues/18,
# but it is *not* a priority.
# We will try to do the modelling as soon as possible,
# starting with a sequential Monte Carlo sampler
# and Metropolis transitions.
# Only after initial experiments
# (we will probably see that it's not scalable),
# we'll consider switching to Hamiltonian Monte Carlo,
# which requires gradients.
raise NotImplementedError
subtrees_size = len(tree_wrapper._subtrees_dict)
x = np.zeros(subtrees_size)
x[0] = 1
n_mutations = len(theta)
all_mut = set(i + 1 for i in range(n_mutations))
for i, (subtree_i, subtree_size_i) in enumerate(
tree_wrapper._subtrees_dict.items()
):
V_col = {}
V_diag = 0.0
for j, (subtree_j, subtree_size_j) in enumerate(
tree_wrapper._subtrees_dict.items()
):
if subtree_size_i - subtree_size_j == 1:
V_col[j] = -self._off_diag_entry(subtree_j, subtree_i, theta)
elif i == j:
V_diag = sampling_rate - self._diag_entry(subtree_i, theta, all_mut)
for index, val in V_col.items():
x[i] -= val * x[index]
x[i] /= V_diag

return np.log(x[-1] + self._jitter) + np.log(sampling_rate)
Loading

0 comments on commit 9b0b2e2

Please sign in to comment.