Skip to content

Commit

Permalink
Merge pull request #449 from hyanwong/md-overwrite
Browse files (browse the repository at this point in the history)
Add `set_metadata` parameter
  • Loading branch information
hyanwong authored Jan 16, 2025
2 parents 3c08bec + ec0c730 commit 240593d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 40 deletions.
5 changes: 0 additions & 5 deletions tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,6 @@ def test_basic(
mu = sim_mutations_parameters["rate"]

dts = tsdate.inside_outside(ts, population_size=Ne, mutation_rate=mu)
# make sure we can read node metadata - old tsdate versions didn't set a schema
if dts.table_metadata_schemas.node.schema is None:
tables = dts.dump_tables()
tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
dts = tables.tree_sequence()

# Only test nonsample node times
nonsamples = np.ones(ts.num_nodes, dtype=bool)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,13 @@ def test_bad_arguments(self):
with pytest.raises(ValueError, match="Maximum number of EP iterations"):
tsdate.variational_gamma(self.ts, mutation_rate=5, max_iterations=-1)

def test_no_set_metadata(self):
assert len(self.ts.tables.mutations.metadata) == 0
assert len(self.ts.tables.nodes.metadata) == 0
ts = tsdate.variational_gamma(self.ts, mutation_rate=1e-8, set_metadata=False)
assert len(ts.tables.mutations.metadata) == 0
assert len(ts.tables.nodes.metadata) == 0

def test_no_existing_mutation_metadata(self):
# Currently only the variational_gamma method embeds mutation metadata
ts = tsdate.variational_gamma(self.ts, mutation_rate=1e-8)
Expand Down Expand Up @@ -565,3 +572,7 @@ def test_incompatible_schema_mutation_metadata(self, caplog):
dts = tsdate.variational_gamma(ts, mutation_rate=1e-8)
assert "Could not set" in caplog.text
assert len(dts.tables.mutations.metadata) == 0
assert len(dts.tables.nodes.metadata) > 0
dts = tsdate.variational_gamma(ts, mutation_rate=1e-8, set_metadata=True)
assert len(dts.tables.mutations.metadata) > 0
assert len(dts.tables.nodes.metadata) > 0
84 changes: 49 additions & 35 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,12 @@ def __init__(
allow_unary=None,
record_provenance=None,
constr_iterations=None,
set_metadata=None,
progress=None,
# Deprecated params
return_posteriors=None,
):
# Set up all the generic params describe in the tsdate.date function, and define
# Set up all the generic params described in the tsdate.date function, and define
# priors if not passed-in already
if return_posteriors is not None:
raise ValueError(
Expand All @@ -106,6 +107,7 @@ def __init__(
self.recombination_rate = recombination_rate
self.return_fit = return_fit
self.return_likelihood = return_likelihood
self.set_metadata = set_metadata
self.pbar = progress
self.time_units = "generations" if time_units is None else time_units
if record_provenance is None:
Expand Down Expand Up @@ -206,7 +208,7 @@ def get_modified_ts(self, result, eps):
# Add posterior mean and variance to node/mutation metadata
meta_timing = time.time()
self.set_time_metadata(
nodes, node_mean_t, node_var_t, schemas.default_node_schema, overwrite=True
nodes, node_mean_t, node_var_t, schemas.default_node_schema
)
self.set_time_metadata(
mutations, mut_mean_t, mut_var_t, schemas.default_mutation_schema
Expand All @@ -226,40 +228,44 @@ def get_modified_ts(self, result, eps):
)
return tables.tree_sequence()

def set_time_metadata(self, table, mean, var, default_schema, overwrite=False):
if var is not None:
def set_time_metadata(self, table, mean, var, default_schema):
# Try to set metadata: if we fail, clear metadata, reset schema, and try again
def _time_md_array(table, mean, var):
# Return an array of metadata dicts, or raise an error if
# schema is None or metadata is not valid
schema = table.metadata_schema
if schema.schema is None:
raise tskit.MetadataEncodingError("No schema set")
if len(table.metadata) > 0:
md_iter = (row.metadata for row in table)
else:
md_iter = ({} for _ in range(table.num_rows)) # no decoding needed
metadata_array = []
for metadata_dict, mn, vr in zip(md_iter, mean, var):
metadata_dict.update((("mn", mn), ("vr", vr)))
metadata_array.append(schema.validate_and_encode_row(metadata_dict))
return metadata_array

if self.set_metadata is False or var is None:
return # no md to set (e.g. outside maximization method)
assert len(mean) == len(var) == table.num_rows
try:
table.packset_metadata(_time_md_array(table, mean, var))
except (tskit.MetadataEncodingError, tskit.MetadataValidationError) as e:
table_name = type(table).__name__
assert len(mean) == len(var) == table.num_rows
if table.metadata_schema.schema is None or overwrite:
if len(table.metadata) == 0 or overwrite:
table.metadata_schema = default_schema
md_iter = ({} for _ in range(table.num_rows))
# For speed, assume we don't need to validate
encoder = table.metadata_schema.encode_row
logger.info(f"Set metadata schema on {table_name}")
else:
if len(table.metadata) > 0 or table.metadata_schema.schema is not None:
if not self.set_metadata:
logger.warning(
f"Could not set metadata on {table_name}: "
"data already exists with no schema"
f"Could not set time metadata on {table_name} "
f"(force this by specifying `set_metadata=True`): {e}"
)
return
else:
md_iter = (
table.metadata_schema.decode_row(md)
for md in tskit.unpack_bytes(table.metadata, table.metadata_offset)
)
encoder = table.metadata_schema.validate_and_encode_row
# TODO: could try to add to the existing schema if it's compatible
metadata_array = []
try:
# wrap entire loop in try/except so metadata is either all set or not
for metadata_dict, mn, vr in zip(md_iter, mean, var):
metadata_dict.update((("mn", mn), ("vr", vr)))
# validate and replace
metadata_array.append(encoder(metadata_dict))
table.packset_metadata(metadata_array)
except tskit.MetadataValidationError as e:
logger.warning(f"Could not set time metadata in {table_name}: {e}")
else:
logger.info(f"Clearing metadata from {table_name}")
table.drop_metadata()
logger.info(f"Setting metadata schema on {table_name}")
table.metadata_schema = default_schema
table.packset_metadata(_time_md_array(table, mean, var))

def parse_result(self, result, epsilon):
# Construct the tree sequence to return and add other stuff we might want to
Expand Down Expand Up @@ -876,6 +882,7 @@ def date(
time_units=None,
method=None,
constr_iterations=None,
set_metadata=None,
return_fit=None,
return_likelihood=None,
allow_unary=None,
Expand Down Expand Up @@ -922,12 +929,18 @@ def date(
:param string method: What estimation method to use. See
:data:`~tsdate.core.estimation_methods` for possible values.
If ``None`` (default) the "variational_gamma" method is currently chosen.
:param bool return_fit: If ``True``, instead of returning just a dated tree
sequence, return a tuple of ``(dated_ts, fit)``.
Default: None, treated as False.
:param int constr_iterations: The maximum number of constrained least
squares iterations to use prior to forcing positive branch lengths.
Default: None, treated as 0.
:param bool set_metadata: Should unconstrained times be stored in table metadata,
in the form of ``"mn"`` (mean) and ``"vr"`` (variance) fields? If ``False``,
do not store metadata. If ``True``, force metadata to be set (if no schema
is set or the schema is incompatible, clear existing metadata in the relevant
tables and set a new schema). If ``None`` (default), only set metadata if
the existing schema allows (this may overwrite existing ``"mn"`` and ``"vr"``
fields) or if existing metadata is empty, otherwise issue a warning.
:param bool return_fit: If ``True``, instead of just a dated tree sequence,
return a tuple of ``(dated_ts, fit)``. Default: None, treated as False.
:param bool return_likelihood: If ``True``, return the log marginal likelihood
from the inside algorithm in addition to the dated tree sequence. If
``return_fit`` is also ``True``, then the marginal likelihood
Expand Down Expand Up @@ -963,6 +976,7 @@ def date(
return_fit=return_fit,
return_likelihood=return_likelihood,
allow_unary=allow_unary,
set_metadata=set_metadata,
record_provenance=record_provenance,
**kwargs,
)

0 comments on commit 240593d

Please sign in to comment.