Skip to content

Commit

Permalink
#46. Refactored. Still slow, and inaccurate for sample
Browse files Browse the repository at this point in the history
  • Loading branch information
weka511 committed Dec 21, 2020
1 parent d33950d commit 68b2ea6
Showing 1 changed file with 47 additions and 25 deletions.
72 changes: 47 additions & 25 deletions qrtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,20 @@ def dfs(tree):
dfs(tree)
return adj

def create_edges(tree,indices=None):
edges = []
def dfs(tree):
id = tree['id']
name = tree['name']
children = tree['children']
parentid = tree['parentid']
def create_edges(adj,n=None):
return [(a,b) for a,children in adj.items() for b in children if b>=n]

for child in children:
if len(child['name'])==0:
edges.append((id,child['id']))
dfs(child)
dfs(tree)
return edges
# extract_quartets

def extract_quartets(edges,adj,n=None):
def get_leaves(x):

# get_leaves
#
# Get all leaves in graph 'adj' that are below some specified node
def get_leaves(node):
# dfs
#
# Conduct depth first search, looking for leaves
def dfs(u):
if u>=n:
for v in adj[u]:
Expand All @@ -62,35 +59,60 @@ def dfs(u):
leaves.append(u)

leaves = []
dfs(x)
dfs(node)
return leaves

def split(a,b):
s_b = leaves[b]
s_a = [s for s in leaves[a] if s not in s_b]
s_a = [s for s in all_leaves if s not in s_b]
return [(s_a[j],s_a[i],s_b[l],s_b[k]) for i in range(len(s_a))
for j in range(i)
for k in range(len(s_b))
for l in range(k)]

all_leaves = list(range(n))
leaves = {x:sorted(get_leaves(x)) for x in adj.keys()}
internal_edges = [(a,b) for a,b in edges if b>=n]
return [q for a,b in internal_edges for q in split(a,b)]

def get_matches(quartets1,quartets2):
i = 0
j = 0
matches = 0
while i < len(quartets1) and j < len(quartets2):
if quartets1[i]==quartets2[j]:
matches += 1
i += 1
j += 1
elif quartets1[i]<quartets2[j]:
i += 1
else: # quartets1[i]>quartets2[j]
j+=1

return matches

# qrtd
#
# Given: A list containing n taxa and two unrooted binary trees T1 and T2 on the given taxa.
# Both T1 and T2 are given in Newick format.
#
# Return: The quartet distance dq(T1,T2)


def qrtd(species,newick1,newick2):
def qrtd(species,T1,T2):
n = len(species)
indices = {species[i]:i for i in range(n)}
tree1 = parse(newick1,start=n)
edges1 = create_edges(tree1,indices=indices)
tree1 = parse(T1,start=n)
adj1 = create_adj(tree1,indices=indices)
q1 = extract_quartets(edges1,adj1,n=n)
quartets1 = set(q1)
tree2 = parse(newick2,start=n)
edges2 = create_edges(tree2,indices=indices)
edges1 = create_edges(adj1,n=n)
quartets1 = sorted(set(extract_quartets(edges1,adj1,n=n)))
print (len(quartets1), quartets1)

tree2 = parse(T2,start=n)
adj2 = create_adj(tree2,indices=indices)
return 2*(n - sum([1 for q in set(extract_quartets(edges2,adj2,n=n)) if q in quartets1]))
edges2 = create_edges(adj2,n=n)
quartets2 = sorted(set(extract_quartets(edges2,adj2,n=n)))
print (len(quartets2),quartets2)
return len(quartets1) + len(quartets2) - 2*get_matches(quartets1,quartets2)

if __name__=='__main__':
start = time.time()
Expand Down

0 comments on commit 68b2ea6

Please sign in to comment.