diff --git a/src/pmhn/_trees/_backend.py b/src/pmhn/_trees/_backend.py index 50d6651..5f97406 100644 --- a/src/pmhn/_trees/_backend.py +++ b/src/pmhn/_trees/_backend.py @@ -1,96 +1,21 @@ -from typing import Protocol - - import numpy as np -from scipy.sparse.linalg import spsolve_triangular -from scipy.sparse import csr_matrix -from pmhn._trees._interfaces import Tree -from pmhn._trees._tree_utils import create_all_subtrees, bfs_compare -from anytree import Node - - -class IndividualTreeMHNBackendInterface(Protocol): - def loglikelihood( - self, - tree: Tree, - theta: np.ndarray, - ) -> float: - """Calculates loglikelihood `log P(tree | theta)`. - - Args: - tree: a tree - theta: real-valued (i.e., log-theta) matrix, - shape (n_mutations, n_mutations) - Returns: - loglikelihood of the tree - """ - raise NotImplementedError - - def gradient( - self, - tree: Tree, - theta: np.ndarray, - ) -> np.ndarray: - """Calculates the partial derivatives of `log P(tree | theta)` - with respect to `theta`. - - Args: - tree: a tree - theta: real-valued matrix, - shape (n_mutations, n_mutatations) +from pmhn._trees._tree_utils import create_all_subtrees, bfs_compare +from anytree import Node, LevelOrderGroupIter - Returns: - gradient `d log P(tree | theta) / d theta`, - shape (n_mutations, n_mutations) - """ - raise NotImplementedError - def gradient_and_loglikelihood( - self, tree: Tree, theta: np.ndarray - ) -> tuple[np.ndarray, float]: - """Returns the gradient and the loglikelihood. +class TreeWrapper: + """A wrapper for a tree which stores all subtrees.""" - Note: - This function may be faster than calling `gradient` and `loglikelihood` - separately. - """ - return self.gradient(tree, theta), self.loglikelihood(tree, theta) + def __init__(self, tree: Node): + self._subtrees_dict: dict[Node, int] = create_all_subtrees(tree) -class OriginalTreeMHNBackend(IndividualTreeMHNBackendInterface): +class OriginalTreeMHNBackend: def __init__(self, jitter: float = 1e-10): - self._jitter = jitter - - _jitter: float + self._jitter: float = jitter - def create_V_Mat( - self, tree: Node, theta: np.ndarray, sampling_rate: float - ) -> np.ndarray: - """Calculates the V matrix. - - Args: - tree: a tree - theta: real-valued (i.e., log-theta) matrix, - shape (n_mutations, n_mutations) - sampling_rate: a scalar of type float - Returns: - the V matrix. - """ - - subtrees = create_all_subtrees(tree) - subtrees_size = len(subtrees) - Q = np.zeros((subtrees_size, subtrees_size)) - for i in range(subtrees_size): - for j in range(subtrees_size): - if i == j: - Q[i][j] = self.diag_entry(subtrees[i], theta) - else: - Q[i][j] = self.off_diag_entry(subtrees[i], subtrees[j], theta) - V = np.eye(subtrees_size) * sampling_rate - Q - return V - - def diag_entry(self, tree: Node, theta: np.ndarray) -> float: + def _diag_entry(self, tree: Node, theta: np.ndarray, all_mut: set[int]) -> float: """ Calculates a diagonal entry of the V matrix. @@ -99,43 +24,34 @@ def diag_entry(self, tree: Node, theta: np.ndarray) -> float: theta: real-valued (i.e., log-theta) matrix, shape (n_mutations, n_mutations) + all_mut: set containing all possible mutations + Returns: the diagonal entry of the V matrix corresponding to tree """ lamb_sum = 0 - n_mutations = len(theta) - current_nodes = [tree] - while len(current_nodes) != 0: - next_nodes = [] - for node in current_nodes: - tree_mutations = list(node.path) + list(node.children) - exit_mutations = list( - set([i + 1 for i in range(n_mutations)]).difference( - set( - [ - tree_mutation.name # type: ignore - for tree_mutation in tree_mutations - ] - ) - ) + + for level in LevelOrderGroupIter(tree): + for node in level: + tree_mutations = {n.name for n in node.path}.union( + {c.name for c in node.children} ) + exit_mutations = set(all_mut).difference(tree_mutations) + for mutation in exit_mutations: lamb = 0 - exit_subclone = [ - anc.name # type: ignore - for anc in node.path - if anc.parent is not None - ] + [mutation] + exit_subclone = { + anc.name for anc in node.path if anc.parent is not None + }.union({mutation}) + for j in exit_subclone: lamb += theta[mutation - 1][j - 1] lamb = np.exp(lamb) lamb_sum -= lamb - for child in node.children: - next_nodes.append(child) - current_nodes = next_nodes + return lamb_sum - def off_diag_entry(self, tree1: Node, tree2: Node, theta: np.ndarray) -> float: + def _off_diag_entry(self, tree1: Node, tree2: Node, theta: np.ndarray) -> float: """ Calculates an off-diagonal entry of the V matrix. @@ -163,51 +79,38 @@ def off_diag_entry(self, tree1: Node, tree2: Node, theta: np.ndarray) -> float: return float(lamb) def loglikelihood( - self, tree: Node, theta: np.ndarray, sampling_rate: float + self, tree_wrapper: TreeWrapper, theta: np.ndarray, sampling_rate: float ) -> float: - """ - Calculates loglikelihood `log P(tree | theta)`. + """Calculates loglikelihood `log P(tree | theta)`. Args: - tree: a tree + tree: a wrapper storing a tree (and its subtrees) theta: real-valued (i.e., log-theta) matrix, shape (n_mutations, n_mutations) - sampling_rate: a scalar of type float - Returns: - loglikelihood of the tree - """ - # TODO(Pawel): this is part of https://github.com/cbg-ethz/pMHN/issues/15 - # It can be implemented in any way. - V = self.create_V_Mat(tree=tree, theta=theta, sampling_rate=sampling_rate) - V_size = V.shape[0] - b = np.zeros(V_size) - b[0] = 1 - V_transposed = V.transpose() - V_csr = csr_matrix(V_transposed) - x = spsolve_triangular(V_csr, b, lower=True) - - return np.log(x[V_size - 1] + self._jitter) + np.log(sampling_rate) - - def gradient(self, tree: Node, theta: np.ndarray) -> np.ndarray: - """Calculates the partial derivatives of `log P(tree | theta)` - with respect to `theta`. - - Args: - tree: a tree - theta: real-valued matrix, shape (n_mutations, n_mutatations) + sampling_rate: a scalar representing sampling rate Returns: - gradient `d log P(tree | theta) / d theta`, - shape (n_mutations, n_mutations) + loglikelihood of the tree """ - # TODO(Pawel): This is part of - # https://github.com/cbg-ethz/pMHN/issues/18, - # but it is *not* a priority. - # We will try to do the modelling as soon as possible, - # starting with a sequential Monte Carlo sampler - # and Metropolis transitions. - # Only after initial experiments - # (we will probably see that it's not scalable), - # we'll consider switching to Hamiltonian Monte Carlo, - # which requires gradients. - raise NotImplementedError + subtrees_size = len(tree_wrapper._subtrees_dict) + x = np.zeros(subtrees_size) + x[0] = 1 + n_mutations = len(theta) + all_mut = set(i + 1 for i in range(n_mutations)) + for i, (subtree_i, subtree_size_i) in enumerate( + tree_wrapper._subtrees_dict.items() + ): + V_col = {} + V_diag = 0.0 + for j, (subtree_j, subtree_size_j) in enumerate( + tree_wrapper._subtrees_dict.items() + ): + if subtree_size_i - subtree_size_j == 1: + V_col[j] = -self._off_diag_entry(subtree_j, subtree_i, theta) + elif i == j: + V_diag = sampling_rate - self._diag_entry(subtree_i, theta, all_mut) + for index, val in V_col.items(): + x[i] -= val * x[index] + x[i] /= V_diag + + return np.log(x[-1] + self._jitter) + np.log(sampling_rate) diff --git a/src/pmhn/_trees/_backend_code.py b/src/pmhn/_trees/_backend_code.py new file mode 100644 index 0000000..db81051 --- /dev/null +++ b/src/pmhn/_trees/_backend_code.py @@ -0,0 +1,168 @@ +from typing import Optional + +import numpy as np +from pmhn._trees._tree_utils_geno import create_mappings +from anytree import Node + + +class TreeWrapperCode: + """Tree wrapper using smart encoding of subtrees.""" + + def __init__(self, tree: Node) -> None: + self._genotype_subtree_node_map: dict[ + tuple[tuple[Node, int], ...], tuple[int, int] + ] + self._index_subclone_map: dict[int, tuple[int, ...]] + + ( + self._genotype_subtree_node_map, + self._index_subclone_map, + ) = create_mappings(tree) + + +class TreeMHNBackendCode: + def __init__(self, jitter: float = 1e-10) -> None: + self._jitter: float = jitter + + def _diag_entry( + self, + tree_wrapper: TreeWrapperCode, + genotype: tuple[tuple[Node, int], ...], + theta: np.ndarray, + all_mut: set[int], + ) -> float: + """Calculates a diagonal entry of the V matrix. + + Args: + tree: a tree wrappper + genotype: the genotype of a subtree + theta: real-valued (i.e., log-theta) matrix, + shape (n_mutations, n_mutations) + all_mut: a set containing all possible mutations + Returns: + the diagonal entry of the V matrix corresponding to + genotype + """ + lamb_sum = 0 + for i, (node, val) in enumerate(genotype): + if val: + lineage = tree_wrapper._index_subclone_map[i] + lineage = list(lineage) + tree_mutations = set(lineage + [c.name for c in node.children]) + + exit_mutations = all_mut.difference(tree_mutations) + + for mutation in exit_mutations: + lamb = 0 + lamb += theta[mutation - 1][mutation - 1] + for j in lineage: + if j != 0: + lamb += theta[mutation - 1][j - 1] + lamb = np.exp(lamb) + lamb_sum -= lamb + return lamb_sum + + def find_single_difference( + self, arr1: np.ndarray, arr2: np.ndarray + ) -> Optional[int]: + """ + Checks if two binary arrays of equal size differ in only one entry. + If so, the index of the differing entry is returned, otherwise None. + + Args: + arr1: the first array + arr2: the second array + Returns: + the index of the differing entry if there's + a single difference, otherwise None. + """ + differing_indices = np.nonzero(np.bitwise_xor(arr1, arr2))[0] + + return differing_indices[0] if len(differing_indices) == 1 else None + + def _off_diag_entry( + self, + tree_wrapper: TreeWrapperCode, + genotype_i: np.ndarray, + genotype_j: np.ndarray, + theta: np.ndarray, + ) -> float: + """ + Calculates an off-diagonal entry of the V matrix. + + Args: + tree: the original tree + genotype_i: the genotype of a subtree + genotype_j: the genotype of another subtree + theta: real-valued (i.e., log-theta) matrix, + shape (n_mutations, n_mutations) + Returns: + an off-diagonal entry of the V matrix corresponding to + the genotype_i and genotype_j + """ + index = self.find_single_difference(genotype_i, genotype_j) + if index is None: + return 0 + else: + lamb = 0 + lineage = tree_wrapper._index_subclone_map[index] + exit_mutation = lineage[-1] + for mutation in lineage: + if mutation != 0: + lamb += theta[exit_mutation - 1][mutation - 1] + lamb = np.exp(lamb) + return float(lamb) + + def loglikelihood( + self, + tree_wrapper: TreeWrapperCode, + theta: np.ndarray, + sampling_rate: float, + all_mut: set[int], + ) -> float: + """ + Calculates loglikelihood `log P(tree | theta)`. + + Args: + tree: a tree + theta: real-valued (i.e., log-theta) matrix, + shape (n_mutations, n_mutations) + sampling_rate: a scalar of type float + all_mut: a set containing all possible mutations + + Returns: + the loglikelihood of tree + """ + subtrees_size = len(tree_wrapper._genotype_subtree_node_map) + x = np.zeros(subtrees_size) + x[0] = 1 + genotype_lists = [] + for genotype in tree_wrapper._genotype_subtree_node_map.keys(): + genotype_lists.append(np.array([item[1] for item in genotype])) + for genotype_i, ( + i, + subtree_size_i, + ) in tree_wrapper._genotype_subtree_node_map.items(): + V_col = [] + V_diag = 0.0 + for j, subtree_size_j in tree_wrapper._genotype_subtree_node_map.values(): + if subtree_size_i - subtree_size_j == 1: + V_col.append( + ( + j, + -self._off_diag_entry( + tree_wrapper, + genotype_lists[j], + genotype_lists[i], + theta, + ), + ) + ) + elif i == j: + V_diag = sampling_rate - self._diag_entry( + tree_wrapper, genotype_i, theta, all_mut + ) + for index, val in V_col: + x[i] -= val * x[index] + x[i] /= V_diag + return np.log(x[-1] + self._jitter) + np.log(sampling_rate) diff --git a/src/pmhn/_trees/_tree_utils.py b/src/pmhn/_trees/_tree_utils.py index c292c89..6dff521 100644 --- a/src/pmhn/_trees/_tree_utils.py +++ b/src/pmhn/_trees/_tree_utils.py @@ -55,7 +55,7 @@ def all_combinations_of_elements(*lists): yield list(element_combination) -def create_subtree(original_root: Node, nodes_list: list[Node]) -> Optional[Node]: +def create_subtree(original_root: Node, nodes_list: list[Node]) -> Node: """ Creates a subtree given a list of nodes and the root node. @@ -71,8 +71,7 @@ def create_subtree(original_root: Node, nodes_list: list[Node]) -> Optional[Node if node in nodes_list: parent_node = next((n for n in nodes_list if n is node.parent), None) nodes_dict[node] = Node(node.name, parent=nodes_dict.get(parent_node)) - - return nodes_dict.get(original_root) + return nodes_dict[original_root] def get_subtrees(node: Node) -> list[list[Node]]: @@ -86,7 +85,8 @@ def get_subtrees(node: Node) -> list[list[Node]]: Args: node: the root node Returns: - a list of subtrees + a list of subtrees + """ if not node.children: return [[node]] @@ -95,44 +95,44 @@ def get_subtrees(node: Node) -> list[list[Node]]: combined_subtrees = all_combinations_of_elements(*child_subtrees) - result_subtrees = [] - result_subtrees.append([node]) - for combination in combined_subtrees: - subtree_with_root = [node] + [ - item for sublist in combination for item in sublist - ] - result_subtrees.append(subtree_with_root) + result_subtrees = [[node]] + [ + [node] + [item for sublist in combination for item in sublist] + for combination in combined_subtrees + ] return result_subtrees -def create_all_subtrees(root: Node) -> list[Node]: +def create_all_subtrees(root: Node) -> dict[Node, int]: """ - Creates a list of subtrees and sorts the list in ascending subtree size. + Creates a dictionary where each key is a subtree, + and each value is the size of that subtree. + Args: root: the root node Returns: - the final list of subtrees + A dictionary mapping subtrees to their sizes """ all_node_lists = get_subtrees(root) all_node_lists = sorted(all_node_lists, key=len) - all_subtrees = [] - for subtree in all_node_lists: - all_subtrees.append(create_subtree(root, subtree)) - return all_subtrees + all_subtrees_dict = { + create_subtree(root, node_list): len(node_list) for node_list in all_node_lists + } + return all_subtrees_dict -def get_lineage(node: Node) -> list[int]: +def get_lineage(node: Node) -> tuple[int]: """ - Creates a list of the names of the nodes that - are in the lineage of input node. + Creates a tuple of the names of the nodes that + are in the lineage of the input node. + Args: node: a node Returns: the lineage of a node """ - return [ancestor.name for ancestor in node.path] # type: ignore + return tuple(ancestor.name for ancestor in node.path) # type: ignore def check_equality(tree1: Optional[Node], tree2: Optional[Node]) -> bool: @@ -152,8 +152,9 @@ def check_equality(tree1: Optional[Node], tree2: Optional[Node]) -> bool: if len(tree1.descendants) != len(tree2.descendants): return False for nodes1, nodes2 in zip(iter1, iter2): - set_nodes1_lineages = {tuple(get_lineage(node)) for node in nodes1} - set_nodes2_lineages = {tuple(get_lineage(node)) for node in nodes2} + set_nodes1_lineages = {get_lineage(node) for node in nodes1} + set_nodes2_lineages = {get_lineage(node) for node in nodes2} + additional_nodes_lineages = set_nodes2_lineages ^ set_nodes1_lineages if len(additional_nodes_lineages) != 0: return False @@ -163,40 +164,50 @@ def check_equality(tree1: Optional[Node], tree2: Optional[Node]) -> bool: def bfs_compare(tree1: Node, tree2: Node) -> Optional[Node]: """ - Checks if tree1 is a subtree of tree2 and is smaller in - size by one. + Checks if tree1 is a subtree of tree2 with the assumption + that tree2 is larger than the first tree by one. + Args: tree1: the first tree tree2: the second tree Returns: the additional node in the second tree if available, otherwise None. + """ + diff_count = 0 iter1 = list(LevelOrderGroupIter(tree1)) iter2 = list(LevelOrderGroupIter(tree2)) exit_node = None - if len(list(tree2.descendants)) - len(list(tree1.descendants)) != 1: - return None - for nodes1, nodes2 in zip(iter1, iter2): - set_nodes1_lineages = {tuple(get_lineage(node)) for node in nodes1} - set_nodes2_lineages = {tuple(get_lineage(node)) for node in nodes2} + + for level, (nodes1, nodes2) in enumerate(zip(iter1, iter2)): + dict_nodes1_lineages = {node: get_lineage(node) for node in nodes1} + dict_nodes2_lineages = {node: get_lineage(node) for node in nodes2} + set_nodes1_lineages = set(dict_nodes1_lineages.values()) + set_nodes2_lineages = set(dict_nodes2_lineages.values()) + additional_nodes_lineages = set_nodes2_lineages ^ set_nodes1_lineages diff_count += len(additional_nodes_lineages) if diff_count == 1 and exit_node is None: additional_node_lineage = additional_nodes_lineages.pop() + for node in nodes1: - if tuple(get_lineage(node)) == additional_node_lineage: + if dict_nodes1_lineages[node] == additional_node_lineage: return None for node in nodes2: - if tuple(get_lineage(node)) == additional_node_lineage: + if dict_nodes2_lineages[node] == additional_node_lineage: exit_node = node - if diff_count > 1: - return None - if len(iter1) < len(iter2): + break + + if diff_count > 1: + return None + + if diff_count == 0: return iter2[-1][0] + return exit_node diff --git a/src/pmhn/_trees/_tree_utils_geno.py b/src/pmhn/_trees/_tree_utils_geno.py new file mode 100644 index 0000000..ad1e001 --- /dev/null +++ b/src/pmhn/_trees/_tree_utils_geno.py @@ -0,0 +1,190 @@ +from anytree import Node, LevelOrderGroupIter +from itertools import combinations, product +from typing import Optional + + +def all_combinations_of_elements(*lists): + """ + Takes a variable number of lists as input and returns a generator that yields + all possible combinations of the input lists. In our use case: It takes a list + of lists of subtrees as input where a subtree itself is a list of nodes and + outputs all possible combinations of the lists of subtrees. + + For instance, if we have the following tree: + + 0 + / | \ + 1 3 2 + | + 2 + + and assumed that we know the list of subtrees for the trees: + + 1 + | -> list of subtrees: [[1], [1, 2]] + 2 + + 3 -> list of subtrees: [[3]] + + and + + 2 -> list of subtrees: [[2]] + + , we can find the subtrees of the original tree by looking at + all possible combinations of the list of subtrees for the trees above + and add the root node (0) to each combination (this is done in the + get_subtrees function). + + So the input would be [[[1], [1, 2]],[[3]], [[2]]] + + The generator would yield the following combinations one at a time: + [[1]], [[1, 2]], [[3]], [[2]], [[1], [3]], [[1, 2], [3]], [[1], [2]], + [[1, 2], [2]], [[3], [2]], [[1], [3], [2]], [[1, 2], [3], [2]] + + Args: + *lists: any number of lists + + Returns: + A generator that yields all combinations of the input lists. + + """ + n = len(lists) + for r in range(1, n + 1): + for list_combination in combinations(lists, r): + for element_combination in product(*list_combination): + yield list(element_combination) + + +def create_subtree(subtree_nodes: list[Node], original_tree_nodes: list[Node]) -> Node: + """ + Creates a certain subtree of the original tree. + + Args: + subtree_nodes: the nodes that are contained in both + the subtree and the original tree + original_tree_nodes: all nodes of the original tree + Returns: + a subtree + """ + nodes_dict = {} + + for node in subtree_nodes: + parent_node = node.parent + nodes_dict[node] = Node(node.name, parent=nodes_dict.get(parent_node)) + + return nodes_dict[original_tree_nodes[0]] + + +def get_subtrees(node: Node) -> list[list[Node]]: + """ + Creates a list of all subtrees of a tree. + A recursive approach is employed: If one knows the subtrees of the + children of the root node, then one can find all combinations of + the subtrees of the children and add the root node to each one + of these combinations, this way one obtains all subtrees of the original tree. + + Args: + node: the root node + Returns: + a list of subtrees + """ + if not node.children: + return [[node]] + + child_subtrees = [get_subtrees(child) for child in node.children] + + combined_subtrees = all_combinations_of_elements(*child_subtrees) + + result_subtrees = [[node]] + [ + [node] + [item for sublist in combination for item in sublist] + for combination in combined_subtrees + ] + + return result_subtrees + + +def get_lineage(node: Node) -> tuple[int]: + """ + Creates a tuple of the names of the nodes that + are in the lineage of the input node. + Args: + node: a node + Returns: + the lineage of a node + """ + return tuple(ancestor.name for ancestor in node.path) # type: ignore + + +def create_index_subclone_maps( + root: Node, +) -> tuple[dict[int, tuple[int, ...]], dict[tuple[int, ...], int]]: + """ + Assigns a unique index to each subclone in the provided + tree and generates two dictionaries: one mapping each unique + index to its corresponding subclone, and the other inverting this relationship. + Args: + root: the root node of a tree + Returns: + two dictionaries that contain the mappings + """ + index_subclone_map = {} + subclone_index_map = {} + index = 0 + for level in LevelOrderGroupIter(root): + for node in level: + index_subclone_map[index] = get_lineage(node) + subclone_index_map[get_lineage(node)] = index + index += 1 + return index_subclone_map, subclone_index_map + + +def create_genotype( + size: int, root: Node, subclone_index_map: dict[tuple[int, ...], int] +) -> tuple[tuple[Optional[Node], int], ...]: + """ + Creates the genotype of a given tree. + + Args: + size: the size of the original tree + root: the root node of a subtree of the original tree + subclone_index_map: a dictionary that maps subclones to their indices + Returns: + a tuple of tuples, where each inner tuple represents a subclone from + the original tree. For each subclone, if it exists in the subtree, + the inner tuple contains the last node of that subclone and the value 1; + if it doesn't exist, the tuple contains None and the value 0. + """ + x = [(Node(None), int(0))] * size + for level in LevelOrderGroupIter(root): + for node in level: + lineage = get_lineage(node) + x[subclone_index_map[lineage]] = (node, 1) + return tuple(x) + + +def create_mappings( + root: Node, +) -> tuple[ + dict[tuple[tuple[Node, int], ...], tuple[int, int]], dict[int, tuple[int, ...]] +]: + """ + Creates the required mappings to calculate the likelihood of a tree. + + Args: + root: the root node of the original tree + Returns: + two dictionaries, one mapping genotypes to subtrees (here only the + index and length of the subtrees are needed) and the other one + mapping indices to subclones + """ + index_subclone_map, subclone_index_map = create_index_subclone_maps(root) + genotype_subtree_map = {} + subtrees = get_subtrees(root) + original_tree = subtrees[-1] + all_node_lists_with_len = [(subtree, len(subtree)) for subtree in subtrees] + size = len(subtrees) + for index, (subtree, subtree_size) in enumerate(all_node_lists_with_len): + subtree = create_subtree(subtree, original_tree) + genotype = create_genotype(size, subtree, subclone_index_map) + genotype_subtree_map[genotype] = (index, subtree_size) + return genotype_subtree_map, index_subclone_map diff --git a/tests/ppl/test_multiplemhn.py b/tests/ppl/test_multiplemhn.py index 962e52b..f3d7704 100644 --- a/tests/ppl/test_multiplemhn.py +++ b/tests/ppl/test_multiplemhn.py @@ -16,7 +16,7 @@ def test_loglikelihood(n_patients: int, n_genes: int) -> None: mutations = rng.binomial(1, 0.5, size=(n_patients, n_genes)) thetas = rng.normal(size=(n_patients, n_genes, n_genes)) - loglikelihood = np.sum( + loglikelihood = np.sum( # pyright: ignore [ lmhn.MHNCythonBackend().gradient_and_loglikelihood( mutations=mutations[i].reshape((1, -1)), theta=thetas[i] diff --git a/tests/trees/test_likelihood.py b/tests/trees/test_likelihood.py index 5f98ac4..548315a 100644 --- a/tests/trees/test_likelihood.py +++ b/tests/trees/test_likelihood.py @@ -1,294 +1,49 @@ -from pmhn._trees._backend import OriginalTreeMHNBackend -from pmhn._trees._tree_utils import create_all_subtrees -from anytree import Node import numpy as np +import pytest +from anytree import Node +import pmhn._trees._backend as backend_orig +import pmhn._trees._backend_code as backend_geno -def test_create_V_Mat(): - """ - Checks if create_V_Mat is implemented correctly. - - tree: - 0 - / \ - 1 3 - / - 3 - """ - - true_Q = np.array( - [ - [ - -(np.exp(-1.41) + np.exp(-2.26) + np.exp(-2.55)), - np.exp(-1.41), - np.exp(-2.55), - 0, - 0, - 0, - ], - [ - 0, - -( - np.exp(-2.26) - + np.exp(-2.55) - + np.exp(-1.12 - 2.26) - + np.exp(1 - 2.55) - ), - 0, - np.exp(1 - 2.55), - np.exp(-2.55), - 0, - ], - [ - 0, - 0, - -( - np.exp(-1.41) - + np.exp(-2.26) - + np.exp(-1.41 + 3) - + np.exp(-2.26 + 2) - ), - 0, - np.exp(-1.41), - 0, - ], - [ - 0, - 0, - 0, - -( - np.exp(-2.26) - + np.exp(-2.55) - + np.exp(-2.26 - 1.12) - + np.exp(-2.26 - 1.12 + 2) - ), - 0, - np.exp(-2.55), - ], - [ - 0, - 0, - 0, - 0, - -( - np.exp(-2.26) - + np.exp(-2.26 - 1.12) - + np.exp(-2.55 + 1) - + np.exp(-1.41 + 3) - + np.exp(-2.26 + 2) - ), - np.exp(-2.55 + 1), - ], - [ - 0, - 0, - 0, - 0, - 0, - -( - np.exp(-2.26) - + np.exp(-2.26 - 1.12) - + np.exp(-2.26 + 2 - 1.12) - + np.exp(-1.41 + 3) - + np.exp(-2.26 + 2) - ), - ], - ] - ) - A = Node(0) - B = Node(1, parent=A) - Node(3, parent=A) - Node(3, parent=B) - subtrees = create_all_subtrees(A) - subtrees_size = len(subtrees) - sampling_rate = 1.0 - true_V = np.eye(subtrees_size) * sampling_rate - true_Q - backend = OriginalTreeMHNBackend() - theta = np.array([[-1.41, 2, 3], [-1.12, -2.26, 2], [1, -0.86, -2.55]]) - - V = backend.create_V_Mat(A, theta, sampling_rate) - - assert np.allclose(V, true_V, atol=1e-8) - - -def test_diag_entry(): - r""" - - Checks if the diagonal values of the V matrix are calculated correctly. - - 0 - / | \ - 2 1 3 - | | - 3 3 - - - augmented tree: - - - 0 - / | \ - 2 1 3 - |\ |\ |\ - 3 1 3 2 1 2 - | | - 1 2 - - """ - A = Node(0) - B = Node(2, parent=A) - D = Node(1, parent=A) - Node(3, parent=A) - Node(3, parent=B) - Node(3, parent=D) - theta = np.array([[-1.41, 2, 3], [-1.12, -2.26, 2], [1, -0.86, -2.55]]) - true_diag_entry = -( - np.exp(-1.41 + 2) - + np.exp(-2.26 - 1.12) - + np.exp(-1.41 + 3) - + np.exp(-2.26 + 2) - + np.exp(-1.41 + 2 + 3) - + np.exp(-2.26 + 2 - 1.12) - ) - backend = OriginalTreeMHNBackend() - diag_entry = backend.diag_entry(A, theta) - - assert np.allclose(diag_entry, true_diag_entry, atol=1e-8) - - -def test_off_diag_entry_valid(): - r""" - Checks if the off-diagonal entries of the V matrix - are calculated correctly. - - first tree: - 0 - / | \ - 2 1 3 - | - 3 - - second tree: - 0 - / | \ - 2 1 3 - | | - 3 3 - - - """ - - theta = np.array([[-1.41, 2, 3], [-1.12, -2.26, 2], [1, -0.86, -2.55]]) - # first tree - A_1 = Node(0) - B_1 = Node(2, parent=A_1) - Node(1, parent=A_1) - Node(3, parent=A_1) - Node(3, parent=B_1) - - # second tree - A_2 = Node(0) - B_2 = Node(2, parent=A_2) - D_2 = Node(1, parent=A_2) - Node(3, parent=A_2) - Node(3, parent=B_2) - Node(3, parent=D_2) - - true_off_diag_entry = np.exp(-2.55 + 1) - backend = OriginalTreeMHNBackend() - off_diag_entry = backend.off_diag_entry(A_1, A_2, theta) - - assert np.allclose(off_diag_entry, true_off_diag_entry, atol=1e-8) - - -def test_off_diag_entry_invalid_size(): - r""" - - Checks if off_diag_entry successfully returns 0 when the size is invalid - (i.e the first tree is not smaller than the second tree by one). - - first tree: - 0 - / | \ - 2 1 3 - | - 3 - - second tree: - 0 - / | \ - 2 1 3 - | | | - 3 3 2 - - - """ - - # first tree - A_1 = Node(0) - B_1 = Node(2, parent=A_1) - Node(1, parent=A_1) - Node(3, parent=A_1) - Node(3, parent=B_1) - - # second tree - A_2 = Node(0) - B_2 = Node(2, parent=A_2) - C_2 = Node(1, parent=A_2) - D_2 = Node(3, parent=A_2) - Node(3, parent=B_2) - Node(3, parent=C_2) - Node(2, parent=D_2) - - theta = np.array([[-1.41, 2, 3], [-1.12, -2.26, 2], [1, -0.86, -2.55]]) - backend = OriginalTreeMHNBackend() - assert backend.off_diag_entry(A_1, A_2, theta) == 0 +def get_loglikelihood_functions() -> list: + """This is an auxiliary function which returns a list of + loglikelihood functions to be tested. + Each of these functions has signature: -def test_off_diag_entry_not_subset(): - r""" - Checks if the off_diag_entry succesfully returns 0 when the - first tree is not a subtree of the second tree. + loglikelihood( + tree: Node, + theta: np.ndarray, + sampling_rate: float, + all_mut: set[int], + ) -> float - first tree: - 0 - / | \ - 2 1 3 - | - 3 - - second tree: - 0 - / | \ - 2 1 3 - | | - 3 2 - - + Note: + Whenever a new backend is used + (and it has a wrapper around trees for memoization), + it could just be added here. """ - # first tree - A_1 = Node(0) - B_1 = Node(2, parent=A_1) - Node(1, parent=A_1) - Node(3, parent=A_1) - Node(3, parent=B_1) - # second tree - A_2 = Node(0) - Node(2, parent=A_2) - C_2 = Node(1, parent=A_2) - D_2 = Node(3, parent=A_2) - Node(3, parent=C_2) - Node(2, parent=D_2) + def backend1( + tree: Node, theta: np.ndarray, sampling_rate: float, all_mut: set[int] + ) -> float: + return backend_orig.OriginalTreeMHNBackend().loglikelihood( + backend_orig.TreeWrapper(tree), theta, sampling_rate + ) - theta = np.array([[-1.41, 2, 3], [-1.12, -2.26, 2], [1, -0.86, -2.55]]) - backend = OriginalTreeMHNBackend() + def backend2( + tree: Node, theta: np.ndarray, sampling_rate: float, all_mut: set[int] + ) -> float: + return backend_geno.TreeMHNBackendCode().loglikelihood( + backend_geno.TreeWrapperCode(tree), theta, sampling_rate, all_mut + ) - assert backend.off_diag_entry(A_1, A_2, theta) == 0 + return [backend1, backend2] -def test_likelihood_small_tree(): +@pytest.mark.parametrize("backend", get_loglikelihood_functions()) +def test_likelihood_small_tree(backend) -> None: """ Checks if the likelihood of a small tree is calculated correctly. @@ -319,14 +74,14 @@ def test_likelihood_small_tree(): ] ) sampling_rate = 1.0 + all_mut = set(range(1, 11)) - backend = OriginalTreeMHNBackend() - log_value = backend.loglikelihood(A, theta, sampling_rate) - + log_value = backend(A, theta, sampling_rate, all_mut) assert np.allclose(log_value, -5.793104, atol=1e-5) -def test_likelihood_medium_tree(): +@pytest.mark.parametrize("backend", get_loglikelihood_functions()) +def test_likelihood_medium_tree(backend) -> None: """ Checks if the likelihood of a medium-sized tree is calculated correctly. @@ -363,13 +118,14 @@ def test_likelihood_medium_tree(): ) sampling_rate = 1.0 - backend = OriginalTreeMHNBackend() - log_value = backend.loglikelihood(A, theta, sampling_rate) + all_mut = set(range(1, 11)) + log_value = backend(A, theta, sampling_rate, all_mut) assert np.allclose(log_value, -14.729560, atol=1e-5) -def test_likelihood_large_tree(): +@pytest.mark.parametrize("backend", get_loglikelihood_functions()) +def test_likelihood_large_tree(backend) -> None: """ Checks if the likelihood of a large tree is calculated correctly. @@ -412,8 +168,8 @@ def test_likelihood_large_tree(): ] ) sampling_rate = 1.0 + all_mut = set(range(1, 11)) - backend = OriginalTreeMHNBackend() - log_value = backend.loglikelihood(A, theta, sampling_rate) + log_value = backend(A, theta, sampling_rate, all_mut) assert np.allclose(log_value, -22.288420, atol=1e-5) diff --git a/warmup/likelihood/R_py_loglikelihood_comparison.py b/warmup/likelihood/R_py_loglikelihood_comparison.py index a1d079a..0f90674 100644 --- a/warmup/likelihood/R_py_loglikelihood_comparison.py +++ b/warmup/likelihood/R_py_loglikelihood_comparison.py @@ -1,8 +1,9 @@ import pandas as pd import pmhn._trees._io as io -from pmhn._trees._backend import OriginalTreeMHNBackend +from pmhn._trees._backend import OriginalTreeMHNBackend, TreeWrapper import csv import numpy as np +import time def csv_to_numpy(file_path): @@ -48,19 +49,25 @@ def csv_to_numpy(file_path): # calculate loglikelihoods log_vec_py_AML = np.empty(len(trees_AML)) log_vec_py_500 = np.empty(len(trees_500)) +start_time = time.time() backend = OriginalTreeMHNBackend() - for idx, tree in trees_AML.items(): print(f"Processing tree {idx} of {len(trees_AML)}") - log_value = backend.loglikelihood(tree, theta_AML, sampling_rate) + tree_log = TreeWrapper(tree) + log_value = backend.loglikelihood(tree_log, theta_AML, sampling_rate) + log_vec_py_AML[idx - 1] = log_value print(f"log_value: {log_value}") for idx, tree in trees_500.items(): print(f"Processing tree {idx} of {len(trees_500)}") - log_value = backend.loglikelihood(tree, theta_500, sampling_rate) + tree_log = TreeWrapper(tree) + log_value = backend.loglikelihood(tree_log, theta_500, sampling_rate) log_vec_py_500[idx - 1] = log_value print(f"log_value: {log_value}") +end_time = time.time() +elapsed_time = end_time - start_time +print(f"Time elapsed: {elapsed_time} seconds") # write Python loglikelihoods to CSV np.savetxt("likelihood_py/log_vec_py_AML.csv", log_vec_py_AML, delimiter=",") diff --git a/warmup/likelihood/R_py_loglikelihood_comparison_geno.py b/warmup/likelihood/R_py_loglikelihood_comparison_geno.py new file mode 100644 index 0000000..cf8942a --- /dev/null +++ b/warmup/likelihood/R_py_loglikelihood_comparison_geno.py @@ -0,0 +1,88 @@ +import pandas as pd +import pmhn._trees._io as io +from pmhn._trees._backend_code import TreeMHNBackendCode, TreeWrapperCode +import csv +import numpy as np +import time + + +def csv_to_numpy(file_path): + with open(file_path, "r") as file: + reader = csv.reader(file) + next(reader) + data_list = list(reader) + return np.array(data_list, dtype=float) + + +# AML trees +df_AML = pd.read_csv("likelihood_R/trees_AML_R.csv") + +# randomly generated 500 trees using a random theta +df_500 = pd.read_csv("likelihood_R/trees_500_R.csv") + +# theta matrices +theta_AML = csv_to_numpy("likelihood_R/MHN_Matrix_AML.csv") +theta_500 = csv_to_numpy("likelihood_R/MHN_Matrix_500.csv") +# loglikelihoods in R +log_vec_R_AML = np.genfromtxt("likelihood_R/log_vec_R_AML.csv", delimiter=",") +log_vec_R_500 = np.genfromtxt("likelihood_R/log_vec_R_500.csv", delimiter=",") + +# define sampling rate +sampling_rate = 1.0 + +# use modified io +naming = io.ForestNaming( + tree_name="Tree_ID", + naming=io.TreeNaming( + node="Node_ID", + parent="Parent_ID", + data={ + "Mutation_ID": "mutation", + }, + ), +) + +# parse trees +trees_AML = io.parse_forest(df_AML, naming=naming) +trees_500 = io.parse_forest(df_500, naming=naming) + +# calculate loglikelihoods +log_vec_py_AML = np.empty(len(trees_AML)) +log_vec_py_500 = np.empty(len(trees_500)) + +start_time = time.time() +backend = TreeMHNBackendCode() +theta_AML_size = len(theta_AML) +all_mut_AML = set(range(1, theta_AML_size + 1)) +for idx, tree in trees_AML.items(): + print(f"Processing tree {idx} of {len(trees_AML)}") + tree_log = TreeWrapperCode(tree) + log_value = backend.loglikelihood(tree_log, theta_AML, sampling_rate, all_mut_AML) + log_vec_py_AML[idx - 1] = log_value + print(f"log_value: {log_value}") +theta_500_size = len(theta_500) +all_mut_500 = set(range(1, theta_500_size + 1)) +for idx, tree in trees_500.items(): + print(f"Processing tree {idx} of {len(trees_500)}") + tree_log = TreeWrapperCode(tree) + log_value = backend.loglikelihood(tree_log, theta_500, sampling_rate, all_mut_500) + log_vec_py_500[idx - 1] = log_value + print(f"log_value: {log_value}") + +end_time = time.time() +elapsed_time = end_time - start_time +print(f"Time elapsed: {elapsed_time} seconds") +# write Python loglikelihoods to CSV +np.savetxt("likelihood_py/log_vec_py_AML.csv", log_vec_py_AML, delimiter=",") +np.savetxt("likelihood_py/log_vec_py_500.csv", log_vec_py_500, delimiter=",") + + +# check if the loglikelihood vectors are the same +if np.allclose(log_vec_py_AML, log_vec_R_AML, atol=1e-10): + print("The loglikelihoods of the AML trees are the same in R and Python.") + +if np.allclose(log_vec_py_500, log_vec_R_500, atol=1e-10): + print( + "The loglikelihoods of the 500 randomly generated" + " trees are the same in R and Python." + ) diff --git a/warmup/likelihood/R_py_loglikelihood_comparison_geno_profiling.py b/warmup/likelihood/R_py_loglikelihood_comparison_geno_profiling.py new file mode 100644 index 0000000..61bb543 --- /dev/null +++ b/warmup/likelihood/R_py_loglikelihood_comparison_geno_profiling.py @@ -0,0 +1,93 @@ +import pandas as pd +import pmhn._trees._io as io +from pmhn._trees._backend_code import TreeMHNBackendCode, TreeWrapperCode +import csv +import numpy as np +import time +import pstats +import cProfile + + +def csv_to_numpy(file_path): + with open(file_path, "r") as file: + reader = csv.reader(file) + next(reader) + data_list = list(reader) + return np.array(data_list, dtype=float) + + +# AML trees +df_AML = pd.read_csv("likelihood_R/trees_AML_R.csv") + +# randomly generated 500 trees using a random theta +df_500 = pd.read_csv("likelihood_R/trees_500_R.csv") + +# theta matrices +theta_AML = csv_to_numpy("likelihood_R/MHN_Matrix_AML.csv") +theta_500 = csv_to_numpy("likelihood_R/MHN_Matrix_500.csv") +# loglikelihoods in R +log_vec_R_AML = np.genfromtxt("likelihood_R/log_vec_R_AML.csv", delimiter=",") +log_vec_R_500 = np.genfromtxt("likelihood_R/log_vec_R_500.csv", delimiter=",") + +# define sampling rate +sampling_rate = 1.0 + +# use modified io +naming = io.ForestNaming( + tree_name="Tree_ID", + naming=io.TreeNaming( + node="Node_ID", + parent="Parent_ID", + data={ + "Mutation_ID": "mutation", + }, + ), +) + +# parse trees +trees_AML = io.parse_forest(df_AML, naming=naming) +trees_500 = io.parse_forest(df_500, naming=naming) + +# calculate loglikelihoods +log_vec_py_AML = np.empty(len(trees_AML)) +log_vec_py_500 = np.empty(len(trees_500)) +profiler = cProfile.Profile() +profiler.enable() +start_time = time.time() +backend = TreeMHNBackendCode() +theta_AML_size = len(theta_AML) +all_mut_AML = set(i + 1 for i in range(theta_AML_size)) +for idx, tree in trees_AML.items(): + print(f"Processing tree {idx} of {len(trees_AML)}") + tree_log = TreeWrapperCode(tree) + log_value = backend.loglikelihood(tree_log, theta_AML, sampling_rate, all_mut_AML) + log_vec_py_AML[idx - 1] = log_value + print(f"log_value: {log_value}") +theta_500_size = len(theta_500) +all_mut_500 = set(i + 1 for i in range(theta_500_size)) +for idx, tree in trees_500.items(): + print(f"Processing tree {idx} of {len(trees_500)}") + tree_log = TreeWrapperCode(tree) + log_value = backend.loglikelihood(tree_log, theta_500, sampling_rate, all_mut_500) + log_vec_py_500[idx - 1] = log_value + print(f"log_value: {log_value}") +end_time = time.time() +profiler.disable() +elapsed_time = end_time - start_time +print(f"Time elapsed: {elapsed_time} seconds") +# write Python loglikelihoods to CSV +np.savetxt("likelihood_py/log_vec_py_AML.csv", log_vec_py_AML, delimiter=",") +np.savetxt("likelihood_py/log_vec_py_500.csv", log_vec_py_500, delimiter=",") + + +# check if the loglikelihood vectors are the same +if np.allclose(log_vec_py_AML, log_vec_R_AML, atol=1e-10): + print("The loglikelihoods of the AML trees are the same in R and Python.") + +if np.allclose(log_vec_py_500, log_vec_R_500, atol=1e-10): + print( + "The loglikelihoods of the 500 randomly generated" + " trees are the same in R and Python." + ) +stats = pstats.Stats(profiler).sort_stats("cumtime") # Sort by cumulative time spent +stats.print_stats() diff --git a/warmup/likelihood/R_py_loglikelihood_comparison_geno_seed43.py b/warmup/likelihood/R_py_loglikelihood_comparison_geno_seed43.py new file mode 100644 index 0000000..932fef0 --- /dev/null +++ b/warmup/likelihood/R_py_loglikelihood_comparison_geno_seed43.py @@ -0,0 +1,102 @@ +import pandas as pd +import pmhn._trees._io as io +from pmhn._trees._backend_code import TreeWrapperCode, TreeMHNBackendCode +import csv +import numpy as np +import time + + +def csv_to_numpy(file_path): + with open(file_path, "r") as file: + reader = csv.reader(file) + next(reader) + data_list = list(reader) + return np.array(data_list, dtype=float) + + +# AML trees +df_AML = pd.read_csv("likelihood_R/trees_AML_R.csv") + +# randomly generated 500 trees using a random theta +df_500 = pd.read_csv("likelihood_R/trees_500_R.csv") +df_1000 = pd.read_csv("likelihood_R/trees_1000_seed43_R.csv") +# theta matrices +theta_AML = csv_to_numpy("likelihood_R/MHN_Matrix_AML.csv") +theta_500 = csv_to_numpy("likelihood_R/MHN_Matrix_500.csv") +theta_1000 = csv_to_numpy("likelihood_R/MHN_Matrix_1000_seed43.csv") +# loglikelihoods in R +log_vec_R_AML = np.genfromtxt("likelihood_R/log_vec_R_AML.csv", delimiter=",") +log_vec_R_500 = np.genfromtxt("likelihood_R/log_vec_R_500.csv", delimiter=",") +log_vec_R_1000 = np.genfromtxt("likelihood_R/log_vec_R_1000_seed43.csv", delimiter=",") +# define sampling rate +sampling_rate = 1.0 + +# use modified io +naming = io.ForestNaming( + tree_name="Tree_ID", + naming=io.TreeNaming( + node="Node_ID", + parent="Parent_ID", + data={ + "Mutation_ID": "mutation", + }, + ), +) + +# parse trees +trees_AML = io.parse_forest(df_AML, naming=naming) +trees_500 = io.parse_forest(df_500, naming=naming) +trees_1000 = io.parse_forest(df_1000, naming=naming) +# calculate loglikelihoods +log_vec_py_AML = np.empty(len(trees_AML)) +log_vec_py_500 = np.empty(len(trees_500)) +log_vec_py_1000 = np.empty(len(trees_1000)) +start_time = time.time() +backend = TreeMHNBackendCode() +theta_AML_size = len(theta_AML) +all_mut_AML = set(range(1, theta_AML_size + 1)) +for idx, tree in trees_AML.items(): + print(f"Processing tree {idx} of {len(trees_AML)}") + tree_log = TreeWrapperCode(tree) + log_value = backend.loglikelihood(tree_log, theta_AML, sampling_rate, all_mut_AML) + log_vec_py_AML[idx - 1] = log_value + print(f"log_value: {log_value}") +theta_500_size = len(theta_500) +all_mut_500 = set(range(1, theta_500_size + 1)) +for idx, tree in trees_500.items(): + print(f"Processing tree {idx} of {len(trees_500)}") + tree_log = TreeWrapperCode(tree) + log_value = backend.loglikelihood(tree_log, theta_500, sampling_rate, all_mut_500) + log_vec_py_500[idx - 1] = log_value + print(f"log_value: {log_value}") +theta_1000_size = len(theta_1000) +all_mut_1000 = set(range(1, theta_1000_size + 1)) +for idx, tree in trees_1000.items(): + print(f"Processing tree {idx} of {len(trees_1000)}") + tree_log = TreeWrapperCode(tree) + log_value = backend.loglikelihood(tree_log, theta_1000, sampling_rate, all_mut_1000) + log_vec_py_1000[idx - 1] = log_value + print(f"log_value: {log_value}") +end_time = time.time() +elapsed_time = end_time - start_time +print(f"Time elapsed: {elapsed_time} seconds") +# write Python loglikelihoods to CSV +np.savetxt("likelihood_py/log_vec_py_AML.csv", log_vec_py_AML, delimiter=",") +np.savetxt("likelihood_py/log_vec_py_500.csv", log_vec_py_500, delimiter=",") + + +# check if the loglikelihood vectors are the same +if np.allclose(log_vec_py_AML, log_vec_R_AML, atol=1e-10): + print("The loglikelihoods of the AML trees are the same in R and Python.") + +if np.allclose(log_vec_py_500, log_vec_R_500, atol=1e-10): + print( + "The loglikelihoods of the 500 randomly generated" + " trees are the same in R and Python." + ) + +if np.allclose(log_vec_py_1000, log_vec_R_1000, atol=1e-10): + print( + "The loglikelihoods of the 1000 randomly generated" + " trees are the same in R and Python." + )