From 33470dac711f1027b0c64563e6480df99e93e2de Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Tue, 14 Jan 2025 15:06:37 +0000 Subject: [PATCH] Record more tsinfer parameters in provenance These are useful when coming to inspect how the tree sequence was inferred --- tests/test_provenance.py | 59 ++++++++++++++++++++++++++++++++++++---- tsinfer/inference.py | 33 ++++++++++++++++------ 2 files changed, 78 insertions(+), 14 deletions(-) diff --git a/tests/test_provenance.py b/tests/test_provenance.py index a3c9b5e9..4417c3b1 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -103,15 +103,27 @@ def test_no_provenance_match_samples(self, small_sd_fixture): assert ts.num_provenances == small_sd_fixture.num_provenances @pytest.mark.parametrize("mmr", [None, 0.1]) - def test_provenance_infer(self, small_sd_fixture, mmr): + @pytest.mark.parametrize("pc", [True, False]) + @pytest.mark.parametrize("post", [True, False]) + @pytest.mark.parametrize("precision", [4, 5]) + def test_provenance_infer(self, small_sd_fixture, mmr, pc, post, precision): ts = tsinfer.infer( - small_sd_fixture, mismatch_ratio=mmr, recombination_rate=1e-8 + small_sd_fixture, + path_compression=pc, + post_process=post, + precision=precision, + mismatch_ratio=mmr, + recombination_rate=1e-8, ) assert ts.num_provenances == small_sd_fixture.num_provenances + 1 record = json.loads(ts.provenance(-1).record) params = record["parameters"] assert params["command"] == "infer" + assert params["post_process"] == post + assert params["precision"] == precision assert params["mismatch_ratio"] == mmr + assert params["path_compression"] == pc + assert "simplify" not in params def test_provenance_generate_ancestors(self, small_sd_fixture): ancestors = tsinfer.generate_ancestors(small_sd_fixture) @@ -122,10 +134,17 @@ def test_provenance_generate_ancestors(self, small_sd_fixture): assert params["command"] == "generate_ancestors" @pytest.mark.parametrize("mmr", [None, 0.1]) - def test_provenance_match_ancestors(self, small_sd_fixture, mmr): + @pytest.mark.parametrize("pc", [True, False]) + @pytest.mark.parametrize("precision", [4, 5]) + def test_provenance_match_ancestors(self, small_sd_fixture, mmr, pc, precision): ancestors = tsinfer.generate_ancestors(small_sd_fixture) anc_ts = tsinfer.match_ancestors( - small_sd_fixture, ancestors, mismatch_ratio=mmr, recombination_rate=1e-8 + small_sd_fixture, + ancestors, + mismatch_ratio=mmr, + recombination_rate=1e-8, + path_compression=pc, + precision=precision, ) assert anc_ts.num_provenances == small_sd_fixture.num_provenances + 2 params = json.loads(anc_ts.provenance(-2).record)["parameters"] @@ -133,13 +152,24 @@ def test_provenance_match_ancestors(self, small_sd_fixture, mmr): params = json.loads(anc_ts.provenance(-1).record)["parameters"] assert params["command"] == "match_ancestors" assert params["mismatch_ratio"] == mmr + assert params["path_compression"] == pc + assert params["precision"] == precision @pytest.mark.parametrize("mmr", [None, 0.1]) - def test_provenance_match_samples(self, small_sd_fixture, mmr): + @pytest.mark.parametrize("pc", [True, False]) + @pytest.mark.parametrize("post", [True, False]) + @pytest.mark.parametrize("precision", [4, 5]) + def test_provenance_match_samples(self, small_sd_fixture, mmr, pc, precision, post): ancestors = tsinfer.generate_ancestors(small_sd_fixture) anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors) ts = tsinfer.match_samples( - small_sd_fixture, anc_ts, mismatch_ratio=mmr, recombination_rate=1e-8 + small_sd_fixture, + anc_ts, + mismatch_ratio=mmr, + path_compression=pc, + precision=precision, + post_process=post, + recombination_rate=1e-8, ) assert ts.num_provenances == small_sd_fixture.num_provenances + 3 params = json.loads(ts.provenance(-3).record)["parameters"] @@ -149,6 +179,23 @@ def test_provenance_match_samples(self, small_sd_fixture, mmr): params = json.loads(ts.provenance(-1).record)["parameters"] assert params["command"] == "match_samples" assert params["mismatch_ratio"] == mmr + assert params["path_compression"] == pc + assert params["precision"] == precision + assert params["post_process"] == post + assert "simplify" not in params # deprecated + + @pytest.mark.parametrize("simp", [True, False]) + def test_deprecated_simplify(self, small_sd_fixture, simp): + # Included for completeness, but this is deprecated + ancestors = tsinfer.generate_ancestors(small_sd_fixture) + anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors) + ts1 = tsinfer.match_samples(small_sd_fixture, anc_ts, simplify=simp) + ts2 = tsinfer.infer(small_sd_fixture, simplify=simp) + for ts in [ts1, ts2]: + record = json.loads(ts.provenance(-1).record) + params = record["parameters"] + assert params["simplify"] == simp + assert "post_process" not in params class TestGetProvenance: diff --git a/tsinfer/inference.py b/tsinfer/inference.py index f22811b0..4c596b22 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -377,10 +377,16 @@ def infer( ) if record_provenance: tables = inferred_ts.dump_tables() - record = provenance.get_provenance_dict( - command="infer", - mismatch_ratio=mismatch_ratio, - ) + params = { # TODO: maybe record recombination rate (which could be a RateMap) + "mismatch_ratio": mismatch_ratio, + "path_compression": path_compression, + "precision": precision, + } + if simplify is None: + params["post_process"] = post_process + else: + params["simplify"] = simplify + record = provenance.get_provenance_dict(command="infer", **params) tables.provenances.add_row(record=json.dumps(record)) inferred_ts = tables.tree_sequence() return inferred_ts @@ -577,6 +583,9 @@ def match_ancestors( record = provenance.get_provenance_dict( command="match_ancestors", mismatch_ratio=mismatch_ratio, + path_compression=path_compression, + precision=precision, + # TODO: maybe record recombination rate (which could be a RateMap) ) tables.provenances.add_row(record=json.dumps(record)) ts = tables.tree_sequence() @@ -901,6 +910,8 @@ def augment_ancestors( record = provenance.get_provenance_dict( command="augment_ancestors", mismatch_ratio=mismatch_ratio, + path_compression=path_compression, + precision=precision, ) tables.provenances.add_row(record=json.dumps(record)) ts = tables.tree_sequence() @@ -1254,10 +1265,16 @@ def match_samples( if record_provenance: tables = ts.dump_tables() # We don't have a source here because tree sequence files don't have a UUID yet. - record = provenance.get_provenance_dict( - command="match_samples", - mismatch_ratio=mismatch_ratio, - ) + params = { # TODO: maybe record recombination rate (which could be a RateMap) + "mismatch_ratio": mismatch_ratio, + "path_compression": path_compression, + "precision": precision, + } + if simplify is None: + params["post_process"] = post_process + else: + params["simplify"] = simplify + record = provenance.get_provenance_dict(command="match_samples", **params) tables.provenances.add_row(record=json.dumps(record)) ts = tables.tree_sequence() return ts