Skip to content

Commit

Permalink
rewrite of compare
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Jan 23, 2025
1 parent 5b6dc82 commit 697d08f
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 56 deletions.
2 changes: 1 addition & 1 deletion requirements/CI-docs-pip/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ sphinx-argparse==0.4.0
svgwrite==1.4.3
tskit==0.6.0
tsinfer==0.3.3
scipy==1.14.1
scipy==1.15.1
msprime==1.3.2
sphinx-book-theme

2 changes: 1 addition & 1 deletion requirements/CI-tests-conda/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
scipy==1.14.1
scipy==1.15.1
msprime==1.3.2
35 changes: 35 additions & 0 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,41 @@ def verify_compare(self, ts, other, transform=None):
assert np.isclose(other_span, dis.total_span[1])
assert np.isclose(rmse, dis.rmse), f"{rmse} != {dis.rmse}"

def test_very_simple(self):
# 1.00┊ 2 ┊
# ┊ ┏┻┓ ┊
# 0.00┊ 0 1 ┊
# 0 1
ts = tskit.Tree.generate_star(2).tree_sequence
dis = tscompare.compare(ts, ts)
assert dis.arf == 0.0
assert dis.tpr == 1.0
assert dis.dissimilarity == 0.0
assert dis.inverse_dissimilarity == 0.0
assert dis.total_span == (3.0, 3.0)
assert dis.rmse == 0.0
# remove 1->2 branch
tables = ts.tables
tables.edges.clear()
tables.edges.add_row(parent=2, child=0, left=0, right=1)
empty_ts = tables.tree_sequence()
dis = tscompare.compare(ts, empty_ts)
assert np.isclose(dis.arf, 1/3)
assert dis.tpr == 1.0
assert dis.dissimilarity == 1.0
assert dis.inverse_dissimilarity == 0.0
assert dis.total_span == (3.0, 2.0)
assert dis.rmse == 0.0
dis = tscompare.compare(empty_ts, ts)
print(dis)
assert np.isclose(dis.arf, 1/3)
assert np.isclose(dis.tpr, 2/3)
assert dis.dissimilarity == 1.0
assert dis.inverse_dissimilarity == 1.0
assert dis.total_span == (2.0, 3.0)
assert np.isnan(dis.rmse)


@pytest.mark.parametrize(
"pair",
[(true_ext, true_ext), (true_simpl, true_ext), (true_simpl, true_unary)],
Expand Down
86 changes: 32 additions & 54 deletions tscompare/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,74 +408,52 @@ def f(t):
if transform is None:
transform = f

ts_node_spans = node_spans(ts)
shared_spans = shared_node_spans(ts, other)
col_ind = shared_spans.indices
row_ind = np.repeat(
np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
)
# Find all potential matches for a node based on max shared span length
max_span = shared_spans.max(axis=1).toarray().flatten()
total_match_n1_span = np.sum(max_span) # <---- one thing to output
# zero out everything that's not a row max
shared_spans.data[shared_spans.data != max_span[row_ind]] = 0.0
# now re-sparsify the matrix: but, beware! don't do this again later.
shared_spans.eliminate_zeros()
col_ind = shared_spans.indices
row_ind = np.repeat(
np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
)
# mask to find all potential node matches
match = shared_spans.data == max_span[row_ind]
# scale with difference in node times
# determine best matches with the best_match_matrix
ts_times = ts.nodes_time[row_ind[match]]
other_times = other.nodes_time[col_ind[match]]
time_difference = np.absolute(

# now, make a matrix with differences in transformed times
# in the places where shared_spans retains nonzero elements
time_diff = shared_spans.copy()
ts_times = ts.nodes_time[row_ind]
other_times = other.nodes_time[col_ind]
time_diff.data[:] = np.absolute(
np.asarray(transform(ts_times) - transform(other_times))
)
# If a node x in `ts` has no match then we set time_difference to zero
# This node then does not effect the rmse
for j in range(len(shared_spans.data[match])):
if shared_spans.data[match][j] == 0:
time_difference[j] = 0.0
# If two nodes have the same time, then
# time_difference is zero, which causes problems with argmin
# Instead we store data as 1/(1+x) and find argmax
best_match_matrix = scipy.sparse.coo_matrix(
(
1 / (1 + time_difference),
(row_ind[match], col_ind[match]),
),
shape=(ts.num_nodes, other.num_nodes),
)
# Between each pair of nodes, find the maximum shared span
# n1_match is the matching N1 -> N2 (for arf, dissimilarity)
# n2_match is finds the max match between nodes in N2 and their
# best match in N1 based on max-span (for tpr, inverse_dissimilarity)
best_n1_match = best_match_matrix.argmax(axis=1).A1
n2_match_matrix = best_match_matrix.tocsr()
bmm_row_ind = np.repeat(
np.arange(n2_match_matrix.shape[0]), repeats=np.diff(n2_match_matrix.indptr)
)
n2_match_matrix.data *= n2_match_matrix.indices == best_n1_match[bmm_row_ind]
best_n2_match = n2_match_matrix.argmax(axis=0).A1
n2_match_mask = best_n1_match[best_n2_match] == np.arange(other.num_nodes)
best_match_n1_spans = shared_spans[np.arange(ts.num_nodes), best_n1_match].reshape(
-1
)
best_match_n2_spans = shared_spans[
best_n2_match, np.arange(other.num_nodes)
].reshape(-1)[0, n2_match_mask]
total_match_n1_span = np.sum(best_match_n1_spans)
total_match_n2_span = np.sum(best_match_n2_spans)
ts_node_spans = node_spans(ts)
# "explicit=True" takes the min of only the entries explicitly represented
dt = time_diff.min(axis=1, explicit=True).toarray().flatten()
has_match = (max_span != 0)
if np.any(has_match):
rmse = np.sqrt(np.sum(dt[has_match]**2 * ts_node_spans[has_match]) / np.sum(ts_node_spans[has_match]))
# ^-- another thing to output
else:
rmse = np.nan

# next, zero out also those non-best-time-match elements
shared_spans.data[time_diff.data != dt[row_ind]] = 0.0
# and, find sum of column maxima
total_match_n2_span = shared_spans.max(axis=0).sum() # <--- the other thing we return

total_span_ts = np.sum(ts_node_spans)
total_span_other = np.sum(node_spans(other))
# Compute the root-mean-square difference in transformed time
# with the average weighted by span in ts
time_matrix = scipy.sparse.csr_matrix(
(time_difference, (row_ind[match], col_ind[match])),
shape=(ts.num_nodes, other.num_nodes),
)
time_discrepancies = np.asarray(
time_matrix[np.arange(len(best_n1_match)), best_n1_match].reshape(-1)
)
product = np.multiply((time_discrepancies**2), ts_node_spans)
rmse = np.sqrt(np.sum(product) / total_span_ts)
return ARFResult(
arf=1.0 - total_match_n1_span / total_span_ts,
tpr=total_match_n2_span / total_span_other,
# matched_span=(total_match_n1_span, total_match_n2_span),
dissimilarity=total_span_ts - total_match_n1_span,
inverse_dissimilarity=total_span_other - total_match_n2_span,
total_span=(total_span_ts, total_span_other),
Expand Down

0 comments on commit 697d08f

Please sign in to comment.