Skip to content

Commit

Permalink
Record more tsinfer parameters in provenance
Browse files Browse the repository at this point in the history
These are useful when coming to inspect how the tree sequence was inferred
  • Loading branch information
hyanwong committed Jan 14, 2025
1 parent 0a83c7b commit 33470da
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 14 deletions.
59 changes: 53 additions & 6 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,27 @@ def test_no_provenance_match_samples(self, small_sd_fixture):
assert ts.num_provenances == small_sd_fixture.num_provenances

@pytest.mark.parametrize("mmr", [None, 0.1])
def test_provenance_infer(self, small_sd_fixture, mmr):
@pytest.mark.parametrize("pc", [True, False])
@pytest.mark.parametrize("post", [True, False])
@pytest.mark.parametrize("precision", [4, 5])
def test_provenance_infer(self, small_sd_fixture, mmr, pc, post, precision):
ts = tsinfer.infer(
small_sd_fixture, mismatch_ratio=mmr, recombination_rate=1e-8
small_sd_fixture,
path_compression=pc,
post_process=post,
precision=precision,
mismatch_ratio=mmr,
recombination_rate=1e-8,
)
assert ts.num_provenances == small_sd_fixture.num_provenances + 1
record = json.loads(ts.provenance(-1).record)
params = record["parameters"]
assert params["command"] == "infer"
assert params["post_process"] == post
assert params["precision"] == precision
assert params["mismatch_ratio"] == mmr
assert params["path_compression"] == pc
assert "simplify" not in params

def test_provenance_generate_ancestors(self, small_sd_fixture):
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
Expand All @@ -122,24 +134,42 @@ def test_provenance_generate_ancestors(self, small_sd_fixture):
assert params["command"] == "generate_ancestors"

@pytest.mark.parametrize("mmr", [None, 0.1])
def test_provenance_match_ancestors(self, small_sd_fixture, mmr):
@pytest.mark.parametrize("pc", [True, False])
@pytest.mark.parametrize("precision", [4, 5])
def test_provenance_match_ancestors(self, small_sd_fixture, mmr, pc, precision):
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
anc_ts = tsinfer.match_ancestors(
small_sd_fixture, ancestors, mismatch_ratio=mmr, recombination_rate=1e-8
small_sd_fixture,
ancestors,
mismatch_ratio=mmr,
recombination_rate=1e-8,
path_compression=pc,
precision=precision,
)
assert anc_ts.num_provenances == small_sd_fixture.num_provenances + 2
params = json.loads(anc_ts.provenance(-2).record)["parameters"]
assert params["command"] == "generate_ancestors"
params = json.loads(anc_ts.provenance(-1).record)["parameters"]
assert params["command"] == "match_ancestors"
assert params["mismatch_ratio"] == mmr
assert params["path_compression"] == pc
assert params["precision"] == precision

@pytest.mark.parametrize("mmr", [None, 0.1])
def test_provenance_match_samples(self, small_sd_fixture, mmr):
@pytest.mark.parametrize("pc", [True, False])
@pytest.mark.parametrize("post", [True, False])
@pytest.mark.parametrize("precision", [4, 5])
def test_provenance_match_samples(self, small_sd_fixture, mmr, pc, precision, post):
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
ts = tsinfer.match_samples(
small_sd_fixture, anc_ts, mismatch_ratio=mmr, recombination_rate=1e-8
small_sd_fixture,
anc_ts,
mismatch_ratio=mmr,
path_compression=pc,
precision=precision,
post_process=post,
recombination_rate=1e-8,
)
assert ts.num_provenances == small_sd_fixture.num_provenances + 3
params = json.loads(ts.provenance(-3).record)["parameters"]
Expand All @@ -149,6 +179,23 @@ def test_provenance_match_samples(self, small_sd_fixture, mmr):
params = json.loads(ts.provenance(-1).record)["parameters"]
assert params["command"] == "match_samples"
assert params["mismatch_ratio"] == mmr
assert params["path_compression"] == pc
assert params["precision"] == precision
assert params["post_process"] == post
assert "simplify" not in params # deprecated

@pytest.mark.parametrize("simp", [True, False])
def test_deprecated_simplify(self, small_sd_fixture, simp):
# Included for completeness, but this is deprecated
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
ts1 = tsinfer.match_samples(small_sd_fixture, anc_ts, simplify=simp)
ts2 = tsinfer.infer(small_sd_fixture, simplify=simp)
for ts in [ts1, ts2]:
record = json.loads(ts.provenance(-1).record)
params = record["parameters"]
assert params["simplify"] == simp
assert "post_process" not in params


class TestGetProvenance:
Expand Down
33 changes: 25 additions & 8 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,16 @@ def infer(
)
if record_provenance:
tables = inferred_ts.dump_tables()
record = provenance.get_provenance_dict(
command="infer",
mismatch_ratio=mismatch_ratio,
)
params = { # TODO: maybe record recombination rate (which could be a RateMap)
"mismatch_ratio": mismatch_ratio,
"path_compression": path_compression,
"precision": precision,
}
if simplify is None:
params["post_process"] = post_process
else:
params["simplify"] = simplify
record = provenance.get_provenance_dict(command="infer", **params)
tables.provenances.add_row(record=json.dumps(record))
inferred_ts = tables.tree_sequence()
return inferred_ts
Expand Down Expand Up @@ -577,6 +583,9 @@ def match_ancestors(
record = provenance.get_provenance_dict(
command="match_ancestors",
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
# TODO: maybe record recombination rate (which could be a RateMap)
)
tables.provenances.add_row(record=json.dumps(record))
ts = tables.tree_sequence()
Expand Down Expand Up @@ -901,6 +910,8 @@ def augment_ancestors(
record = provenance.get_provenance_dict(
command="augment_ancestors",
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
)
tables.provenances.add_row(record=json.dumps(record))
ts = tables.tree_sequence()
Expand Down Expand Up @@ -1254,10 +1265,16 @@ def match_samples(
if record_provenance:
tables = ts.dump_tables()
# We don't have a source here because tree sequence files don't have a UUID yet.
record = provenance.get_provenance_dict(
command="match_samples",
mismatch_ratio=mismatch_ratio,
)
params = { # TODO: maybe record recombination rate (which could be a RateMap)
"mismatch_ratio": mismatch_ratio,
"path_compression": path_compression,
"precision": precision,
}
if simplify is None:
params["post_process"] = post_process
else:
params["simplify"] = simplify
record = provenance.get_provenance_dict(command="match_samples", **params)
tables.provenances.add_row(record=json.dumps(record))
ts = tables.tree_sequence()
return ts
Expand Down

0 comments on commit 33470da

Please sign in to comment.