Skip to content

Commit

Permalink
Merge pull request #437 from hyanwong/extras
Browse files Browse the repository at this point in the history
Allow unary nodes
  • Loading branch information
hyanwong authored Nov 10, 2024
2 parents bd9b072 + 40d6231 commit c05676f
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 46 deletions.
6 changes: 2 additions & 4 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ command-line interface. See {ref}`sec_cli` for more details.

### Numerical stability and preprocessing

Numerical stability issues witll manifest themselves by raising an error when dating.
Numerical stability issues will manifest themselves by raising an error when dating.
They are usually caused by "bad" tree sequences (i.e.
pathological combinations of topologies and mutations). These can be caused,
for example, by long deep branches with very few mutations, such as samples attaching directly
Expand All @@ -343,7 +343,5 @@ increase or decrease its stringency.

:::{note}
If unary regions are *correctly* estimated, they can help improve dating slightly.
There is therefore a specific route to date a tree sequence containing locally unary
nodes. For example, for discrete time methods, you can use the `allow_unary` option
when {ref}`building a prior<sec_priors>`.
You can set the `allow_unary=True` option to run tsdate on such tree sequences.
:::
68 changes: 36 additions & 32 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import pytest
import tsinfer
import tskit
import utility_functions
import utility_functions as util

import tsdate
from tsdate.demography import PopulationSizeHistory
Expand All @@ -50,24 +50,24 @@ class TestPrebuilt:
"""

def test_invalid_method_failure(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(ValueError, match="method must be one of"):
tsdate.date(ts, population_size=1, mutation_rate=None, method="foo")

def test_no_mutations_failure(self):
ts = utility_functions.single_tree_ts_n2()
ts = util.single_tree_ts_n2()
with pytest.raises(ValueError, match="No mutations present"):
tsdate.variational_gamma(ts, mutation_rate=1)

def test_no_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Must specify population size"):
tsdate.inside_outside(ts, mutation_rate=None)

def test_no_mutation(self):
for ts in (
utility_functions.two_tree_mutation_ts(),
utility_functions.single_tree_ts_mutation_n3(),
util.two_tree_mutation_ts(),
util.single_tree_ts_mutation_n3(),
):
with pytest.raises(ValueError, match="method requires mutation rate"):
tsdate.date(
Expand All @@ -86,53 +86,59 @@ def test_no_mutation(self):
)

def test_not_needed_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
prior = tsdate.build_prior_grid(ts, population_size=1, timepoints=10)
with pytest.raises(ValueError, match="Cannot specify population size"):
tsdate.inside_outside(ts, population_size=1, mutation_rate=None, priors=prior)

def test_bad_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
for Ne in [0, -1]:
with pytest.raises(ValueError, match="greater than 0"):
tsdate.inside_outside(ts, mutation_rate=None, population_size=Ne)

def test_both_ne_and_population_size_specified(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Only provide one of Ne"):
tsdate.inside_outside(
ts, mutation_rate=1, population_size=PopulationSizeHistory(1), Ne=1
)
tsdate.inside_outside(ts, mutation_rate=1, Ne=PopulationSizeHistory(1))

def test_inside_outside_dangling_failure(self):
ts = utility_functions.single_tree_ts_n2_dangling()
ts = util.single_tree_ts_n2_dangling()
with pytest.raises(ValueError, match="simplified"):
tsdate.inside_outside(ts, mutation_rate=None, population_size=1)

def test_variational_gamma_dangling(self):
# Dangling nodes are fine for the variational gamma method
ts = utility_functions.single_tree_ts_n2_dangling()
ts = util.single_tree_ts_n2_dangling()
ts = msprime.sim_mutations(ts, rate=2, random_seed=1)
assert ts.num_mutations > 1
tsdate.variational_gamma(ts, mutation_rate=2)

def test_inside_outside_unary_failure(self):
ts = utility_functions.single_tree_ts_with_unary()
ts = util.single_tree_ts_with_unary()
with pytest.raises(ValueError, match="unary"):
tsdate.inside_outside(ts, mutation_rate=None, population_size=1)

@pytest.mark.skip("V_gamma should fail with unary nodes, but doesn't currently")
def test_variational_gamma_unary_failure(self):
ts = utility_functions.single_tree_ts_with_unary()
@pytest.mark.parametrize("method", tsdate.estimation_methods.keys())
@pytest.mark.parametrize(
"ts", [util.single_tree_ts_with_unary(), util.two_tree_ts_with_unary_n3()]
)
def test_allow_unary(self, method, ts):
Ne = None if method == "variational_gamma" else 1
ts = msprime.sim_mutations(ts, rate=1, random_seed=1)
with pytest.raises(ValueError, match="unary"):
tsdate.variational_gamma(ts, mutation_rate=1)
tsdate.date(ts, method=method, population_size=Ne, mutation_rate=1)
tsdate.date(
ts, method=method, population_size=Ne, mutation_rate=1, allow_unary=True
)

@pytest.mark.parametrize("probability_space", [LOG_GRID, LIN_GRID])
@pytest.mark.parametrize("mu", [None, 1])
def test_fails_with_recombination(self, probability_space, mu):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(NotImplementedError):
tsdate.inside_outside(
ts,
Expand All @@ -143,27 +149,27 @@ def test_fails_with_recombination(self, probability_space, mu):
)

def test_default_time_units(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts = tsdate.date(ts, mutation_rate=1)
assert ts.time_units == "generations"

def test_default_alternative_time_units(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts = tsdate.date(ts, mutation_rate=1, time_units="years")
assert ts.time_units == "years"

def test_deprecated_return_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(ValueError, match="deprecated"):
tsdate.date(ts, return_posteriors=True, mutation_rate=1)

def test_return_fit(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
_, fit = tsdate.date(ts, return_fit=True, mutation_rate=1)
assert hasattr(fit, "node_posteriors")

def test_no_maximization_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
_, fit = tsdate.date(
ts,
population_size=1,
Expand All @@ -175,7 +181,7 @@ def test_no_maximization_posteriors(self):
fit.node_posteriors()

def test_discretised_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts, fit = tsdate.inside_outside(
ts, mutation_rate=1, population_size=1, return_fit=True
)
Expand All @@ -190,7 +196,7 @@ def test_discretised_posteriors(self):
assert np.isclose(np.sum(nd_vals), 1)

def test_variational_node_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts, fit = tsdate.date(
ts,
mutation_rate=1e-2,
Expand All @@ -206,7 +212,7 @@ def test_variational_node_posteriors(self):
assert np.isclose(nd.metadata["vr"], vr)

def test_variational_mutation_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts, fit = tsdate.date(
ts,
mutation_rate=1e-2,
Expand All @@ -223,7 +229,7 @@ def test_variational_mutation_posteriors(self):

def test_variational_mean_edge_logconst(self):
# This should give a guide to EP convergence
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts, fit = tsdate.date(
ts,
mutation_rate=1e-2,
Expand All @@ -239,7 +245,7 @@ def test_variational_mean_edge_logconst(self):
assert np.all(obs[5:] == test_vals[-1])

def test_marginal_likelihood(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
_, _, marg_lik = tsdate.inside_outside(
ts,
mutation_rate=1,
Expand All @@ -253,8 +259,8 @@ def test_marginal_likelihood(self):
assert marg_lik == marg_lik_again

def test_intervals(self):
ts = utility_functions.two_tree_ts()
long_ts = utility_functions.two_tree_ts_extra_length()
ts = util.two_tree_ts()
long_ts = util.two_tree_ts_extra_length()
keep_ts = long_ts.keep_intervals([[0.0, 1.0]])
del_ts = long_ts.delete_intervals([[1.0, 1.5]])
dat_ts = tsdate.inside_outside(ts, mutation_rate=1, population_size=1)
Expand Down Expand Up @@ -414,9 +420,7 @@ def test_truncated_ts(self):
mutation_rate=mu,
random_seed=12,
)
truncated_ts = utility_functions.truncate_ts_samples(
ts, average_span=200, random_seed=123
)
truncated_ts = util.truncate_ts_samples(ts, average_span=200, random_seed=123)
dated_ts = tsdate.date(truncated_ts, population_size=Ne, mutation_rate=mu)
# We should ideally test whether *haplotypes* are the same here
# in case allele encoding has changed. But haplotypes() doesn't currently
Expand Down
24 changes: 18 additions & 6 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
priors=None,
return_likelihood=None,
return_fit=None,
allow_unary=None,
record_provenance=None,
constr_iterations=None,
progress=None,
Expand Down Expand Up @@ -140,6 +141,8 @@ def __init__(
)
self.constr_iterations = constr_iterations

self.allow_unary = False if allow_unary is None else allow_unary

if self.prior_grid_func_name is None:
if priors is not None:
raise ValueError(f"Priors are not used for method {self.name}")
Expand All @@ -157,7 +160,11 @@ def __init__(
# greater than DEFAULT_APPROX_PRIOR_SIZE samples
approx = ts.num_samples > prior.DEFAULT_APPROX_PRIOR_SIZE
self.priors = mk_prior(
ts, Ne, approximate_priors=approx, progress=progress
ts,
Ne,
approximate_priors=approx,
allow_unary=self.allow_unary,
progress=progress,
)
else:
logger.info("Using user-specified priors")
Expand Down Expand Up @@ -444,6 +451,7 @@ def run(
fit_obj = variational.ExpectationPropagation(
self.ts,
mutation_rate=self.mutation_rate,
allow_unary=self.allow_unary,
singletons_phased=singletons_phased,
)
fit_obj.infer(
Expand Down Expand Up @@ -552,7 +560,7 @@ def maximization(
"linear" space (fast, may overflow). Default: None, treated as "logarithmic"
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
Further arguments include ``time_units``, ``progress``, and
Further arguments include ``time_units``, ``progress``, ``allow_unary`` and
``record_provenance``. The additional arguments ``return_fit`` and
``return_likelihood`` can be used to return additional information (see below).
:return:
Expand Down Expand Up @@ -685,7 +693,7 @@ def inside_outside(
"linear" space (fast, may overflow). Default: "logarithmic"
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
Further arguments include ``time_units``, ``progress``, and
Further arguments include ``time_units``, ``progress``, ``allow_unary`` and
``record_provenance``. The additional arguments ``return_fit`` and
``return_likelihood`` can be used to return additional information (see below).
:return:
Expand Down Expand Up @@ -784,9 +792,9 @@ def variational_gamma(
length are approximately equal, which gives unbiased estimates when there
are polytomies. Default ``False``.
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, including ``time_units``, ``progress``, and ``record_provenance``.
The arguments ``return_fit`` and ``return_likelihood`` can be
used to return additional information (see below).
function, including ``time_units``, ``progress``, ``allow_unary`` and
``record_provenance``. The arguments ``return_fit`` and ``return_likelihood``
can be used to return additional information (see below).
:return:
- **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with
updated node times based on the posterior mean, corrected where necessary to
Expand Down Expand Up @@ -866,6 +874,7 @@ def date(
constr_iterations=None,
return_fit=None,
return_likelihood=None,
allow_unary=None,
progress=None,
record_provenance=True,
# Other kwargs documented in the functions for each specific estimation-method
Expand Down Expand Up @@ -919,6 +928,8 @@ def date(
from the inside algorithm in addition to the dated tree sequence. If
``return_fit`` is also ``True``, then the marginal likelihood
will be the last element of the tuple. Default: None, treated as False.
:param bool allow_unary: Allow nodes that are "locally unary" (i.e. have only
one child in one or more local trees). Default: None, treated as False.
:param bool progress: Show a progress bar. Default: None, treated as False.
:param bool record_provenance: Should the tsdate command be appended to the
provenance information in the returned tree sequence?
Expand Down Expand Up @@ -947,6 +958,7 @@ def date(
constr_iterations=constr_iterations,
return_fit=return_fit,
return_likelihood=return_likelihood,
allow_unary=allow_unary,
record_provenance=record_provenance,
**kwargs,
)
8 changes: 4 additions & 4 deletions tsdate/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ def _check_valid_constraints(constraints, edges_parent, edges_child):
)

@staticmethod
def _check_valid_inputs(ts, mutation_rate):
def _check_valid_inputs(ts, mutation_rate, allow_unary):
if not mutation_rate > 0.0:
raise ValueError("Mutation rate must be positive")
if contains_unary_nodes(ts):
if not allow_unary and contains_unary_nodes(ts):
raise ValueError("Tree sequence contains unary nodes, simplify first")

@staticmethod
Expand All @@ -185,7 +185,7 @@ def _check_valid_state(
posterior_check += node_factors[:, CONSTRNT]
np.testing.assert_allclose(posterior_check, posterior)

def __init__(self, ts, *, mutation_rate, singletons_phased=True):
def __init__(self, ts, *, mutation_rate, allow_unary=None, singletons_phased=True):
"""
Initialize an expectation propagation algorithm for dating nodes
in a tree sequence.
Expand All @@ -202,7 +202,7 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True):
time unit.
"""

self._check_valid_inputs(ts, mutation_rate)
self._check_valid_inputs(ts, mutation_rate, allow_unary)
self.edge_parents = ts.edges_parent
self.edge_children = ts.edges_child

Expand Down

0 comments on commit c05676f

Please sign in to comment.