Skip to content

Commit

Permalink
Merge pull request #990 from hyanwong/record-postprocess
Browse files Browse the repository at this point in the history
Record more tsinfer parameters in provenance
  • Loading branch information
benjeffery authored Jan 16, 2025
2 parents ba51298 + 4a123d8 commit e3b2155
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 6 deletions.
59 changes: 53 additions & 6 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,27 @@ def test_no_provenance_match_samples(self, small_sd_fixture):
assert ts.num_provenances == small_sd_fixture.num_provenances

@pytest.mark.parametrize("mmr", [None, 0.1])
def test_provenance_infer(self, small_sd_fixture, mmr):
@pytest.mark.parametrize("pc", [True, False])
@pytest.mark.parametrize("post", [True, False])
@pytest.mark.parametrize("precision", [4, 5])
def test_provenance_infer(self, small_sd_fixture, mmr, pc, post, precision):
ts = tsinfer.infer(
small_sd_fixture, mismatch_ratio=mmr, recombination_rate=1e-8
small_sd_fixture,
path_compression=pc,
post_process=post,
precision=precision,
mismatch_ratio=mmr,
recombination_rate=1e-8,
)
assert ts.num_provenances == small_sd_fixture.num_provenances + 1
record = json.loads(ts.provenance(-1).record)
params = record["parameters"]
assert params["command"] == "infer"
assert params["post_process"] == post
assert params["precision"] == precision
assert params["mismatch_ratio"] == mmr
assert params["path_compression"] == pc
assert "simplify" not in params

def test_provenance_generate_ancestors(self, small_sd_fixture):
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
Expand All @@ -122,24 +134,42 @@ def test_provenance_generate_ancestors(self, small_sd_fixture):
assert params["command"] == "generate_ancestors"

@pytest.mark.parametrize("mmr", [None, 0.1])
def test_provenance_match_ancestors(self, small_sd_fixture, mmr):
@pytest.mark.parametrize("pc", [True, False])
@pytest.mark.parametrize("precision", [4, 5])
def test_provenance_match_ancestors(self, small_sd_fixture, mmr, pc, precision):
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
anc_ts = tsinfer.match_ancestors(
small_sd_fixture, ancestors, mismatch_ratio=mmr, recombination_rate=1e-8
small_sd_fixture,
ancestors,
mismatch_ratio=mmr,
recombination_rate=1e-8,
path_compression=pc,
precision=precision,
)
assert anc_ts.num_provenances == small_sd_fixture.num_provenances + 2
params = json.loads(anc_ts.provenance(-2).record)["parameters"]
assert params["command"] == "generate_ancestors"
params = json.loads(anc_ts.provenance(-1).record)["parameters"]
assert params["command"] == "match_ancestors"
assert params["mismatch_ratio"] == mmr
assert params["path_compression"] == pc
assert params["precision"] == precision

@pytest.mark.parametrize("mmr", [None, 0.1])
def test_provenance_match_samples(self, small_sd_fixture, mmr):
@pytest.mark.parametrize("pc", [True, False])
@pytest.mark.parametrize("post", [True, False])
@pytest.mark.parametrize("precision", [4, 5])
def test_provenance_match_samples(self, small_sd_fixture, mmr, pc, precision, post):
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
ts = tsinfer.match_samples(
small_sd_fixture, anc_ts, mismatch_ratio=mmr, recombination_rate=1e-8
small_sd_fixture,
anc_ts,
mismatch_ratio=mmr,
path_compression=pc,
precision=precision,
post_process=post,
recombination_rate=1e-8,
)
assert ts.num_provenances == small_sd_fixture.num_provenances + 3
params = json.loads(ts.provenance(-3).record)["parameters"]
Expand All @@ -149,6 +179,23 @@ def test_provenance_match_samples(self, small_sd_fixture, mmr):
params = json.loads(ts.provenance(-1).record)["parameters"]
assert params["command"] == "match_samples"
assert params["mismatch_ratio"] == mmr
assert params["path_compression"] == pc
assert params["precision"] == precision
assert params["post_process"] == post
assert "simplify" not in params # deprecated

@pytest.mark.parametrize("simp", [True, False])
def test_deprecated_simplify(self, small_sd_fixture, simp):
# Included for completeness, but this is deprecated
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
ts1 = tsinfer.match_samples(small_sd_fixture, anc_ts, simplify=simp)
ts2 = tsinfer.infer(small_sd_fixture, simplify=simp)
for ts in [ts1, ts2]:
record = json.loads(ts.provenance(-1).record)
params = record["parameters"]
assert params["simplify"] == simp
assert "post_process" not in params


class TestGetProvenance:
Expand Down
17 changes: 17 additions & 0 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,11 @@ def infer(
record = provenance.get_provenance_dict(
command="infer",
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
simplify=simplify,
post_process=post_process,
# TODO: maybe record recombination rate (which could be a RateMap)
)
tables.provenances.add_row(record=json.dumps(record))
inferred_ts = tables.tree_sequence()
Expand Down Expand Up @@ -577,6 +582,9 @@ def match_ancestors(
record = provenance.get_provenance_dict(
command="match_ancestors",
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
# TODO: maybe record recombination rate (which could be a RateMap)
)
tables.provenances.add_row(record=json.dumps(record))
ts = tables.tree_sequence()
Expand Down Expand Up @@ -810,6 +818,8 @@ def match_ancestors_batch_finalise(work_dir):
record = provenance.get_provenance_dict(
command="match_ancestors",
mismatch_ratio=metadata["mismatch_ratio"],
path_compression=metadata["path_compression"],
precision=metadata["precision"],
)
tables.provenances.add_row(record=json.dumps(record))
ts = tables.tree_sequence()
Expand Down Expand Up @@ -901,6 +911,8 @@ def augment_ancestors(
record = provenance.get_provenance_dict(
command="augment_ancestors",
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
)
tables.provenances.add_row(record=json.dumps(record))
ts = tables.tree_sequence()
Expand Down Expand Up @@ -1257,6 +1269,11 @@ def match_samples(
record = provenance.get_provenance_dict(
command="match_samples",
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
simplify=simplify,
post_process=post_process,
# TODO: maybe record recombination rate (which could be a RateMap)
)
tables.provenances.add_row(record=json.dumps(record))
ts = tables.tree_sequence()
Expand Down
5 changes: 5 additions & 0 deletions tsinfer/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def get_provenance_dict(command=None, **kwargs):
raise ValueError("Command must be provided")
parameters = dict(kwargs)
parameters["command"] = command
if "simplify" in parameters:
if parameters["simplify"] is None:
del parameters["simplify"] # simplify is deprecated version of post_process
else:
del parameters["post_process"]
document = {
"schema_version": "1.0.0",
"software": {"name": "tsinfer", "version": __version__},
Expand Down

0 comments on commit e3b2155

Please sign in to comment.