diff --git a/qrtd.py b/qrtd.py index 1762ec3..8b684bd 100644 --- a/qrtd.py +++ b/qrtd.py @@ -1,5 +1,6 @@ #!/usr/bin/env python -# Copyright (C) 2020-2023 Greenweaves Software Limited + +# Copyright (C) 2020-2023 Simon Crase, simon@greenweaves.nz # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -16,20 +17,87 @@ ''' QRTD Quartet Distance''' -from argparse import ArgumentParser -from deprecated import deprecated -from io import StringIO -from os.path import basename -from time import time -from Bio.Phylo import read, draw -from Bio.Phylo.BaseTree import Clade, Tree -from matplotlib.pyplot import figure, show -import numpy as np -from helpers import read_strings +from argparse import ArgumentParser +from deprecated import deprecated +from io import StringIO +from os.path import basename +from time import time +from Bio.Phylo import read, draw +from Bio.Phylo.BaseTree import Clade, Tree +from matplotlib.pyplot import figure, show +import numpy as np +from helpers import read_strings + +class Colour: + '''The 3 colours mentioned in the paper''' + A = 0 + B = 1 + C = 2 + +class Species: + def __init__(self,name=None,colour=None): + self.name = name + self.colour = colour + +class ComponentClade(Clade): + ''' + ComponentClade + This class represents a clade in the Hierarchical Decomposition Tree, and also + represents a component in the component decomposition + ''' + def __init__(self, + name = '', + clades = [], + branch_length = 1, + composition_type = None): + super().__init__(name = name, + clades = clades, + branch_length = branch_length) + self.composition_type = composition_type + self.colour = None + + @staticmethod + def get_label(clade): + ''' + get_label + + Label nodes on HT tree + ''' + if clade.is_terminal(): + return clade.name + else: + return ''.join(['i' for i in range(clade.composition_type)]) if clade.composition_type<4 else 'iv' + + def decorate(self,S): 
def get_element(species,target): + 1 if species.colour == target else 0 + + if self.composition_type==None: + if self.name in S: + species = S[self.name] + self.tuple = (get_element(species,Colour.A), + get_element(species,Colour.B), + get_element(species,Colour.C)) + self.F = lambda a,b,c: 0 + else: + self.tuple = (0,0,0) + self.F = lambda a,b,c: 0 #TODO + elif self.composition_type==1: + self.tuple = (0,0,0) #TODO + self.F = lambda a,b,c: 0 #TODO + elif self.composition_type==2: + self.tuple = (0,0,0) #TODO + self.F = lambda a,b,c: 0 #TODO + elif self.composition_type==3: + self.tuple = (0,0,0) #TODO + self.F = lambda a,b,c: 0 #TODO + elif self.composition_type==4: + self.tuple = (0,0,0) #TODO + self.F = lambda a,b,c: 0 #TODO class HTreeBuilder(Tree): - '''This class constructs Tree associated with Hierarchical Decomposition''' + '''This class constructs the Tree associated with Hierarchical Decomposition''' def build(self,T): '''Build the tree''' @@ -41,21 +109,28 @@ def build(self,T): open_edges = self.collect_edges() while len(open_edges)>0: open_edges = self.extract_open_edges(open_edges) - names = [k for k,v in self.components.items()] - i = np.argmax([len(name) for name in names]) - return Tree.from_clade(self.components[names[i]]) + + return Tree.from_clade(self.get_root()) def create_initial_components(self): - '''Ensure all nodes have names, and wrap each nodes in a component''' + ''' + create_initial_components + + Ensure all nodes have names, and wrap each node in a component + ''' for idx, clade in enumerate(self.T.find_clades()): if not clade.name: clade.name = str(idx) - self.components[clade.name] = Clade(name=clade.name,branch_length=1) + self.components[clade.name] = ComponentClade(name=clade.name) def collect_edges(self): - '''Collect edges for first pass; prioritize edges that terminate in leaves.''' + ''' + collect_edges + + Collect edges for first pass; prioritize edges that terminate in leaves. 
+ ''' edges_leaves = [] edges_internal = [] for a in self.T.find_clades(): @@ -67,9 +142,21 @@ def collect_edges(self): return edges_leaves + edges_internal + def get_type(self,key,clades): + ''' + get_type + + Assign a type to a node in tree + ''' + if len(clades[0].clades)==0 and len(clades[1].clades)==0: return 1 + if all([hasattr(clade,'type') and clade.type==1 for clade in clades]): return 2 + return 3 + def extract_open_edges(self,open_edges): ''' - This is the heart of the algorithm. It processes as many edges as possible, adding them + extract_open_edges + + This is the heart of build(...). It processes as many edges as possible, adding them to the clade structure. The remaining edges are modified to point to the top of the constructed clade that contains the original nodes. ''' @@ -77,11 +164,15 @@ def extract_open_edges(self,open_edges): for a,b in open_edges: if a not in self.top and b not in self.top: key = f'{a}-{b}' - self.top[a] = key - self.top[b] = key - self.components[key] = Clade(name = key, - clades = [self.components[a],self.components[b]], - branch_length = 1) + self.top[a] = key + self.top[b] = key + clades = [self.components[a],self.components[b]] + new_component = ComponentClade(name = key, + clades = clades, + branch_length = 1, + composition_type = self.get_type(key,clades)) + + self.components[key] = new_component else: if a in self.top: a = self.top[a] @@ -90,7 +181,20 @@ def extract_open_edges(self,open_edges): remaining_edges.append((a,b)) return remaining_edges + def get_root(self): + ''' + get_root + + Choose clade that will define the root node for tree + ''' + names = [k for k,v in self.components.items()] + index_root = np.argmax([len(name) for name in names]) + root = self.components[names[index_root]] + root.composition_type = 4 + return root + def tabulate_names(tree): + '''tabulate_names''' names = {} for idx, clade in enumerate(tree.find_clades()): if not clade.name: @@ -98,129 +202,116 @@ def tabulate_names(tree): 
names[clade.name] = clade return names -def root_with_specified_leaf(T,index=3): +def root_with_specified_leaf(T,S,index=0): ''' Allow user to change root to some arbitrary leaf - see Figure 9 ''' leaves = T.get_terminals() T.root_with_outgroup([leaves[index]]) tabulate_names(T) - root_name = leaves[index].name - T.root.clades = [clade for clade in T.root.clades if clade.name!=root_name] - T.root.name = root_name + root_name = leaves[index].name + T.root.clades = [clade for clade in T.root.clades if clade.name!=root_name] + T.root.name = root_name + return root_name + +def link(T1,HT2): + ''' + Create the links shown in Figure 9 + + Currently the link is via dictionary lookup via clade names. Moreover, the root of T1 is omitted. + ''' + leaves1 = {clade.name:clade for clade in T1.get_terminals()} + leaves2 = {clade.name:clade for clade in HT2.get_terminals() if clade.name in leaves1.keys()} + return leaves1,leaves2 + +def cache_sizes(Tr1): + ''' + cache_sizes + + Used by Count(...) to determine which of two subtrees is larger or smaller + + We process clades bottom up so we can use values that were calculated previously + ''' + size = {} + clades = list(Tr1.find_elements(order='postorder')) + for clade in clades[:-1]: + size[clade.name] = 0 if clade.is_terminal() else sum([(1 + size[child.name]) for child in clade.clades]) + + return size + + +def get_node_count(v): + pass + +def colour_leaves(): + pass + +def Count(v,sizes=[]): + '''Code from Figure 8''' + def get_subtree(v,small=True): + a = v.clades[0] + b = v.clades[1] + m = sizes[a.name] + n = sizes[b.name] + return m if small == (m-1: - adj[parentid].append(id if len(name)==0 else indices[name]) - for child in children: - dfs(child) - dfs(tree) - return adj -@deprecated -def create_edges(adj,n=None): - return [(a,b) for a,children in adj.items() for b in children if b>=n] - - -@deprecated -def extract_quartets(edges,adj,n=None): - - # get_leaves - # - # Get all leaves in graph 'adj' that are below some 
specified node - def get_leaves(node): - # dfs - # - # Conduct depth first search, looking for leaves - def dfs(u): - if u>=n: - for v in adj[u]: - dfs(v) - else: - leaves.append(u) - - leaves = [] - dfs(node) - return leaves - - def split(a,b): - s_b = leaves[b] - s_a = [s for s in all_leaves if s not in s_b] - return [(s_a[j],s_a[i],s_b[l],s_b[k]) for i in range(len(s_a)) - for j in range(i) - for k in range(len(s_b)) - for l in range(k)] - - all_leaves = list(range(n)) - leaves = {x:sorted(get_leaves(x)) for x in adj.keys()} - internal_edges = [(a,b) for a,b in edges if b>=n] - return [q for a,b in internal_edges for q in split(a,b)] -@deprecated -def get_matches(quartets1,quartets2): - i = 0 - j = 0 - matches = 0 - while i < len(quartets1) and j < len(quartets2): - if quartets1[i]==quartets2[j]: - matches += 1 - i += 1 - j += 1 - elif quartets1[i]quartets2[j] - j+=1 - - return matches - -# qrtd -# -# Given: A list containing n taxa and two unrooted binary trees T1 and T2 on the given taxa. -# Both T1 and T2 are given in Newick format. -# -# Return: The quartet distance dq(T1,T2) -@deprecated def qrtd(species,T1,T2): - n = len(species) - indices = {species[i]:i for i in range(n)} - tree1 = parse(T1,start=n) - adj1 = create_adj(tree1,indices=indices) - edges1 = create_edges(adj1,n=n) - quartets1 = extract_quartets(edges1,adj1,n=n) - print (len(quartets1), quartets1) - - tree2 = parse(T2,start=n) - adj2 = create_adj(tree2,indices=indices) - edges2 = create_edges(adj2,n=n) - quartets2 = extract_quartets(edges2,adj2,n=n) - print (len(quartets2),quartets2) - mismatches = 0 - for q in quartets1: - if not q in quartets2: - mismatches+=1 - for q in quartets2: - if not q in quartets1: - mismatches+=1 - return mismatches - #return len(quartets1) + len(quartets2) - 2*get_matches(quartets1,quartets2) + ''' + qrtd + + Given: A list containing n taxa and two unrooted binary trees T1 and T2 on the given taxa. + Both T1 and T2 are given in Newick format. 
+ + Return: The quartet distance dq(T1,T2) + ''' + + S = {s:Species(s,colour=Colour.A) for s in species} + + Tr1 = read(StringIO(T1), 'newick') + Tr2 = read(StringIO(T2), 'newick') + root_name = root_with_specified_leaf(Tr1,S) + S[root_name].colour = Colour.C + Factory = HTreeBuilder() + HT2 = Factory.build(Tr2) + leaves1,leaves2 = link(Tr1,HT2) + decorate(HT2,S) + return Count(Tr1.root, + sizes = cache_sizes(Tr1)) + if __name__=='__main__': start = time() @@ -235,7 +326,8 @@ def qrtd(species,T1,T2): T0 = read(StringIO('a,((e,f),d),(b,c);'), 'newick') T1 = read(StringIO('a,((e,f),d),(b,c);'), 'newick') T2 = read(StringIO('a,((b,d),c),(e,f);'), 'newick') - root_with_specified_leaf(T1) + root_with_specified_leaf(T1, ['a', 'b', 'c', 'd','e', 'f'], + index = 3) Factory = HTreeBuilder() HT2 = Factory.build(T2) @@ -252,7 +344,7 @@ def qrtd(species,T1,T2): draw_tree(T1,'T1 re rooted',ax21) ax22 = fig.add_subplot(224) - draw_tree(HT2,'HT2',ax22) + draw_tree(HT2,'HT2',ax22, label_func = ComponentClade.get_label) if args.sample: