From e1568f3d8e7388a820502692e8f43f66c55f65c8 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Sat, 14 Dec 2024 09:39:33 +0000 Subject: [PATCH 1/2] No need to worry about older tsdate versions --- tests/test_accuracy.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_accuracy.py b/tests/test_accuracy.py index 82324751..57aeb099 100644 --- a/tests/test_accuracy.py +++ b/tests/test_accuracy.py @@ -109,11 +109,6 @@ def test_basic( mu = sim_mutations_parameters["rate"] dts = tsdate.inside_outside(ts, population_size=Ne, mutation_rate=mu) - # make sure we can read node metadata - old tsdate versions didn't set a schema - if dts.table_metadata_schemas.node.schema is None: - tables = dts.dump_tables() - tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() - dts = tables.tree_sequence() # Only test nonsample node times nonsamples = np.ones(ts.num_nodes, dtype=bool) From ec0c73093a2435bc5cafe3c7e94fdc08512799c3 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Sat, 14 Dec 2024 09:40:29 +0000 Subject: [PATCH 2/2] Implement a set_metadata parameter --- tests/test_inference.py | 11 ++++++ tsdate/core.py | 84 ++++++++++++++++++++++++----------------- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index cec220ad..d96ece8b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -478,6 +478,13 @@ def test_bad_arguments(self): with pytest.raises(ValueError, match="Maximum number of EP iterations"): tsdate.variational_gamma(self.ts, mutation_rate=5, max_iterations=-1) + def test_no_set_metadata(self): + assert len(self.ts.tables.mutations.metadata) == 0 + assert len(self.ts.tables.nodes.metadata) == 0 + ts = tsdate.variational_gamma(self.ts, mutation_rate=1e-8, set_metadata=False) + assert len(ts.tables.mutations.metadata) == 0 + assert len(ts.tables.nodes.metadata) == 0 + def test_no_existing_mutation_metadata(self): # Currently only the variational_gamma method embeds mutation metadata ts = tsdate.variational_gamma(self.ts, mutation_rate=1e-8) @@ -565,3 +572,7 @@ def test_incompatible_schema_mutation_metadata(self, caplog): dts = tsdate.variational_gamma(ts, mutation_rate=1e-8) assert "Could not set" in caplog.text assert len(dts.tables.mutations.metadata) == 0 + assert len(dts.tables.nodes.metadata) > 0 + dts = tsdate.variational_gamma(ts, mutation_rate=1e-8, set_metadata=True) + assert len(dts.tables.mutations.metadata) > 0 + assert len(dts.tables.nodes.metadata) > 0 diff --git a/tsdate/core.py b/tsdate/core.py index f6c796c4..f4ecac11 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -87,11 +87,12 @@ def __init__( allow_unary=None, record_provenance=None, constr_iterations=None, + set_metadata=None, progress=None, # Deprecated params return_posteriors=None, ): - # Set up all the generic params describe in the tsdate.date function, and define + # Set up all the generic params described in the tsdate.date function, and define # priors if not passed-in already if return_posteriors is not None: raise ValueError( @@ -106,6 +107,7 @@ def __init__( self.recombination_rate = recombination_rate self.return_fit = return_fit self.return_likelihood = return_likelihood + self.set_metadata = set_metadata self.pbar = progress self.time_units = "generations" if time_units is None else time_units if record_provenance is None: @@ -206,7 +208,7 @@ def get_modified_ts(self, result, eps): # Add posterior mean and variance to node/mutation metadata meta_timing = time.time() self.set_time_metadata( - nodes, node_mean_t, node_var_t, schemas.default_node_schema, overwrite=True + nodes, node_mean_t, node_var_t, schemas.default_node_schema ) self.set_time_metadata( mutations, mut_mean_t, mut_var_t, schemas.default_mutation_schema @@ -226,40 +228,44 @@ def get_modified_ts(self, result, eps): ) return tables.tree_sequence() - def set_time_metadata(self, table, mean, var, default_schema, overwrite=False): - if var is not None: + def set_time_metadata(self, table, mean, var, default_schema): + # Try to set metadata: if we fail, clear metadata, reset schema, and try again + def _time_md_array(table, mean, var): + # Return an array of metadata dicts, or raise an error if + # schema is None or metadata is not valid + schema = table.metadata_schema + if schema.schema is None: + raise tskit.MetadataEncodingError("No schema set") + if len(table.metadata) > 0: + md_iter = (row.metadata for row in table) + else: + md_iter = ({} for _ in range(table.num_rows)) # no decoding needed + metadata_array = [] + for metadata_dict, mn, vr in zip(md_iter, mean, var): + metadata_dict.update((("mn", mn), ("vr", vr))) + metadata_array.append(schema.validate_and_encode_row(metadata_dict)) + return metadata_array + + if self.set_metadata is False or var is None: + return # no md to set (e.g. outside maximization method) + assert len(mean) == len(var) == table.num_rows + try: + table.packset_metadata(_time_md_array(table, mean, var)) + except (tskit.MetadataEncodingError, tskit.MetadataValidationError) as e: table_name = type(table).__name__ - assert len(mean) == len(var) == table.num_rows - if table.metadata_schema.schema is None or overwrite: - if len(table.metadata) == 0 or overwrite: - table.metadata_schema = default_schema - md_iter = ({} for _ in range(table.num_rows)) - # For speed, assume we don't need to validate - encoder = table.metadata_schema.encode_row - logger.info(f"Set metadata schema on {table_name}") - else: + if len(table.metadata) > 0 or table.metadata_schema.schema is not None: + if not self.set_metadata: logger.warning( - f"Could not set metadata on {table_name}: " - "data already exists with no schema" + f"Could not set time metadata on {table_name} " + f"(force this by specifying `set_metadata=True`): {e}" ) return - else: - md_iter = ( - table.metadata_schema.decode_row(md) - for md in tskit.unpack_bytes(table.metadata, table.metadata_offset) - ) - encoder = table.metadata_schema.validate_and_encode_row - # TODO: could try to add to the existing schema if it's compatible - metadata_array = [] - try: - # wrap entire loop in try/except so metadata is either all set or not - for metadata_dict, mn, vr in zip(md_iter, mean, var): - metadata_dict.update((("mn", mn), ("vr", vr))) - # validate and replace - metadata_array.append(encoder(metadata_dict)) - table.packset_metadata(metadata_array) - except tskit.MetadataValidationError as e: - logger.warning(f"Could not set time metadata in {table_name}: {e}") + else: + logger.info(f"Clearing metadata from {table_name}") + table.drop_metadata() + logger.info(f"Setting metadata schema on {table_name}") + table.metadata_schema = default_schema + table.packset_metadata(_time_md_array(table, mean, var)) def parse_result(self, result, epsilon): # Construct the tree sequence to return and add other stuff we might want to @@ -876,6 +882,7 @@ def date( time_units=None, method=None, constr_iterations=None, + set_metadata=None, return_fit=None, return_likelihood=None, allow_unary=None, @@ -922,12 +929,18 @@ def date( :param string method: What estimation method to use. See :data:`~tsdate.core.estimation_methods` for possible values. If ``None`` (default) the "variational_gamma" method is currently chosen. - :param bool return_fit: If ``True``, instead of returning just a dated tree - sequence, return a tuple of ``(dated_ts, fit)``. - Default: None, treated as False. :param int constr_iterations: The maximum number of constrained least squares iterations to use prior to forcing positive branch lengths. Default: None, treated as 0. + :param bool set_metadata: Should unconstrained times be stored in table metadata, + in the form of ``"mn"`` (mean) and ``"vr"`` (variance) fields? If ``False``, + do not store metadata. If ``True``, force metadata to be set (if no schema + is set or the schema is incompatible, clear existing metadata in the relevant + tables and set a new schema). If ``None`` (default), only set metadata if + the existing schema allows (this may overwrite existing ``"mn"`` and ``"vr"`` + fields) or if existing metadata is empty, otherwise issue a warning. + :param bool return_fit: If ``True``, instead of just a dated tree sequence, + return a tuple of ``(dated_ts, fit)``. Default: None, treated as False. :param bool return_likelihood: If ``True``, return the log marginal likelihood from the inside algorithm in addition to the dated tree sequence. If ``return_fit`` is also ``True``, then the marginal likelihood @@ -963,6 +976,7 @@ def date( return_fit=return_fit, return_likelihood=return_likelihood, allow_unary=allow_unary, + set_metadata=set_metadata, record_provenance=record_provenance, **kwargs, )