Skip to content

Commit

Permalink
Allow unary
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Nov 10, 2024
1 parent 9fa2bfc commit 40d6231
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 45 deletions.
4 changes: 1 addition & 3 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,5 @@ increase or decrease its stringency.

:::{note}
If unary regions are *correctly* estimated, they can help improve dating slightly.
There is therefore a specific route to date a tree sequence containing locally unary
nodes. For example, for discrete time methods, you can use the `allow_unary` option
when {ref}`building a prior<sec_priors>`.
You can set the `allow_unary=True` option to run tsdate on such tree sequences.
:::
68 changes: 36 additions & 32 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import pytest
import tsinfer
import tskit
import utility_functions
import utility_functions as util

import tsdate
from tsdate.demography import PopulationSizeHistory
Expand All @@ -50,24 +50,24 @@ class TestPrebuilt:
"""

def test_invalid_method_failure(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(ValueError, match="method must be one of"):
tsdate.date(ts, population_size=1, mutation_rate=None, method="foo")

def test_no_mutations_failure(self):
ts = utility_functions.single_tree_ts_n2()
ts = util.single_tree_ts_n2()
with pytest.raises(ValueError, match="No mutations present"):
tsdate.variational_gamma(ts, mutation_rate=1)

def test_no_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Must specify population size"):
tsdate.inside_outside(ts, mutation_rate=None)

def test_no_mutation(self):
for ts in (
utility_functions.two_tree_mutation_ts(),
utility_functions.single_tree_ts_mutation_n3(),
util.two_tree_mutation_ts(),
util.single_tree_ts_mutation_n3(),
):
with pytest.raises(ValueError, match="method requires mutation rate"):
tsdate.date(
Expand All @@ -86,53 +86,59 @@ def test_no_mutation(self):
)

def test_not_needed_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
prior = tsdate.build_prior_grid(ts, population_size=1, timepoints=10)
with pytest.raises(ValueError, match="Cannot specify population size"):
tsdate.inside_outside(ts, population_size=1, mutation_rate=None, priors=prior)

def test_bad_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
for Ne in [0, -1]:
with pytest.raises(ValueError, match="greater than 0"):
tsdate.inside_outside(ts, mutation_rate=None, population_size=Ne)

def test_both_ne_and_population_size_specified(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Only provide one of Ne"):
tsdate.inside_outside(
ts, mutation_rate=1, population_size=PopulationSizeHistory(1), Ne=1
)
tsdate.inside_outside(ts, mutation_rate=1, Ne=PopulationSizeHistory(1))

def test_inside_outside_dangling_failure(self):
ts = utility_functions.single_tree_ts_n2_dangling()
ts = util.single_tree_ts_n2_dangling()
with pytest.raises(ValueError, match="simplified"):
tsdate.inside_outside(ts, mutation_rate=None, population_size=1)

def test_variational_gamma_dangling(self):
# Dangling nodes are fine for the variational gamma method
ts = utility_functions.single_tree_ts_n2_dangling()
ts = util.single_tree_ts_n2_dangling()
ts = msprime.sim_mutations(ts, rate=2, random_seed=1)
assert ts.num_mutations > 1
tsdate.variational_gamma(ts, mutation_rate=2)

def test_inside_outside_unary_failure(self):
ts = utility_functions.single_tree_ts_with_unary()
ts = util.single_tree_ts_with_unary()
with pytest.raises(ValueError, match="unary"):
tsdate.inside_outside(ts, mutation_rate=None, population_size=1)

@pytest.mark.skip("V_gamma should fail with unary nodes, but doesn't currently")
def test_variational_gamma_unary_failure(self):
ts = utility_functions.single_tree_ts_with_unary()
@pytest.mark.parametrize("method", tsdate.estimation_methods.keys())
@pytest.mark.parametrize(
"ts", [util.single_tree_ts_with_unary(), util.two_tree_ts_with_unary_n3()]
)
def test_allow_unary(self, method, ts):
Ne = None if method == "variational_gamma" else 1
ts = msprime.sim_mutations(ts, rate=1, random_seed=1)
with pytest.raises(ValueError, match="unary"):
tsdate.variational_gamma(ts, mutation_rate=1)
tsdate.date(ts, method=method, population_size=Ne, mutation_rate=1)
tsdate.date(
ts, method=method, population_size=Ne, mutation_rate=1, allow_unary=True
)

@pytest.mark.parametrize("probability_space", [LOG_GRID, LIN_GRID])
@pytest.mark.parametrize("mu", [None, 1])
def test_fails_with_recombination(self, probability_space, mu):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(NotImplementedError):
tsdate.inside_outside(
ts,
Expand All @@ -143,27 +149,27 @@ def test_fails_with_recombination(self, probability_space, mu):
)

def test_default_time_units(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts = tsdate.date(ts, mutation_rate=1)
assert ts.time_units == "generations"

def test_default_alternative_time_units(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts = tsdate.date(ts, mutation_rate=1, time_units="years")
assert ts.time_units == "years"

def test_deprecated_return_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
with pytest.raises(ValueError, match="deprecated"):
tsdate.date(ts, return_posteriors=True, mutation_rate=1)

def test_return_fit(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
_, fit = tsdate.date(ts, return_fit=True, mutation_rate=1)
assert hasattr(fit, "node_posteriors")

def test_no_maximization_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
_, fit = tsdate.date(
ts,
population_size=1,
Expand All @@ -175,7 +181,7 @@ def test_no_maximization_posteriors(self):
fit.node_posteriors()

def test_discretised_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts, fit = tsdate.inside_outside(
ts, mutation_rate=1, population_size=1, return_fit=True
)
Expand All @@ -190,7 +196,7 @@ def test_discretised_posteriors(self):
assert np.isclose(np.sum(nd_vals), 1)

def test_variational_node_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts, fit = tsdate.date(
ts,
mutation_rate=1e-2,
Expand All @@ -206,7 +212,7 @@ def test_variational_node_posteriors(self):
assert np.isclose(nd.metadata["vr"], vr)

def test_variational_mutation_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts, fit = tsdate.date(
ts,
mutation_rate=1e-2,
Expand All @@ -223,7 +229,7 @@ def test_variational_mutation_posteriors(self):

def test_variational_mean_edge_logconst(self):
# This should give a guide to EP convergence
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
ts, fit = tsdate.date(
ts,
mutation_rate=1e-2,
Expand All @@ -239,7 +245,7 @@ def test_variational_mean_edge_logconst(self):
assert np.all(obs[5:] == test_vals[-1])

def test_marginal_likelihood(self):
ts = utility_functions.two_tree_mutation_ts()
ts = util.two_tree_mutation_ts()
_, _, marg_lik = tsdate.inside_outside(
ts,
mutation_rate=1,
Expand All @@ -253,8 +259,8 @@ def test_marginal_likelihood(self):
assert marg_lik == marg_lik_again

def test_intervals(self):
ts = utility_functions.two_tree_ts()
long_ts = utility_functions.two_tree_ts_extra_length()
ts = util.two_tree_ts()
long_ts = util.two_tree_ts_extra_length()
keep_ts = long_ts.keep_intervals([[0.0, 1.0]])
del_ts = long_ts.delete_intervals([[1.0, 1.5]])
dat_ts = tsdate.inside_outside(ts, mutation_rate=1, population_size=1)
Expand Down Expand Up @@ -414,9 +420,7 @@ def test_truncated_ts(self):
mutation_rate=mu,
random_seed=12,
)
truncated_ts = utility_functions.truncate_ts_samples(
ts, average_span=200, random_seed=123
)
truncated_ts = util.truncate_ts_samples(ts, average_span=200, random_seed=123)
dated_ts = tsdate.date(truncated_ts, population_size=Ne, mutation_rate=mu)
# We should ideally test whether *haplotypes* are the same here
# in case allele encoding has changed. But haplotypes() doesn't currently
Expand Down
24 changes: 18 additions & 6 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
priors=None,
return_likelihood=None,
return_fit=None,
allow_unary=None,
record_provenance=None,
constr_iterations=None,
progress=None,
Expand Down Expand Up @@ -140,6 +141,8 @@ def __init__(
)
self.constr_iterations = constr_iterations

self.allow_unary = False if allow_unary is None else allow_unary

if self.prior_grid_func_name is None:
if priors is not None:
raise ValueError(f"Priors are not used for method {self.name}")
Expand All @@ -157,7 +160,11 @@ def __init__(
# greater than DEFAULT_APPROX_PRIOR_SIZE samples
approx = ts.num_samples > prior.DEFAULT_APPROX_PRIOR_SIZE
self.priors = mk_prior(
ts, Ne, approximate_priors=approx, progress=progress
ts,
Ne,
approximate_priors=approx,
allow_unary=self.allow_unary,
progress=progress,
)
else:
logger.info("Using user-specified priors")
Expand Down Expand Up @@ -444,6 +451,7 @@ def run(
fit_obj = variational.ExpectationPropagation(
self.ts,
mutation_rate=self.mutation_rate,
allow_unary=self.allow_unary,
singletons_phased=singletons_phased,
)
fit_obj.infer(
Expand Down Expand Up @@ -552,7 +560,7 @@ def maximization(
"linear" space (fast, may overflow). Default: None treated as"logarithmic"
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
Further arguments include ``time_units``, ``progress``, and
Further arguments include ``time_units``, ``progress``, ``allow_unary`` and
``record_provenance``. The additional arguments ``return_fit`` and
``return_likelihood`` can be used to return additional information (see below).
:return:
Expand Down Expand Up @@ -685,7 +693,7 @@ def inside_outside(
"linear" space (fast, may overflow). Default: "logarithmic"
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
Further arguments include ``time_units``, ``progress``, and
Further arguments include ``time_units``, ``progress``, ``allow_unary`` and
``record_provenance``. The additional arguments ``return_fit`` and
``return_likelihood`` can be used to return additional information (see below).
:return:
Expand Down Expand Up @@ -784,9 +792,9 @@ def variational_gamma(
length are approximately equal, which gives unbiased estimates when there
are polytomies. Default ``False``.
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, including ``time_units``, ``progress``, and ``record_provenance``.
The arguments ``return_fit`` and ``return_likelihood`` can be
used to return additional information (see below).
function, including ``time_units``, ``progress``, ``allow_unary`` and
``record_provenance``. The arguments ``return_fit`` and ``return_likelihood``
can be used to return additional information (see below).
:return:
- **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with
updated node times based on the posterior mean, corrected where necessary to
Expand Down Expand Up @@ -866,6 +874,7 @@ def date(
constr_iterations=None,
return_fit=None,
return_likelihood=None,
allow_unary=None,
progress=None,
record_provenance=True,
# Other kwargs documented in the functions for each specific estimation-method
Expand Down Expand Up @@ -919,6 +928,8 @@ def date(
from the inside algorithm in addition to the dated tree sequence. If
``return_fit`` is also ``True``, then the marginal likelihood
will be the last element of the tuple. Default: None, treated as False.
:param bool allow_unary: Allow nodes that are "locally unary" (i.e. have only
one child in one or more local trees). Default: None, treated as False.
:param bool progress: Show a progress bar. Default: None, treated as False.
:param bool record_provenance: Should the tsdate command be appended to the
provenence information in the returned tree sequence?
Expand Down Expand Up @@ -947,6 +958,7 @@ def date(
constr_iterations=constr_iterations,
return_fit=return_fit,
return_likelihood=return_likelihood,
allow_unary=allow_unary,
record_provenance=record_provenance,
**kwargs,
)
8 changes: 4 additions & 4 deletions tsdate/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ def _check_valid_constraints(constraints, edges_parent, edges_child):
)

@staticmethod
def _check_valid_inputs(ts, mutation_rate):
def _check_valid_inputs(ts, mutation_rate, allow_unary):
if not mutation_rate > 0.0:
raise ValueError("Mutation rate must be positive")
if contains_unary_nodes(ts):
if not allow_unary and contains_unary_nodes(ts):
raise ValueError("Tree sequence contains unary nodes, simplify first")

@staticmethod
Expand All @@ -185,7 +185,7 @@ def _check_valid_state(
posterior_check += node_factors[:, CONSTRNT]
np.testing.assert_allclose(posterior_check, posterior)

def __init__(self, ts, *, mutation_rate, singletons_phased=True):
def __init__(self, ts, *, mutation_rate, allow_unary=None, singletons_phased=True):
"""
Initialize an expectation propagation algorithm for dating nodes
in a tree sequence.
Expand All @@ -202,7 +202,7 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True):
time unit.
"""

self._check_valid_inputs(ts, mutation_rate)
self._check_valid_inputs(ts, mutation_rate, allow_unary)
self.edge_parents = ts.edges_parent
self.edge_children = ts.edges_child

Expand Down

0 comments on commit 40d6231

Please sign in to comment.