Skip to content

Commit

Permalink
Merge pull request #442 from hyanwong/resources
Browse files Browse the repository at this point in the history
Record resources used in tsdate inference
  • Loading branch information
hyanwong authored Dec 5, 2024
2 parents 2b3412a + f797a4a commit 441ea2b
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 10 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tskit>=0.5.8
tskit>=0.6.0
tsinfer>=0.3.0
ruff
numpy
Expand Down
22 changes: 22 additions & 0 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ def test_date_params_recorded(self):
assert np.isclose(rec["parameters"]["population_size"], Ne)
assert rec["parameters"]["command"] == "maximization"

def test_date_time_recorded(self):
ts = utility_functions.single_tree_ts_n2()
mu = 0.123
Ne = 9
dated_ts = tsdate.date(
ts, population_size=Ne, mutation_rate=mu, method="maximization"
)
rec = json.loads(dated_ts.provenance(-1).record)
assert "resources" in rec
assert rec["resources"]["elapsed_time"] >= 0
assert rec["resources"]["user_time"] >= 0
assert rec["resources"]["sys_time"] >= 0

@pytest.mark.parametrize(
"popdict",
[
Expand Down Expand Up @@ -119,6 +132,15 @@ def test_preprocess_interval_recorded(self):
assert 40 < deleted_intervals[0][0] < 60
assert 40 < deleted_intervals[0][1] < 60

def test_preprocess_time_recorded(self):
ts = utility_functions.ts_w_data_desert(40, 60, 100)
preprocessed_ts = tsdate.preprocess_ts(ts, minimum_gap=20)
rec = json.loads(preprocessed_ts.provenance(-1).record)
assert "resources" in rec
assert rec["resources"]["elapsed_time"] >= 0
assert rec["resources"]["user_time"] >= 0
assert rec["resources"]["sys_time"] >= 0

@pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys())
def test_named_methods(self, method):
ts = utility_functions.single_tree_ts_mutation_n3()
Expand Down
8 changes: 6 additions & 2 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
"then access `fit.node_posteriors()` to obtain a transposed version "
"of the matrix previously returned when ``return_posteriors=True.``"
)
self.start_time = time.time()
self.ts = ts
self.mutation_rate = mutation_rate
self.recombination_rate = recombination_rate
Expand Down Expand Up @@ -193,8 +194,6 @@ def get_modified_ts(self, result, eps):
nodes = tables.nodes
mutations = tables.mutations

if self.provenance_params is not None:
provenance.record_provenance(tables, self.name, **self.provenance_params)
# Constrain node ages for positive branch lengths
constr_timing = time.time()
nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations)
Expand All @@ -220,6 +219,11 @@ def get_modified_ts(self, result, eps):
tables.compute_mutation_parents()
sort_timing -= time.time()
logger.info(f"Sorted tree sequence in {abs(sort_timing):.2f} seconds")
if self.provenance_params is not None:
# Note that the time recorded in provenance excludes numba compilation time
provenance.record_provenance(
tables, self.name, self.start_time, **self.provenance_params
)
return tables.tree_sequence()

def set_time_metadata(self, table, mean, var, default_schema, overwrite=False):
Expand Down
7 changes: 4 additions & 3 deletions tsdate/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_environment():
return env


def get_provenance_dict(command, **kwargs):
def get_provenance_dict(command, start_time=None, **kwargs):
"""
Returns a dictionary encoding an execution of tsdate conforming to the
tskit provenance schema.
Expand All @@ -78,14 +78,15 @@ def get_provenance_dict(command, **kwargs):
"software": {"name": "tsdate", "version": __version__},
"parameters": parameters,
"environment": get_environment(),
"resources": tskit.provenance.get_resources(start_time),
}
return document


def record_provenance(tables, command=None, **kwargs):
def record_provenance(tables, command=None, start_time=None, **kwargs):
"""
Adds provenance information to this table collection using the
tskit provenances schema.
"""
record = get_provenance_dict(command=command, **kwargs)
record = get_provenance_dict(command=command, start_time=start_time, **kwargs)
tables.provenances.add_row(record=json.dumps(record))
13 changes: 9 additions & 4 deletions tsdate/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import json
import logging
import time

import numba
import numpy as np
Expand Down Expand Up @@ -115,6 +116,7 @@ def preprocess_ts(
"""

logger.info("Beginning preprocessing")
start_time = time.time()
logger.info(f"Minimum_gap: {minimum_gap} and remove_telomeres: {remove_telomeres}")
if split_disjoint is None:
split_disjoint = True
Expand Down Expand Up @@ -184,10 +186,14 @@ def preprocess_ts(
record_provenance=False,
**kwargs,
)
if split_disjoint:
ts = split_disjoint_nodes(tables.tree_sequence(), record_provenance=False)
tables = ts.dump_tables()
if record_provenance:
provenance.record_provenance(
tables,
"preprocess_ts",
start_time=start_time,
minimum_gap=minimum_gap,
remove_telomeres=remove_telomeres,
split_disjoint=split_disjoint,
Expand All @@ -196,10 +202,7 @@ def preprocess_ts(
filter_sites=filter_sites,
delete_intervals=delete_intervals,
)
ts = tables.tree_sequence()
if split_disjoint:
ts = split_disjoint_nodes(ts, record_provenance=False)
return ts
return tables.tree_sequence()


def nodes_time_unconstrained(tree_sequence):
Expand Down Expand Up @@ -526,6 +529,7 @@ def split_disjoint_nodes(ts, *, record_provenance=None):
as ``True``).
"""
metadata_key = "unsplit_node_id"
start_time = time.time()
if record_provenance is None:
record_provenance = True
node_is_sample = np.bitwise_and(ts.nodes_flags, tskit.NODE_IS_SAMPLE).astype(bool)
Expand Down Expand Up @@ -578,6 +582,7 @@ def split_disjoint_nodes(ts, *, record_provenance=None):
provenance.record_provenance(
tables,
"split_disjoint_nodes",
start_time=start_time,
)
return tables.tree_sequence()

Expand Down

0 comments on commit 441ea2b

Please sign in to comment.