From 4a123d8d3e7af76d899dd3bba9c6f1b782eab559 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Tue, 14 Jan 2025 15:06:37 +0000 Subject: [PATCH] Record more tsinfer parameters in provenance These are useful when coming to inspect how the tree sequence was inferred --- tests/test_provenance.py | 59 ++++++++++++++++++++++++++++++++++++---- tsinfer/inference.py | 17 ++++++++++++ tsinfer/provenance.py | 5 ++++ 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/tests/test_provenance.py b/tests/test_provenance.py index a3c9b5e9..4417c3b1 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -103,15 +103,27 @@ def test_no_provenance_match_samples(self, small_sd_fixture): assert ts.num_provenances == small_sd_fixture.num_provenances @pytest.mark.parametrize("mmr", [None, 0.1]) - def test_provenance_infer(self, small_sd_fixture, mmr): + @pytest.mark.parametrize("pc", [True, False]) + @pytest.mark.parametrize("post", [True, False]) + @pytest.mark.parametrize("precision", [4, 5]) + def test_provenance_infer(self, small_sd_fixture, mmr, pc, post, precision): ts = tsinfer.infer( - small_sd_fixture, mismatch_ratio=mmr, recombination_rate=1e-8 + small_sd_fixture, + path_compression=pc, + post_process=post, + precision=precision, + mismatch_ratio=mmr, + recombination_rate=1e-8, ) assert ts.num_provenances == small_sd_fixture.num_provenances + 1 record = json.loads(ts.provenance(-1).record) params = record["parameters"] assert params["command"] == "infer" + assert params["post_process"] == post + assert params["precision"] == precision assert params["mismatch_ratio"] == mmr + assert params["path_compression"] == pc + assert "simplify" not in params def test_provenance_generate_ancestors(self, small_sd_fixture): ancestors = tsinfer.generate_ancestors(small_sd_fixture) @@ -122,10 +134,17 @@ def test_provenance_generate_ancestors(self, small_sd_fixture): assert params["command"] == "generate_ancestors" @pytest.mark.parametrize("mmr", [None, 0.1]) - def test_provenance_match_ancestors(self, small_sd_fixture, mmr): + @pytest.mark.parametrize("pc", [True, False]) + @pytest.mark.parametrize("precision", [4, 5]) + def test_provenance_match_ancestors(self, small_sd_fixture, mmr, pc, precision): ancestors = tsinfer.generate_ancestors(small_sd_fixture) anc_ts = tsinfer.match_ancestors( - small_sd_fixture, ancestors, mismatch_ratio=mmr, recombination_rate=1e-8 + small_sd_fixture, + ancestors, + mismatch_ratio=mmr, + recombination_rate=1e-8, + path_compression=pc, + precision=precision, ) assert anc_ts.num_provenances == small_sd_fixture.num_provenances + 2 params = json.loads(anc_ts.provenance(-2).record)["parameters"] @@ -133,13 +152,24 @@ def test_provenance_match_ancestors(self, small_sd_fixture, mmr): params = json.loads(anc_ts.provenance(-1).record)["parameters"] assert params["command"] == "match_ancestors" assert params["mismatch_ratio"] == mmr + assert params["path_compression"] == pc + assert params["precision"] == precision @pytest.mark.parametrize("mmr", [None, 0.1]) - def test_provenance_match_samples(self, small_sd_fixture, mmr): + @pytest.mark.parametrize("pc", [True, False]) + @pytest.mark.parametrize("post", [True, False]) + @pytest.mark.parametrize("precision", [4, 5]) + def test_provenance_match_samples(self, small_sd_fixture, mmr, pc, precision, post): ancestors = tsinfer.generate_ancestors(small_sd_fixture) anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors) ts = tsinfer.match_samples( - small_sd_fixture, anc_ts, mismatch_ratio=mmr, recombination_rate=1e-8 + small_sd_fixture, + anc_ts, + mismatch_ratio=mmr, + path_compression=pc, + precision=precision, + post_process=post, + recombination_rate=1e-8, ) assert ts.num_provenances == small_sd_fixture.num_provenances + 3 params = json.loads(ts.provenance(-3).record)["parameters"] @@ -149,6 +179,23 @@ def test_provenance_match_samples(self, small_sd_fixture, mmr): params = json.loads(ts.provenance(-1).record)["parameters"] assert params["command"] == "match_samples" assert params["mismatch_ratio"] == mmr + assert params["path_compression"] == pc + assert params["precision"] == precision + assert params["post_process"] == post + assert "simplify" not in params # deprecated + + @pytest.mark.parametrize("simp", [True, False]) + def test_deprecated_simplify(self, small_sd_fixture, simp): + # Included for completeness, but this is deprecated + ancestors = tsinfer.generate_ancestors(small_sd_fixture) + anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors) + ts1 = tsinfer.match_samples(small_sd_fixture, anc_ts, simplify=simp) + ts2 = tsinfer.infer(small_sd_fixture, simplify=simp) + for ts in [ts1, ts2]: + record = json.loads(ts.provenance(-1).record) + params = record["parameters"] + assert params["simplify"] == simp + assert "post_process" not in params class TestGetProvenance: diff --git a/tsinfer/inference.py b/tsinfer/inference.py index f22811b0..dcb61f8f 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -380,6 +380,11 @@ def infer( record = provenance.get_provenance_dict( command="infer", mismatch_ratio=mismatch_ratio, + path_compression=path_compression, + precision=precision, + simplify=simplify, + post_process=post_process, + # TODO: maybe record recombination rate (which could be a RateMap) ) tables.provenances.add_row(record=json.dumps(record)) inferred_ts = tables.tree_sequence() @@ -577,6 +582,9 @@ def match_ancestors( record = provenance.get_provenance_dict( command="match_ancestors", mismatch_ratio=mismatch_ratio, + path_compression=path_compression, + precision=precision, + # TODO: maybe record recombination rate (which could be a RateMap) ) tables.provenances.add_row(record=json.dumps(record)) ts = tables.tree_sequence() @@ -810,6 +818,8 @@ def match_ancestors_batch_finalise(work_dir): record = provenance.get_provenance_dict( command="match_ancestors", mismatch_ratio=metadata["mismatch_ratio"], + path_compression=metadata["path_compression"], + precision=metadata["precision"], ) tables.provenances.add_row(record=json.dumps(record)) ts = tables.tree_sequence() @@ -901,6 +911,8 @@ def augment_ancestors( record = provenance.get_provenance_dict( command="augment_ancestors", mismatch_ratio=mismatch_ratio, + path_compression=path_compression, + precision=precision, ) tables.provenances.add_row(record=json.dumps(record)) ts = tables.tree_sequence() @@ -1257,6 +1269,11 @@ def match_samples( record = provenance.get_provenance_dict( command="match_samples", mismatch_ratio=mismatch_ratio, + path_compression=path_compression, + precision=precision, + simplify=simplify, + post_process=post_process, + # TODO: maybe record recombination rate (which could be a RateMap) ) tables.provenances.add_row(record=json.dumps(record)) ts = tables.tree_sequence() diff --git a/tsinfer/provenance.py b/tsinfer/provenance.py index d7081fd8..6e165c70 100644 --- a/tsinfer/provenance.py +++ b/tsinfer/provenance.py @@ -75,6 +75,11 @@ def get_provenance_dict(command=None, **kwargs): raise ValueError("Command must be provided") parameters = dict(kwargs) parameters["command"] = command + if "simplify" in parameters: + if parameters["simplify"] is None: + del parameters["simplify"] # simplify is deprecated version of post_process + else: + del parameters["post_process"] document = { "schema_version": "1.0.0", "software": {"name": "tsinfer", "version": __version__},