From 094e064a96e6826b35e4fe1784f78fb0f5c5d60b Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 16 May 2024 12:52:20 -0700 Subject: [PATCH 01/29] Infrastructure for singleton phasing --- tests/test_approximations.py | 433 +++++++++++++++++++++++++++-------- tsdate/approx.py | 375 +++++++++++++++++++++++------- tsdate/evaluation.py | 37 ++- tsdate/hypergeo.py | 2 +- tsdate/phasing.py | 106 +++++++++ tsdate/variational.py | 57 +++++ 6 files changed, 828 insertions(+), 182 deletions(-) create mode 100644 tsdate/phasing.py diff --git a/tests/test_approximations.py b/tests/test_approximations.py index fb41f843..d7e6b0c7 100644 --- a/tests/test_approximations.py +++ b/tests/test_approximations.py @@ -39,10 +39,10 @@ # TODO: better test set? # TODO: test special case where child is fixed to age 0 _gamma_trio_test_cases = [ # [shape1, rate1, shape2, rate2, muts, rate] - [2.0, 0.0005, 2.0, 0.005, 0.0, 0.001], - [2.0, 0.0005, 2.0, 0.005, 1.0, 0.001], - [2.0, 0.0005, 2.0, 0.005, 2.0, 0.001], - [2.0, 0.0005, 2.0, 0.005, 3.0, 0.001], + [2.0, 0.0005, 1.5, 0.005, 0.0, 0.001], + [2.0, 0.0005, 1.5, 0.005, 1.0, 0.001], + [2.0, 0.0005, 1.5, 0.005, 2.0, 0.001], + [2.0, 0.0005, 1.5, 0.005, 3.0, 0.001], ] @@ -59,17 +59,15 @@ def pdf(t_i, t_j, a_i, b_i, a_j, b_j, y, mu): Target joint (pair) distribution, proportional to the parent/child marginals (gamma) and a Poisson mutation likelihood """ - if t_i < t_j: - return 0.0 - else: - return ( - t_i ** (a_i - 1) - * np.exp(-t_i * b_i) - * t_j ** (a_j - 1) - * np.exp(-t_j * b_j) - * (t_i - t_j) ** y - * np.exp(-(t_i - t_j) * mu) - ) + assert 0 < t_j < t_i + return ( + t_i ** (a_i - 1) + * np.exp(-t_i * b_i) + * t_j ** (a_j - 1) + * np.exp(-t_j * b_j) + * (t_i - t_j) ** y + * np.exp(-(t_i - t_j) * mu) + ) @staticmethod def pdf_rootward(t_i, t_j, a_i, b_i, y, mu): @@ -78,15 +76,13 @@ def pdf_rootward(t_i, t_j, a_i, b_i, y, mu): marginals (gamma) and a Poisson mutation likelihood at a fixed child age """ - if t_i < t_j: - return 0.0 - else: - return ( - t_i ** (a_i - 1) - * np.exp(-t_i * b_i) - * (t_i - t_j) ** y - * np.exp(-(t_i - t_j) * mu) - ) + assert 0 <= t_j < t_i + return ( + t_i ** (a_i - 1) + * np.exp(-t_i * b_i) + * (t_i - t_j) ** y + * np.exp(-(t_i - t_j) * mu) + ) @staticmethod def pdf_leafward(t_i, t_j, a_j, b_j, y, mu): @@ -95,31 +91,46 @@ def pdf_leafward(t_i, t_j, a_j, b_j, y, mu): marginals (gamma) and a Poisson mutation likelihood at a fixed parent age """ - if t_i < t_j: - return 0.0 - else: - return ( - t_j ** (a_j - 1) - * np.exp(-t_j * b_j) - * (t_i - t_j) ** y - * np.exp(-(t_i - t_j) * mu) - ) + assert 0 < t_j < t_i + return ( + t_j ** (a_j - 1) + * np.exp(-t_j * b_j) + * (t_i - t_j) ** y + * np.exp(-(t_i - t_j) * mu) + ) @staticmethod - def pdf_truncated(t_i, low, upp, a_i, b_i): + def pdf_unphased(t_i, t_j, a_i, b_i, a_j, b_j, y, mu): """ - Target proportional to the node marginals (gamma) and an indicator - function + Target joint (pair) distribution, proportional to the parent + marginals (gamma) and a Poisson mutation likelihood over the + two branches leading from (present-day) individual to parents """ - if low < t_i < upp: - return np.exp( - np.log(t_i) * (a_i - 1) - - t_i * b_i - - scipy.special.gammaln(a_i) - + np.log(b_i) * a_i - ) - else: - return 0.0 + assert t_i > 0 and t_j > 0 + return ( + t_i ** (a_i - 1) + * np.exp(-t_i * b_i) + * t_j ** (a_j - 1) + * np.exp(-t_j * b_j) + * (t_i + t_j) ** y + * np.exp(-(t_i + t_j) * mu) + ) + + @staticmethod + def pdf_unphased_rightward(t_i, t_j, a_j, b_j, y, mu): + """ + Target joint (pair) 
distribution, proportional to the parent + marginals (gamma) and a Poisson mutation likelihood over the + two branches leading from (present-day) individual to parents, + with left parent fixed to t_i + """ + assert t_i > 0 and t_j > 0 + return ( + t_j ** (a_j - 1) + * np.exp(-t_j * b_j) + * (t_i + t_j) ** y + * np.exp(-(t_i + t_j) * mu) + ) def test_moments(self, pars): """ @@ -183,35 +194,36 @@ def test_rootward_moments(self, pars): Test mean and variance of parent age when child age is fixed to a nonzero value """ a_i, b_i, a_j, b_j, y, mu = pars - t_j = a_j / b_j # point "estimate" for child pars_redux = (a_i, b_i, y, mu) - logconst, t_i, _, var_t_i = approx.rootward_moments(t_j, *pars_redux) - ck_normconst = scipy.integrate.quad( - lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), - t_j, - np.inf, - epsabs=0, - )[0] - assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2) - ck_t_i = scipy.integrate.quad( - lambda t_i: t_i * self.pdf_rootward(t_i, t_j, *pars_redux) / ck_normconst, - t_j, - np.inf, - epsabs=0, - )[0] - assert np.isclose(t_i, ck_t_i, rtol=2e-2) - ck_var_t_i = ( - scipy.integrate.quad( - lambda t_i: t_i**2 - * self.pdf_rootward(t_i, t_j, *pars_redux) - / ck_normconst, + mn_j = a_j / b_j # point "estimate" for child + for t_j in [0.0, mn_j]: + logconst, t_i, _, var_t_i = approx.rootward_moments(t_j, *pars_redux) + ck_normconst = scipy.integrate.quad( + lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), t_j, np.inf, epsabs=0, )[0] - - ck_t_i**2 - ) - assert np.isclose(var_t_i, ck_var_t_i, rtol=2e-2) + assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2) + ck_t_i = scipy.integrate.quad( + lambda t_i: t_i * self.pdf_rootward(t_i, t_j, *pars_redux) / ck_normconst, + t_j, + np.inf, + epsabs=0, + )[0] + assert np.isclose(t_i, ck_t_i, rtol=2e-2) + ck_var_t_i = ( + scipy.integrate.quad( + lambda t_i: t_i**2 + * self.pdf_rootward(t_i, t_j, *pars_redux) + / ck_normconst, + t_j, + np.inf, + epsabs=0, + )[0] + - ck_t_i**2 + ) + assert np.isclose(var_t_i, ck_var_t_i, rtol=2e-2) def test_leafward_moments(self, pars): """ @@ -248,41 +260,274 @@ def test_leafward_moments(self, pars): ) assert np.isclose(var_t_j, ck_var_t_j, rtol=2e-2) - def test_truncated_moments(self, pars): + def test_unphased_moments(self, pars): """ - Test mean and variance of child age when parent age is fixed to a nonzero value + Parent ages for an singleton nodes above an unphased individual """ - a_i, b_i, *_ = pars - upp = a_i / b_i * 2 - low = a_i / b_i / 2 - pars_redux = (low, upp, a_i, b_i) - logconst, t_i, _, var_t_i = approx.truncated_moments(*pars_redux) - ck_normconst = scipy.integrate.quad( - lambda t_i: self.pdf_truncated(t_i, *pars_redux), - low, - upp, + logconst, t_i, _, var_t_i, t_j, _, var_t_j = approx.unphased_moments(*pars) + ck_normconst = scipy.integrate.dblquad( + lambda t_i, t_j: self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2) + ck_t_i = scipy.integrate.dblquad( + lambda t_i, t_j: t_i * self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + 0, + np.inf, epsabs=0, )[0] - assert np.isclose(logconst, np.log(ck_normconst), rtol=1e-4) - ck_t_i = scipy.integrate.quad( - lambda t_i: t_i * self.pdf_truncated(t_i, *pars_redux) / ck_normconst, - low, - upp, + assert np.isclose(t_i, ck_t_i, rtol=2e-2) + ck_t_j = scipy.integrate.dblquad( + lambda t_i, t_j: t_j * self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + 0, + np.inf, epsabs=0, )[0] - assert 
np.isclose(t_i, ck_t_i, rtol=1e-4) + assert np.isclose(t_j, ck_t_j, rtol=2e-2) ck_var_t_i = ( - scipy.integrate.quad( - lambda t_i: t_i**2 - * self.pdf_truncated(t_i, *pars_redux) - / ck_normconst, - low, - upp, + scipy.integrate.dblquad( + lambda t_i, t_j: t_i**2 * self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + 0, + np.inf, epsabs=0, )[0] - ck_t_i**2 ) - assert np.isclose(var_t_i, ck_var_t_i, rtol=1e-4) + assert np.isclose(var_t_i, ck_var_t_i, rtol=2e-2) + ck_var_t_j = ( + scipy.integrate.dblquad( + lambda t_i, t_j: t_j**2 * self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + - ck_t_j**2 + ) + assert np.isclose(var_t_j, ck_var_t_j, rtol=2e-2) + + def test_unphased_rightward_moments(self, pars): + """ + Parent ages for an singleton nodes above an unphased individual, where + second parent is fixed to a particular time + """ + a_i, b_i, a_j, b_j, y, mu = pars + pars_redux = (a_j, b_j, y, mu) + t_i = a_i / b_i # point "estimate" for left parent + nc, mn, _, va = approx.unphased_rightward_moments(t_i, *pars_redux) + ck_nc = scipy.integrate.quad( + lambda t_j: self.pdf_unphased_rightward(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + assert np.isclose(np.exp(nc), ck_nc, rtol=2e-2) + ck_mn = scipy.integrate.quad( + lambda t_j: t_j * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] / ck_nc + assert np.isclose(mn, ck_mn, rtol=2e-2) + ck_va = scipy.integrate.quad( + lambda t_j: t_j**2 * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] / ck_nc - ck_mn**2 + assert np.isclose(va, ck_va, rtol=2e-2) + + def test_mutation_moments(self, pars): + """ + Mutation mapped to a single branch with both nodes free + """ + def f(t_i, t_j): + assert t_j < t_i + mn = t_i / 2 + t_j / 2 + sq = (t_i**2 + t_i*t_j + t_j**2) / 3 + return mn, sq + mn, _, va = approx.mutation_moments(*pars) + nc = scipy.integrate.dblquad( + lambda t_i, t_j: self.pdf(t_i, t_j, *pars), + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + ck_mn = scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[0] * self.pdf(t_i, t_j, *pars), + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] / nc + assert np.isclose(mn, ck_mn, rtol=2e-2) + ck_va = scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[1] * self.pdf(t_i, t_j, *pars), + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] / nc - ck_mn**2 + assert np.isclose(va, ck_va, rtol=5e-2) + + def test_mutation_rootward_moments(self, pars): + """ + Mutation mapped to a single branch with child node fixed + """ + def f(t_i, t_j): # conditional moments + assert t_j < t_i + mn = t_i / 2 + t_j / 2 + sq = (t_i**2 + t_i*t_j + t_j**2) / 3 + return mn, sq + a_i, b_i, a_j, b_j, y, mu = pars + pars_redux = (a_i, b_i, y, mu) + mn_j = a_j / b_j # point "estimate" for child + for t_j in [0.0, mn_j]: + mn, _, va = approx.mutation_rootward_moments(t_j, *pars_redux) + nc = scipy.integrate.quad( + lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), + t_j, + np.inf, + )[0] + ck_mn = scipy.integrate.quad( + lambda t_i: f(t_i, t_j)[0] * self.pdf_rootward(t_i, t_j, *pars_redux), + t_j, + np.inf, + )[0] / nc + assert np.isclose(mn, ck_mn, rtol=2e-2) + ck_va = scipy.integrate.quad( + lambda t_i: f(t_i, t_j)[1] * self.pdf_rootward(t_i, t_j, *pars_redux), + t_j, + np.inf, + )[0] / nc - ck_mn**2 + assert np.isclose(va, ck_va, rtol=2e-2) + + def test_mutation_leafward_moments(self, pars): + """ + Mutation mapped to a single branch with parent node fixed + """ + def 
f(t_i, t_j): + assert t_j < t_i + mn = t_i / 2 + t_j / 2 + sq = (t_i**2 + t_i*t_j + t_j**2) / 3 + return mn, sq + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i # point "estimate" for parent + pars_redux = (a_j, b_j, y, mu) + mn, _, va = approx.mutation_leafward_moments(t_i, *pars_redux) + nc = scipy.integrate.quad( + lambda t_j: self.pdf_leafward(t_i, t_j, *pars_redux), + 0, + t_i, + )[0] + ck_mn = scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[0] * self.pdf_leafward(t_i, t_j, *pars_redux), + 0, + t_i, + )[0] / nc + assert np.isclose(mn, ck_mn, rtol=2e-2) + ck_va = scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[1] * self.pdf_leafward(t_i, t_j, *pars_redux), + 0, + t_i, + )[0] / nc - ck_mn**2 + assert np.isclose(va, ck_va, rtol=2e-2) + + def test_unphased_mutation_moments(self, pars): + """ + Mutation mapped to two singleton branches with children fixed to time zero + """ + def f(t_i, t_j): # conditional moments + pr = t_i / (t_i + t_j) + mn = pr * t_i / 2 + (1 - pr) * t_j / 2 + sq = pr * t_i**2 / 3 + (1 - pr) * t_j**2 / 3 + return pr, mn, sq + pr, mn, _, va = approx.unphased_mutation_moments(*pars) + nc = scipy.integrate.dblquad( + lambda t_i, t_j: self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + ck_pr = scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[0] * self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] / nc + assert np.isclose(pr, ck_pr, rtol=2e-2) + ck_mn = scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[1] * self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] / nc + assert np.isclose(mn, ck_mn, rtol=2e-2) + ck_va = scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[2] * self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] / nc - ck_mn**2 + assert np.isclose(va, ck_va, rtol=2e-2) + + def test_unphased_mutation_rightward_moments(self, pars): + """ + Mutation mapped to two branches with children fixed to time zero, and + left parent (i) fixed + """ + def f(t_i, t_j): # conditional moments + pr = t_i / (t_i + t_j) + mn = pr * t_i / 2 + (1 - pr) * t_j / 2 + sq = pr * t_i**2 / 3 + (1 - pr) * t_j**2 / 3 + return pr, mn, sq + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i # point "estimate" for left parent + pars_redux = (a_j, b_j, y, mu) + pr, mn, _, va = approx.unphased_mutation_rightward_moments(t_i, *pars_redux) + nc = scipy.integrate.quad( + lambda t_j: self.pdf_unphased_rightward(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + ck_pr = scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[0] * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] / nc + assert np.isclose(pr, ck_pr, rtol=2e-2) + ck_mn = scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[1] * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] / nc + assert np.isclose(mn, ck_mn, rtol=2e-2) + ck_va = scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[2] * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] / nc - ck_mn**2 + assert np.isclose(va, ck_va, rtol=2e-2) def test_approximate_gamma_kl(self, pars): _, t_i, ln_t_i, _, t_j, ln_t_j, _ = approx.moments(*pars) @@ -311,6 +556,8 @@ def test_approximate_gamma_mom(self, pars): assert np.isclose(va_t_j, ck_va_t_j) + + class TestPriorMomentMatching: """ Test approximation of the conditional coalescent prior via diff --git a/tsdate/approx.py b/tsdate/approx.py index cb832217..7f3aa621 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -28,6 +28,7 @@ from 
math import lgamma from math import log +import mpmath import numba import numpy as np from numba.types import Tuple as _tuple @@ -184,6 +185,17 @@ def average_gammas(alpha, beta): return approximate_gamma_kl(avg_x, avg_logx) +@numba.njit(_b(_f, _f, _f)) +def _valid_moments(mn, ln, va): + if not (mn > 0.0 and va > 0.0): + return False + if not (ln < log(mn)): + return False + return True + + +# --- node posteriors --- # + @numba.njit(_unituple(_f, 7)(_f, _f, _f, _f, _f, _f)) def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): """ @@ -233,7 +245,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): mn_i = mn_j * z + b / t sq_i = sq_j * z**2 + (b + 1) * (mn_i + mn_j * z) / t va_i = sq_i - mn_i**2 - ln_i = np.log(mn_i) - va_i / 2 / mn_i**2 if mn_j > 0 else -np.inf + ln_i = np.log(mn_i) - va_i / 2 / mn_i**2 if mn_i > 0 else -np.inf return logl, mn_i, ln_i, va_i, mn_j, ln_j, va_j @@ -371,42 +383,6 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): return logl, mn_j, ln_j, va_j -@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f)) -def truncated_moments(low, upp, a_i, b_i): - """ - Calculate sufficient statistics for a gamma PDF truncated to (low, upp). - The logarithmic moments are approximated via a Taylor expansion around the - mean. - - :param float low: lower bound on node age - :param float upp: upper bound on node age - :param float a_i: the shape parameter of the cavity distribution for the node - :param float b_i: the rate parameter of the cavity distribution for the node - - See, e.g., https://doi.org/10.1080%2F03610920008832519 - """ - - assert upp > low - - gammainc = hypergeo._gammainc - p0 = gammainc(a_i + 0, b_i * upp) - gammainc(a_i + 0, b_i * low) - p1 = gammainc(a_i + 1, b_i * upp) - gammainc(a_i + 1, b_i * low) - p2 = gammainc(a_i + 2, b_i * upp) - gammainc(a_i + 2, b_i * low) - - # TODO: replace with error? skip update? - assert p0 > 0.0, "Zero mass in truncation region" - - t = a_i / b_i - - logl = log(p0) - mn_i = p1 / p0 * t - sq_i = p2 / p0 * t * (1 / b_i + t) - va_i = sq_i - mn_i**2 - ln_i = log(mn_i) - va_i / 2 / mn_i**2 if mn_i > 0 else -np.inf - - return logl, mn_i, ln_i, va_i - - @numba.njit(_b(_f, _f, _f, _f, _f, _f)) def _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y, mu): """Uses shape / rate parameterization""" @@ -454,16 +430,6 @@ def _hyperu_valid_parameterization(t_j, a_i, b_i, y, mu): return False return True - -@numba.njit(_b(_f, _f, _f)) -def _valid_moments(mn, ln, va): - if not (mn > 0.0 and va > 0.0): - return False - if not (ln < log(mn)): - return False - return True - - @numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r, _b)) def gamma_projection(pars_i, pars_j, pars_ij, min_kl): """ @@ -590,41 +556,8 @@ def rootward_projection(t_j, pars_i, pars_ij, min_kl): return logconst, np.array(proj_i) -@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _b)) -def truncated_projection(bounds_i, pars_i, min_kl): - r""" - Match a gamma distributions to the surrogate posterior :math:`Ga(t_i | a_i + - 1, b_i) I[low < t_i < upp]`, by minimizing KL divergence. 
- - :param float bounds_i: lower and upper bounds on node age - :param float pars_i: gamma natural parameters for the cavity distribution - :param bool min_kl: minimize KL divergence (match central moments if False) - - :return: normalizing constant, gamma natural parameters for node - """ - - # switch from natural to canonical parameterization - a_i, b_i = pars_i - low, upp = bounds_i - a_i += 1 - - logconst, t_i, ln_t_i, va_t_i = truncated_moments(low, upp, a_i, b_i) - - valid_i = _valid_moments(t_i, ln_t_i, va_t_i) - if not valid_i: - return np.nan, pars_i - - if min_kl: - proj_i = approximate_gamma_kl(t_i, ln_t_i) - else: - proj_i = approximate_gamma_mom(t_i, va_t_i) - - return logconst, np.array(proj_i) - - # --- mutation posteriors from node posteriors --- # - @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" @@ -831,3 +764,285 @@ def mutation_rootward_projection(t_j, pars_i, pars_ij, min_kl): proj_m = approximate_gamma_mom(t_m, va_t_m) return np.array(proj_m) + + +# --- unphased node posteriors --- # + +@numba.njit(_unituple(_f, 7)(_f, _f, _f, _f, _f, _f)) +def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): + """ + Calculate sufficient statistics for the PDF proportional to :math:`Ga(t_j | + a_j, b_j) Ga(t_i | a_i, b_i) Po(y_{ij} | \\mu_{ij} t_i + t_j)`, where + :math:`i` and :math:`j` are parents of the same individual (assumed to be at + time zero). The logarithmic moments are approximated via a Taylor expansion + around the mean. + + :param float a_i: the shape parameter of the cavity distribution for the first parent + :param float b_i: the rate parameter of the cavity distribution for the first parent + :param float a_j: the shape parameter of the cavity distribution for the second parent + :param float b_j: the rate parameter of the cavity distribution for the second parent + :param float y_ij: the number of mutations on the singleton edge pair + :param float mu_ij: the span-weighted mutation rate of the singleton edge pair + + :return: normalizing constant, E[t_i], E[log t_i], V[t_i], + E[t_j], E[log t_j], V[t_j] + """ + + a = a_j + b = a_i + a_j + y_ij + c = a_j + a_i + t = mu_ij + b_i + z = (mu_ij + b_j) / t + + #with numba.objmode(f0='f8', f1='f8', f2='f8'): + # f0 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 0, 1 - z))) + # f1 = float(mpmath.log(mpmath.hyp2f1(a + 1, b + 1, c + 1, 1 - z))) + # f2 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 2, c + 2, 1 - z))) + hyp2f1 = hypergeo._hyp2f1_laplace + f0 = hyp2f1(a + 0, b + 0, c + 0, 1 - z) + f1 = hyp2f1(a + 1, b + 1, c + 1, 1 - z) + f2 = hyp2f1(a + 2, b + 2, c + 2, 1 - z) + s1 = a * b / c + s2 = s1 * (a + 1) * (b + 1) / (c + 1) + d1 = s1 * np.exp(f1 - f0) + d2 = s2 * np.exp(f2 - f0) + + logl = f0 + hypergeo._betaln(a_j, a_i) + hypergeo._gammaln(b) - b * np.log(t) + + mn_j = d1 / t + sq_j = d2 / t**2 + va_j = sq_j - mn_j**2 + ln_j = np.log(mn_j) - va_j / 2 / mn_j**2 if mn_j > 0 else -np.inf + + mn_i = -mn_j * z + b / t + sq_i = sq_j * z**2 + (b + 1) * (mn_i - mn_j * z) / t + va_i = sq_i - mn_i**2 + ln_i = np.log(mn_i) - va_i / 2 / mn_i**2 if mn_i > 0 else -np.inf + + return logl, mn_i, ln_i, va_i, mn_j, ln_j, va_j + + +@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f, _f)) +def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): + """ + Calculate sufficient statistics for the PDF proportional to :math:`Ga(t_j | + a_j, b_j) Po(y_{ij} | \\mu_{ij} t_i + t_j)`, where :math:`i` and :math:`j` + are parents of the same individual (assumed to be at time zero). 
The + logarithmic moments are approximated via a Taylor expansion around the + mean. + + :param float t_i: the age of the first parent + :param float a_j: the shape parameter of the cavity distribution for the second parent + :param float b_j: the rate parameter of the cavity distribution for the second parent + :param float y_ij: the number of mutations on the singleton edge pair + :param float mu_ij: the span-weighted mutation rate of the singleton edge pair + + :return: normalizing constant, E[t_j], E[log t_j], V[t_j] + """ + + assert t_i > 0.0 + + a = a_j + b = a_j + y_ij + 1 + z = t_i * (mu_ij + b_j) + + hyperu = hypergeo._hyperu_laplace + f0, d0 = hyperu(a + 0, b + 0, z) + f1, d1 = hyperu(a + 1, b + 1, z) + + logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + hypergeo._gammaln(a) + mn_j = -t_i * d0 + va_j = t_i**2 * d0 * (d1 - d0) + ln_j = log(mn_j) - va_j / 2 / mn_j**2 if mn_j > 0 else -np.inf + + return logl, mn_j, ln_j, va_j + + +@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f, _f, _f)) +def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): + r""" + Calculate gamma sufficient statistics for the PDF proportional to: + + ..math:: + + p(x) = \int_0^\infty \int_0^\infty (Unif(x | 0, t_i) + Unif(x | 0, t_j)) + Ga(t_i | a_i, b_i) Ga(t_j | a_j b_j) Po(y | \mu_ij (t_i + t_j)) dt_j dt_i + + which models the time :math:`x` of a mutation uniformly distributed between + zero and one of two parents with ages :math:`t_i` and :math:`t_j`, where + the mutation count on both branches is :math:`y_{ij}` with total mutation + rate :math:`\mu_{ij}`. + + Returns log P[x under i], E[x], E[\log x], V[x]. + """ + + # Conditioning on ages of parents: + # P[x under i | t_i, t_j] = t_i / (t_i + t_j) + # E[x | x under t_i, t_i] = t_i / 2 + # E[x^2 | x under t_i, t_i] = t_i**2 / 3 + # and equivalently for t_j. Integrating these moments over the EP surrogate + # density leads to hypergeometric functions similar to the node case, but + # with integer perturbations of a_i, a_j, y_ij. + + a = a_j + b = a_j + a_i + y_ij + c = a_j + a_i + t = mu_ij + b_i + z = (mu_ij + b_j) / t + + hyp2f1 = hypergeo._hyp2f1_laplace + f000 = hyp2f1(a + 0, b + 0, c + 0, 1 - z) + f001 = hyp2f1(a + 0, b + 0, c + 1, 1 - z) + f012 = hyp2f1(a + 0, b + 1, c + 2, 1 - z) + f023 = hyp2f1(a + 0, b + 2, c + 3, 1 - z) + f212 = hyp2f1(a + 2, b + 1, c + 2, 1 - z) + f323 = hyp2f1(a + 3, b + 2, c + 3, 1 - z) + + s0 = b / t / c / (c + 1) + s1 = (c - a) * (c - a + 1) + s2 = a * (a + 1) + d0 = s0 * (b + 1) / t / (c + 2) + d1 = s1 * (c - a + 2) + d2 = s2 * (a + 2) + + mn_m = s0 * s1 * exp(f012 - f000) / 2 + s0 * s2 * exp(f212 - f000) / 2 + sq_m = d0 * d1 * exp(f023 - f000) / 3 + d0 * d2 * exp(f323 - f000) / 3 + va_m = sq_m - mn_m**2 + ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf + pr_m = (c - a) / c * exp(f001 - f000) + + return pr_m, mn_m, ln_m, va_m + + +@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f, _f)) +def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): + r""" + Calculate gamma sufficient statistics for the PDF proportional to: + + ..math:: + + p(x) = \int_0^\infty (Unif(x | 0, t_i) + Unif(x | 0, t_j)) + Ga(t_j | a_j b_j) Po(y | \mu_ij (t_i + t_j)) dt_j + + which models the time :math:`x` of a mutation uniformly distributed between + zero and one of two parents with ages :math:`t_i` and :math:`t_j`, where + the mutation count on both branches is :math:`y_{ij}` with total mutation + rate :math:`\mu_{ij}`. + + Returns log P[x under i], E[x], E[\log x], V[x]. 
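+
+    A doctest-style sketch of a call (illustrative values only; marked
+    ``+SKIP`` because the exact output depends on the Laplace approximation):
+
+    >>> pr, mn, ln, va = unphased_mutation_rightward_moments(
+    ...     4000.0, 1.5, 0.005, 1.0, 0.001
+    ... )  # doctest: +SKIP
+    >>> 0.0 < pr < 1.0 and va > 0.0  # doctest: +SKIP
+    True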
+ """ + + # Conditioning on ages of parents: + # P[x under i | t_i, t_j] = t_i / (t_i + t_j) + # E[x | x under t_i, t_i] = t_i / 2 + # E[x^2 | x under t_i, t_i] = t_i**2 / 3 + # and equivalently for t_j. Integrating these moments over the EP surrogate + # density leads to Tricomi functions similar to the node case, but + # with integer perturbations of a_j, y_ij. + + a = a_j + b = a_j + y_ij + 1 + z = t_i * (mu_ij + b_j) + + #with numba.objmode(f00='f8', f10='f8', f21='f8', f32='f8'): + # f00 = float(mpmath.log(mpmath.hyperu(a + 0, b + 0, z))) + # f10 = float(mpmath.log(mpmath.hyperu(a + 1, b + 0, z))) + # f21 = float(mpmath.log(mpmath.hyperu(a + 2, b + 1, z))) + # f32 = float(mpmath.log(mpmath.hyperu(a + 3, b + 2, z))) + + # direct but unstable: + hyperu = hypergeo._hyperu_laplace + f00, d00 = hyperu(a + 0, b + 0, z) + f10, d10 = hyperu(a + 1, b + 0, z) + f21, d21 = hyperu(a + 2, b + 1, z) + f32, d32 = hyperu(a + 3, b + 2, z) + pr_m = 1.0 - exp(f10 - f00) * a + mn_m = pr_m * t_i / 2 + t_i * exp(f21 - f00) * a * (a + 1) / 2 + sq_m = pr_m * t_i ** 2 / 3 + t_i ** 2 * exp(f32 - f00) * a * (a + 1) * (a + 2) / 3 + + # TODO: use a stabler approach with derivatives + # note that exp(f10 - f00) = (a + z * d00) / (a - b + 1) + # however the denominator is 0 if y_ij is 0 + # note that when y_ij == 0 then a == b + 1 and f00 = z**(-a) + + va_m = sq_m - mn_m**2 + ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf + + return pr_m, mn_m, ln_m, va_m + + +@numba.njit(_b(_f, _f, _f, _f, _f, _f)) +def _hyp2f1_unphased_valid_parameterization(a_i, b_i, a_j, b_j, y, mu): + """Uses shape / rate parameterization""" + a = a_j + b = a_i + a_j + y + c = a_j + a_i + s = mu + b_j + t = mu + b_i + # check that 2F1 argument is less than unity + if t <= 0.0: + return False + z = 1.0 - s / t + if z >= 1.0 or z / (z - 1) >= 1.0: + return False + # check that 2F1 is positive + if a <= 0: + return False + if b <= 0: + return False + if c <= 0: + return False + return True + + +@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r, _b)) +def unphased_projection(pars_i, pars_j, pars_ij, min_kl): + """ + Match a pair of gamma distributions to the potential function :math:`Ga(t_j + | a_j + 1, b_j) Ga(t_i | a_i + 1, b_i) Po(y_{ij} | \\mu_{ij} t_i + t_j)`, + where :math:`i` and :math:`j` are parents of the same individual (assumed + to be at time zero), by minimizing KL divergence. 
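+
+    Natural parameters are ``(shape - 1, rate)``, so a ``Ga(2, 0.0005)``
+    cavity enters as ``(1.0, 0.0005)``; the ``+= 1`` below restores the
+    canonical shape. A hedged sketch of a call (illustrative values):
+
+    >>> logl, post_i, post_j = unphased_projection(
+    ...     np.array([1.0, 0.0005]),  # first parent cavity
+    ...     np.array([0.5, 0.005]),   # second parent cavity
+    ...     np.array([1.0, 0.001]),   # mutation count, edge-pair rate
+    ...     True,                     # min_kl
+    ... )  # doctest: +SKIP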
+
+    :param float pars_i: gamma natural parameters for the first parent's cavity
+        distribution
+    :param float pars_j: gamma natural parameters for the second parent's
+        cavity distribution
+    :param float pars_ij: gamma natural parameters for the edge pair likelihood
+    :param bool min_kl: minimize KL divergence (match central moments if False)
+
+    :return: normalizing constant, gamma natural parameters for parents
+    """
+
+    # switch from natural to canonical parameterization
+    a_i, b_i = pars_i
+    a_j, b_j = pars_j
+    y_ij, mu_ij = pars_ij
+    a_i += 1
+    a_j += 1
+
+    if not _hyp2f1_unphased_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij):
+        return np.nan, pars_i, pars_j
+
+    logconst, t_i, ln_t_i, va_t_i, t_j, ln_t_j, va_t_j = unphased_moments(
+        a_i,
+        b_i,
+        a_j,
+        b_j,
+        y_ij,
+        mu_ij,
+    )
+
+    valid_i = _valid_moments(t_i, ln_t_i, va_t_i)
+    valid_j = _valid_moments(t_j, ln_t_j, va_t_j)
+    if not (valid_i and valid_j):
+        return np.nan, pars_i, pars_j
+
+    if min_kl:
+        proj_i = approximate_gamma_kl(t_i, ln_t_i)
+        proj_j = approximate_gamma_kl(t_j, ln_t_j)
+    else:
+        proj_i = approximate_gamma_mom(t_i, va_t_i)
+        proj_j = approximate_gamma_mom(t_j, va_t_j)
+
+    return logconst, np.array(proj_i), np.array(proj_j)
diff --git a/tsdate/evaluation.py b/tsdate/evaluation.py
index 7ed44784..1fbf2447 100644
--- a/tsdate/evaluation.py
+++ b/tsdate/evaluation.py
@@ -533,7 +533,7 @@ def mutation_coverage(ts, inferred_ts, alpha):
     return prop_covered
 
 
-def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None):
+def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None, title=None, subtending_node=False):
     """
     Return true and inferred mutation ages, optionally creating a scatterplot
     and filtering by minimum or maximum frequency.
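+    If `subtending_node` is True, the ages of the nodes subtending each
+    mutation (deduplicated, with zero-age pairs dropped) are compared instead
+    of the mutation ages; if `title` is given, it is used as the plot title.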
@@ -555,8 +555,6 @@ def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None) missing = np.logical_or(true_mut == tskit.NULL, infr_mut == tskit.NULL) infr_mut = infr_mut[~missing] true_mut = true_mut[~missing] - mean = inferred_ts.mutations_time[infr_mut] - truth = ts.mutations_time[true_mut] # filter by frequency if min_freq is not None or max_freq is not None: freq = np.zeros(inferred_ts.num_mutations) @@ -569,9 +567,26 @@ def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None) max_freq = np.max(freq) freq = freq[infr_mut] is_freq = np.logical_and(freq >= min_freq, freq <= max_freq) - mean = mean[is_freq] - truth = truth[is_freq] - # plot + infr_mut = infr_mut[is_freq] + true_mut = true_mut[is_freq] + # get age of mutation or subtended node + if subtending_node: + infr_node = inferred_ts.mutations_node[infr_mut] + true_node = ts.mutations_node[true_mut] + _, uniq_idx = np.unique(infr_node, return_index=True) + infr_node = infr_node[uniq_idx] + true_node = true_node[uniq_idx] + _, uniq_idx = np.unique(true_node, return_index=True) + infr_node = infr_node[uniq_idx] + true_node = true_node[uniq_idx] + mean = inferred_ts.nodes_time[infr_node] + truth = ts.nodes_time[true_node] + nonzero = np.logical_and(mean > 0, truth > 0) + mean = mean[nonzero] + truth = truth[nonzero] + else: + mean = inferred_ts.mutations_time[infr_mut] + truth = ts.mutations_time[true_mut] if plotpath is not None: rsq = np.corrcoef(np.log10(mean), np.log10(truth))[0, 1] ** 2 bias = np.mean(np.log10(mean) - np.log10(truth)) @@ -581,8 +596,14 @@ def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None) plt.hexbin(truth, mean, xscale="log", yscale="log", mincnt=1) plt.text(0.01, 0.99, info, ha="left", va="top", transform=plt.gca().transAxes) plt.axline(pt1, pt2, linestyle="--", color="firebrick") - plt.xlabel("True mutation age") - plt.ylabel("Estimated mutation age") + if subtending_node: + plt.xlabel("True node age") + plt.ylabel("Estimated node age") + else: + plt.xlabel("True mutation age") + plt.ylabel("Estimated mutation age") + if title is not None: + plt.title(title) plt.tight_layout() plt.savefig(plotpath) plt.clf() diff --git a/tsdate/hypergeo.py b/tsdate/hypergeo.py index b64d9b69..1df623f8 100644 --- a/tsdate/hypergeo.py +++ b/tsdate/hypergeo.py @@ -163,7 +163,7 @@ def _hyperu_laplace(a, b, x): TODO: details """ - assert b > a > 0.0 + assert b >= a > 0.0 assert x > 0.0 t = b - x - 1 diff --git a/tsdate/phasing.py b/tsdate/phasing.py new file mode 100644 index 00000000..350dc2e4 --- /dev/null +++ b/tsdate/phasing.py @@ -0,0 +1,106 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+Tools for phasing singleton mutations
+"""
+import numpy as np
+import tskit
+
+#def _mutations_frequency(ts):
+#
+#def mutations_frequency(ts):
+
+def remove_singletons(ts):
+    """
+    Remove all singleton mutations from the tree sequence.
+
+    Return the new ts, along with the id of the removed mutations in the
+    original tree sequence.
+    """
+
+    nodes_sample = np.bitwise_and(ts.nodes_flags, tskit.NODE_IS_SAMPLE).astype(bool)
+    assert np.sum(nodes_sample) == ts.num_samples
+    assert np.all(~nodes_sample[ts.edges_parent]), "Sample node has a child"
+    singletons = nodes_sample[ts.mutations_node]
+
+    old_metadata = np.array(tskit.unpack_strings(
+        ts.tables.mutations.metadata,
+        ts.tables.mutations.metadata_offset,
+    ))
+    old_state = np.array(tskit.unpack_strings(
+        ts.tables.mutations.derived_state,
+        ts.tables.mutations.derived_state_offset,
+    ))
+    new_metadata, new_metadata_offset = tskit.pack_strings(old_metadata[~singletons])
+    new_state, new_state_offset = tskit.pack_strings(old_state[~singletons])
+
+    tables = ts.dump_tables()
+    tables.mutations.set_columns(
+        node=ts.mutations_node[~singletons],
+        time=ts.mutations_time[~singletons],
+        site=ts.mutations_site[~singletons],
+        derived_state=new_state,
+        derived_state_offset=new_state_offset,
+        metadata=new_metadata,
+        metadata_offset=new_metadata_offset,
+    )
+    tables.sort()
+    tables.build_index()
+    tables.compute_mutation_parents()
+
+    return tables.tree_sequence(), np.flatnonzero(singletons)
+
+
+def rephase_singletons(ts, use_node_times=True, random_seed=None):
+    """
+    Rephase singleton mutations in the tree sequence. If `use_node_times`
+    is True, singletons are added to permissible branches with probability
+    proportional to the branch length (and with equal probability otherwise).
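+
+    A hypothetical round trip (sketch only; assumes `ts` carries singleton
+    mutations):
+
+    >>> stripped_ts, dropped = remove_singletons(ts)  # doctest: +SKIP
+    >>> rephased_ts, moved = rephase_singletons(ts, random_seed=1)  # doctest: +SKIP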
+ """ + rng = np.random.default_rng(random_seed) + + mutations_node = ts.mutations_node.copy() + mutations_time = ts.mutations_time.copy() + + singletons = np.bitwise_and(ts.nodes_flags[mutations_node], tskit.NODE_IS_SAMPLE) + singletons = np.flatnonzero(singletons) + tree = ts.first() + for i in singletons: + position = ts.sites_position[ts.mutations_site[i]] + individual = ts.nodes_individual[ts.mutations_node[i]] + time = ts.nodes_time[ts.mutations_node[i]] + assert individual != tskit.NULL + assert time == 0.0 + tree.seek(position) + nodes_id = ts.individual(individual).nodes + nodes_length = np.array([tree.time(tree.parent(n)) - time for n in nodes_id]) + nodes_prob = nodes_length if use_node_times else np.ones(nodes_id.size) + mutations_node[i] = rng.choice(nodes_id, p=nodes_prob / nodes_prob.sum(), size=1) + if not np.isnan(mutations_time[i]): + mutations_time[i] = (time + tree.time(tree.parent(mutations_node[i]))) / 2 + + # TODO: add metadata with phase probability + tables = ts.dump_tables() + tables.mutations.node = mutations_node + tables.mutations.time = mutations_time + tables.sort() + return tables.tree_sequence(), singletons diff --git a/tsdate/variational.py b/tsdate/variational.py index aa11e4e5..e6db3eb5 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -55,6 +55,10 @@ ROOTWARD = 0 # edge likelihood to parent LEAFWARD = 1 # edge likelihood to child +# columns for unphased_factors +FIRSTPAR = 0 # edge likelihood to first parent +SECNDPAR = 1 # edge likelihood to second parent + # columns for node_factors MIXPRIOR = 0 # mixture prior to node CONSTRNT = 1 # bounds on node ages @@ -182,6 +186,7 @@ def _check_valid_state( for i, (p, c) in enumerate(zip(edges_parent, edges_child)): posterior_check[p] += edge_factors[i, ROOTWARD] posterior_check[c] += edge_factors[i, LEAFWARD] + # TODO: unphased factors posterior_check += node_factors[:, MIXPRIOR] posterior_check += node_factors[:, CONSTRNT] return np.allclose(posterior_check, posterior) @@ -233,9 +238,12 @@ def __init__(self, ts, likelihoods, constraints, mutations_edge): self.constraints = constraints self.mutations_edge = mutations_edge + # TODO: get likelihoods + unphaseed + # mutable self.node_factors = np.zeros((ts.num_nodes, 2, 2)) self.edge_factors = np.zeros((ts.num_edges, 2, 2)) + #self.unph_factors = np.zeros((..., 2, 2)) #TODO self.posterior = np.zeros((ts.num_nodes, 2)) self.log_partition = np.zeros(ts.num_edges) self.scale = np.ones(ts.num_nodes) @@ -452,6 +460,46 @@ def posterior_damping(x): return np.nan + # @staticmethod + # @numba.njit(_f(_i2r, _i1r, _f2r, _f2w, _f3w, _f1w, _f)) + # def propagate_unphased( + # parents, individual, likelihoods, posterior, factors, scale, max_shape + # ): + # """ + # Update approximating factors for unphased singletons. + + # :param ndarray parents: rows are unphased intervals, columns are first + # and second parents of an individual over that interval. + # :param ndarray individual: the individual associated with each + # unphased interval. + # :param ndarray likelihoods: rows are unphased intervals, columns are + # number of singleton mutations and interval span. + # :param ndarray posterior: rows are nodes, columns are first and + # second natural parameters of gamma posteriors. Updated in + # place. + # :param ndarray factors: rows are unphased intervals, columns index + # different types of updates. Updated in place. + # :param ndarray scale: array of dimension `[num_nodes]` containing a + # scaling factor for the posteriors, updated in-place. 
+ # :param float max_shape: the maximum allowed shape for node posteriors. + # """ + + # # TODO assert ??? + # assert max_shape >= 1.0 + # assert 0.0 < min_step < 1.0 + + # def cavity_damping(x, y): + # return _damp(x, y, min_step) + + # def posterior_damping(x): + # return _rescale(x, max_shape) + + # # TODO copy from propagate_likelihood... + + # # TODO copy from propagate_likelihood... + + # return np.nan + @staticmethod @numba.njit(_f2w(_i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b)) def propagate_mutations( @@ -501,6 +549,7 @@ def propagate_mutations( if i == tskit.NULL: # skip mutations above root mutations_posterior[m] = np.nan continue + # TODO: if unphased skip, set to nan p, c = edges_parent[i], edges_child[i] if fixed[p] and fixed[c]: child_age = constraints[c, 0] @@ -550,6 +599,9 @@ def rescale_factors(edges_parent, edges_child, node_factors, edge_factors, scale edge_factors[:, LEAFWARD] *= scale[c, np.newaxis] node_factors[:, MIXPRIOR] *= scale[:, np.newaxis] node_factors[:, CONSTRNT] *= scale[:, np.newaxis] + # TODO: unphased factors + #unph_factors[:, FIRSTPAR] *= scale[:, np.newaxis] + #unph_factors[:, SECNDPAR] *= scale[:, np.newaxis] scale[:] = 1.0 def iterate( @@ -563,6 +615,11 @@ def iterate( regularise=True, check_valid=False, ): + # TODO: pass through unphased intervals + #self.propagate_unphased( + # ... + #) + # rootward + leafward pass through edges self.propagate_likelihood( self.edge_order, From a192d099fd631cc1b135fa2770f8d113052c133f Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 16 May 2024 15:20:52 -0700 Subject: [PATCH 02/29] Remove min_kl slash match_central_moments --- tests/test_approximations.py | 94 +++----------- tsdate/approx.py | 234 ++++++++++------------------------- tsdate/core.py | 8 +- tsdate/variational.py | 25 ++-- 4 files changed, 91 insertions(+), 270 deletions(-) diff --git a/tests/test_approximations.py b/tests/test_approximations.py index d7e6b0c7..09e599d7 100644 --- a/tests/test_approximations.py +++ b/tests/test_approximations.py @@ -136,7 +136,7 @@ def test_moments(self, pars): """ Test mean and variance when ages of both nodes are free """ - logconst, t_i, _, var_t_i, t_j, _, var_t_j = approx.moments(*pars) + logconst, t_i, var_t_i, t_j, var_t_j = approx.moments(*pars) ck_normconst = scipy.integrate.dblquad( lambda t_i, t_j: self.pdf(t_i, t_j, *pars), 0, @@ -197,7 +197,7 @@ def test_rootward_moments(self, pars): pars_redux = (a_i, b_i, y, mu) mn_j = a_j / b_j # point "estimate" for child for t_j in [0.0, mn_j]: - logconst, t_i, _, var_t_i = approx.rootward_moments(t_j, *pars_redux) + logconst, t_i, var_t_i = approx.rootward_moments(t_j, *pars_redux) ck_normconst = scipy.integrate.quad( lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), t_j, @@ -232,7 +232,7 @@ def test_leafward_moments(self, pars): a_i, b_i, a_j, b_j, y, mu = pars t_i = a_i / b_i # point "estimate" for parent pars_redux = (a_j, b_j, y, mu) - logconst, t_j, _, var_t_j = approx.leafward_moments(t_i, *pars_redux) + logconst, t_j, var_t_j = approx.leafward_moments(t_i, *pars_redux) ck_normconst = scipy.integrate.quad( lambda t_j: self.pdf_leafward(t_i, t_j, *pars_redux), 0, @@ -264,7 +264,7 @@ def test_unphased_moments(self, pars): """ Parent ages for an singleton nodes above an unphased individual """ - logconst, t_i, _, var_t_i, t_j, _, var_t_j = approx.unphased_moments(*pars) + logconst, t_i, var_t_i, t_j, var_t_j = approx.unphased_moments(*pars) ck_normconst = scipy.integrate.dblquad( lambda t_i, t_j: self.pdf_unphased(t_i, t_j, *pars), 0, @@ -325,7 
+325,7 @@ def test_unphased_rightward_moments(self, pars): a_i, b_i, a_j, b_j, y, mu = pars pars_redux = (a_j, b_j, y, mu) t_i = a_i / b_i # point "estimate" for left parent - nc, mn, _, va = approx.unphased_rightward_moments(t_i, *pars_redux) + nc, mn, va = approx.unphased_rightward_moments(t_i, *pars_redux) ck_nc = scipy.integrate.quad( lambda t_j: self.pdf_unphased_rightward(t_i, t_j, *pars_redux), 0, @@ -354,7 +354,7 @@ def f(t_i, t_j): mn = t_i / 2 + t_j / 2 sq = (t_i**2 + t_i*t_j + t_j**2) / 3 return mn, sq - mn, _, va = approx.mutation_moments(*pars) + mn, va = approx.mutation_moments(*pars) nc = scipy.integrate.dblquad( lambda t_i, t_j: self.pdf(t_i, t_j, *pars), 0, @@ -395,7 +395,7 @@ def f(t_i, t_j): # conditional moments pars_redux = (a_i, b_i, y, mu) mn_j = a_j / b_j # point "estimate" for child for t_j in [0.0, mn_j]: - mn, _, va = approx.mutation_rootward_moments(t_j, *pars_redux) + mn, va = approx.mutation_rootward_moments(t_j, *pars_redux) nc = scipy.integrate.quad( lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), t_j, @@ -426,7 +426,7 @@ def f(t_i, t_j): a_i, b_i, a_j, b_j, y, mu = pars t_i = a_i / b_i # point "estimate" for parent pars_redux = (a_j, b_j, y, mu) - mn, _, va = approx.mutation_leafward_moments(t_i, *pars_redux) + mn, va = approx.mutation_leafward_moments(t_i, *pars_redux) nc = scipy.integrate.quad( lambda t_j: self.pdf_leafward(t_i, t_j, *pars_redux), 0, @@ -454,7 +454,7 @@ def f(t_i, t_j): # conditional moments mn = pr * t_i / 2 + (1 - pr) * t_j / 2 sq = pr * t_i**2 / 3 + (1 - pr) * t_j**2 / 3 return pr, mn, sq - pr, mn, _, va = approx.unphased_mutation_moments(*pars) + pr, mn, va = approx.unphased_mutation_moments(*pars) nc = scipy.integrate.dblquad( lambda t_i, t_j: self.pdf_unphased(t_i, t_j, *pars), 0, @@ -504,7 +504,7 @@ def f(t_i, t_j): # conditional moments a_i, b_i, a_j, b_j, y, mu = pars t_i = a_i / b_i # point "estimate" for left parent pars_redux = (a_j, b_j, y, mu) - pr, mn, _, va = approx.unphased_mutation_rightward_moments(t_i, *pars_redux) + pr, mn, va = approx.unphased_mutation_rightward_moments(t_i, *pars_redux) nc = scipy.integrate.quad( lambda t_j: self.pdf_unphased_rightward(t_i, t_j, *pars_redux), 0, @@ -530,7 +530,11 @@ def f(t_i, t_j): # conditional moments assert np.isclose(va, ck_va, rtol=2e-2) def test_approximate_gamma_kl(self, pars): - _, t_i, ln_t_i, _, t_j, ln_t_j, _ = approx.moments(*pars) + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i + ln_t_i = hypergeo._digamma(a_i) - np.log(b_i) + t_j = a_j / b_j + ln_t_j = hypergeo._digamma(a_j) - np.log(b_j) alpha_i, beta_i = approx.approximate_gamma_kl(t_i, ln_t_i) alpha_j, beta_j = approx.approximate_gamma_kl(t_j, ln_t_j) ck_t_i = (alpha_i + 1) / beta_i @@ -543,7 +547,7 @@ def test_approximate_gamma_kl(self, pars): assert np.isclose(ln_t_j, ck_ln_t_j) def test_approximate_gamma_mom(self, pars): - _, t_i, _, va_t_i, t_j, _, va_t_j = approx.moments(*pars) + _, t_i, va_t_i, t_j, va_t_j = approx.moments(*pars) alpha_i, beta_i = approx.approximate_gamma_mom(t_i, va_t_i) alpha_j, beta_j = approx.approximate_gamma_mom(t_j, va_t_j) ck_t_i = (alpha_i + 1) / beta_i @@ -556,72 +560,6 @@ def test_approximate_gamma_mom(self, pars): assert np.isclose(va_t_j, ck_va_t_j) - - -class TestPriorMomentMatching: - """ - Test approximation of the conditional coalescent prior via - moment matching to a gamma distribution - """ - - n = 10 - priors = prior.ConditionalCoalescentTimes(False) - priors.add(n) - - @pytest.mark.parametrize("k", np.arange(2, 10)) - def test_conditional_coalescent_pdf(self, 
k): - """ - Check that the utility function matches the implementation in - `tsdate.prior` - """ - mean, _ = scipy.integrate.quad( - lambda x: x * conditional_coalescent_pdf(x, self.n, k), 0, np.inf - ) - var, _ = scipy.integrate.quad( - lambda x: x**2 * conditional_coalescent_pdf(x, self.n, k), 0, np.inf - ) - var -= mean**2 - mean_column = prior.PriorParams.field_index("mean") - var_column = prior.PriorParams.field_index("var") - assert np.isclose(mean, self.priors[self.n][k][mean_column]) - assert np.isclose(var, self.priors[self.n][k][var_column]) - - @pytest.mark.parametrize("k", np.arange(2, 10)) - def test_approximate_gamma(self, k): - """ - Test that matching gamma to Taylor-series-approximated sufficient - statistics will result in lower KL divergence than matching to - mean/variance - """ - mean_column = prior.PriorParams.field_index("mean") - var_column = prior.PriorParams.field_index("var") - x = self.priors[self.n][k][mean_column] - xvar = self.priors[self.n][k][var_column] - # match mean/variance - alpha_0, beta_0 = approx.approximate_gamma_mom(x, xvar) - ck_x = (alpha_0 + 1) / beta_0 - ck_xvar = (alpha_0 + 1) / beta_0**2 - assert np.isclose(x, ck_x) - assert np.isclose(xvar, ck_xvar) - # match approximate sufficient statistics - logx, _, _ = approx.approximate_log_moments(x, xvar) - alpha_1, beta_1 = approx.approximate_gamma_kl(x, logx) - ck_x = (alpha_1 + 1) / beta_1 - ck_logx = hypergeo._digamma(alpha_1 + 1) - np.log(beta_1) - assert np.isclose(x, ck_x) - assert np.isclose(logx, ck_logx) - # compare KL divergence between strategies - kl_0 = kl_divergence( - lambda x: conditional_coalescent_pdf(x, self.n, k), - lambda x: scipy.stats.gamma.logpdf(x, alpha_0 + 1, scale=1 / beta_0), - ) - kl_1 = kl_divergence( - lambda x: conditional_coalescent_pdf(x, self.n, k), - lambda x: scipy.stats.gamma.logpdf(x, alpha_1 + 1, scale=1 / beta_1), - ) - assert kl_1 < kl_0 - - class TestGammaFactorization: """ Test various functions for manipulating factorizations of gamma distributions diff --git a/tsdate/approx.py b/tsdate/approx.py index 7f3aa621..cbe2aa32 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -185,18 +185,16 @@ def average_gammas(alpha, beta): return approximate_gamma_kl(avg_x, avg_logx) -@numba.njit(_b(_f, _f, _f)) -def _valid_moments(mn, ln, va): +@numba.njit(_b(_f, _f)) +def _valid_moments(mn, va): if not (mn > 0.0 and va > 0.0): return False - if not (ln < log(mn)): - return False return True # --- node posteriors --- # -@numba.njit(_unituple(_f, 7)(_f, _f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): """ Calculate sufficient statistics for the PDF proportional to @@ -240,67 +238,15 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): mn_j = d1 / t sq_j = d2 / t**2 va_j = sq_j - mn_j**2 - ln_j = np.log(mn_j) - va_j / 2 / mn_j**2 if mn_j > 0 else -np.inf mn_i = mn_j * z + b / t sq_i = sq_j * z**2 + (b + 1) * (mn_i + mn_j * z) / t va_i = sq_i - mn_i**2 - ln_i = np.log(mn_i) - va_i / 2 / mn_i**2 if mn_i > 0 else -np.inf - - return logl, mn_i, ln_i, va_i, mn_j, ln_j, va_j - -# @numba.njit("UniTuple(f8, 4)(f8, f8, f8, f8, f8)") -# def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): -# """ -# Numerically unstable. -# -# Calculate sufficient statistics for the PDF proportional to -# :math:`Ga(t_i | a_i, b_i) Po(y_{ij} | \\mu_{ij} t_i - t_j)`, where -# :math:`i` is the parent and :math:`j` is the child. The logarithmic moments -# are approximated via a Taylor expansion around the mean. 
-# -# :param float t_j: the age of the child -# :param float a_i: the shape parameter of the cavity distribution for the parent -# :param float b_i: the rate parameter of the cavity distribution for the parent -# :param float y_ij: the number of mutations on the edge -# :param float mu_ij: the span-weighted mutation rate of the edge -# -# :return: normalizing constant, E[t_i], E[log t_i], V[t_i] -# """ -# -# assert t_j >= 0.0 -# -# if t_j == 0.0: -# shape = a_i + y_ij -# rate = mu_ij + b_i -# logl = hypergeo._gammaln(shape) - shape * log(rate) -# mn_i = shape / rate -# va_i = shape / rate**2 -# ln_i = hypergeo._digamma(shape) - log(rate) -# return logl, mn_i, ln_i, va_i -# -# a = y_ij + 1 -# b = a_i + y_ij + 1 -# z = t_j * (mu_ij + b_i) -# -# hyperu = hypergeo._hyperu_laplace -# f0 = hyperu(a + 0, b + 0, z) -# f1 = hyperu(a + 1, b + 1, z) -# f2 = hyperu(a + 2, b + 2, z) -# d1 = a * exp(f1 - f0) -# d2 = a * (a + 1) * exp(f2 - f0) -# -# logl = f0 - b_i * t_j + (b - 1) * log(t_j) + hypergeo._gammaln(a) -# mn_i = t_j * (1 + d1) -# sq_i = t_j**2 * (1 + 2 * d1 + d2) -# va_i = sq_i - mn_i**2 -# ln_i = log(mn_i) - va_i / 2 / mn_i**2 -# -# return logl, mn_i, ln_i, va_i + return logl, mn_i, va_i, mn_j, va_j -@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): """ Calculate sufficient statistics for the PDF proportional to @@ -329,8 +275,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): logl = hypergeo._gammaln(s) - s * log(r) mn_i = s / r va_i = s / r**2 - ln_i = hypergeo._digamma(s) - log(r) - return logl, mn_i, ln_i, va_i + return logl, mn_i, va_i hyperu = hypergeo._hyperu_laplace f0, d0 = hyperu(a + 0, b + 0, z) @@ -339,12 +284,11 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): logl = f0 - b_i * t_j + (b - 1) * log(t_j) + hypergeo._gammaln(a) mn_i = t_j * (1 - d0) va_i = t_j**2 * d0 * (d1 - d0) - ln_i = log(mn_i) - va_i / 2 / mn_i**2 if mn_i > 0 else -np.inf - return logl, mn_i, ln_i, va_i + return logl, mn_i, va_i -@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): """ Calculate sufficient statistics for the PDF proportional to @@ -378,9 +322,8 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): mn_j = t_i * d1 sq_j = t_i**2 * d2 va_j = sq_j - mn_j**2 - ln_j = log(mn_j) - va_j / 2 / mn_j**2 if mn_j > 0 else -np.inf - return logl, mn_j, ln_j, va_j + return logl, mn_j, va_j @numba.njit(_b(_f, _f, _f, _f, _f, _f)) @@ -430,8 +373,8 @@ def _hyperu_valid_parameterization(t_j, a_i, b_i, y, mu): return False return True -@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r, _b)) -def gamma_projection(pars_i, pars_j, pars_ij, min_kl): +@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) +def gamma_projection(pars_i, pars_j, pars_ij): """ Match a pair of gamma distributions to the potential function :math:`Ga(t_j | a_j + 1, b_j) Ga(t_i | a_i + 1, b_i) Po(y_{ij} | @@ -441,7 +384,6 @@ def gamma_projection(pars_i, pars_j, pars_ij, min_kl): :param float pars_i: gamma natural parameters for the parent cavity distribution :param float pars_j: gamma natural parameters for the child cavity distribution :param float pars_ij: gamma natural parameters for the edge likelihood - :param bool min_kl: minimize KL divergence (match central moments if False) :return: normalizing constant, gamma natural parameters for parent and child """ @@ -456,7 +398,7 @@ def gamma_projection(pars_i, pars_j, pars_ij, min_kl): 
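+    # a NaN log-normalizer with unchanged cavity parameters signals that
+    # this edge update is skipped when the 2F1 parameterization is invalid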
if not _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): return np.nan, pars_i, pars_j - logconst, t_i, ln_t_i, va_t_i, t_j, ln_t_j, va_t_j = moments( + logconst, t_i, va_t_i, t_j, va_t_j = moments( a_i, b_i, a_j, @@ -465,23 +407,17 @@ def gamma_projection(pars_i, pars_j, pars_ij, min_kl): mu_ij, ) - valid_i = _valid_moments(t_i, ln_t_i, va_t_i) - valid_j = _valid_moments(t_j, ln_t_j, va_t_j) - if not (valid_i and valid_j): + if not _valid_moments(t_i, va_t_i) or not _valid_moments(t_j, va_t_j): return np.nan, pars_i, pars_j - if min_kl: - proj_i = approximate_gamma_kl(t_i, ln_t_i) - proj_j = approximate_gamma_kl(t_j, ln_t_j) - else: - proj_i = approximate_gamma_mom(t_i, va_t_i) - proj_j = approximate_gamma_mom(t_j, va_t_j) + proj_i = approximate_gamma_mom(t_i, va_t_i) + proj_j = approximate_gamma_mom(t_j, va_t_j) return logconst, np.array(proj_i), np.array(proj_j) -@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r, _b)) -def leafward_projection(t_i, pars_j, pars_ij, min_kl): +@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +def leafward_projection(t_i, pars_j, pars_ij): r""" Match a gamma distributions to the potential function :math:`Ga(t_j | a_j + 1, b_j) Po(y_{ij} | \mu_{ij} t_i - t_j)`, where :math:`i` is the parent and @@ -490,7 +426,6 @@ def leafward_projection(t_i, pars_j, pars_ij, min_kl): :param float t_i: the age of the parent :param float pars_j: gamma natural parameters for the child cavity distribution :param float pars_ij: gamma natural parameters for the edge likelihood - :param bool min_kl: minimize KL divergence (match central moments if False) :return: normalizing constant, gamma natural parameters for child """ @@ -504,22 +439,18 @@ def leafward_projection(t_i, pars_j, pars_ij, min_kl): if not _hyp1f1_valid_parameterization(t_i, a_j, b_j, y_ij, mu_ij): return np.nan, pars_j - logconst, t_j, ln_t_j, va_t_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) + logconst, t_j, va_t_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) - valid_j = _valid_moments(t_j, ln_t_j, va_t_j) - if not valid_j: + if not _valid_moments(t_j, va_t_j): return np.nan, pars_j - if min_kl: - proj_j = approximate_gamma_kl(t_j, ln_t_j) - else: - proj_j = approximate_gamma_mom(t_j, va_t_j) + proj_j = approximate_gamma_mom(t_j, va_t_j) return logconst, np.array(proj_j) -@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r, _b)) -def rootward_projection(t_j, pars_i, pars_ij, min_kl): +@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +def rootward_projection(t_j, pars_i, pars_ij): r""" Match a gamma distributions to the potential function :math:`Ga(t_i | a_i + 1, b_i) Po(y_{ij} | \mu_{ij} t_i - t_j)`, where :math:`i` is the parent and @@ -528,7 +459,6 @@ def rootward_projection(t_j, pars_i, pars_ij, min_kl): :param float t_j: the age of the child :param float pars_i: gamma natural parameters for the parent cavity distribution :param float pars_ij: gamma natural parameters for the edge likelihood - :param bool min_kl: minimize KL divergence (match central moments if False) :return: normalizing constant, gamma natural parameters for child """ @@ -542,23 +472,19 @@ def rootward_projection(t_j, pars_i, pars_ij, min_kl): if not _hyperu_valid_parameterization(t_j, a_i, b_i, y_ij, mu_ij): return np.nan, pars_i - logconst, t_i, ln_t_i, va_t_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) + logconst, t_i, va_t_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) - valid_i = _valid_moments(t_i, ln_t_i, va_t_i) - if not valid_i: + if not _valid_moments(t_i, va_t_i): return np.nan, pars_i - if min_kl: - proj_i = 
approximate_gamma_kl(t_i, ln_t_i) - else: - proj_i = approximate_gamma_mom(t_i, va_t_i) + proj_i = approximate_gamma_mom(t_i, va_t_i) return logconst, np.array(proj_i) # --- mutation posteriors from node posteriors --- # -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f)) def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" Calculate gamma sufficient statistics for the PDF proportional to: @@ -575,19 +501,18 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): Returns E[x], E[\log x], V[x]. """ - f, t_i, _, _, t_j, _, _ = moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - f_ii, _, _, _, _, _, _ = moments(a_i + 2, b_i, a_j, b_j, y_ij, mu_ij) - f_ij, _, _, _, _, _, _ = moments(a_i + 1, b_i, a_j + 1, b_j, y_ij, mu_ij) - f_jj, _, _, _, _, _, _ = moments(a_i, b_i, a_j + 2, b_j, y_ij, mu_ij) + f, t_i, _, t_j, _ = moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) + f_ii, _, _, _, _ = moments(a_i + 2, b_i, a_j, b_j, y_ij, mu_ij) + f_ij, _, _, _, _ = moments(a_i + 1, b_i, a_j + 1, b_j, y_ij, mu_ij) + f_jj, _, _, _, _ = moments(a_i, b_i, a_j + 2, b_j, y_ij, mu_ij) mn_m = t_i / 2 + t_j / 2 sq_m = 1 / 3 * (np.exp(f_ii - f) + np.exp(f_ij - f) + np.exp(f_jj - f)) va_m = sq_m - mn_m**2 - ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf - return mn_m, ln_m, va_m + return mn_m, va_m -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): r""" Calculate gamma sufficient statistics for the PDF proportional to: @@ -604,16 +529,15 @@ def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): Returns E[x], E[\log x], V[x]. """ - _, mn_i, _, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) + _, mn_i, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) mn_m = mn_i / 2 + t_j / 2 sq_m = (va_i + mn_i**2 + mn_i * t_j + t_j**2) / 3 va_m = sq_m - mn_m**2 - ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf - return mn_m, ln_m, va_m + return mn_m, va_m -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" Calculate gamma sufficient statistics for the PDF proportional to: @@ -630,17 +554,16 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): Returns E[x], E[\log x], V[x]. 
""" - _, mn_j, _, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) + _, mn_j, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) mn_m = mn_j / 2 + t_i / 2 sq_m = (va_j + mn_j**2 + mn_j * t_i + t_i**2) / 3 va_m = sq_m - mn_m**2 - ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf - return mn_m, ln_m, va_m + return mn_m, va_m -@numba.njit(_f1r(_f1r, _f1r, _f1r, _b)) -def mutation_gamma_projection(pars_i, pars_j, pars_ij, min_kl): +@numba.njit(_f1r(_f1r, _f1r, _f1r)) +def mutation_gamma_projection(pars_i, pars_j, pars_ij): r""" Match a gamma distribution via KL minimization to the potential function @@ -668,22 +591,18 @@ def mutation_gamma_projection(pars_i, pars_j, pars_ij, min_kl): if not _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): return np.full(2, np.nan) - t_m, ln_t_m, va_t_m = mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) + t_m, va_t_m = mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - valid = _valid_moments(t_m, ln_t_m, va_t_m) - if not valid: + if not _valid_moments(t_m, va_t_m): return np.full(2, np.nan) - if min_kl: - proj_m = approximate_gamma_kl(t_m, ln_t_m) - else: - proj_m = approximate_gamma_mom(t_m, va_t_m) + proj_m = approximate_gamma_mom(t_m, va_t_m) return np.array(proj_m) -@numba.njit(_f1r(_f, _f1r, _f1r, _b)) -def mutation_leafward_projection(t_i, pars_j, pars_ij, min_kl): +@numba.njit(_f1r(_f, _f1r, _f1r)) +def mutation_leafward_projection(t_i, pars_j, pars_ij): r""" Match a gamma distribution via KL minimization to the potential function @@ -710,22 +629,18 @@ def mutation_leafward_projection(t_i, pars_j, pars_ij, min_kl): if not _hyp1f1_valid_parameterization(t_i, a_j, b_j, y_ij, mu_ij): return np.full(2, np.nan) - t_m, ln_t_m, va_t_m = mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) + t_m, va_t_m = mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) - valid = _valid_moments(t_m, ln_t_m, va_t_m) - if not valid: + if not _valid_moments(t_m, va_t_m): return np.full(2, np.nan) - if min_kl: - proj_m = approximate_gamma_kl(t_m, ln_t_m) - else: - proj_m = approximate_gamma_mom(t_m, va_t_m) + proj_m = approximate_gamma_mom(t_m, va_t_m) return np.array(proj_m) -@numba.njit(_f1r(_f, _f1r, _f1r, _b)) -def mutation_rootward_projection(t_j, pars_i, pars_ij, min_kl): +@numba.njit(_f1r(_f, _f1r, _f1r)) +def mutation_rootward_projection(t_j, pars_i, pars_ij): r""" Match a gamma distribution via KL minimization to the potential function @@ -752,23 +667,19 @@ def mutation_rootward_projection(t_j, pars_i, pars_ij, min_kl): if not _hyperu_valid_parameterization(t_j, a_i, b_i, y_ij, mu_ij): return np.full(2, np.nan) - t_m, ln_t_m, va_t_m = mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) + t_m, va_t_m = mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) - valid = _valid_moments(t_m, ln_t_m, va_t_m) - if not valid: + if not _valid_moments(t_m, va_t_m): return np.full(2, np.nan) - if min_kl: - proj_m = approximate_gamma_kl(t_m, ln_t_m) - else: - proj_m = approximate_gamma_mom(t_m, va_t_m) + proj_m = approximate_gamma_mom(t_m, va_t_m) return np.array(proj_m) # --- unphased node posteriors --- # -@numba.njit(_unituple(_f, 7)(_f, _f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): """ Calculate sufficient statistics for the PDF proportional to :math:`Ga(t_j | @@ -812,17 +723,15 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): mn_j = d1 / t sq_j = d2 / t**2 va_j = sq_j - mn_j**2 - ln_j = np.log(mn_j) - va_j / 2 / mn_j**2 if mn_j > 0 else -np.inf mn_i 
= -mn_j * z + b / t sq_i = sq_j * z**2 + (b + 1) * (mn_i - mn_j * z) / t va_i = sq_i - mn_i**2 - ln_i = np.log(mn_i) - va_i / 2 / mn_i**2 if mn_i > 0 else -np.inf - return logl, mn_i, ln_i, va_i, mn_j, ln_j, va_j + return logl, mn_i, va_i, mn_j, va_j -@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): """ Calculate sufficient statistics for the PDF proportional to :math:`Ga(t_j | @@ -853,12 +762,11 @@ def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + hypergeo._gammaln(a) mn_j = -t_i * d0 va_j = t_i**2 * d0 * (d1 - d0) - ln_j = log(mn_j) - va_j / 2 / mn_j**2 if mn_j > 0 else -np.inf - return logl, mn_j, ln_j, va_j + return logl, mn_j, va_j -@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" Calculate gamma sufficient statistics for the PDF proportional to: @@ -908,13 +816,12 @@ def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): mn_m = s0 * s1 * exp(f012 - f000) / 2 + s0 * s2 * exp(f212 - f000) / 2 sq_m = d0 * d1 * exp(f023 - f000) / 3 + d0 * d2 * exp(f323 - f000) / 3 va_m = sq_m - mn_m**2 - ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf pr_m = (c - a) / c * exp(f001 - f000) - return pr_m, mn_m, ln_m, va_m + return pr_m, mn_m, va_m -@numba.njit(_unituple(_f, 4)(_f, _f, _f, _f, _f)) +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" Calculate gamma sufficient statistics for the PDF proportional to: @@ -966,9 +873,8 @@ def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): # note that when y_ij == 0 then a == b + 1 and f00 = z**(-a) va_m = sq_m - mn_m**2 - ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf - return pr_m, mn_m, ln_m, va_m + return pr_m, mn_m, va_m @numba.njit(_b(_f, _f, _f, _f, _f, _f)) @@ -995,8 +901,8 @@ def _hyp2f1_unphased_valid_parameterization(a_i, b_i, a_j, b_j, y, mu): return True -@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r, _b)) -def unphased_projection(pars_i, pars_j, pars_ij, min_kl): +@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) +def unphased_projection(pars_i, pars_j, pars_ij): """ Match a pair of gamma distributions to the potential function :math:`Ga(t_j | a_j + 1, b_j) Ga(t_i | a_i + 1, b_i) Po(y_{ij} | \\mu_{ij} t_i + t_j)`, @@ -1008,7 +914,6 @@ def unphased_projection(pars_i, pars_j, pars_ij, min_kl): :param float pars_j: gamma natural parameters for the second parent's cavity distribution :param float pars_ij: gamma natural parameters for the edge pair likelihood - :param bool min_kl: minimize KL divergence (match central moments if False) :return: normalizing constant, gamma natural parameters for parents """ @@ -1021,10 +926,9 @@ def unphased_projection(pars_i, pars_j, pars_ij, min_kl): a_j += 1 if not _hyp2f1_unphased_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): - print("DEBUG") #DEBUG return np.nan, pars_i, pars_j - logconst, t_i, ln_t_i, va_t_i, t_j, ln_t_j, va_t_j = unphased_moments( + logconst, t_i, va_t_i, t_j, va_t_j = unphased_moments( a_i, b_i, a_j, @@ -1033,16 +937,10 @@ def unphased_projection(pars_i, pars_j, pars_ij, min_kl): mu_ij, ) - valid_i = _valid_moments(t_i, ln_t_i, va_t_i) - valid_j = _valid_moments(t_j, ln_t_j, va_t_j) - if not (valid_i and valid_j): + if not _valid_moments(t_i, va_t_i) 
or not _valid_moments(t_j, va_t_j):
         return np.nan, pars_i, pars_j
 
-    if min_kl:
-        proj_i = approximate_gamma_kl(t_i, ln_t_i)
-        proj_j = approximate_gamma_kl(t_j, ln_t_j)
-    else:
-        proj_i = approximate_gamma_mom(t_i, va_t_i)
-        proj_j = approximate_gamma_mom(t_j, va_t_j)
+    proj_i = approximate_gamma_mom(t_i, va_t_i)
+    proj_j = approximate_gamma_mom(t_j, va_t_j)
 
     return logconst, np.array(proj_i), np.array(proj_j)
 
diff --git a/tsdate/core.py b/tsdate/core.py
index a90cbf7e..e23ceb96 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -1255,7 +1255,6 @@ def run(
         eps,
         max_iterations,
         max_shape,
-        match_central_moments,
         rescaling_intervals,
         match_segregating_sites,
         regularise_roots,
@@ -1270,12 +1269,10 @@ def run(
             raise ValueError("Variational gamma method requires mutation rate")
 
         # match sufficient statistics or match central moments
-        min_kl = not match_central_moments
         dynamic_prog = self.main_algorithm()
         dynamic_prog.run(
             ep_maxitt=max_iterations,
             max_shape=max_shape,
-            min_kl=min_kl,
             rescale_intervals=rescaling_intervals,
             regularise=regularise_roots,
             rescale_segsites=match_segregating_sites,
@@ -1578,9 +1575,8 @@ def variational_gamma(
     rescaling_intervals=None,
     # deliberately undocumented parameters below. We may eventually document these
     max_shape=None,
-    match_central_moments=None,
     match_segregating_sites=None,
     regularise_roots=None,
     **kwargs,
 ):
     """
@@ -1643,8 +1640,6 @@ def variational_gamma(
     max_shape = 1000
     if rescaling_intervals is None:
         rescaling_intervals = DEFAULT_RESCALING_INTERVALS
-    if match_central_moments is None:
-        match_central_moments = True
     if match_segregating_sites is None:
         match_segregating_sites = False
     if regularise_roots is None:
@@ -1660,7 +1655,6 @@ def variational_gamma(
         eps=eps,
         max_iterations=max_iterations,
         max_shape=max_shape,
-        match_central_moments=match_central_moments,
         rescaling_intervals=rescaling_intervals,
         match_segregating_sites=match_segregating_sites,
         regularise_roots=regularise_roots,
diff --git a/tsdate/variational.py b/tsdate/variational.py
index e6db3eb5..1c970969 100644
--- a/tsdate/variational.py
+++ b/tsdate/variational.py
@@ -272,7 +272,7 @@ def __init__(self, ts, likelihoods, constraints, mutations_edge):
         )
 
     @staticmethod
-    @numba.njit(_f(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f, _b))
+    @numba.njit(_f(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f))
     def propagate_likelihood(
         edge_order,
         edges_parent,
@@ -285,7 +285,6 @@ def propagate_likelihood(
         scale,
         max_shape,
         min_step,
-        min_kl,
     ):
         """
         Update approximating factors for Poisson mutation likelihoods on edges.
@@ -308,7 +307,6 @@ def propagate_likelihood(
             scaling factor for the posteriors, updated in-place.
         :param float max_shape: the maximum allowed shape for node posteriors.
         :param float min_step: the minimum allowed step size in (0, 1).
-        :param bool min_kl: minimize KL divergence or match central moments. 
""" assert constraints.shape == posterior.shape @@ -341,7 +339,7 @@ def posterior_damping(x): # match moments and update factor parent_age = constraints[p, LOWER] lognorm[i], posterior[c] = approx.leafward_projection( - parent_age, child_cavity, edge_likelihood, min_kl + parent_age, child_cavity, edge_likelihood, ) factors[i, LEAFWARD] *= 1.0 - child_delta factors[i, LEAFWARD] += (posterior[c] - child_cavity) / scale[c] @@ -361,7 +359,7 @@ def posterior_damping(x): # match moments and update factor child_age = constraints[c, LOWER] lognorm[i], posterior[p] = approx.rootward_projection( - child_age, parent_cavity, edge_likelihood, min_kl + child_age, parent_cavity, edge_likelihood, ) factors[i, ROOTWARD] *= 1.0 - parent_delta @@ -385,7 +383,7 @@ def posterior_damping(x): # match moments and update factors lognorm[i], posterior[p], posterior[c] = approx.gamma_projection( - parent_cavity, child_cavity, edge_likelihood, min_kl + parent_cavity, child_cavity, edge_likelihood, ) factors[i, ROOTWARD] *= 1.0 - delta factors[i, ROOTWARD] += (posterior[p] - parent_cavity) / scale[p] @@ -501,7 +499,7 @@ def posterior_damping(x): # return np.nan @staticmethod - @numba.njit(_f2w(_i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b)) + @numba.njit(_f2w(_i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r)) def propagate_mutations( mutations_edge, edges_parent, @@ -511,7 +509,6 @@ def propagate_mutations( posterior, factors, scale, - min_kl, ): """ Calculate posteriors for mutations. @@ -531,7 +528,6 @@ def propagate_mutations( edge, updated in-place. :param ndarray scale: array of dimension `[num_nodes]` containing a scaling factor for the posteriors, updated in-place. - :param bool min_kl: minimize KL divergence or match central moments. """ # TODO: scale should be 1.0, can we delete @@ -564,7 +560,7 @@ def propagate_mutations( edge_likelihood = child_delta * likelihoods[i] parent_age = constraints[p, LOWER] mutations_posterior[m] = approx.mutation_leafward_projection( - parent_age, child_cavity, edge_likelihood, min_kl + parent_age, child_cavity, edge_likelihood, ) elif fixed[c] and not fixed[p]: parent_message = factors[i, ROOTWARD] * scale[p] @@ -573,7 +569,7 @@ def propagate_mutations( edge_likelihood = parent_delta * likelihoods[i] child_age = constraints[c, LOWER] mutations_posterior[m] = approx.mutation_rootward_projection( - child_age, parent_cavity, edge_likelihood, min_kl + child_age, parent_cavity, edge_likelihood, ) else: parent_message = factors[i, ROOTWARD] * scale[p] @@ -585,7 +581,7 @@ def propagate_mutations( child_cavity = posterior[c] - delta * child_message edge_likelihood = delta * likelihoods[i] mutations_posterior[m] = approx.mutation_gamma_projection( - parent_cavity, child_cavity, edge_likelihood, min_kl + parent_cavity, child_cavity, edge_likelihood, ) return mutations_posterior @@ -611,7 +607,6 @@ def iterate( min_step=0.1, em_maxitt=100, em_reltol=1e-8, - min_kl=False, regularise=True, check_valid=False, ): @@ -633,7 +628,6 @@ def iterate( self.scale, max_shape, min_step, - min_kl, ) # exponential regularization on roots @@ -711,7 +705,6 @@ def run( ep_maxitt=10, max_shape=1000, min_step=0.1, - min_kl=False, rescale_intervals=1000, rescale_segsites=False, regularise=True, @@ -726,7 +719,6 @@ def run( self.iterate( max_shape=max_shape, min_step=min_step, - min_kl=min_kl, regularise=regularise, ) nodes_timing -= time.time() @@ -745,7 +737,6 @@ def run( self.posterior, self.edge_factors, self.scale, - min_kl, ) muts_timing -= time.time() skipped_muts = 
np.sum(np.isnan(self.mutations_posterior[:, 0]))

From ebac4815737de3b81b94997a6e6d0859bd736969 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Thu, 16 May 2024 18:46:43 -0700
Subject: [PATCH 03/29] Add aux file with exact moments

---
 tests/exact_moments.py       | 189 +++++++++++++++++++++++++++++++
 tests/test_approximations.py |   2 -
 tsdate/approx.py             | 210 ++++++++++++++--------------------
 3 files changed, 273 insertions(+), 128 deletions(-)
 create mode 100644 tests/exact_moments.py

diff --git a/tests/exact_moments.py b/tests/exact_moments.py
new file mode 100644
index 00000000..5616ba14
--- /dev/null
+++ b/tests/exact_moments.py
@@ -0,0 +1,189 @@
+"""
+Moments for EP updates using exact hypergeometric evaluations rather than
+a Laplace approximation; intended for testing and accuracy benchmarking.
+"""
+
+import mpmath
+import numpy as np
+from scipy.special import betaln, gammaln
+from math import log, exp
+
+
+def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
+    """
+    log p(t_i, t_j) := \
+        log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \
+        log(t_i) * (a_i - 1) - mu_ij * t_i + \
+        log(t_j) * (a_j - 1) - mu_ij * t_j
+    """
+    a = a_j
+    b = a_i + a_j + y_ij
+    c = a_j + y_ij + 1
+    t = mu_ij + b_i
+    z = (mu_ij - b_j) / t
+    f0 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 0, z)))
+    f1 = float(mpmath.log(mpmath.hyp2f1(a + 1, b + 1, c + 1, z)))
+    f2 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 2, c + 2, z)))
+    s1 = a * b / c
+    s2 = s1 * (a + 1) * (b + 1) / (c + 1)
+    d1 = s1 * exp(f1 - f0)
+    d2 = s2 * exp(f2 - f0)
+    logl = f0 + betaln(y_ij + 1, a) + gammaln(b) - b * log(t)
+    mn_j = d1 / t
+    sq_j = d2 / t**2
+    va_j = sq_j - mn_j**2
+    mn_i = mn_j * z + b / t
+    sq_i = sq_j * z**2 + (b + 1) * (mn_i + mn_j * z) / t
+    va_i = sq_i - mn_i**2
+    return logl, mn_i, va_i, mn_j, va_j
+
+
+def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
+    r"""
+    log p(t_i) := \
+        log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \
+        log(t_i) * (a_i - 1) - mu_ij * t_i
+    """
+    assert t_j >= 0.0
+    s = a_i + y_ij
+    r = mu_ij + b_i
+    a = y_ij + 1
+    b = s + 1
+    z = t_j * r
+    if t_j == 0.0:
+        logl = gammaln(s) - s * log(r)
+        mn_i = s / r
+        va_i = s / r**2
+        return logl, mn_i, va_i
+    f0 = float(mpmath.log(mpmath.hyperu(a + 0, b + 0, z)))
+    f1 = float(mpmath.log(mpmath.hyperu(a + 1, b + 1, z)))
+    f2 = float(mpmath.log(mpmath.hyperu(a + 2, b + 2, z)))
+    d0 = -a * exp(f1 - f0)
+    d1 = -(a + 1) * exp(f2 - f1)
+    logl = f0 - b_i * t_j + (b - 1) * log(t_j) + gammaln(a)
+    mn_i = t_j * (1 - d0)
+    va_i = t_j**2 * d0 * (d1 - d0)
+    return logl, mn_i, va_i
+
+
+def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
+    r"""
+    log p(t_i, t_j) := \
+        log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \
+        log(t_j) * (a_j - 1) - mu_ij * t_j
+    """
+    assert t_i > 0.0
+    a = a_j
+    b = a_j + y_ij + 1
+    z = t_i * (mu_ij - b_j)
+    f0 = float(mpmath.log(mpmath.hyp1f1(a + 0, b + 0, z)))
+    f1 = float(mpmath.log(mpmath.hyp1f1(a + 1, b + 1, z)))
+    f2 = float(mpmath.log(mpmath.hyp1f1(a + 2, b + 2, z)))
+    d1 = a / b * exp(f1 - f0)
+    d2 = a / b * (a + 1) / (b + 1) * exp(f2 - f0)
+    logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + betaln(a, b - a)
+    mn_j = t_i * d1
+    sq_j = t_i**2 * d2
+    va_j = sq_j - mn_j**2
+    return logl, mn_j, va_j
+
+
+def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
+    r"""
+    log p(t_i, t_j) := \
+        log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \
+        log(t_i) * (a_i - 1) - mu_ij * t_i + \
+        log(t_j) * (a_j - 1) - mu_ij * t_j
+    """
+    a = a_j
+    b = a_i + a_j + y_ij
+    c = a_j + a_i
+    t = mu_ij + b_i
+    z = (mu_ij + b_j) / t
+    f0 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 0, 1 - z)))
+    f1 = float(mpmath.log(mpmath.hyp2f1(a + 1, b + 1, c + 1, 1 - z)))
+    f2 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 2, c + 2, 1 - z)))
+    s1 = a * b / c
+    s2 = s1 * (a + 1) * (b + 1) / (c + 1)
+    d1 = s1 * exp(f1 - f0)
+    d2 = s2 * exp(f2 - f0)
+    logl = f0 + betaln(a_j, a_i) + gammaln(b) - b * log(t)
+    mn_j = d1 / t
+    sq_j = d2 / t**2
+    va_j = sq_j - mn_j**2
+    mn_i = b / t - mn_j * z
+    sq_i = sq_j * z**2 + (b + 1) * (mn_i - mn_j * z) / t
+    va_i = sq_i - mn_i**2
+    return logl, mn_i, va_i, mn_j, va_j
+
+
+def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij):
+    r"""
+    log p(t_i, t_j) := \
+        log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \
+        log(t_i) * (a_i - 1) - mu_ij * t_i + \
+        log(t_j) * (a_j - 1) - mu_ij * t_j
+    """
+    assert t_i > 0.0
+    a = a_j
+    b = a_j + y_ij + 1
+    z = t_i * (mu_ij + b_j)
+    f0 = float(mpmath.log(mpmath.hyperu(a + 0, b + 0, z)))
+    f1 = float(mpmath.log(mpmath.hyperu(a + 1, b + 1, z)))
+    f2 = float(mpmath.log(mpmath.hyperu(a + 2, b + 2, z)))
+    d0 = -a * exp(f1 - f0)
+    d1 = -(a + 1) * exp(f2 - f1)
+    logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + gammaln(a)
+    mn_j = -t_i * d0
+    va_j = t_i**2 * d0 * (d1 - d0)
+    return logl, mn_j, va_j
+
+
+def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
+    r"""
+    log p(t_m, t_i, t_j) := \
+        log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \
+        log(t_i) * (a_i - 1) - mu_ij * t_i + \
+        log(t_j) * (a_j - 1) - mu_ij * t_j - \
+        log(t_i - t_j) * int(t_j < t_m < t_i)
+    """
+    a = a_j
+    b = a_i + a_j + y_ij
+    c = a_j + y_ij + 1
+    t = mu_ij + b_i
+    z = (mu_ij - b_j) / t
+    f000 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 0, z)))
+    f020 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 2, c + 0, z)))
+    f111 = float(mpmath.log(mpmath.hyp2f1(a + 1, b + 1, c + 1, z)))
+    f121 = float(mpmath.log(mpmath.hyp2f1(a + 1, b + 2, c + 1, z)))
+    f222 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 2, c + 2, z)))
+    s1 = a * b / c
+    s2 = s1 * (a + 1) * (b + 1) / (c + 1)
+    d1 = b * (b + 1) / t ** 2
+    d2 = d1 * a / c
+    d3 = d2 * (a + 1) / (c + 1)
+    mn_m = s1 * exp(f111 - f000) / t / 2 * (1 + z) + b / t / 2
+    sq_m = d1 * exp(f020 - f000) / 3 + d2 * exp(f121 - f000) / 3 + d3 * exp(f222 - f000) / 3
+    va_m = sq_m - mn_m**2
+    return mn_m, va_m
+
+
+def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
+    r"""
+    log p(t_m, t_i) := \
+        log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \
+        log(t_i) * (a_i - 1) - mu_ij * t_i + \
+        log(t_i - t_j) * int(t_j < t_m < t_i)
+    """
+    logl, mn_i, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij)
+    mn_m = mn_i / 2 + t_j / 2
+    sq_m = (va_i + mn_i**2 + mn_i * t_j + t_j**2) / 3
+    va_m = sq_m - mn_m**2
+    return mn_m, va_m
+
+
+def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
+    r"""
+    log p(t_m, t_j) := \
+        log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \
+        log(t_j) * (a_j - 1) - mu_ij * t_j - \
+        log(t_i - t_j) * int(t_j < t_m < t_i)
+    """
+    logl, mn_j, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij)
+    mn_m = mn_j / 2 + t_i / 2
+    sq_m = (va_j + mn_j**2 + mn_j * t_i + t_i**2) / 3
+    va_m = sq_m - mn_m**2
+    return mn_m, va_m
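+
+
+# Editorial sketch, not part of the original patch: a minimal self-check of
+# moments() against direct numerical integration, using arbitrary test
+# parameters (mirroring those used in tests/test_approximations.py).
+if __name__ == "__main__":
+    import scipy.integrate
+
+    a_i, b_i, a_j, b_j, y_ij, mu_ij = 2.0, 0.0005, 1.5, 0.005, 2.0, 0.001
+    logl, mn_i, va_i, mn_j, va_j = moments(a_i, b_i, a_j, b_j, y_ij, mu_ij)
+
+    def pdf(t_i, t_j):  # target density, up to the returned normalizing constant
+        return (
+            t_i ** (a_i - 1) * np.exp(-t_i * b_i)
+            * t_j ** (a_j - 1) * np.exp(-t_j * b_j)
+            * (t_i - t_j) ** y_ij * np.exp(-(t_i - t_j) * mu_ij)
+        )
+
+    # outer integral over t_j in (0, inf), inner over t_i in (t_j, inf)
+    norm = scipy.integrate.dblquad(
+        pdf, 0, np.inf, lambda t_j: t_j, np.inf, epsabs=0
+    )[0]
+    mean_i = scipy.integrate.dblquad(
+        lambda t_i, t_j: t_i * pdf(t_i, t_j) / norm,
+        0, np.inf, lambda t_j: t_j, np.inf, epsabs=0,
+    )[0]
+    assert np.isclose(logl, np.log(norm), rtol=1e-4)
+    assert np.isclose(mn_i, mean_i, rtol=1e-4)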
diff --git a/tests/test_approximations.py b/tests/test_approximations.py
index 09e599d7..6dc60538 100644
--- a/tests/test_approximations.py
+++ b/tests/test_approximations.py
@@ -29,8 +29,6 @@
 import scipy.integrate
 import scipy.special
 import scipy.stats
-from distribution_functions import conditional_coalescent_pdf
-from distribution_functions import kl_divergence
 
 from tsdate import approx
 from tsdate import hypergeo
diff --git a/tsdate/approx.py b/tsdate/approx.py
index cbe2aa32..dd4068ab 100644
--- a/tsdate/approx.py
+++ b/tsdate/approx.py
@@ -187,11 +187,38 @@ def average_gammas(alpha, beta):
 
 @numba.njit(_b(_f, _f))
 def _valid_moments(mn, va):
+    if not (np.isfinite(mn) and 
np.isfinite(va)): + return False if not (mn > 0.0 and va > 0.0): return False return True +@numba.njit(_b(_f, _f, _f)) +def _valid_hyp1f1(a, b, z): + if not (b >= a > 0.0): + return False + return True + + +@numba.njit(_b(_f, _f, _f)) +def _valid_hyperu(a, b, z): + if z <= 0.0: + return False + if not (b > a > 0.0): + return False + return True + + +@numba.njit(_b(_f, _f, _f, _f)) +def _valid_hyp2f1(a, b, c, z): + if z >= 1 or z / (z - 1) >= 1.0: + return False + if a <= 0 or b <= 0 or c <= 0: + return False + return True + + # --- node posteriors --- # @numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) @@ -220,10 +247,6 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): t = mu_ij + b_i z = (mu_ij - b_j) / t - # with numba.objmode(f0='f8', f1='f8', f2='f8'): - # f0 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 0, z))) - # f1 = float(mpmath.log(mpmath.hyp2f1(a + 1, b + 1, c + 1, z))) - # f2 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 2, c + 2, z))) hyp2f1 = hypergeo._hyp2f1_laplace f0 = hyp2f1(a + 0, b + 0, c + 0, z) f1 = hyp2f1(a + 1, b + 1, c + 1, z) @@ -233,7 +256,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): d1 = s1 * np.exp(f1 - f0) d2 = s2 * np.exp(f2 - f0) - logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * np.log(t) + logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * log(t) mn_j = d1 / t sq_j = d2 / t**2 @@ -388,7 +411,6 @@ def gamma_projection(pars_i, pars_j, pars_ij): :return: normalizing constant, gamma natural parameters for parent and child """ - # switch from natural to canonical parameterization a_i, b_i = pars_i a_j, b_j = pars_j y_ij, mu_ij = pars_ij @@ -398,22 +420,15 @@ def gamma_projection(pars_i, pars_j, pars_ij): if not _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): return np.nan, pars_i, pars_j - logconst, t_i, va_t_i, t_j, va_t_j = moments( - a_i, - b_i, - a_j, - b_j, - y_ij, - mu_ij, - ) + logl, mn_i, va_i, mn_j, va_j = moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - if not _valid_moments(t_i, va_t_i) or not _valid_moments(t_j, va_t_j): + if not _valid_moments(mn_i, va_i) or not _valid_moments(mn_j, va_j): return np.nan, pars_i, pars_j - proj_i = approximate_gamma_mom(t_i, va_t_i) - proj_j = approximate_gamma_mom(t_j, va_t_j) + proj_i = approximate_gamma_mom(mn_i, va_i) + proj_j = approximate_gamma_mom(mn_j, va_j) - return logconst, np.array(proj_i), np.array(proj_j) + return logl, np.array(proj_i), np.array(proj_j) @numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) @@ -430,23 +445,21 @@ def leafward_projection(t_i, pars_j, pars_ij): :return: normalizing constant, gamma natural parameters for child """ - # switch from natural to canonical parameterization a_j, b_j = pars_j y_ij, mu_ij = pars_ij a_j += 1 - # skip update, zeroing out message if not _hyp1f1_valid_parameterization(t_i, a_j, b_j, y_ij, mu_ij): return np.nan, pars_j - logconst, t_j, va_t_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) + logl, mn_j, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) - if not _valid_moments(t_j, va_t_j): + if not _valid_moments(mn_j, va_j): return np.nan, pars_j - proj_j = approximate_gamma_mom(t_j, va_t_j) + proj_j = approximate_gamma_mom(mn_j, va_j) - return logconst, np.array(proj_j) + return logl, np.array(proj_j) @numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) @@ -463,23 +476,21 @@ def rootward_projection(t_j, pars_i, pars_ij): :return: normalizing constant, gamma natural parameters for child """ - # switch from natural to canonical parameterization a_i, b_i = pars_i y_ij, mu_ij = pars_ij a_i += 1 - # 
skip update, zeroing out message if not _hyperu_valid_parameterization(t_j, a_i, b_i, y_ij, mu_ij): return np.nan, pars_i - logconst, t_i, va_t_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) + logl, mn_i, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) - if not _valid_moments(t_i, va_t_i): + if not _valid_moments(mn_i, va_i): return np.nan, pars_i - proj_i = approximate_gamma_mom(t_i, va_t_i) + proj_i = approximate_gamma_mom(mn_i, va_i) - return logconst, np.array(proj_i) + return logl, np.array(proj_i) # --- mutation posteriors from node posteriors --- # @@ -487,26 +498,36 @@ def rootward_projection(t_j, pars_i, pars_ij): @numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f)) def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" - Calculate gamma sufficient statistics for the PDF proportional to: + log p(t_m, t_i, t_j) = \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - mu_ij * t_i + \ + log(t_j) * (a_j - 1) - mu_ij * t_j - \ + log(t_i - t_j) * int(t_j < t_m < t_i) - ..math:: + Returns E[t_m], V[t_m]. + """ - p(x) = \int_0^\infty \int_0^{t_i} Unif(x | t_i, t_j) - Ga(t_i | a_i, b_i) Ga(t_j | a_j b_j) Po(y | \mu_ij (t_i - t_j)) dt_j dt_i + a = a_j + b = a_i + a_j + y_ij + c = a_j + y_ij + 1 + t = mu_ij + b_i + z = (mu_ij - b_j) / t - which models the time :math:`x` of a mutation uniformly distributed between - parent age :math:`t_i` and child age :math:`t_j`, on a branch with - :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. + hyp2f1 = hypergeo._hyp2f1_laplace + f000 = hyp2f1(a + 0, b + 0, c + 0, z) + f020 = hyp2f1(a + 0, b + 2, c + 0, z) + f111 = hyp2f1(a + 1, b + 1, c + 1, z) + f121 = hyp2f1(a + 1, b + 2, c + 1, z) + f222 = hyp2f1(a + 2, b + 2, c + 2, z) - Returns E[x], E[\log x], V[x]. - """ + s1 = a * b / c + s2 = s1 * (a + 1) * (b + 1) / (c + 1) + d1 = b * (b + 1) / t ** 2 + d2 = d1 * a / c + d3 = d2 * (a + 1) / (c + 1) - f, t_i, _, t_j, _ = moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - f_ii, _, _, _, _ = moments(a_i + 2, b_i, a_j, b_j, y_ij, mu_ij) - f_ij, _, _, _, _ = moments(a_i + 1, b_i, a_j + 1, b_j, y_ij, mu_ij) - f_jj, _, _, _, _ = moments(a_i, b_i, a_j + 2, b_j, y_ij, mu_ij) - mn_m = t_i / 2 + t_j / 2 - sq_m = 1 / 3 * (np.exp(f_ii - f) + np.exp(f_ij - f) + np.exp(f_jj - f)) + mn_m = s1 * exp(f111 - f000) / t / 2 * (1 + z) + b / t / 2 + sq_m = d1 * exp(f020 - f000) / 3 + d2 * exp(f121 - f000) / 3 + d3 * exp(f222 - f000) / 3 va_m = sq_m - mn_m**2 return mn_m, va_m @@ -581,7 +602,6 @@ def mutation_gamma_projection(pars_i, pars_j, pars_ij): :return: gamma parameters for mutation age """ - # switch from natural to canonical parameterization a_i, b_i = pars_i a_j, b_j = pars_j y_ij, mu_ij = pars_ij @@ -591,12 +611,12 @@ def mutation_gamma_projection(pars_i, pars_j, pars_ij): if not _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): return np.full(2, np.nan) - t_m, va_t_m = mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) + mn_m, va_m = mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - if not _valid_moments(t_m, va_t_m): + if not _valid_moments(mn_m, va_m): return np.full(2, np.nan) - proj_m = approximate_gamma_mom(t_m, va_t_m) + proj_m = approximate_gamma_mom(mn_m, va_m) return np.array(proj_m) @@ -620,21 +640,19 @@ def mutation_leafward_projection(t_i, pars_j, pars_ij): :return: gamma parameters for mutation age """ - # switch from natural to canonical parameterization a_j, b_j = pars_j y_ij, mu_ij = pars_ij a_j += 1 - # skip update, zeroing out message if not _hyp1f1_valid_parameterization(t_i, a_j, b_j, y_ij, 
mu_ij): return np.full(2, np.nan) - t_m, va_t_m = mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) + mn_m, va_m = mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) - if not _valid_moments(t_m, va_t_m): + if not _valid_moments(mn_m, va_m): return np.full(2, np.nan) - proj_m = approximate_gamma_mom(t_m, va_t_m) + proj_m = approximate_gamma_mom(mn_m, va_m) return np.array(proj_m) @@ -658,21 +676,19 @@ def mutation_rootward_projection(t_j, pars_i, pars_ij): :return: gamma parameters for mutation age """ - # switch from natural to canonical parameterization a_i, b_i = pars_i y_ij, mu_ij = pars_ij a_i += 1 - # skip update, zeroing out message if not _hyperu_valid_parameterization(t_j, a_i, b_i, y_ij, mu_ij): return np.full(2, np.nan) - t_m, va_t_m = mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) + mn_m, va_m = mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) - if not _valid_moments(t_m, va_t_m): + if not _valid_moments(mn_m, va_m): return np.full(2, np.nan) - proj_m = approximate_gamma_mom(t_m, va_t_m) + proj_m = approximate_gamma_mom(mn_m, va_m) return np.array(proj_m) @@ -705,26 +721,22 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): t = mu_ij + b_i z = (mu_ij + b_j) / t - #with numba.objmode(f0='f8', f1='f8', f2='f8'): - # f0 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 0, 1 - z))) - # f1 = float(mpmath.log(mpmath.hyp2f1(a + 1, b + 1, c + 1, 1 - z))) - # f2 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 2, c + 2, 1 - z))) hyp2f1 = hypergeo._hyp2f1_laplace f0 = hyp2f1(a + 0, b + 0, c + 0, 1 - z) f1 = hyp2f1(a + 1, b + 1, c + 1, 1 - z) f2 = hyp2f1(a + 2, b + 2, c + 2, 1 - z) s1 = a * b / c s2 = s1 * (a + 1) * (b + 1) / (c + 1) - d1 = s1 * np.exp(f1 - f0) - d2 = s2 * np.exp(f2 - f0) + d1 = s1 * exp(f1 - f0) + d2 = s2 * exp(f2 - f0) - logl = f0 + hypergeo._betaln(a_j, a_i) + hypergeo._gammaln(b) - b * np.log(t) + logl = f0 + hypergeo._betaln(a_j, a_i) + hypergeo._gammaln(b) - b * log(t) mn_j = d1 / t sq_j = d2 / t**2 va_j = sq_j - mn_j**2 - mn_i = -mn_j * z + b / t + mn_i = b / t - mn_j * z sq_i = sq_j * z**2 + (b + 1) * (mn_i - mn_j * z) / t va_i = sq_i - mn_i**2 @@ -877,70 +889,16 @@ def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): return pr_m, mn_m, va_m -@numba.njit(_b(_f, _f, _f, _f, _f, _f)) -def _hyp2f1_unphased_valid_parameterization(a_i, b_i, a_j, b_j, y, mu): - """Uses shape / rate parameterization""" - a = a_j - b = a_i + a_j + y - c = a_j + a_i - s = mu + b_j - t = mu + b_i - # check that 2F1 argument is less than unity - if t <= 0.0: - return False - z = 1.0 - s / t - if z >= 1.0 or z / (z - 1) >= 1.0: - return False - # check that 2F1 is positive - if a <= 0: - return False - if b <= 0: - return False - if c <= 0: - return False - return True - - @numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) def unphased_projection(pars_i, pars_j, pars_ij): - """ - Match a pair of gamma distributions to the potential function :math:`Ga(t_j - | a_j + 1, b_j) Ga(t_i | a_i + 1, b_i) Po(y_{ij} | \\mu_{ij} t_i + t_j)`, - where :math:`i` and :math:`j` are parents of the same individual (assumed - to be at time zero), by minimizing KL divergence. 
- - :param float pars_i: gamma natural parameters for the first parent's cavity - distribution - :param float pars_j: gamma natural parameters for the second parent's - cavity distribution - :param float pars_ij: gamma natural parameters for the edge pair likelihood - - :return: normalizing constant, gamma natural parameters for parents - """ - - # switch from natural to canonical parameterization a_i, b_i = pars_i a_j, b_j = pars_j y_ij, mu_ij = pars_ij a_i += 1 a_j += 1 - - if not _hyp2f1_unphased_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): - return np.nan, pars_i, pars_j - - logconst, t_i, va_t_i, t_j, va_t_j = unphased_moments( - a_i, - b_i, - a_j, - b_j, - y_ij, - mu_ij, - ) - - if not _valid_moments(t_i, va_t_i) or not _valid_moments(t_j, va_t_j): + logl, mn_i, va_i, mn_j, va_j = unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) + if not _valid_moments(mn_i, va_i) or not _valid_moments(mn_j, va_j): return np.nan, pars_i, pars_j - - proj_i = approximate_gamma_mom(t_i, va_t_i) - proj_j = approximate_gamma_mom(t_j, va_t_j) - - return logconst, np.array(proj_i), np.array(proj_j) + proj_i = approximate_gamma_mom(mn_i, va_i) + proj_j = approximate_gamma_mom(mn_j, va_j) + return logl, np.array(proj_i), np.array(proj_j) From e623324890f8c665a13c615936a06aab53934250 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 17 May 2024 15:54:55 -0700 Subject: [PATCH 04/29] WIP --- tests/exact_moments.py | 2 +- tsdate/approx.py | 253 ++++++++++++++++++----------------------- tsdate/phasing.py | 60 ++++++++++ tsdate/variational.py | 6 +- 4 files changed, 176 insertions(+), 145 deletions(-) diff --git a/tests/exact_moments.py b/tests/exact_moments.py index 5616ba14..cfff45d5 100644 --- a/tests/exact_moments.py +++ b/tests/exact_moments.py @@ -64,7 +64,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" - log p(t_i, t_j) := \ + log p(t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ log(t_j) * (a_j - 1) - mu_ij * t_j """ diff --git a/tsdate/approx.py b/tsdate/approx.py index dd4068ab..05ddbf4c 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -27,6 +27,7 @@ from math import inf from math import lgamma from math import log +from math import nan import mpmath import numba @@ -196,6 +197,8 @@ def _valid_moments(mn, va): @numba.njit(_b(_f, _f, _f)) def _valid_hyp1f1(a, b, z): + if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)): + return False if not (b >= a > 0.0): return False return True @@ -203,6 +206,8 @@ def _valid_hyp1f1(a, b, z): @numba.njit(_b(_f, _f, _f)) def _valid_hyperu(a, b, z): + if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)): + return False if z <= 0.0: return False if not (b > a > 0.0): @@ -212,40 +217,34 @@ def _valid_hyperu(a, b, z): @numba.njit(_b(_f, _f, _f, _f)) def _valid_hyp2f1(a, b, c, z): - if z >= 1 or z / (z - 1) >= 1.0: + if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(c)): + return False + if not np.isfinite(z) or z >= 1 or z / (z - 1) >= 1: return False - if a <= 0 or b <= 0 or c <= 0: + if not (a > 0 and b > 0 and c > 0): return False return True -# --- node posteriors --- # - @numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): """ - Calculate sufficient statistics for the PDF proportional to - :math:`Ga(t_j | a_j, b_j) Ga(t_i | a_i, b_i) Po(y_{ij} | - \\mu_{ij} t_i - t_j)`, where :math:`i` is the parent and :math:`j` is - the child. 
The logarithmic moments are approximated via a Taylor - expansion around the mean. - - :param float a_i: the shape parameter of the cavity distribution for the parent - :param float b_i: the rate parameter of the cavity distribution for the parent - :param float a_j: the shape parameter of the cavity distribution for the child - :param float b_j: the rate parameter of the cavity distribution for the child - :param float y_ij: the number of mutations on the edge - :param float mu_ij: the span-weighted mutation rate of the edge + log p(t_i, t_j) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - mu_ij * t_i + \ + log(t_j) * (a_j - 1) - mu_ij * t_j - :return: normalizing constant, E[t_i], E[log t_i], V[t_i], - E[t_j], E[log t_j], V[t_j] + Returns normalizing constant, E[t_i], V[t_i], E[t_j], V[t_j]. """ a = a_j b = a_i + a_j + y_ij c = a_j + y_ij + 1 t = mu_ij + b_i - z = (mu_ij - b_j) / t + z = (mu_ij - b_j) / t if t > 0 else nan + + if not _valid_hyp2f1(a, b, c, z): + return nan, nan, nan, nan, nan hyp2f1 = hypergeo._hyp2f1_laplace f0 = hyp2f1(a + 0, b + 0, c + 0, z) @@ -253,8 +252,8 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): f2 = hyp2f1(a + 2, b + 2, c + 2, z) s1 = a * b / c s2 = s1 * (a + 1) * (b + 1) / (c + 1) - d1 = s1 * np.exp(f1 - f0) - d2 = s2 * np.exp(f2 - f0) + d1 = s1 * exp(f1 - f0) + d2 = s2 * exp(f2 - f0) logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * log(t) @@ -272,18 +271,11 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): """ - Calculate sufficient statistics for the PDF proportional to - :math:`Ga(t_i | a_i, b_i) Po(y_{ij} | \\mu_{ij} t_i - t_j)`, where - :math:`i` is the parent and :math:`j` is the child. The logarithmic moments - are approximated via a Taylor expansion around the mean. - - :param float t_j: the age of the child - :param float a_i: the shape parameter of the cavity distribution for the parent - :param float b_i: the rate parameter of the cavity distribution for the parent - :param float y_ij: the number of mutations on the edge - :param float mu_ij: the span-weighted mutation rate of the edge + log p(t_i) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - mu_ij * t_i - :return: normalizing constant, E[t_i], E[log t_i], V[t_i] + Returns normalizing constant, E[t_i], V[t_i]. """ assert t_j >= 0.0 @@ -314,18 +306,11 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): """ - Calculate sufficient statistics for the PDF proportional to - :math:`Ga(t_j | a_j, b_j) Po(y_{ij} | \\mu_{ij} t_i - t_j)`, where - :math:`i` is the parent and :math:`j` is the child. The logarithmic moments - are approximated via a Taylor expansion around the mean. - - :param float t_i: the age of the parent - :param float a_j: the shape parameter of the cavity distribution for the child - :param float b_j: the rate parameter of the cavity distribution for the child - :param float y_ij: the number of mutations on the edge - :param float mu_ij: the span-weighted mutation rate of the edge + log p(t_j) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_j) * (a_j - 1) - mu_ij * t_j - :return: normalizing constant, E[t_j], E[log t_j], V[t_j] + Returns normalizing constant, E[t_j], V[t_j]. 
""" assert t_i > 0.0 @@ -349,6 +334,82 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): return logl, mn_j, va_j +@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f)) +def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): + r""" + log p(t_m, t_i, t_j) = \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - mu_ij * t_i + \ + log(t_j) * (a_j - 1) - mu_ij * t_j - \ + log(t_i - t_j) * int(t_j < t_m < t_i) + + Returns E[t_m], V[t_m]. + """ + + a = a_j + b = a_i + a_j + y_ij + c = a_j + y_ij + 1 + t = mu_ij + b_i + z = (mu_ij - b_j) / t + + hyp2f1 = hypergeo._hyp2f1_laplace + f000 = hyp2f1(a + 0, b + 0, c + 0, z) + f020 = hyp2f1(a + 0, b + 2, c + 0, z) + f111 = hyp2f1(a + 1, b + 1, c + 1, z) + f121 = hyp2f1(a + 1, b + 2, c + 1, z) + f222 = hyp2f1(a + 2, b + 2, c + 2, z) + + s1 = a * b / c + s2 = s1 * (a + 1) * (b + 1) / (c + 1) + d1 = b * (b + 1) / t ** 2 + d2 = d1 * a / c + d3 = d2 * (a + 1) / (c + 1) + + mn_m = s1 * exp(f111 - f000) / t / 2 * (1 + z) + b / t / 2 + sq_m = d1 * exp(f020 - f000) / 3 + d2 * exp(f121 - f000) / 3 + d3 * exp(f222 - f000) / 3 + va_m = sq_m - mn_m**2 + + return mn_m, va_m + + +@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) +def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): + r""" + log p(t_m, t_i) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - mu_ij * t_i + \ + log(t_i - t_j) * int(t_j < t_m < t_i) + + Returns E[t_m], V[t_m]. + """ + + logl, mn_i, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) + mn_m = mn_i / 2 + t_j / 2 + sq_m = (va_i + mn_i**2 + mn_i * t_j + t_j**2) / 3 + va_m = sq_m - mn_m**2 + + return mn_m, va_m + + +@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) +def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): + r""" + log p(t_m, t_j) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_j) * (a_j - 1) - mu_ij * t_j - \ + log(t_i - t_j) * int(t_j < t_m < t_i) + + Returns E[t_m], V[t_m]. + """ + + logl, mn_j, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) + mn_m = mn_j / 2 + t_i / 2 + sq_m = (va_j + mn_j**2 + mn_j * t_i + t_i**2) / 3 + va_m = sq_m - mn_m**2 + + return mn_m, va_m + + @numba.njit(_b(_f, _f, _f, _f, _f, _f)) def _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y, mu): """Uses shape / rate parameterization""" @@ -399,16 +460,12 @@ def _hyperu_valid_parameterization(t_j, a_i, b_i, y, mu): @numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) def gamma_projection(pars_i, pars_j, pars_ij): """ - Match a pair of gamma distributions to the potential function - :math:`Ga(t_j | a_j + 1, b_j) Ga(t_i | a_i + 1, b_i) Po(y_{ij} | - \\mu_{ij} t_i - t_j)`, where :math:`i` is the parent and :math:`j` is - the child, by minimizing KL divergence. - - :param float pars_i: gamma natural parameters for the parent cavity distribution - :param float pars_j: gamma natural parameters for the child cavity distribution - :param float pars_ij: gamma natural parameters for the edge likelihood + log p(t_i, t_j) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - mu_ij * t_i + \ + log(t_j) * (a_j - 1) - mu_ij * t_j - :return: normalizing constant, gamma natural parameters for parent and child + Returns normalizing constant, gamma natural parameters for parent and child. 
""" a_i, b_i = pars_i @@ -417,12 +474,9 @@ def gamma_projection(pars_i, pars_j, pars_ij): a_i += 1 a_j += 1 - if not _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): - return np.nan, pars_i, pars_j - logl, mn_i, va_i, mn_j, va_j = moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - if not _valid_moments(mn_i, va_i) or not _valid_moments(mn_j, va_j): + if not (_valid_moments(mn_i, va_i) and _valid_moments(mn_j, va_j)): return np.nan, pars_i, pars_j proj_i = approximate_gamma_mom(mn_i, va_i) @@ -495,93 +549,6 @@ def rootward_projection(t_j, pars_i, pars_ij): # --- mutation posteriors from node posteriors --- # -@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f)) -def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): - r""" - log p(t_m, t_i, t_j) = \ - log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_j) * (a_j - 1) - mu_ij * t_j - \ - log(t_i - t_j) * int(t_j < t_m < t_i) - - Returns E[t_m], V[t_m]. - """ - - a = a_j - b = a_i + a_j + y_ij - c = a_j + y_ij + 1 - t = mu_ij + b_i - z = (mu_ij - b_j) / t - - hyp2f1 = hypergeo._hyp2f1_laplace - f000 = hyp2f1(a + 0, b + 0, c + 0, z) - f020 = hyp2f1(a + 0, b + 2, c + 0, z) - f111 = hyp2f1(a + 1, b + 1, c + 1, z) - f121 = hyp2f1(a + 1, b + 2, c + 1, z) - f222 = hyp2f1(a + 2, b + 2, c + 2, z) - - s1 = a * b / c - s2 = s1 * (a + 1) * (b + 1) / (c + 1) - d1 = b * (b + 1) / t ** 2 - d2 = d1 * a / c - d3 = d2 * (a + 1) / (c + 1) - - mn_m = s1 * exp(f111 - f000) / t / 2 * (1 + z) + b / t / 2 - sq_m = d1 * exp(f020 - f000) / 3 + d2 * exp(f121 - f000) / 3 + d3 * exp(f222 - f000) / 3 - va_m = sq_m - mn_m**2 - - return mn_m, va_m - - -@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) -def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): - r""" - Calculate gamma sufficient statistics for the PDF proportional to: - - ..math:: - - p(x) = \int_{t_j}^\infty Unif(x | t_i, t_j) - Ga(t_i | a_i, b_i) Po(y | \mu_ij (t_i - t_j)) dt_i - - which models the time :math:`x` of a mutation uniformly distributed between - parent age :math:`t_i` and child age :math:`t_j`, on a branch with - :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. - - Returns E[x], E[\log x], V[x]. - """ - - _, mn_i, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) - mn_m = mn_i / 2 + t_j / 2 - sq_m = (va_i + mn_i**2 + mn_i * t_j + t_j**2) / 3 - va_m = sq_m - mn_m**2 - - return mn_m, va_m - - -@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) -def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): - r""" - Calculate gamma sufficient statistics for the PDF proportional to: - - ..math:: - - p(x) = \int_0^{t_i} Unif(x | t_i, t_j) - Ga(t_j | a_j, b_j) Po(y | \mu_ij (t_i - t_j)) dt_j - - which models the time :math:`x` of a mutation uniformly distributed between - parent age :math:`t_i` and child age :math:`t_j`, on a branch with - :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. - - Returns E[x], E[\log x], V[x]. 
-    """
-
-    _, mn_j, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij)
-    mn_m = mn_j / 2 + t_i / 2
-    sq_m = (va_j + mn_j**2 + mn_j * t_i + t_i**2) / 3
-    va_m = sq_m - mn_m**2
-
-    return mn_m, va_m
-
 
 @numba.njit(_f1r(_f1r, _f1r, _f1r))
 def mutation_gamma_projection(pars_i, pars_j, pars_ij):
diff --git a/tsdate/phasing.py b/tsdate/phasing.py
index 350dc2e4..cddaee04 100644
--- a/tsdate/phasing.py
+++ b/tsdate/phasing.py
@@ -104,3 +104,63 @@ def rephase_singletons(ts, use_node_times=True, random_seed=None):
     tables.mutations.time = mutations_time
     tables.sort()
     return tables.tree_sequence(), singletons
+
+
+def insert_unphased_singletons(ts, position, individual, reference_state, alternate_state, allow_overlapping_sites=False):
+    """
+    Insert unphased singletons into the tree sequence. The phase is arbitrarily chosen
+    so that the mutation subtends the node with the lowest id, at a given position for
+    a given individual.
+
+    :param tskit.TreeSequence ts: the tree sequence to add singletons to
+    :param np.ndarray position: the position of the variants
+    :param np.ndarray individual: the individual id in which the variant occurs
+    :param np.ndarray reference_state: the reference state of the variant
+    :param np.ndarray alternate_state: the alternate state of the variant
+    :param bool allow_overlapping_sites: whether to permit insertion of
+        singletons at existing sites (in which case the reference states must be
+        consistent)
+
+    :returns: A copy of the tree sequence with singletons inserted
+    """
+    # TODO: provenance / metadata
+    tables = ts.dump_tables()
+    individuals_node = {i.id: min(i.nodes) for i in ts.individuals()}
+    sites_id = {p: i for i, p in enumerate(ts.sites_position)}
+    overlap = False
+    for pos, ind, ref, alt in zip(position, individual, reference_state, alternate_state):
+        if ind not in individuals_node:
+            raise LookupError(f"Individual {ind} is not in the tree sequence")
+        if pos in sites_id:
+            if not allow_overlapping_sites:
+                raise ValueError(f"A site already exists at position {pos}")
+            if ref != ts.site(sites_id[pos]).ancestral_state:
+                raise ValueError(
+                    f"Existing site at position {pos} has a different ancestral state"
+                )
+            overlap = True
+        else:
+            sites_id[pos] = tables.sites.add_row(position=pos, ancestral_state=ref)
+        tables.mutations.add_row(
+            site=sites_id[pos],
+            node=individuals_node[ind],
+            time=tskit.UNKNOWN_TIME,
+            derived_state=alt,
+        )
+    tables.sort()
+    if allow_overlapping_sites and overlap:
+        tables.build_index()
+        tables.compute_mutation_parents()
+    return tables.tree_sequence()
+
+
+def unphased_to_likelihood(likelihoods, mutations_phase, mutations_block, block_edges):
+    for m, b in enumerate(mutations_block):
+        if b == tskit.NULL:
+            continue
+        i, j = block_edges[b]
+        likelihoods[i] += mutations_phase
+        likelihoods[j] += 1 - mutations_phase
+
+
+
diff --git a/tsdate/variational.py b/tsdate/variational.py
index 1c970969..921d3a54 100644
--- a/tsdate/variational.py
+++ b/tsdate/variational.py
@@ -460,7 +460,7 @@ def posterior_damping(x):
 
     # @staticmethod
     # @numba.njit(_f(_i2r, _i1r, _f2r, _f2w, _f3w, _f1w, _f))
-    # def propagate_unphased(
+    # def propagate_unphased_likelihoods(
     #     parents, individual, likelihoods, posterior, factors, scale, max_shape
     # ):
     #     """
@@ -585,6 +585,10 @@ def propagate_mutations(
 
         return mutations_posterior
 
+    # @staticmethod
+    # @numba.njit(_f(_i2r, _i1r, _f2r, _f2w, _f3w, _f1w, _f))
+    # def propagate_unphased_mutations()
+
     @staticmethod
     @numba.njit(_void(_i1r, _i1r, _f3w, _f3w, _f1w))
     def rescale_factors(edges_parent, 
edges_child, node_factors, edge_factors, scale): From 71f82eb89c2ae961b4dca2643ffef09b8731e6ee Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 17 May 2024 16:05:10 -0700 Subject: [PATCH 05/29] WIP --- tsdate/phasing.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tsdate/phasing.py b/tsdate/phasing.py index cddaee04..4f30991e 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -154,13 +154,27 @@ def insert_unphased_singletons(ts, position, individual, reference_state, altern return tables.tree_sequence() -def unphased_to_likelihood(likelihoods, mutations_phase, mutations_block, block_edges): - for m, b in enumerate(mutations_block): +def accumulate_unphased(edges_mutations, mutations_phase, mutations_block, block_edges): + """ + Add a proportion of each unphased singleton mutation to one of the two + edges to which it maps. + """ + unphased = mutations_block != tskit.NULL + assert np.all(mutations_phase[~unphased] == 1.0) + assert np.all( + np.logical_and( + mutations_phase[unphased] <= 1.0, + mutations_phase[unphased] >= 0.0, + ) + ) + for b in mutations_block[unphased]: if b == tskit.NULL: continue i, j = block_edges[b] - likelihoods[i] += mutations_phase - likelihoods[j] += 1 - mutations_phase + edges_mutations[i] += mutations_phase + edges_mutations[j] += 1 - mutations_phase + assert np.sum(edges_mutations) == mutations_block.size + return edges_mutations From e9091d2ad2b57fd80f89c1682752e566a93b9fe8 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 7 Jun 2024 14:42:52 -0700 Subject: [PATCH 06/29] WIP --- tests/exact_moments.py | 113 +++++-- tsdate/approx.py | 714 ++++++++++++++++++++++------------------- tsdate/phasing.py | 8 +- tsdate/variational.py | 116 ++++--- 4 files changed, 536 insertions(+), 415 deletions(-) diff --git a/tests/exact_moments.py b/tests/exact_moments.py index cfff45d5..0c9d30e6 100644 --- a/tests/exact_moments.py +++ b/tests/exact_moments.py @@ -13,8 +13,8 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): """ log p(t_i, t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_j) * (a_j - 1) - mu_ij * t_j + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j """ a = a_j b = a_i + a_j + y_ij @@ -35,10 +35,10 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): - r""" + """ log p(t_i) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + log(t_i) * (a_i - 1) - b_i * t_i """ assert t_j >= 0.0 s = a_i + y_ij @@ -63,10 +63,10 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): - r""" + """ log p(t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_j) * (a_j - 1) - mu_ij * t_j + log(t_j) * (a_j - 1) - b_j * t_j """ assert t_i > 0.0 a = a_j @@ -85,11 +85,11 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): - r""" + """ log p(t_i, t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_j) * (a_j - 1) - mu_ij * t_j + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j """ a = a_j b = a_i + a_j + y_ij @@ -114,11 +114,10 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): - r""" - log p(t_i, t_j) := \ + """ + log p(t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - 
log(t_j) * (a_j - 1) - mu_ij * t_j + log(t_j) * (a_j - 1) - b_j * t_j """ assert t_i > 0.0 a = a_j @@ -133,12 +132,12 @@ def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): - r""" + """ log p(t_m, t_i, t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_j) * (a_j - 1) - mu_ij * t_j - \ - log(t_i - t_j) * int(t_j < t_m < t_i) + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j - \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) """ a = a_j b = a_i + a_j + y_ij @@ -162,11 +161,11 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): - r""" + """ log p(t_m, t_i) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_i - t_j) * int(t_j < t_m < t_i) + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) """ logl, mn_i, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) mn_m = mn_i / 2 + t_j / 2 @@ -176,14 +175,82 @@ def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): - r""" + """ log p(t_m, t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_j) * (a_j - 1) - mu_ij * t_j - \ - log(t_i - t_j) * int(t_j < t_m < t_i) + log(t_j) * (a_j - 1) - b_j * t_j - \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) """ logl, mn_j, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) mn_m = mn_j / 2 + t_i / 2 sq_m = (va_j + mn_j**2 + mn_j * t_i + t_i**2) / 3 va_m = sq_m - mn_m**2 return mn_m, va_m + + +def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): + """ + log p(t_m, t_i, t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j + \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) + """ + a = a_j + b = a_j + a_i + y_ij + c = a_j + a_i + t = mu_ij + b_i + z = (mu_ij + b_j) / t + f000 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 0, 1 - z))) + f001 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 1, 1 - z))) + f012 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 1, c + 2, 1 - z))) + f023 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 2, c + 3, 1 - z))) + f212 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 1, c + 2, 1 - z))) + f323 = float(mpmath.log(mpmath.hyp2f1(a + 3, b + 2, c + 3, 1 - z))) + s0 = b / t / c / (c + 1) + s1 = (c - a) * (c - a + 1) + s2 = a * (a + 1) + d0 = s0 * (b + 1) / t / (c + 2) + d1 = s1 * (c - a + 2) + d2 = s2 * (a + 2) + mn_m = s0 * s1 * exp(f012 - f000) / 2 + s0 * s2 * exp(f212 - f000) / 2 + sq_m = d0 * d1 * exp(f023 - f000) / 3 + d0 * d2 * exp(f323 - f000) / 3 + va_m = sq_m - mn_m**2 + pr_m = (c - a) / c * exp(f001 - f000) + return pr_m, mn_m, va_m + + +def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): + """ + log p(t_m, t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_j) * (a_j - 1) - b_j * t_j + \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) + """ + a = a_j + b = a_j + y_ij + 1 + z = t_i * (mu_ij + b_j) + f00 = float(mpmath.log(mpmath.hyperu(a + 0, b + 0, z))) + f10 = float(mpmath.log(mpmath.hyperu(a + 1, b + 0, z))) + f21 = float(mpmath.log(mpmath.hyperu(a + 2, b + 1, z))) + f32 = float(mpmath.log(mpmath.hyperu(a + 3, b + 2, z))) + pr_m = 1.0 - exp(f10 - f00) * a + mn_m = pr_m * t_i / 2 + t_i * exp(f21 - f00) * a * (a + 1) / 2 
+ sq_m = pr_m * t_i ** 2 / 3 + t_i ** 2 * exp(f32 - f00) * a * (a + 1) * (a + 2) / 3 + va_m = sq_m - mn_m**2 + return pr_m, mn_m, va_m + + +def unphased_mutation_fixed_moments(t_i, t_j): + """ + log p(t_m) := \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) + """ + pr_m = t_i / (t_i + t_j) + mn_m = pr_m * t_i / 2 + (1 - pr_m) * t_j / 2 + sq_m = pr_m * t_i ** 2 / 3 + (1 - pr_m) * t_j ** 2 / 3 + va_m = sq_m - mn_m**2 + return pr_m, mn_m, va_m diff --git a/tsdate/approx.py b/tsdate/approx.py index 05ddbf4c..8e7b28ba 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -195,6 +195,15 @@ def _valid_moments(mn, va): return True +@numba.njit(_b(_f, _f)) +def _valid_gamma(s, r): + if not (np.isfinite(s) and np.isfinite(r)): + return False + if s <= 0.0 or r <= 0.0: + return False + return True + + @numba.njit(_b(_f, _f, _f)) def _valid_hyp1f1(a, b, z): if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)): @@ -226,13 +235,15 @@ def _valid_hyp2f1(a, b, c, z): return True +# --- various EP updates --- # + @numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): - """ + r""" log p(t_i, t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_j) * (a_j - 1) - mu_ij * t_j + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j Returns normalizing constant, E[t_i], V[t_i], E[t_j], V[t_j]. """ @@ -270,10 +281,10 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): - """ + r""" log p(t_i) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + log(t_i) * (a_i - 1) - b_i * t_i Returns normalizing constant, E[t_i], V[t_i]. """ @@ -282,9 +293,9 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): s = a_i + y_ij r = mu_ij + b_i - a = y_ij + 1 - b = s + 1 - z = t_j * r + + if not _valid_gamma(s, r): + return nan, nan, nan if t_j == 0.0: logl = hypergeo._gammaln(s) - s * log(r) @@ -292,6 +303,13 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): va_i = s / r**2 return logl, mn_i, va_i + a = y_ij + 1 + b = s + 1 + z = t_j * r + + if not _valid_hyperu(a, b, z): + return nan, nan, nan + hyperu = hypergeo._hyperu_laplace f0, d0 = hyperu(a + 0, b + 0, z) f1, d1 = hyperu(a + 1, b + 1, z) @@ -305,10 +323,10 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): - """ + r""" log p(t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_j) * (a_j - 1) - mu_ij * t_j + log(t_j) * (a_j - 1) - b_j * t_j Returns normalizing constant, E[t_j], V[t_j]. """ @@ -319,6 +337,9 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): b = a_j + y_ij + 1 z = t_i * (mu_ij - b_j) + if not _valid_hyp1f1(a, b, z): + return nan, nan, nan + hyp1f1 = hypergeo._hyp1f1_laplace f0 = hyp1f1(a + 0, b + 0, z) f1 = hyp1f1(a + 1, b + 1, z) @@ -334,14 +355,86 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): return logl, mn_j, va_j +@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) +def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): + r""" + log p(t_i, t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j + + Returns normalizing constant, E[t_i], V[t_i], E[t_j], V[t_j]. 
+ """ + + a = a_j + b = a_i + a_j + y_ij + c = a_j + a_i + t = mu_ij + b_i + z = (mu_ij + b_j) / t if t > 0 else nan + + if not _valid_hyp2f1(a, b, c, z): + return nan, nan, nan, nan, nan + + hyp2f1 = hypergeo._hyp2f1_laplace + f0 = hyp2f1(a + 0, b + 0, c + 0, 1 - z) + f1 = hyp2f1(a + 1, b + 1, c + 1, 1 - z) + f2 = hyp2f1(a + 2, b + 2, c + 2, 1 - z) + s1 = a * b / c + s2 = s1 * (a + 1) * (b + 1) / (c + 1) + d1 = s1 * exp(f1 - f0) + d2 = s2 * exp(f2 - f0) + + logl = f0 + hypergeo._betaln(a_j, a_i) + hypergeo._gammaln(b) - b * log(t) + + mn_j = d1 / t + sq_j = d2 / t**2 + va_j = sq_j - mn_j**2 + + mn_i = b / t - mn_j * z + sq_i = sq_j * z**2 + (b + 1) * (mn_i - mn_j * z) / t + va_i = sq_i - mn_i**2 + + return logl, mn_i, va_i, mn_j, va_j + + +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): + r""" + log p(t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_j) * (a_j - 1) - b_j * t_j + + Returns normalizing constant, E[t_j], V[t_j]. + """ + + assert t_i > 0.0 + + a = a_j + b = a_j + y_ij + 1 + z = t_i * (mu_ij + b_j) + + if not _valid_hyperu(a, b, z): + return nan, nan, nan + + hyperu = hypergeo._hyperu_laplace + f0, d0 = hyperu(a + 0, b + 0, z) + f1, d1 = hyperu(a + 1, b + 1, z) + + logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + hypergeo._gammaln(a) + mn_j = -t_i * d0 + va_j = t_i**2 * d0 * (d1 - d0) + + return logl, mn_j, va_j + + @numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f)) def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_i, t_j) = \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_j) * (a_j - 1) - mu_ij * t_j - \ - log(t_i - t_j) * int(t_j < t_m < t_i) + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j - \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) Returns E[t_m], V[t_m]. """ @@ -350,7 +443,10 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): b = a_i + a_j + y_ij c = a_j + y_ij + 1 t = mu_ij + b_i - z = (mu_ij - b_j) / t + z = (mu_ij - b_j) / t if t > 0 else nan + + if not _valid_hyp2f1(a, b, c, z): + return nan, nan hyp2f1 = hypergeo._hyp2f1_laplace f000 = hyp2f1(a + 0, b + 0, c + 0, z) @@ -377,8 +473,8 @@ def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): r""" log p(t_m, t_i) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_i - t_j) * int(t_j < t_m < t_i) + log(t_i) * (a_i - 1) - b_i * t_i - \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) Returns E[t_m], V[t_m]. """ @@ -396,8 +492,8 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_j) * (a_j - 1) - mu_ij * t_j - \ - log(t_i - t_j) * int(t_j < t_m < t_i) + log(t_j) * (a_j - 1) - b_j * t_j - \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) Returns E[t_m], V[t_m]. """ @@ -410,64 +506,155 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): return mn_m, va_m -@numba.njit(_b(_f, _f, _f, _f, _f, _f)) -def _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y, mu): - """Uses shape / rate parameterization""" +@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) +def mutation_fixed_moments(t_i, a_j, b_j, y_ij, mu_ij): + r""" + log p(t_m) := \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) + + Returns E[t_m], V[t_m]. 
+ """ + + mn_m = 1 / 2 * (t_i + t_j) + va_m = 1 / 12 * (t_i - t_j) ** 2 + + return mn_m, va_m + + +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) +def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): + r""" + log p(t_m, t_i, t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j + \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) + + Returns P[m under i], E[t_m], V[t_m]. + """ + + # Conditioning on ages of parents: + # P[x under i | t_i, t_j] = t_i / (t_i + t_j) + # E[x | x under t_i, t_i] = t_i / 2 + # E[x^2 | x under t_i, t_i] = t_i**2 / 3 + # and equivalently for t_j. Integrating these moments over the EP surrogate + # density leads to hypergeometric functions similar to the node case, but + # with integer perturbations of a_i, a_j, y_ij. + a = a_j - b = a_i + a_j + y - c = a_j + y + 1 - s = mu - b_j - t = mu + b_i - # check that 2F1 argument is less than unity - if t <= 0.0: - return False - z = s / t - if z >= 1.0 or z / (z - 1) >= 1.0: - return False - # check that 2F1 is positive - if a <= 0: - return False - if b <= 0: - return False - if c <= 0: - return False - return True + b = a_j + a_i + y_ij + c = a_j + a_i + t = mu_ij + b_i + z = (mu_ij + b_j) / t if t > 0 else nan + + if not _valid_hyp2f1(a, b, c, z): + return nan, nan + hyp2f1 = hypergeo._hyp2f1_laplace + f000 = hyp2f1(a + 0, b + 0, c + 0, 1 - z) + f001 = hyp2f1(a + 0, b + 0, c + 1, 1 - z) + f012 = hyp2f1(a + 0, b + 1, c + 2, 1 - z) + f023 = hyp2f1(a + 0, b + 2, c + 3, 1 - z) + f212 = hyp2f1(a + 2, b + 1, c + 2, 1 - z) + f323 = hyp2f1(a + 3, b + 2, c + 3, 1 - z) + + s0 = b / t / c / (c + 1) + s1 = (c - a) * (c - a + 1) + s2 = a * (a + 1) + d0 = s0 * (b + 1) / t / (c + 2) + d1 = s1 * (c - a + 2) + d2 = s2 * (a + 2) + + mn_m = s0 * s1 * exp(f012 - f000) / 2 + s0 * s2 * exp(f212 - f000) / 2 + sq_m = d0 * d1 * exp(f023 - f000) / 3 + d0 * d2 * exp(f323 - f000) / 3 + va_m = sq_m - mn_m**2 + pr_m = (c - a) / c * exp(f001 - f000) + + return pr_m, mn_m, va_m + + +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): + r""" + log p(t_m, t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_j) * (a_j - 1) - b_j * t_j + \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) + + Returns P[m under i], E[t_m], V[t_m]. + """ + + assert t_i > 0 + + # Conditioning on ages of parents: + # P[x under i | t_i, t_j] = t_i / (t_i + t_j) + # E[x | x under t_i, t_i] = t_i / 2 + # E[x^2 | x under t_i, t_i] = t_i**2 / 3 + # and equivalently for t_j. Integrating these moments over the EP surrogate + # density leads to Tricomi functions similar to the node case, but + # with integer perturbations of a_j, y_ij. 
-@numba.njit(_b(_f, _f, _f, _f, _f)) -def _hyp1f1_valid_parameterization(t_i, a_j, b_j, y, mu): - """Uses shape / rate parameterization""" a = a_j - b = a_j + y + 1 - if not (b > a > 0.0): - return False - return True + b = a_j + y_ij + 1 + z = t_i * (mu_ij + b_j) + if not _valid_hyperu(a, b, z): + return nan, nan, nan -@numba.njit(_b(_f, _f, _f, _f, _f)) -def _hyperu_valid_parameterization(t_j, a_i, b_i, y, mu): - """Uses shape / rate parameterization""" - a = y + 1 - b = a_i + y + 1 - if t_j < 0.0: - return False - if mu + b_i <= 0.0: - return False - if not (b > a > 0.0): - return False - return True + # direct but unstable: + hyperu = hypergeo._hyperu_laplace + f00, d00 = hyperu(a + 0, b + 0, z) + f10, d10 = hyperu(a + 1, b + 0, z) + f21, d21 = hyperu(a + 2, b + 1, z) + f32, d32 = hyperu(a + 3, b + 2, z) + pr_m = 1.0 - exp(f10 - f00) * a + mn_m = pr_m * t_i / 2 + t_i * exp(f21 - f00) * a * (a + 1) / 2 + sq_m = pr_m * t_i ** 2 / 3 + t_i ** 2 * exp(f32 - f00) * a * (a + 1) * (a + 2) / 3 + + # TODO: use a stabler approach with derivatives + # note that exp(f10 - f00) = (a + z * d00) / (a - b + 1) + # however the denominator is 0 if y_ij is 0 + # note that when y_ij == 0 then a == b + 1 and f00 = z**(-a) + + va_m = sq_m - mn_m**2 + + return pr_m, mn_m, va_m + + +def unphased_mutation_fixed_moments(t_i, t_j): + r""" + log p(t_m) := \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) + + Returns P[m under i], E[t_m], V[t_m]. + """ + + assert t_i > 0 + assert t_j > 0 + + pr_m = t_i / (t_i + t_j) + mn_m = pr_m * t_i / 2 + (1 - pr_m) * t_j / 2 + sq_m = pr_m * t_i ** 2 / 3 + (1 - pr_m) * t_j ** 2 / 3 + va_m = sq_m - mn_m**2 + + return pr_m, mn_m, va_m + + +# --- wrappers around updates --- # @numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) def gamma_projection(pars_i, pars_j, pars_ij): - """ + r""" log p(t_i, t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - mu_ij * t_i + \ - log(t_j) * (a_j - 1) - mu_ij * t_j + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j - Returns normalizing constant, gamma natural parameters for parent and child. + Returns normalizing constant, gamma natural parameters for parent and child ages """ - a_i, b_i = pars_i a_j, b_j = pars_j y_ij, mu_ij = pars_ij @@ -488,24 +675,16 @@ def gamma_projection(pars_i, pars_j, pars_ij): @numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def leafward_projection(t_i, pars_j, pars_ij): r""" - Match a gamma distributions to the potential function :math:`Ga(t_j | a_j + - 1, b_j) Po(y_{ij} | \mu_{ij} t_i - t_j)`, where :math:`i` is the parent and - :math:`j` is the child, by minimizing KL divergence. 
- - :param float t_i: the age of the parent - :param float pars_j: gamma natural parameters for the child cavity distribution - :param float pars_ij: gamma natural parameters for the edge likelihood + log p(t_j) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_j) * (a_j - 1) - b_j * t_j - :return: normalizing constant, gamma natural parameters for child + Returns normalizing constant, gamma natural parameters for child age """ - a_j, b_j = pars_j y_ij, mu_ij = pars_ij a_j += 1 - if not _hyp1f1_valid_parameterization(t_i, a_j, b_j, y_ij, mu_ij): - return np.nan, pars_j - logl, mn_j, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) if not _valid_moments(mn_j, va_j): @@ -519,24 +698,16 @@ def leafward_projection(t_i, pars_j, pars_ij): @numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def rootward_projection(t_j, pars_i, pars_ij): r""" - Match a gamma distributions to the potential function :math:`Ga(t_i | a_i + - 1, b_i) Po(y_{ij} | \mu_{ij} t_i - t_j)`, where :math:`i` is the parent and - :math:`j` is the child, by minimizing KL divergence. - - :param float t_j: the age of the child - :param float pars_i: gamma natural parameters for the parent cavity distribution - :param float pars_ij: gamma natural parameters for the edge likelihood + log p(t_i) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - b_i * t_i - :return: normalizing constant, gamma natural parameters for child + Returns normalizing constant, gamma natural parameters for parent age """ - a_i, b_i = pars_i y_ij, mu_ij = pars_ij a_i += 1 - if not _hyperu_valid_parameterization(t_j, a_i, b_i, y_ij, mu_ij): - return np.nan, pars_i - logl, mn_i, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) if not _valid_moments(mn_i, va_i): @@ -547,325 +718,218 @@ def rootward_projection(t_j, pars_i, pars_ij): return logl, np.array(proj_i) -# --- mutation posteriors from node posteriors --- # +@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) +def unphased_projection(pars_i, pars_j, pars_ij): + r""" + log p(t_i, t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j + Returns normalizing constant, gamma natural parameters for parent ages + """ + a_i, b_i = pars_i + a_j, b_j = pars_j + y_ij, mu_ij = pars_ij + a_i += 1 + a_j += 1 -@numba.njit(_f1r(_f1r, _f1r, _f1r)) -def mutation_gamma_projection(pars_i, pars_j, pars_ij): - r""" - Match a gamma distribution via KL minimization to the potential function + logl, mn_i, va_i, mn_j, va_j = unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - ..math:: + if not _valid_moments(mn_i, va_i) or not _valid_moments(mn_j, va_j): + return np.nan, pars_i, pars_j - p(x) = \int_0^\infty \int_0^{t_i} Unif(x | t_i, t_j) - Ga(t_i | a_i, b_i) Ga(t_j | a_j b_j) Po(y | \mu_ij (t_i - t_j)) dt_j dt_i + proj_i = approximate_gamma_mom(mn_i, va_i) + proj_j = approximate_gamma_mom(mn_j, va_j) - which models the time :math:`x` of a mutation uniformly distributed between - parent age :math:`t_i` and child age :math:`t_j`, on a branch with - :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. 
+ return logl, np.array(proj_i), np.array(proj_j) - TODO: params - :return: gamma parameters for mutation age - """ +@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +def unphased_rightward_projection(t_i, pars_j, pars_ij): + r""" + log p(t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_j) * (a_j - 1) - b_j * t_j + Returns normalizing constant, gamma natural parameters for nonfixed parent age + """ a_i, b_i = pars_i a_j, b_j = pars_j y_ij, mu_ij = pars_ij a_i += 1 a_j += 1 - if not _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): - return np.full(2, np.nan) + logl, mn_j, va_j = unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij) + + if not _valid_moments(mn_j, va_j): + return np.nan, pars_j + + proj_j = approximate_gamma_mom(mn_j, va_j) + + return logl, np.array(proj_j) + + +@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r)) +def mutation_gamma_projection(pars_i, pars_j, pars_ij): + r""" + log p(t_m, t_i, t_j) = \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j - \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) + + Returns phase probability, gamma natural parameters for mutation age + """ + a_i, b_i = pars_i + a_j, b_j = pars_j + y_ij, mu_ij = pars_ij + a_i += 1 + a_j += 1 mn_m, va_m = mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) if not _valid_moments(mn_m, va_m): - return np.full(2, np.nan) + return np.nan, np.full(2, np.nan) proj_m = approximate_gamma_mom(mn_m, va_m) - return np.array(proj_m) + return 1.0, np.array(proj_m) -@numba.njit(_f1r(_f, _f1r, _f1r)) +@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def mutation_leafward_projection(t_i, pars_j, pars_ij): r""" - Match a gamma distribution via KL minimization to the potential function - - ..math:: - - p(x) = \int_0^{t_i} Unif(x | t_i, t_j) - Ga(t_j | a_j, b_j) Po(y | \mu_ij (t_i - t_j)) dt_j - - which models the time :math:`x` of a mutation uniformly distributed between - parent age :math:`t_i` and child age :math:`t_j`, on a branch with - :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. - - TODO + log p(t_m, t_j) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_j) * (a_j - 1) - b_j * t_j - \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) - :return: gamma parameters for mutation age + Returns phase probability, gamma natural parameters for mutation age """ - a_j, b_j = pars_j y_ij, mu_ij = pars_ij a_j += 1 - if not _hyp1f1_valid_parameterization(t_i, a_j, b_j, y_ij, mu_ij): - return np.full(2, np.nan) - mn_m, va_m = mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) if not _valid_moments(mn_m, va_m): - return np.full(2, np.nan) + return np.nan, np.full(2, np.nan) proj_m = approximate_gamma_mom(mn_m, va_m) - return np.array(proj_m) + return 1.0, np.array(proj_m) -@numba.njit(_f1r(_f, _f1r, _f1r)) +@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def mutation_rootward_projection(t_j, pars_i, pars_ij): r""" - Match a gamma distribution via KL minimization to the potential function - - ..math:: - - p(x) = \int_{t_j}^{\infty} Unif(x | t_i, t_j) - Ga(t_i | a_i, b_i) Po(y | \mu_ij (t_i - t_j)) dt_i - - which models the time :math:`x` of a mutation uniformly distributed between - parent age :math:`t_i` and child age :math:`t_j`, on a branch with - :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. 
- - TODO + log p(t_m, t_i) := \ + log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ + log(t_i) * (a_i - 1) - b_i * t_i - \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) - :return: gamma parameters for mutation age + Returns phase probability, gamma natural parameters for mutation age """ - a_i, b_i = pars_i y_ij, mu_ij = pars_ij a_i += 1 - if not _hyperu_valid_parameterization(t_j, a_i, b_i, y_ij, mu_ij): - return np.full(2, np.nan) - mn_m, va_m = mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) if not _valid_moments(mn_m, va_m): - return np.full(2, np.nan) + return np.nan, np.full(2, np.nan) proj_m = approximate_gamma_mom(mn_m, va_m) - return np.array(proj_m) - + return 1.0, np.array(proj_m) -# --- unphased node posteriors --- # -@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) -def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): - """ - Calculate sufficient statistics for the PDF proportional to :math:`Ga(t_j | - a_j, b_j) Ga(t_i | a_i, b_i) Po(y_{ij} | \\mu_{ij} t_i + t_j)`, where - :math:`i` and :math:`j` are parents of the same individual (assumed to be at - time zero). The logarithmic moments are approximated via a Taylor expansion - around the mean. - - :param float a_i: the shape parameter of the cavity distribution for the first parent - :param float b_i: the rate parameter of the cavity distribution for the first parent - :param float a_j: the shape parameter of the cavity distribution for the second parent - :param float b_j: the rate parameter of the cavity distribution for the second parent - :param float y_ij: the number of mutations on the singleton edge pair - :param float mu_ij: the span-weighted mutation rate of the singleton edge pair +@numba.njit(_tuple((_f, _f1r))(_f, _f)) +def mutation_fixed_projection(t_i, t_j): + r""" + log p(t_m) := \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) - :return: normalizing constant, E[t_i], E[log t_i], V[t_i], - E[t_j], E[log t_j], V[t_j] + Returns phase probability, gamma natural parameters for mutation age """ + mn_m, va_m = mutation_fixed_moments(t_i, t_j) - a = a_j - b = a_i + a_j + y_ij - c = a_j + a_i - t = mu_ij + b_i - z = (mu_ij + b_j) / t - - hyp2f1 = hypergeo._hyp2f1_laplace - f0 = hyp2f1(a + 0, b + 0, c + 0, 1 - z) - f1 = hyp2f1(a + 1, b + 1, c + 1, 1 - z) - f2 = hyp2f1(a + 2, b + 2, c + 2, 1 - z) - s1 = a * b / c - s2 = s1 * (a + 1) * (b + 1) / (c + 1) - d1 = s1 * exp(f1 - f0) - d2 = s2 * exp(f2 - f0) - - logl = f0 + hypergeo._betaln(a_j, a_i) + hypergeo._gammaln(b) - b * log(t) - - mn_j = d1 / t - sq_j = d2 / t**2 - va_j = sq_j - mn_j**2 - - mn_i = b / t - mn_j * z - sq_i = sq_j * z**2 + (b + 1) * (mn_i - mn_j * z) / t - va_i = sq_i - mn_i**2 + if not _valid_moments(mn_m, va_m): + return np.nan, np.full(2, np.nan) - return logl, mn_i, va_i, mn_j, va_j + proj_m = approx.approximate_gamma_mom(mn_m, va_m) + return 1.0, np.array(proj_m) -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) -def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): - """ - Calculate sufficient statistics for the PDF proportional to :math:`Ga(t_j | - a_j, b_j) Po(y_{ij} | \\mu_{ij} t_i + t_j)`, where :math:`i` and :math:`j` - are parents of the same individual (assumed to be at time zero). The - logarithmic moments are approximated via a Taylor expansion around the - mean. 
- :param float t_i: the age of the first parent - :param float a_j: the shape parameter of the cavity distribution for the second parent - :param float b_j: the rate parameter of the cavity distribution for the second parent - :param float y_ij: the number of mutations on the singleton edge pair - :param float mu_ij: the span-weighted mutation rate of the singleton edge pair - - :return: normalizing constant, E[t_j], E[log t_j], V[t_j] +@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r)) +def mutation_unphased_projection(pars_i, pars_j, pars_ij): + r""" + log p(t_m, t_i, t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(t_j) * (a_j - 1) - b_j * t_j + \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) + + Returns phase probability, gamma natural parameters for mutation age """ + a_i, b_i = pars_i + a_j, b_j = pars_j + y_ij, mu_ij = pars_ij + a_i += 1 + a_j += 1 - assert t_i > 0.0 - - a = a_j - b = a_j + y_ij + 1 - z = t_i * (mu_ij + b_j) + pr_m, mn_m, va_m = unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - hyperu = hypergeo._hyperu_laplace - f0, d0 = hyperu(a + 0, b + 0, z) - f1, d1 = hyperu(a + 1, b + 1, z) + if not _valid_moments(mn_m, va_m) or not (0 <= pr_m <= 1): + return np.nan, np.full(2, np.nan) - logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + hypergeo._gammaln(a) - mn_j = -t_i * d0 - va_j = t_i**2 * d0 * (d1 - d0) + proj_m = approximate_gamma_mom(mn_m, va_m) - return logl, mn_j, va_j + return pr_m, np.array(proj_m) -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) -def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): +@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +def mutation_unphased_rightward_projection(t_i, pars_j, pars_ij): r""" - Calculate gamma sufficient statistics for the PDF proportional to: - - ..math:: - - p(x) = \int_0^\infty \int_0^\infty (Unif(x | 0, t_i) + Unif(x | 0, t_j)) - Ga(t_i | a_i, b_i) Ga(t_j | a_j b_j) Po(y | \mu_ij (t_i + t_j)) dt_j dt_i - - which models the time :math:`x` of a mutation uniformly distributed between - zero and one of two parents with ages :math:`t_i` and :math:`t_j`, where - the mutation count on both branches is :math:`y_{ij}` with total mutation - rate :math:`\mu_{ij}`. + log p(t_m, t_j) := \ + log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ + log(t_j) * (a_j - 1) - b_j * t_j + \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) - Returns log P[x under i], E[x], E[\log x], V[x]. + Returns phase probability, gamma natural parameters for mutation age """ + a_j, b_j = pars_j + y_ij, mu_ij = pars_ij + a_j += 1 - # Conditioning on ages of parents: - # P[x under i | t_i, t_j] = t_i / (t_i + t_j) - # E[x | x under t_i, t_i] = t_i / 2 - # E[x^2 | x under t_i, t_i] = t_i**2 / 3 - # and equivalently for t_j. Integrating these moments over the EP surrogate - # density leads to hypergeometric functions similar to the node case, but - # with integer perturbations of a_i, a_j, y_ij. 
- - a = a_j - b = a_j + a_i + y_ij - c = a_j + a_i - t = mu_ij + b_i - z = (mu_ij + b_j) / t - - hyp2f1 = hypergeo._hyp2f1_laplace - f000 = hyp2f1(a + 0, b + 0, c + 0, 1 - z) - f001 = hyp2f1(a + 0, b + 0, c + 1, 1 - z) - f012 = hyp2f1(a + 0, b + 1, c + 2, 1 - z) - f023 = hyp2f1(a + 0, b + 2, c + 3, 1 - z) - f212 = hyp2f1(a + 2, b + 1, c + 2, 1 - z) - f323 = hyp2f1(a + 3, b + 2, c + 3, 1 - z) + pr_m, mn_m, va_m = unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij) - s0 = b / t / c / (c + 1) - s1 = (c - a) * (c - a + 1) - s2 = a * (a + 1) - d0 = s0 * (b + 1) / t / (c + 2) - d1 = s1 * (c - a + 2) - d2 = s2 * (a + 2) + if not _valid_moments(mn_m, va_m) or not (0 <= pr_m <= 1): + return np.nan, np.full(2, np.nan) - mn_m = s0 * s1 * exp(f012 - f000) / 2 + s0 * s2 * exp(f212 - f000) / 2 - sq_m = d0 * d1 * exp(f023 - f000) / 3 + d0 * d2 * exp(f323 - f000) / 3 - va_m = sq_m - mn_m**2 - pr_m = (c - a) / c * exp(f001 - f000) + proj_m = approximate_gamma_mom(mn_m, va_m) - return pr_m, mn_m, va_m + return pr_m, np.array(proj_m) -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) -def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): +@numba.njit(_tuple((_f, _f1r))(_f, _f)) +def mutation_unphased_fixed_projection(t_i, t_j): r""" - Calculate gamma sufficient statistics for the PDF proportional to: - - ..math:: - - p(x) = \int_0^\infty (Unif(x | 0, t_i) + Unif(x | 0, t_j)) - Ga(t_j | a_j b_j) Po(y | \mu_ij (t_i + t_j)) dt_j + log p(t_m) := \ + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j)) - which models the time :math:`x` of a mutation uniformly distributed between - zero and one of two parents with ages :math:`t_i` and :math:`t_j`, where - the mutation count on both branches is :math:`y_{ij}` with total mutation - rate :math:`\mu_{ij}`. - - Returns log P[x under i], E[x], E[\log x], V[x]. + Returns phase probability, gamma natural parameters for mutation age """ + pr_m, mn_m, va_m = unphased_mutation_fixed_moments(t_i, t_j) - # Conditioning on ages of parents: - # P[x under i | t_i, t_j] = t_i / (t_i + t_j) - # E[x | x under t_i, t_i] = t_i / 2 - # E[x^2 | x under t_i, t_i] = t_i**2 / 3 - # and equivalently for t_j. Integrating these moments over the EP surrogate - # density leads to Tricomi functions similar to the node case, but - # with integer perturbations of a_j, y_ij. 
- - a = a_j - b = a_j + y_ij + 1 - z = t_i * (mu_ij + b_j) - - #with numba.objmode(f00='f8', f10='f8', f21='f8', f32='f8'): - # f00 = float(mpmath.log(mpmath.hyperu(a + 0, b + 0, z))) - # f10 = float(mpmath.log(mpmath.hyperu(a + 1, b + 0, z))) - # f21 = float(mpmath.log(mpmath.hyperu(a + 2, b + 1, z))) - # f32 = float(mpmath.log(mpmath.hyperu(a + 3, b + 2, z))) - - # direct but unstable: - hyperu = hypergeo._hyperu_laplace - f00, d00 = hyperu(a + 0, b + 0, z) - f10, d10 = hyperu(a + 1, b + 0, z) - f21, d21 = hyperu(a + 2, b + 1, z) - f32, d32 = hyperu(a + 3, b + 2, z) - pr_m = 1.0 - exp(f10 - f00) * a - mn_m = pr_m * t_i / 2 + t_i * exp(f21 - f00) * a * (a + 1) / 2 - sq_m = pr_m * t_i ** 2 / 3 + t_i ** 2 * exp(f32 - f00) * a * (a + 1) * (a + 2) / 3 - - # TODO: use a stabler approach with derivatives - # note that exp(f10 - f00) = (a + z * d00) / (a - b + 1) - # however the denominator is 0 if y_ij is 0 - # note that when y_ij == 0 then a == b + 1 and f00 = z**(-a) - - va_m = sq_m - mn_m**2 - - return pr_m, mn_m, va_m + if not _valid_moments(mn_m, va_m) or not (0 <= pr_m <= 1): + return np.nan, np.full(2, np.nan) + proj_m = approximate_gamma_mom(mn_m, va_m) -@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) -def unphased_projection(pars_i, pars_j, pars_ij): - a_i, b_i = pars_i - a_j, b_j = pars_j - y_ij, mu_ij = pars_ij - a_i += 1 - a_j += 1 - logl, mn_i, va_i, mn_j, va_j = unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) - if not _valid_moments(mn_i, va_i) or not _valid_moments(mn_j, va_j): - return np.nan, pars_i, pars_j - proj_i = approximate_gamma_mom(mn_i, va_i) - proj_j = approximate_gamma_mom(mn_j, va_j) - return logl, np.array(proj_i), np.array(proj_j) + return pr_m, np.array(proj_m) diff --git a/tsdate/phasing.py b/tsdate/phasing.py index 4f30991e..2511cfbf 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -25,10 +25,6 @@ import numpy as np import tskit -#def _mutations_frequency(ts): -# -#def mutations_frequency(ts): - def remove_singletons(ts): """ Remove all singleton mutations from the tree sequence. @@ -72,7 +68,7 @@ def remove_singletons(ts): def rephase_singletons(ts, use_node_times=True, random_seed=None): """ - Rephase singleton mutations in the tree sequence. How If `use_node_times` + Rephase singleton mutations in the tree sequence. If `use_node_times` is True, singletons are added to permissable branches with probability proportional to the branch length (and with equal probability otherwise). 
""" @@ -177,4 +173,4 @@ def accumulate_unphased(edges_mutations, mutations_phase, mutations_block, block return edges_mutations - +# TODO: mutation sort order diff --git a/tsdate/variational.py b/tsdate/variational.py index 921d3a54..d5f9b398 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -187,6 +187,7 @@ def _check_valid_state( posterior_check[p] += edge_factors[i, ROOTWARD] posterior_check[c] += edge_factors[i, LEAFWARD] # TODO: unphased factors + assert False posterior_check += node_factors[:, MIXPRIOR] posterior_check += node_factors[:, CONSTRNT] return np.allclose(posterior_check, posterior) @@ -243,10 +244,14 @@ def __init__(self, ts, likelihoods, constraints, mutations_edge): # mutable self.node_factors = np.zeros((ts.num_nodes, 2, 2)) self.edge_factors = np.zeros((ts.num_edges, 2, 2)) - #self.unph_factors = np.zeros((..., 2, 2)) #TODO + self.block_factors = np.zeros((num_blocks, 2, 2)) #TODO self.posterior = np.zeros((ts.num_nodes, 2)) + #self.mutation_posterior = np.full((ts.num_mutations, np.nan)) self.log_partition = np.zeros(ts.num_edges) + #self.edge_logconst = ... + #self.block_logconst = ... self.scale = np.ones(ts.num_nodes) + assert False # terminal nodes has_parent = np.full(ts.num_nodes, False) @@ -260,6 +265,8 @@ def __init__(self, ts, likelihoods, constraints, mutations_edge): # edge traversal order edges = np.arange(ts.num_edges, dtype=np.int32) + # TODO: mask singleton edges + assert False self.edge_order = np.concatenate((edges[:-1], np.flip(edges))) self.edge_weights = edge_sampling_weight( self.leaves, @@ -270,6 +277,7 @@ def __init__(self, ts, likelihoods, constraints, mutations_edge): ts.indexes_edge_insertion_order, ts.indexes_edge_removal_order, ) + self.block_order = np.arange(num_blocks, dtype=np.int32) @staticmethod @numba.njit(_f(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f)) @@ -322,6 +330,14 @@ def cavity_damping(x, y): def posterior_damping(x): return _rescale(x, max_shape) + # TODO + # in "unphased" mode, edges are singleton blocks, and the two parents + # of each block are given by "parents" and "children" + assert False + leafward_projection = approx.leafward_projection if unphased else approx.unphased_fixed_projection + rootward_projection = approx.rootward_projection if unphased else approx.unphased_fixed_projection + gamma_projection = approx.gamma_projection if unphased else approx.unphased_projection + fixed = constraints[:, LOWER] == constraints[:, UPPER] for i in edge_order: @@ -338,7 +354,7 @@ def posterior_damping(x): # match moments and update factor parent_age = constraints[p, LOWER] - lognorm[i], posterior[c] = approx.leafward_projection( + lognorm[i], posterior[c] = leafward_projection( parent_age, child_cavity, edge_likelihood, ) factors[i, LEAFWARD] *= 1.0 - child_delta @@ -358,7 +374,7 @@ def posterior_damping(x): # match moments and update factor child_age = constraints[c, LOWER] - lognorm[i], posterior[p] = approx.rootward_projection( + lognorm[i], posterior[p] = rootward_projection( child_age, parent_cavity, edge_likelihood, ) @@ -382,7 +398,7 @@ def posterior_damping(x): edge_likelihood = delta * likelihoods[i] # match moments and update factors - lognorm[i], posterior[p], posterior[c] = approx.gamma_projection( + lognorm[i], posterior[p], posterior[c] = gamma_projection( parent_cavity, child_cavity, edge_likelihood, ) factors[i, ROOTWARD] *= 1.0 - delta @@ -458,49 +474,11 @@ def posterior_damping(x): return np.nan - # @staticmethod - # @numba.njit(_f(_i2r, _i1r, _f2r, _f2w, _f3w, _f1w, _f)) - # 
def propagate_unphased_likelihoods( - # parents, individual, likelihoods, posterior, factors, scale, max_shape - # ): - # """ - # Update approximating factors for unphased singletons. - - # :param ndarray parents: rows are unphased intervals, columns are first - # and second parents of an individual over that interval. - # :param ndarray individual: the individual associated with each - # unphased interval. - # :param ndarray likelihoods: rows are unphased intervals, columns are - # number of singleton mutations and interval span. - # :param ndarray posterior: rows are nodes, columns are first and - # second natural parameters of gamma posteriors. Updated in - # place. - # :param ndarray factors: rows are unphased intervals, columns index - # different types of updates. Updated in place. - # :param ndarray scale: array of dimension `[num_nodes]` containing a - # scaling factor for the posteriors, updated in-place. - # :param float max_shape: the maximum allowed shape for node posteriors. - # """ - - # # TODO assert ??? - # assert max_shape >= 1.0 - # assert 0.0 < min_step < 1.0 - - # def cavity_damping(x, y): - # return _damp(x, y, min_step) - - # def posterior_damping(x): - # return _rescale(x, max_shape) - - # # TODO copy from propagate_likelihood... - - # # TODO copy from propagate_likelihood... - - # return np.nan - + # TODO add arguments, void return @staticmethod @numba.njit(_f2w(_i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r)) def propagate_mutations( + TODO, mutations_edge, edges_parent, edges_child, @@ -513,6 +491,11 @@ def propagate_mutations( """ Calculate posteriors for mutations. + :param ndarray mutations_order: integer array giving order in + which to traverse mutations + :param ndarray mutations_posterior: array of dimension `[num_mutations, 2]` + containing natural parameters for each mutation + :param ndarray mutations_edge: integer array giving edge for each mutation :param ndarray edges_parent: integer array of parent ids per edge @@ -539,20 +522,23 @@ def propagate_mutations( assert factors.shape == (edges_parent.size, 2, 2) assert likelihoods.shape == (edges_parent.size, 2) - mutations_posterior = np.zeros((mutations_edge.size, 2)) + #mutations_posterior = np.zeros((mutations_edge.size, 2)) + #pass in mutations posterior filled with nan TODO fixed = constraints[:, LOWER] == constraints[:, UPPER] - for m, i in enumerate(mutations_edge): - if i == tskit.NULL: # skip mutations above root - mutations_posterior[m] = np.nan + for m in mutations_order: + i = mutations_edge[m] + if i == tskit.NULL: # skip mutations above root or unphased continue - # TODO: if unphased skip, set to nan p, c = edges_parent[i], edges_child[i] if fixed[p] and fixed[c]: child_age = constraints[c, 0] parent_age = constraints[p, 0] - mean = 1 / 2 * (child_age + parent_age) - variance = 1 / 12 * (parent_age - child_age) ** 2 - mutations_posterior[m] = approx.approximate_gamma_mom(mean, variance) + mutations_posterior[m] = approx.mutation_fixed_projection( + parent_age, child_age + ) + #mean = 1 / 2 * (child_age + parent_age) + #variance = 1 / 12 * (parent_age - child_age) ** 2 + #mutations_posterior[m] = approx.approximate_gamma_mom(mean, variance) elif fixed[p] and not fixed[c]: child_message = factors[i, LEAFWARD] * scale[c] child_delta = 1.0 # hopefully we don't need to damp @@ -586,10 +572,7 @@ def propagate_mutations( return mutations_posterior - # @staticmethod - # @numba.njit(_f(_i2r, _i1r, _f2r, _f2w, _f3w, _f1w, _f)) - # def propagate_unphased_mutations() - + # TODO more arguments, 
blck_factors and block_parents @staticmethod @numba.njit(_void(_i1r, _i1r, _f3w, _f3w, _f1w)) def rescale_factors(edges_parent, edges_child, node_factors, edge_factors, scale): @@ -597,11 +580,11 @@ def rescale_factors(edges_parent, edges_child, node_factors, edge_factors, scale p, c = edges_parent, edges_child edge_factors[:, ROOTWARD] *= scale[p, np.newaxis] edge_factors[:, LEAFWARD] *= scale[c, np.newaxis] + j, k = block_parents + block_factors[:, ROOTWARD] *= scale[j, np.newaxis] + block_factors[:, LEAFWARD] *= scale[k, np.newaxis] node_factors[:, MIXPRIOR] *= scale[:, np.newaxis] node_factors[:, CONSTRNT] *= scale[:, np.newaxis] - # TODO: unphased factors - #unph_factors[:, FIRSTPAR] *= scale[:, np.newaxis] - #unph_factors[:, SECNDPAR] *= scale[:, np.newaxis] scale[:] = 1.0 def iterate( @@ -615,8 +598,18 @@ def iterate( check_valid=False, ): # TODO: pass through unphased intervals - #self.propagate_unphased( - # ... + #self.propagate_likelihood( + # self.block_order, + # self.block_parents[0], + # self.block_parents[1], + # self.block_likelihoods, + # self.constraints, + # self.block_factors, + # self.block_log_partition, + # self.scale, + # max_shape, + # min_step, + # USE_BLOCK_LIKELIHOOD, #) # rootward + leafward pass through edges @@ -632,6 +625,7 @@ def iterate( self.scale, max_shape, min_step, + USE_EDGE_LIKELIHOOD, ) # exponential regularization on roots From 51d694dba7baa62d5da439688fd04016575bdb5e Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Sat, 8 Jun 2024 20:17:28 -0700 Subject: [PATCH 07/29] Lots of renaming, prepratory factoring --- tests/exact_moments.py | 685 ++++++++++++++++++++++++++++++++++- tests/test_approximations.py | 522 ++++++-------------------- 2 files changed, 778 insertions(+), 429 deletions(-) diff --git a/tests/exact_moments.py b/tests/exact_moments.py index 0c9d30e6..d8f11793 100644 --- a/tests/exact_moments.py +++ b/tests/exact_moments.py @@ -1,12 +1,16 @@ +# flake8: noqa """ Moments for EP updates using exact hypergeometric evaluations rather than a Laplace approximation; intended for testing and accuracy benchmarking. 
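+Each closed-form expression below is checked against numerical quadrature by
+``validate()`` at the bottom of this module.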
""" +from math import exp +from math import log import mpmath import numpy as np -from scipy.special import betaln, gammaln -from math import log, exp +import scipy +from scipy.special import betaln +from scipy.special import gammaln def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): @@ -24,6 +28,10 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): f0 = float(mpmath.log(mpmath.hyp2f1(a + 0, b + 0, c + 0, z))) f1 = float(mpmath.log(mpmath.hyp2f1(a + 1, b + 1, c + 1, z))) f2 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 2, c + 2, z))) + s1 = a * b / c + s2 = s1 * (a + 1) * (b + 1) / (c + 1) + d1 = s1 * exp(f1 - f0) + d2 = s2 * exp(f2 - f0) logl = f0 + betaln(y_ij + 1, a) + gammaln(b) - b * log(t) mn_j = d1 / t sq_j = d2 / t**2 @@ -54,8 +62,8 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): f0 = float(mpmath.log(mpmath.hyperu(a + 0, b + 0, z))) f1 = float(mpmath.log(mpmath.hyperu(a + 1, b + 1, z))) f2 = float(mpmath.log(mpmath.hyperu(a + 2, b + 2, z))) - d0 = -a * f1 / f0 - d1 = -(a + 1) * f2 / f1 + d0 = -a * exp(f1 - f0) + d1 = -(a + 1) * exp(f2 - f1) logl = f0 - b_i * t_j + (b - 1) * log(t_j) + gammaln(a) mn_i = t_j * (1 - d0) va_i = t_j**2 * d0 * (d1 - d0) @@ -103,7 +111,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): s2 = s1 * (a + 1) * (b + 1) / (c + 1) d1 = s1 * exp(f1 - f0) d2 = s2 * exp(f2 - f0) - logl = f0 + _betaln(a_j, a_i) + _gammaln(b) - b * log(t) + logl = f0 + betaln(a_j, a_i) + gammaln(b) - b * log(t) mn_j = d1 / t sq_j = d2 / t**2 va_j = sq_j - mn_j**2 @@ -113,7 +121,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return logl, mn_i, va_i, mn_j, va_j -def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): +def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): """ log p(t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ @@ -123,8 +131,11 @@ def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): a = a_j b = a_j + y_ij + 1 z = t_i * (mu_ij + b_j) - f0, d0 = float(mpmath.log(mpmath.hyperu(a + 0, b + 0, z))) - f1, d1 = float(mpmath.log(mpmath.hyperu(a + 1, b + 1, z))) + f0 = float(mpmath.log(mpmath.hyperu(a + 0, b + 0, z))) + f1 = float(mpmath.log(mpmath.hyperu(a + 1, b + 1, z))) + f2 = float(mpmath.log(mpmath.hyperu(a + 2, b + 2, z))) + d0 = -a * exp(f1 - f0) + d1 = -(a + 1) * exp(f2 - f1) logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + gammaln(a) mn_j = -t_i * d0 va_j = t_i**2 * d0 * (d1 - d0) @@ -151,11 +162,15 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): f222 = float(mpmath.log(mpmath.hyp2f1(a + 2, b + 2, c + 2, z))) s1 = a * b / c s2 = s1 * (a + 1) * (b + 1) / (c + 1) - d1 = b * (b + 1) / t ** 2 + d1 = b * (b + 1) / t**2 d2 = d1 * a / c d3 = d2 * (a + 1) / (c + 1) mn_m = s1 * exp(f111 - f000) / t / 2 * (1 + z) + b / t / 2 - sq_m = d1 * exp(f020 - f000) / 3 + d2 * exp(f121 - f000) / 3 + d3 * exp(f222 - f000) / 3 + sq_m = ( + d1 * exp(f020 - f000) / 3 + + d2 * exp(f121 - f000) / 3 + + d3 * exp(f222 - f000) / 3 + ) va_m = sq_m - mn_m**2 return mn_m, va_m @@ -188,7 +203,7 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): return mn_m, va_m -def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): +def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): """ log p(t_m, t_i, t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ @@ -221,7 +236,7 @@ def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return pr_m, mn_m, va_m -def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): +def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): """ log p(t_m, t_j) := \ log(t_i + t_j) * y_ij - 
mu_ij * (t_i + t_j) + \ @@ -238,12 +253,21 @@ def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): f32 = float(mpmath.log(mpmath.hyperu(a + 3, b + 2, z))) pr_m = 1.0 - exp(f10 - f00) * a mn_m = pr_m * t_i / 2 + t_i * exp(f21 - f00) * a * (a + 1) / 2 - sq_m = pr_m * t_i ** 2 / 3 + t_i ** 2 * exp(f32 - f00) * a * (a + 1) * (a + 2) / 3 + sq_m = pr_m * t_i**2 / 3 + t_i**2 * exp(f32 - f00) * a * (a + 1) * (a + 2) / 3 va_m = sq_m - mn_m**2 return pr_m, mn_m, va_m -def unphased_mutation_fixed_moments(t_i, t_j): +def mutation_edge_moments(t_i, t_j): + """ + log p(t_m) := int(t_j < t_m < t_i) / (t_i - t_j) + """ + mn_m = t_i / 2 + t_j / 2 + va_m = (t_i - t_j) ** 2 / 12 + return mn_m, va_m + + +def mutation_block_moments(t_i, t_j): """ log p(t_m) := \ log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ @@ -251,6 +275,637 @@ def unphased_mutation_fixed_moments(t_i, t_j): """ pr_m = t_i / (t_i + t_j) mn_m = pr_m * t_i / 2 + (1 - pr_m) * t_j / 2 - sq_m = pr_m * t_i ** 2 / 3 + (1 - pr_m) * t_j ** 2 / 3 + sq_m = pr_m * t_i**2 / 3 + (1 - pr_m) * t_j**2 / 3 va_m = sq_m - mn_m**2 return pr_m, mn_m, va_m + + +# --- verify exact solutions with quadrature --- # + + +class TestExactMoments: + @staticmethod + def pdf(t_i, t_j, a_i, b_i, a_j, b_j, y, mu): + """ + Target joint (pair) distribution, proportional to the parent/child + marginals (gamma) and a Poisson mutation likelihood + """ + assert 0 < t_j < t_i + return ( + t_i ** (a_i - 1) + * np.exp(-t_i * b_i) + * t_j ** (a_j - 1) + * np.exp(-t_j * b_j) + * (t_i - t_j) ** y + * np.exp(-(t_i - t_j) * mu) + ) + + @staticmethod + def pdf_rootward(t_i, t_j, a_i, b_i, y, mu): + """ + Target conditional distribution, proportional to the parent + marginals (gamma) and a Poisson mutation likelihood at a + fixed child age + """ + assert 0 <= t_j < t_i + return ( + t_i ** (a_i - 1) + * np.exp(-t_i * b_i) + * (t_i - t_j) ** y + * np.exp(-(t_i - t_j) * mu) + ) + + @staticmethod + def pdf_leafward(t_i, t_j, a_j, b_j, y, mu): + """ + Target conditional distribution, proportional to the child + marginals (gamma) and a Poisson mutation likelihood at a + fixed parent age + """ + assert 0 < t_j < t_i + return ( + t_j ** (a_j - 1) + * np.exp(-t_j * b_j) + * (t_i - t_j) ** y + * np.exp(-(t_i - t_j) * mu) + ) + + @staticmethod + def pdf_unphased(t_i, t_j, a_i, b_i, a_j, b_j, y, mu): + """ + Target joint (pair) distribution, proportional to the parent + marginals (gamma) and a Poisson mutation likelihood over the + two branches leading from (present-day) individual to parents + """ + assert t_i > 0 and t_j > 0 + return ( + t_i ** (a_i - 1) + * np.exp(-t_i * b_i) + * t_j ** (a_j - 1) + * np.exp(-t_j * b_j) + * (t_i + t_j) ** y + * np.exp(-(t_i + t_j) * mu) + ) + + @staticmethod + def pdf_sideways(t_i, t_j, a_j, b_j, y, mu): + """ + Target joint (pair) distribution, proportional to the parent + marginals (gamma) and a Poisson mutation likelihood over the + two branches leading from (present-day) individual to parents, + with left parent fixed to t_i + """ + assert t_i > 0 and t_j > 0 + return ( + t_j ** (a_j - 1) + * np.exp(-t_j * b_j) + * (t_i + t_j) ** y + * np.exp(-(t_i + t_j) * mu) + ) + + @staticmethod + def pdf_edge(x, t_i, t_j): + """ + Mutation uniformly distributed between child and parent + """ + assert t_i > 0 and t_j > 0 + return int(t_j < x < t_i) / (t_i - t_j) + + @staticmethod + def pdf_block(x, t_i, t_j): + """ + Mutation uniformly distributed between child at time zero and one of + two parents with fixed ages + """ + assert t_i > 0 and t_j > 0 + return 
(int(0 < x < t_i) + int(0 < x < t_j)) / (t_i + t_j) + + def test_moments(self, pars): + """ + Test mean and variance when ages of both nodes are free + """ + logconst, t_i, var_t_i, t_j, var_t_j = moments(*pars) + ck_normconst = scipy.integrate.dblquad( + lambda t_i, t_j: self.pdf(t_i, t_j, *pars), + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + assert np.isclose(logconst, np.log(ck_normconst)) + ck_t_i = scipy.integrate.dblquad( + lambda t_i, t_j: t_i * self.pdf(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + assert np.isclose(t_i, ck_t_i) + ck_t_j = scipy.integrate.dblquad( + lambda t_i, t_j: t_j * self.pdf(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + assert np.isclose(t_j, ck_t_j) + ck_var_t_i = ( + scipy.integrate.dblquad( + lambda t_i, t_j: t_i**2 * self.pdf(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + - ck_t_i**2 + ) + assert np.isclose(var_t_i, ck_var_t_i) + ck_var_t_j = ( + scipy.integrate.dblquad( + lambda t_i, t_j: t_j**2 * self.pdf(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + - ck_t_j**2 + ) + assert np.isclose(var_t_j, ck_var_t_j) + + def test_rootward_moments(self, pars): + """ + Test mean and variance of parent age when child age is fixed to a nonzero value + """ + a_i, b_i, a_j, b_j, y, mu = pars + pars_redux = (a_i, b_i, y, mu) + mn_j = a_j / b_j # point "estimate" for child + for t_j in [0.0, mn_j]: + logconst, t_i, var_t_i = rootward_moments(t_j, *pars_redux) + ck_normconst = scipy.integrate.quad( + lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), + t_j, + np.inf, + epsabs=0, + )[0] + assert np.isclose(logconst, np.log(ck_normconst)) + ck_t_i = scipy.integrate.quad( + lambda t_i: t_i + * self.pdf_rootward(t_i, t_j, *pars_redux) + / ck_normconst, + t_j, + np.inf, + epsabs=0, + )[0] + assert np.isclose(t_i, ck_t_i) + ck_var_t_i = ( + scipy.integrate.quad( + lambda t_i: t_i**2 + * self.pdf_rootward(t_i, t_j, *pars_redux) + / ck_normconst, + t_j, + np.inf, + epsabs=0, + )[0] + - ck_t_i**2 + ) + assert np.isclose(var_t_i, ck_var_t_i) + + def test_leafward_moments(self, pars): + """ + Test mean and variance of child age when parent age is fixed to a nonzero value + """ + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i # point "estimate" for parent + pars_redux = (a_j, b_j, y, mu) + logconst, t_j, var_t_j = leafward_moments(t_i, *pars_redux) + ck_normconst = scipy.integrate.quad( + lambda t_j: self.pdf_leafward(t_i, t_j, *pars_redux), + 0, + t_i, + epsabs=0, + )[0] + assert np.isclose(logconst, np.log(ck_normconst)) + ck_t_j = scipy.integrate.quad( + lambda t_j: t_j * self.pdf_leafward(t_i, t_j, *pars_redux) / ck_normconst, + 0, + t_i, + epsabs=0, + )[0] + assert np.isclose(t_j, ck_t_j) + ck_var_t_j = ( + scipy.integrate.quad( + lambda t_j: t_j**2 + * self.pdf_leafward(t_i, t_j, *pars_redux) + / ck_normconst, + 0, + t_i, + epsabs=0, + )[0] + - ck_t_j**2 + ) + assert np.isclose(var_t_j, ck_var_t_j) + + def test_unphased_moments(self, pars): + """ + Parent ages for an singleton nodes above an unphased individual + """ + logconst, t_i, var_t_i, t_j, var_t_j = unphased_moments(*pars) + ck_normconst = scipy.integrate.dblquad( + lambda t_i, t_j: self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + assert np.isclose(logconst, np.log(ck_normconst)) + ck_t_i = scipy.integrate.dblquad( + lambda t_i, t_j: t_i * 
self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + assert np.isclose(t_i, ck_t_i) + ck_t_j = scipy.integrate.dblquad( + lambda t_i, t_j: t_j * self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + assert np.isclose(t_j, ck_t_j) + ck_var_t_i = ( + scipy.integrate.dblquad( + lambda t_i, t_j: t_i**2 + * self.pdf_unphased(t_i, t_j, *pars) + / ck_normconst, + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + - ck_t_i**2 + ) + assert np.isclose(var_t_i, ck_var_t_i) + ck_var_t_j = ( + scipy.integrate.dblquad( + lambda t_i, t_j: t_j**2 + * self.pdf_unphased(t_i, t_j, *pars) + / ck_normconst, + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + - ck_t_j**2 + ) + assert np.isclose(var_t_j, ck_var_t_j) + + def test_sideways_moments(self, pars): + """ + Parent ages for an singleton nodes above an unphased individual, where + second parent is fixed to a particular time + """ + a_i, b_i, a_j, b_j, y, mu = pars + pars_redux = (a_j, b_j, y, mu) + t_i = a_i / b_i # point "estimate" for left parent + nc, mn, va = sideways_moments(t_i, *pars_redux) + ck_nc = scipy.integrate.quad( + lambda t_j: self.pdf_sideways(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + assert np.isclose(np.exp(nc), ck_nc) + ck_mn = ( + scipy.integrate.quad( + lambda t_j: t_j * self.pdf_sideways(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + / ck_nc + ) + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.quad( + lambda t_j: t_j**2 * self.pdf_sideways(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + / ck_nc + - ck_mn**2 + ) + assert np.isclose(va, ck_va) + + def test_mutation_moments(self, pars): + """ + Mutation mapped to a single branch with both nodes free + """ + + def f(t_i, t_j): + assert t_j < t_i + mn = t_i / 2 + t_j / 2 + sq = (t_i**2 + t_i * t_j + t_j**2) / 3 + return mn, sq + + mn, va = mutation_moments(*pars) + nc = scipy.integrate.dblquad( + lambda t_i, t_j: self.pdf(t_i, t_j, *pars), + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + ck_mn = ( + scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[0] * self.pdf(t_i, t_j, *pars), + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + / nc + ) + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[1] * self.pdf(t_i, t_j, *pars), + 0, + np.inf, + lambda t_j: t_j, + np.inf, + epsabs=0, + )[0] + / nc + - ck_mn**2 + ) + assert np.isclose(va, ck_va) + + def test_mutation_rootward_moments(self, pars): + """ + Mutation mapped to a single branch with child node fixed + """ + + def f(t_i, t_j): # conditional moments + assert t_j < t_i + mn = t_i / 2 + t_j / 2 + sq = (t_i**2 + t_i * t_j + t_j**2) / 3 + return mn, sq + + a_i, b_i, a_j, b_j, y, mu = pars + pars_redux = (a_i, b_i, y, mu) + mn_j = a_j / b_j # point "estimate" for child + for t_j in [0.0, mn_j]: + mn, va = mutation_rootward_moments(t_j, *pars_redux) + nc = scipy.integrate.quad( + lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), + t_j, + np.inf, + )[0] + ck_mn = ( + scipy.integrate.quad( + lambda t_i: f(t_i, t_j)[0] + * self.pdf_rootward(t_i, t_j, *pars_redux), + t_j, + np.inf, + )[0] + / nc + ) + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.quad( + lambda t_i: f(t_i, t_j)[1] + * self.pdf_rootward(t_i, t_j, *pars_redux), + t_j, + np.inf, + )[0] + / nc + - ck_mn**2 + ) + assert np.isclose(va, ck_va) + + def test_mutation_leafward_moments(self, pars): + """ + Mutation mapped to a single branch with parent node fixed + """ + + def 
f(t_i, t_j): + assert t_j < t_i + mn = t_i / 2 + t_j / 2 + sq = (t_i**2 + t_i * t_j + t_j**2) / 3 + return mn, sq + + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i # point "estimate" for parent + pars_redux = (a_j, b_j, y, mu) + mn, va = mutation_leafward_moments(t_i, *pars_redux) + nc = scipy.integrate.quad( + lambda t_j: self.pdf_leafward(t_i, t_j, *pars_redux), + 0, + t_i, + )[0] + ck_mn = ( + scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[0] * self.pdf_leafward(t_i, t_j, *pars_redux), + 0, + t_i, + )[0] + / nc + ) + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[1] * self.pdf_leafward(t_i, t_j, *pars_redux), + 0, + t_i, + )[0] + / nc + - ck_mn**2 + ) + assert np.isclose(va, ck_va) + + def test_mutation_unphased_moments(self, pars): + """ + Mutation mapped to two singleton branches with children fixed to time zero + """ + + def f(t_i, t_j): # conditional moments + pr = t_i / (t_i + t_j) + mn = pr * t_i / 2 + (1 - pr) * t_j / 2 + sq = pr * t_i**2 / 3 + (1 - pr) * t_j**2 / 3 + return pr, mn, sq + + pr, mn, va = mutation_unphased_moments(*pars) + nc = scipy.integrate.dblquad( + lambda t_i, t_j: self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + ck_pr = ( + scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[0] * self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + / nc + ) + assert np.isclose(pr, ck_pr) + ck_mn = ( + scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[1] * self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + / nc + ) + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.dblquad( + lambda t_i, t_j: f(t_i, t_j)[2] * self.pdf_unphased(t_i, t_j, *pars), + 0, + np.inf, + 0, + np.inf, + epsabs=0, + )[0] + / nc + - ck_mn**2 + ) + assert np.isclose(va, ck_va) + + def test_mutation_sideways_moments(self, pars): + """ + Mutation mapped to two branches with children fixed to time zero, and + left parent (i) fixed + """ + + def f(t_i, t_j): # conditional moments + pr = t_i / (t_i + t_j) + mn = pr * t_i / 2 + (1 - pr) * t_j / 2 + sq = pr * t_i**2 / 3 + (1 - pr) * t_j**2 / 3 + return pr, mn, sq + + a_i, b_i, a_j, b_j, y, mu = pars + pars_redux = a_j, b_j, y, mu + t_i = a_i / b_i # point "estimate" for left parent + pr, mn, va = mutation_sideways_moments(t_i, *pars_redux) + nc = scipy.integrate.quad( + lambda t_j: self.pdf_sideways(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + ck_pr = ( + scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[0] * self.pdf_sideways(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + / nc + ) + assert np.isclose(pr, ck_pr) + ck_mn = ( + scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[1] * self.pdf_sideways(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + / nc + ) + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.quad( + lambda t_j: f(t_i, t_j)[2] * self.pdf_sideways(t_i, t_j, *pars_redux), + 0, + np.inf, + )[0] + / nc + - ck_mn**2 + ) + assert np.isclose(va, ck_va) + + def test_mutation_edge_moments(self, pars): + """ + Mutation mapped to two branches with children fixed to time zero, and + both parents fixed + """ + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i + t_j = a_j / b_j + mn, va = mutation_edge_moments(t_i, t_j) + ck_mn = scipy.integrate.quad( + lambda x: x * self.pdf_edge(x, t_i, t_j), + 0, + max(t_i, t_j), + )[0] + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.quad( + lambda x: x**2 * self.pdf_edge(x, t_i, t_j), + 0, + max(t_i, t_j), + )[0] + - ck_mn**2 + ) + 
assert np.isclose(va, ck_va) + + def test_mutation_block_moments(self, pars): + """ + Mutation mapped to two branches with children fixed to time zero, and + both parents fixed + """ + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i + t_j = a_j / b_j + pars_redux = (a_j, b_j, y, mu) + pr, mn, va = mutation_block_moments(t_i, t_j) + ck_pr = t_i / (t_i + t_j) + assert np.isclose(pr, ck_pr) + ck_mn = scipy.integrate.quad( + lambda x: x * self.pdf_block(x, t_i, t_j), + 0, + max(t_i, t_j), + )[0] + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.quad( + lambda x: x**2 * self.pdf_block(x, t_i, t_j), + 0, + max(t_i, t_j), + )[0] + - ck_mn**2 + ) + assert np.isclose(va, ck_va) + + +def validate(): + tests = TestExactMoments() + test_names = [f for f in dir(tests) if f.startswith("test")] + test_cases = [ # [shape1, rate1, shape2, rate2, muts, rate] + [2.0, 0.0005, 1.5, 0.005, 0.0, 0.001], + [2.0, 0.0005, 1.5, 0.005, 1.0, 0.001], + [2.0, 0.0005, 1.5, 0.005, 2.0, 0.001], + [2.0, 0.0005, 1.5, 0.005, 3.0, 0.001], + ] + for pars in test_cases: + for test in test_names: + getattr(tests, test)(pars) diff --git a/tests/test_approximations.py b/tests/test_approximations.py index 6dc60538..5f22d26e 100644 --- a/tests/test_approximations.py +++ b/tests/test_approximations.py @@ -24,18 +24,30 @@ """ Test cases for the gamma-variational approximations in tsdate """ +from math import sqrt + import numpy as np import pytest import scipy.integrate import scipy.special import scipy.stats +from exact_moments import leafward_moments +from exact_moments import moments +from exact_moments import mutation_block_moments +from exact_moments import mutation_edge_moments +from exact_moments import mutation_leafward_moments +from exact_moments import mutation_moments +from exact_moments import mutation_rootward_moments +from exact_moments import mutation_sideways_moments +from exact_moments import mutation_unphased_moments +from exact_moments import rootward_moments +from exact_moments import sideways_moments +from exact_moments import unphased_moments from tsdate import approx from tsdate import hypergeo -from tsdate import prior # TODO: better test set? 
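# As a quick accuracy check outside of pytest, the Laplace-based moments in
# tsdate.approx can be compared against the exact versions, e.g. (a sketch,
# assuming it is run with tests/ on the path):
#
#   from exact_moments import moments as exact_moments
#   from tsdate.approx import moments as laplace_moments
#   pars = (2.0, 0.0005, 1.5, 0.005, 2.0, 0.001)
#   print(exact_moments(*pars))    # logconst, E[t_i], V[t_i], E[t_j], V[t_j]
#   print(laplace_moments(*pars))  # Laplace approximation to the same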
-# TODO: test special case where child is fixed to age 0 _gamma_trio_test_cases = [ # [shape1, rate1, shape2, rate2, muts, rate] [2.0, 0.0005, 1.5, 0.005, 0.0, 0.001], [2.0, 0.0005, 1.5, 0.005, 1.0, 0.001], @@ -47,485 +59,167 @@ @pytest.mark.parametrize("pars", _gamma_trio_test_cases) class TestPosteriorMomentMatching: """ - Test approximation of marginal pairwise joint distributions by a gamma via - moment matching of sufficient statistics + Test Laplace approximation of pairwise joint distributions for EP updates """ - @staticmethod - def pdf(t_i, t_j, a_i, b_i, a_j, b_j, y, mu): - """ - Target joint (pair) distribution, proportional to the parent/child - marginals (gamma) and a Poisson mutation likelihood - """ - assert 0 < t_j < t_i - return ( - t_i ** (a_i - 1) - * np.exp(-t_i * b_i) - * t_j ** (a_j - 1) - * np.exp(-t_j * b_j) - * (t_i - t_j) ** y - * np.exp(-(t_i - t_j) * mu) - ) - - @staticmethod - def pdf_rootward(t_i, t_j, a_i, b_i, y, mu): - """ - Target conditional distribution, proportional to the parent - marginals (gamma) and a Poisson mutation likelihood at a - fixed child age - """ - assert 0 <= t_j < t_i - return ( - t_i ** (a_i - 1) - * np.exp(-t_i * b_i) - * (t_i - t_j) ** y - * np.exp(-(t_i - t_j) * mu) - ) - - @staticmethod - def pdf_leafward(t_i, t_j, a_j, b_j, y, mu): - """ - Target conditional distribution, proportional to the child - marginals (gamma) and a Poisson mutation likelihood at a - fixed parent age - """ - assert 0 < t_j < t_i - return ( - t_j ** (a_j - 1) - * np.exp(-t_j * b_j) - * (t_i - t_j) ** y - * np.exp(-(t_i - t_j) * mu) - ) - - @staticmethod - def pdf_unphased(t_i, t_j, a_i, b_i, a_j, b_j, y, mu): - """ - Target joint (pair) distribution, proportional to the parent - marginals (gamma) and a Poisson mutation likelihood over the - two branches leading from (present-day) individual to parents - """ - assert t_i > 0 and t_j > 0 - return ( - t_i ** (a_i - 1) - * np.exp(-t_i * b_i) - * t_j ** (a_j - 1) - * np.exp(-t_j * b_j) - * (t_i + t_j) ** y - * np.exp(-(t_i + t_j) * mu) - ) - - @staticmethod - def pdf_unphased_rightward(t_i, t_j, a_j, b_j, y, mu): - """ - Target joint (pair) distribution, proportional to the parent - marginals (gamma) and a Poisson mutation likelihood over the - two branches leading from (present-day) individual to parents, - with left parent fixed to t_i - """ - assert t_i > 0 and t_j > 0 - return ( - t_j ** (a_j - 1) - * np.exp(-t_j * b_j) - * (t_i + t_j) ** y - * np.exp(-(t_i + t_j) * mu) - ) - def test_moments(self, pars): """ Test mean and variance when ages of both nodes are free """ - logconst, t_i, var_t_i, t_j, var_t_j = approx.moments(*pars) - ck_normconst = scipy.integrate.dblquad( - lambda t_i, t_j: self.pdf(t_i, t_j, *pars), - 0, - np.inf, - lambda t_j: t_j, - np.inf, - epsabs=0, - )[0] - assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2) - ck_t_i = scipy.integrate.dblquad( - lambda t_i, t_j: t_i * self.pdf(t_i, t_j, *pars) / ck_normconst, - 0, - np.inf, - lambda t_j: t_j, - np.inf, - epsabs=0, - )[0] - assert np.isclose(t_i, ck_t_i, rtol=2e-2) - ck_t_j = scipy.integrate.dblquad( - lambda t_i, t_j: t_j * self.pdf(t_i, t_j, *pars) / ck_normconst, - 0, - np.inf, - lambda t_j: t_j, - np.inf, - epsabs=0, - )[0] - assert np.isclose(t_j, ck_t_j, rtol=2e-2) - ck_var_t_i = ( - scipy.integrate.dblquad( - lambda t_i, t_j: t_i**2 * self.pdf(t_i, t_j, *pars) / ck_normconst, - 0, - np.inf, - lambda t_j: t_j, - np.inf, - epsabs=0, - )[0] - - ck_t_i**2 - ) - assert np.isclose(var_t_i, ck_var_t_i, rtol=2e-2) - ck_var_t_j = 
( - scipy.integrate.dblquad( - lambda t_i, t_j: t_j**2 * self.pdf(t_i, t_j, *pars) / ck_normconst, - 0, - np.inf, - lambda t_j: t_j, - np.inf, - epsabs=0, - )[0] - - ck_t_j**2 - ) - assert np.isclose(var_t_j, ck_var_t_j, rtol=2e-2) + rtol = 1e-2 + ll, mn_i, va_i, mn_j, va_j = approx.moments(*pars) + ck_ll, ck_mn_i, ck_va_i, ck_mn_j, ck_va_j = moments(*pars) + assert np.isclose(ck_ll, ll, rtol=rtol) + assert np.isclose(ck_mn_i, mn_i, rtol=rtol) + assert np.isclose(ck_mn_j, mn_j, rtol=rtol) + assert np.isclose(sqrt(ck_va_i), sqrt(va_i), rtol=rtol) + assert np.isclose(sqrt(ck_va_j), sqrt(va_j), rtol=rtol) def test_rootward_moments(self, pars): """ Test mean and variance of parent age when child age is fixed to a nonzero value """ + rtol = 2e-2 a_i, b_i, a_j, b_j, y, mu = pars pars_redux = (a_i, b_i, y, mu) - mn_j = a_j / b_j # point "estimate" for child + mn_j = a_j / b_j for t_j in [0.0, mn_j]: - logconst, t_i, var_t_i = approx.rootward_moments(t_j, *pars_redux) - ck_normconst = scipy.integrate.quad( - lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), - t_j, - np.inf, - epsabs=0, - )[0] - assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2) - ck_t_i = scipy.integrate.quad( - lambda t_i: t_i * self.pdf_rootward(t_i, t_j, *pars_redux) / ck_normconst, - t_j, - np.inf, - epsabs=0, - )[0] - assert np.isclose(t_i, ck_t_i, rtol=2e-2) - ck_var_t_i = ( - scipy.integrate.quad( - lambda t_i: t_i**2 - * self.pdf_rootward(t_i, t_j, *pars_redux) - / ck_normconst, - t_j, - np.inf, - epsabs=0, - )[0] - - ck_t_i**2 - ) - assert np.isclose(var_t_i, ck_var_t_i, rtol=2e-2) + ll, mn_i, va_i = approx.rootward_moments(t_j, *pars_redux) + ck_ll, ck_mn_i, ck_va_i = rootward_moments(t_j, *pars_redux) + assert np.isclose(ck_ll, ll, rtol=rtol) + assert np.isclose(ck_mn_i, mn_i, rtol=rtol) + assert np.isclose(sqrt(ck_va_i), sqrt(va_i), rtol=rtol) def test_leafward_moments(self, pars): """ Test mean and variance of child age when parent age is fixed to a nonzero value """ + rtol = 1e-2 a_i, b_i, a_j, b_j, y, mu = pars - t_i = a_i / b_i # point "estimate" for parent + t_i = a_i / b_i pars_redux = (a_j, b_j, y, mu) - logconst, t_j, var_t_j = approx.leafward_moments(t_i, *pars_redux) - ck_normconst = scipy.integrate.quad( - lambda t_j: self.pdf_leafward(t_i, t_j, *pars_redux), - 0, - t_i, - epsabs=0, - )[0] - assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2) - ck_t_j = scipy.integrate.quad( - lambda t_j: t_j * self.pdf_leafward(t_i, t_j, *pars_redux) / ck_normconst, - 0, - t_i, - epsabs=0, - )[0] - assert np.isclose(t_j, ck_t_j, rtol=2e-2) - ck_var_t_j = ( - scipy.integrate.quad( - lambda t_j: t_j**2 - * self.pdf_leafward(t_i, t_j, *pars_redux) - / ck_normconst, - 0, - t_i, - epsabs=0, - )[0] - - ck_t_j**2 - ) - assert np.isclose(var_t_j, ck_var_t_j, rtol=2e-2) + ll, mn_j, va_j = approx.leafward_moments(t_i, *pars_redux) + ck_ll, ck_mn_j, ck_va_j = leafward_moments(t_i, *pars_redux) + assert np.isclose(ck_ll, ll, rtol=rtol) + assert np.isclose(ck_mn_j, mn_j, rtol=rtol) + assert np.isclose(sqrt(ck_va_j), sqrt(va_j), rtol=rtol) def test_unphased_moments(self, pars): """ Parent ages for an singleton nodes above an unphased individual """ - logconst, t_i, var_t_i, t_j, var_t_j = approx.unphased_moments(*pars) - ck_normconst = scipy.integrate.dblquad( - lambda t_i, t_j: self.pdf_unphased(t_i, t_j, *pars), - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] - assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2) - ck_t_i = scipy.integrate.dblquad( - lambda t_i, t_j: t_i * self.pdf_unphased(t_i, t_j, 
*pars) / ck_normconst, - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] - assert np.isclose(t_i, ck_t_i, rtol=2e-2) - ck_t_j = scipy.integrate.dblquad( - lambda t_i, t_j: t_j * self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] - assert np.isclose(t_j, ck_t_j, rtol=2e-2) - ck_var_t_i = ( - scipy.integrate.dblquad( - lambda t_i, t_j: t_i**2 * self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] - - ck_t_i**2 - ) - assert np.isclose(var_t_i, ck_var_t_i, rtol=2e-2) - ck_var_t_j = ( - scipy.integrate.dblquad( - lambda t_i, t_j: t_j**2 * self.pdf_unphased(t_i, t_j, *pars) / ck_normconst, - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] - - ck_t_j**2 - ) - assert np.isclose(var_t_j, ck_var_t_j, rtol=2e-2) + rtol = 1e-2 + ll, mn_i, va_i, mn_j, va_j = approx.unphased_moments(*pars) + ck_ll, ck_mn_i, ck_va_i, ck_mn_j, ck_va_j = unphased_moments(*pars) + assert np.isclose(ck_ll, ll, rtol=rtol) + assert np.isclose(ck_mn_i, mn_i, rtol=rtol) + assert np.isclose(ck_mn_j, mn_j, rtol=rtol) + assert np.isclose(sqrt(ck_va_i), sqrt(va_i), rtol=rtol) + assert np.isclose(sqrt(ck_va_j), sqrt(va_j), rtol=rtol) - def test_unphased_rightward_moments(self, pars): + def test_sideways_moments(self, pars): """ Parent ages for an singleton nodes above an unphased individual, where second parent is fixed to a particular time """ + rtol = 1e-2 a_i, b_i, a_j, b_j, y, mu = pars pars_redux = (a_j, b_j, y, mu) - t_i = a_i / b_i # point "estimate" for left parent - nc, mn, va = approx.unphased_rightward_moments(t_i, *pars_redux) - ck_nc = scipy.integrate.quad( - lambda t_j: self.pdf_unphased_rightward(t_i, t_j, *pars_redux), - 0, - np.inf, - )[0] - assert np.isclose(np.exp(nc), ck_nc, rtol=2e-2) - ck_mn = scipy.integrate.quad( - lambda t_j: t_j * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), - 0, - np.inf, - )[0] / ck_nc - assert np.isclose(mn, ck_mn, rtol=2e-2) - ck_va = scipy.integrate.quad( - lambda t_j: t_j**2 * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), - 0, - np.inf, - )[0] / ck_nc - ck_mn**2 - assert np.isclose(va, ck_va, rtol=2e-2) + t_i = a_i / b_i + ll, mn_j, va_j = approx.sideways_moments(t_i, *pars_redux) + ck_ll, ck_mn_j, ck_va_j = sideways_moments(t_i, *pars_redux) + assert np.isclose(ck_ll, ll, rtol=rtol) + assert np.isclose(ck_mn_j, mn_j, rtol=rtol) + assert np.isclose(sqrt(ck_va_j), sqrt(va_j), rtol=rtol) def test_mutation_moments(self, pars): """ Mutation mapped to a single branch with both nodes free """ - def f(t_i, t_j): - assert t_j < t_i - mn = t_i / 2 + t_j / 2 - sq = (t_i**2 + t_i*t_j + t_j**2) / 3 - return mn, sq + rtol = 2e-2 mn, va = approx.mutation_moments(*pars) - nc = scipy.integrate.dblquad( - lambda t_i, t_j: self.pdf(t_i, t_j, *pars), - 0, - np.inf, - lambda t_j: t_j, - np.inf, - epsabs=0, - )[0] - ck_mn = scipy.integrate.dblquad( - lambda t_i, t_j: f(t_i, t_j)[0] * self.pdf(t_i, t_j, *pars), - 0, - np.inf, - lambda t_j: t_j, - np.inf, - epsabs=0, - )[0] / nc - assert np.isclose(mn, ck_mn, rtol=2e-2) - ck_va = scipy.integrate.dblquad( - lambda t_i, t_j: f(t_i, t_j)[1] * self.pdf(t_i, t_j, *pars), - 0, - np.inf, - lambda t_j: t_j, - np.inf, - epsabs=0, - )[0] / nc - ck_mn**2 - assert np.isclose(va, ck_va, rtol=5e-2) + ck_mn, ck_va = mutation_moments(*pars) + assert np.isclose(ck_mn, mn, rtol=rtol) + assert np.isclose(sqrt(ck_va), sqrt(va), rtol=rtol) def test_mutation_rootward_moments(self, pars): """ Mutation mapped to a single branch with child node fixed """ - def f(t_i, t_j): # 
conditional moments - assert t_j < t_i - mn = t_i / 2 + t_j / 2 - sq = (t_i**2 + t_i*t_j + t_j**2) / 3 - return mn, sq + rtol = 1e-2 a_i, b_i, a_j, b_j, y, mu = pars pars_redux = (a_i, b_i, y, mu) - mn_j = a_j / b_j # point "estimate" for child + mn_j = a_j / b_j for t_j in [0.0, mn_j]: mn, va = approx.mutation_rootward_moments(t_j, *pars_redux) - nc = scipy.integrate.quad( - lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux), - t_j, - np.inf, - )[0] - ck_mn = scipy.integrate.quad( - lambda t_i: f(t_i, t_j)[0] * self.pdf_rootward(t_i, t_j, *pars_redux), - t_j, - np.inf, - )[0] / nc - assert np.isclose(mn, ck_mn, rtol=2e-2) - ck_va = scipy.integrate.quad( - lambda t_i: f(t_i, t_j)[1] * self.pdf_rootward(t_i, t_j, *pars_redux), - t_j, - np.inf, - )[0] / nc - ck_mn**2 - assert np.isclose(va, ck_va, rtol=2e-2) + ck_mn, ck_va = mutation_rootward_moments(t_j, *pars_redux) + assert np.isclose(ck_mn, mn, rtol=rtol) + assert np.isclose(sqrt(ck_va), sqrt(va), rtol=rtol) def test_mutation_leafward_moments(self, pars): """ Mutation mapped to a single branch with parent node fixed """ - def f(t_i, t_j): - assert t_j < t_i - mn = t_i / 2 + t_j / 2 - sq = (t_i**2 + t_i*t_j + t_j**2) / 3 - return mn, sq + rtol = 1e-2 a_i, b_i, a_j, b_j, y, mu = pars - t_i = a_i / b_i # point "estimate" for parent + t_i = a_i / b_i pars_redux = (a_j, b_j, y, mu) mn, va = approx.mutation_leafward_moments(t_i, *pars_redux) - nc = scipy.integrate.quad( - lambda t_j: self.pdf_leafward(t_i, t_j, *pars_redux), - 0, - t_i, - )[0] - ck_mn = scipy.integrate.quad( - lambda t_j: f(t_i, t_j)[0] * self.pdf_leafward(t_i, t_j, *pars_redux), - 0, - t_i, - )[0] / nc - assert np.isclose(mn, ck_mn, rtol=2e-2) - ck_va = scipy.integrate.quad( - lambda t_j: f(t_i, t_j)[1] * self.pdf_leafward(t_i, t_j, *pars_redux), - 0, - t_i, - )[0] / nc - ck_mn**2 - assert np.isclose(va, ck_va, rtol=2e-2) + ck_mn, ck_va = mutation_leafward_moments(t_i, *pars_redux) + assert np.isclose(ck_mn, mn, rtol=rtol) + assert np.isclose(sqrt(ck_va), sqrt(va), rtol=rtol) - def test_unphased_mutation_moments(self, pars): + def test_mutation_unphased_moments(self, pars): """ Mutation mapped to two singleton branches with children fixed to time zero """ - def f(t_i, t_j): # conditional moments - pr = t_i / (t_i + t_j) - mn = pr * t_i / 2 + (1 - pr) * t_j / 2 - sq = pr * t_i**2 / 3 + (1 - pr) * t_j**2 / 3 - return pr, mn, sq - pr, mn, va = approx.unphased_mutation_moments(*pars) - nc = scipy.integrate.dblquad( - lambda t_i, t_j: self.pdf_unphased(t_i, t_j, *pars), - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] - ck_pr = scipy.integrate.dblquad( - lambda t_i, t_j: f(t_i, t_j)[0] * self.pdf_unphased(t_i, t_j, *pars), - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] / nc - assert np.isclose(pr, ck_pr, rtol=2e-2) - ck_mn = scipy.integrate.dblquad( - lambda t_i, t_j: f(t_i, t_j)[1] * self.pdf_unphased(t_i, t_j, *pars), - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] / nc - assert np.isclose(mn, ck_mn, rtol=2e-2) - ck_va = scipy.integrate.dblquad( - lambda t_i, t_j: f(t_i, t_j)[2] * self.pdf_unphased(t_i, t_j, *pars), - 0, - np.inf, - 0, - np.inf, - epsabs=0, - )[0] / nc - ck_mn**2 - assert np.isclose(va, ck_va, rtol=2e-2) + rtol = 1e-2 + pr, mn, va = approx.mutation_unphased_moments(*pars) + ck_pr, ck_mn, ck_va = mutation_unphased_moments(*pars) + assert np.isclose(ck_pr, pr, rtol=rtol) + assert np.isclose(ck_mn, mn, rtol=rtol) + assert np.isclose(sqrt(ck_va), sqrt(va), rtol=rtol) - def test_unphased_mutation_rightward_moments(self, pars): + def 
test_mutation_sideways_moments(self, pars): """ Mutation mapped to two branches with children fixed to time zero, and left parent (i) fixed """ - def f(t_i, t_j): # conditional moments - pr = t_i / (t_i + t_j) - mn = pr * t_i / 2 + (1 - pr) * t_j / 2 - sq = pr * t_i**2 / 3 + (1 - pr) * t_j**2 / 3 - return pr, mn, sq + rtol = 1e-2 a_i, b_i, a_j, b_j, y, mu = pars - t_i = a_i / b_i # point "estimate" for left parent + t_i = a_i / b_i pars_redux = (a_j, b_j, y, mu) - pr, mn, va = approx.unphased_mutation_rightward_moments(t_i, *pars_redux) - nc = scipy.integrate.quad( - lambda t_j: self.pdf_unphased_rightward(t_i, t_j, *pars_redux), - 0, - np.inf, - )[0] - ck_pr = scipy.integrate.quad( - lambda t_j: f(t_i, t_j)[0] * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), - 0, - np.inf, - )[0] / nc - assert np.isclose(pr, ck_pr, rtol=2e-2) - ck_mn = scipy.integrate.quad( - lambda t_j: f(t_i, t_j)[1] * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), - 0, - np.inf, - )[0] / nc - assert np.isclose(mn, ck_mn, rtol=2e-2) - ck_va = scipy.integrate.quad( - lambda t_j: f(t_i, t_j)[2] * self.pdf_unphased_rightward(t_i, t_j, *pars_redux), - 0, - np.inf, - )[0] / nc - ck_mn**2 - assert np.isclose(va, ck_va, rtol=2e-2) + pr, mn, va = approx.mutation_sideways_moments(t_i, *pars_redux) + ck_pr, ck_mn, ck_va = mutation_sideways_moments(t_i, *pars_redux) + assert np.isclose(ck_pr, pr, rtol=rtol) + assert np.isclose(ck_mn, mn, rtol=rtol) + assert np.isclose(sqrt(ck_va), sqrt(va), rtol=rtol) + + def test_mutation_edge_moments(self, pars): + """ + Mutation mapped to a single edge with parent and child fixed + """ + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i + t_j = a_j / b_j + mn, va = approx.mutation_edge_moments(t_i, t_j) + ck_mn, ck_va = mutation_edge_moments(t_i, t_j) + assert np.isclose(ck_mn, mn) + assert np.isclose(sqrt(ck_va), sqrt(va)) + + def test_mutation_block_moments(self, pars): + """ + Mutation mapped to two branches with children fixed to time zero, and + both parents fixed + """ + a_i, b_i, a_j, b_j, y, mu = pars + t_i = a_i / b_i + t_j = a_j / b_j + pr, mn, va = approx.mutation_block_moments(t_i, t_j) + ck_pr, ck_mn, ck_va = mutation_block_moments(t_i, t_j) + assert np.isclose(ck_pr, pr) + assert np.isclose(ck_mn, mn) + assert np.isclose(sqrt(ck_va), sqrt(va)) def test_approximate_gamma_kl(self, pars): a_i, b_i, a_j, b_j, y, mu = pars From 2379c30dc5d33549bd83ba7673073c9477b2dfaa Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 11 Jun 2024 09:57:54 -0700 Subject: [PATCH 08/29] Tests and such --- tsdate/approx.py | 72 +++++------ tsdate/core.py | 5 +- tsdate/phasing.py | 275 ++++++++++++++++++++++++++++++++++++++---- tsdate/variational.py | 192 ++++++++++++++++------------- 4 files changed, 399 insertions(+), 145 deletions(-) diff --git a/tsdate/approx.py b/tsdate/approx.py index 8e7b28ba..d3b290ea 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -57,8 +57,13 @@ _f3r = numba.types.Array(_f, 3, "C", readonly=True) _i1w = numba.types.Array(_i, 1, "C", readonly=False) _i1r = numba.types.Array(_i, 1, "C", readonly=True) +_i2w = numba.types.Array(_i, 2, "C", readonly=False) +_i2r = numba.types.Array(_i, 2, "C", readonly=True) _b1w = numba.types.Array(_b, 1, "C", readonly=False) _b1r = numba.types.Array(_b, 1, "C", readonly=True) +_tuple = numba.types.Tuple +_unituple = numba.types.UniTuple +_void = numba.types.void class KLMinimizationFailed(Exception): @@ -372,7 +377,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): t = mu_ij + b_i z = (mu_ij + b_j) / t if t > 0 else nan 
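+    # here z = (mu_ij + b_j) / (mu_ij + b_i); the moment expressions below
+    # are written in terms of 2F1(a, b; c; 1 - z), evaluated with
+    # hypergeo._hyp2f1_laplace, so 1 - z is also the argument whose
+    # validity needs checking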
- if not _valid_hyp2f1(a, b, c, z): + if not _valid_hyp2f1(a, b, c, 1 - z): return nan, nan, nan, nan, nan hyp2f1 = hypergeo._hyp2f1_laplace @@ -398,7 +403,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) -def unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): +def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ @@ -506,23 +511,8 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): return mn_m, va_m -@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) -def mutation_fixed_moments(t_i, a_j, b_j, y_ij, mu_ij): - r""" - log p(t_m) := \ - log(t_i - t_j) + log(int(t_j < t_m < t_i)) - - Returns E[t_m], V[t_m]. - """ - - mn_m = 1 / 2 * (t_i + t_j) - va_m = 1 / 12 * (t_i - t_j) ** 2 - - return mn_m, va_m - - @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) -def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): +def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_i, t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ @@ -548,8 +538,8 @@ def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): t = mu_ij + b_i z = (mu_ij + b_j) / t if t > 0 else nan - if not _valid_hyp2f1(a, b, c, z): - return nan, nan + if not _valid_hyp2f1(a, b, c, 1 - z): + return nan, nan, nan hyp2f1 = hypergeo._hyp2f1_laplace f000 = hyp2f1(a + 0, b + 0, c + 0, 1 - z) @@ -575,7 +565,7 @@ def unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) -def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): +def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ @@ -623,7 +613,23 @@ def unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij): return pr_m, mn_m, va_m -def unphased_mutation_fixed_moments(t_i, t_j): +@numba.njit(_unituple(_f, 2)(_f, _f)) +def mutation_edge_moments(t_i, t_j): + r""" + log p(t_m) := \ + log(t_i - t_j) + log(int(t_j < t_m < t_i)) + + Returns E[t_m], V[t_m]. 
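+
+    For example, with t_i = 2.0 and t_j = 0.5 the mutation age is uniform
+    on (0.5, 2.0), so E[t_m] = 1.25 and V[t_m] = 1.5**2 / 12 = 0.1875.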
+ """ + + mn_m = 1 / 2 * (t_i + t_j) + va_m = 1 / 12 * (t_i - t_j) ** 2 + + return mn_m, va_m + + +@numba.njit(_unituple(_f, 3)(_f, _f)) +def mutation_block_moments(t_i, t_j): r""" log p(t_m) := \ log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ @@ -746,7 +752,7 @@ def unphased_projection(pars_i, pars_j, pars_ij): @numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) -def unphased_rightward_projection(t_i, pars_j, pars_ij): +def sideways_projection(t_i, pars_j, pars_ij): r""" log p(t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ @@ -754,13 +760,11 @@ def unphased_rightward_projection(t_i, pars_j, pars_ij): Returns normalizing constant, gamma natural parameters for nonfixed parent age """ - a_i, b_i = pars_i a_j, b_j = pars_j y_ij, mu_ij = pars_ij - a_i += 1 a_j += 1 - logl, mn_j, va_j = unphased_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij) + logl, mn_j, va_j = sideways_moments(t_i, a_j, b_j, y_ij, mu_ij) if not _valid_moments(mn_j, va_j): return np.nan, pars_j @@ -846,19 +850,19 @@ def mutation_rootward_projection(t_j, pars_i, pars_ij): @numba.njit(_tuple((_f, _f1r))(_f, _f)) -def mutation_fixed_projection(t_i, t_j): +def mutation_edge_projection(t_i, t_j): r""" log p(t_m) := \ log(t_i - t_j) + log(int(t_j < t_m < t_i)) Returns phase probability, gamma natural parameters for mutation age """ - mn_m, va_m = mutation_fixed_moments(t_i, t_j) + mn_m, va_m = mutation_edge_moments(t_i, t_j) if not _valid_moments(mn_m, va_m): return np.nan, np.full(2, np.nan) - proj_m = approx.approximate_gamma_mom(mn_m, va_m) + proj_m = approximate_gamma_mom(mn_m, va_m) return 1.0, np.array(proj_m) @@ -881,7 +885,7 @@ def mutation_unphased_projection(pars_i, pars_j, pars_ij): a_i += 1 a_j += 1 - pr_m, mn_m, va_m = unphased_mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) + pr_m, mn_m, va_m = mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) if not _valid_moments(mn_m, va_m) or not (0 <= pr_m <= 1): return np.nan, np.full(2, np.nan) @@ -892,7 +896,7 @@ def mutation_unphased_projection(pars_i, pars_j, pars_ij): @numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) -def mutation_unphased_rightward_projection(t_i, pars_j, pars_ij): +def mutation_sideways_projection(t_i, pars_j, pars_ij): r""" log p(t_m, t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ @@ -906,7 +910,7 @@ def mutation_unphased_rightward_projection(t_i, pars_j, pars_ij): y_ij, mu_ij = pars_ij a_j += 1 - pr_m, mn_m, va_m = unphased_mutation_rightward_moments(t_i, a_j, b_j, y_ij, mu_ij) + pr_m, mn_m, va_m = mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij) if not _valid_moments(mn_m, va_m) or not (0 <= pr_m <= 1): return np.nan, np.full(2, np.nan) @@ -917,7 +921,7 @@ def mutation_unphased_rightward_projection(t_i, pars_j, pars_ij): @numba.njit(_tuple((_f, _f1r))(_f, _f)) -def mutation_unphased_fixed_projection(t_i, t_j): +def mutation_block_projection(t_i, t_j): r""" log p(t_m) := \ log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ @@ -925,7 +929,7 @@ def mutation_unphased_fixed_projection(t_i, t_j): Returns phase probability, gamma natural parameters for mutation age """ - pr_m, mn_m, va_m = unphased_mutation_fixed_moments(t_i, t_j) + pr_m, mn_m, va_m = mutation_block_moments(t_i, t_j) if not _valid_moments(mn_m, va_m) or not (0 <= pr_m <= 1): return np.nan, np.full(2, np.nan) diff --git a/tsdate/core.py b/tsdate/core.py index e23ceb96..fe9f55e6 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1281,11 +1281,11 @@ def run( # TODO: use dynamic_prog.point_estimate posterior_mean, posterior_vari = self.mean_var( - 
dynamic_prog.posterior, dynamic_prog.constraints
+            dynamic_prog.node_posterior, dynamic_prog.node_constraints
         )

         # TODO: clean up
-        mutation_post = dynamic_prog.mutations_posterior
+        mutation_post = dynamic_prog.mutation_posterior
         mutation_mean = np.full(mutation_post.shape[0], np.nan)
         mutation_vari = np.full(mutation_post.shape[0], np.nan)
         idx = mutation_post[:, 1] > 0
@@ -1577,7 +1577,6 @@ def variational_gamma(
     max_shape=None,
     match_segregating_sites=None,
     regularise_roots=None,
-    rescaling_intervals=None,
     **kwargs,
 ):
     """
diff --git a/tsdate/phasing.py b/tsdate/phasing.py
index 2511cfbf..525ad953 100644
--- a/tsdate/phasing.py
+++ b/tsdate/phasing.py
@@ -22,9 +22,260 @@
 """
 Tools for phasing singleton mutations
 """
+
+import numba
 import numpy as np
 import tskit
+
+from .approx import _f
+from .approx import _f1r
+from .approx import _f1w
+from .approx import _f2r
+from .approx import _f2w
+from .approx import _i
+from .approx import _i1r
+from .approx import _i1w
+from .approx import _i2r
+from .approx import _i2w
+from .approx import _b
+from .approx import _b1r
+from .approx import _tuple
+
+# --- machinery used by ExpectationPropagation class --- #
+
+@numba.njit(_f1w(_f1r, _b1r, _f1r, _i1r, _i2r))
+def reallocate_unphased(edges_mutations, edges_unphased, mutations_phase, mutations_block, blocks_edges):
+    """
+    Add a proportion of each unphased singleton mutation to one of the two
+    edges to which it maps, and return the modified `edges_mutations`.
+    """
+    assert mutations_phase.size == mutations_block.size
+    assert blocks_edges.shape[1] == 2
+    assert edges_mutations.size == edges_unphased.size
+
+    num_mutations = mutations_phase.size
+    num_edges = edges_mutations.size
+    num_blocks = blocks_edges.shape[0]
+
+    new_edges_mutations = edges_mutations.copy()
+    new_edges_mutations[edges_unphased] = 0.0
+    for m, b in enumerate(mutations_block):
+        if b == tskit.NULL:
+            continue
+        i, j = blocks_edges[b]
+        assert tskit.NULL < i < num_edges and edges_unphased[i]
+        assert tskit.NULL < j < num_edges and edges_unphased[j]
+        assert 0.0 <= mutations_phase[m] <= 1.0
+        new_edges_mutations[i] += mutations_phase[m]
+        new_edges_mutations[j] += 1 - mutations_phase[m]
+
+    assert np.isclose(np.sum(new_edges_mutations), np.sum(edges_mutations))
+    return new_edges_mutations
+
+
+@numba.njit(_tuple((_f2w, _i2w, _i1w))(_b1r, _i1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f))
+def _block_singletons(individuals_unphased, nodes_individual, mutations_node, mutations_position, edges_parent, edges_child, edges_left, edges_right, indexes_insert, indexes_remove, sequence_length):
+    """
+    Traverse edges and mutations in positional order, partitioning the
+    genome spanned by each unphased individual into blocks over which the
+    individual's two parental edges are unchanged. Returns per-block
+    singleton counts and spans, the pair of edges for each block, and a
+    mapping from singleton mutations to blocks.
+    """
+    assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size
+    assert indexes_insert.size == indexes_remove.size == edges_parent.size
+    assert mutations_node.size == mutations_position.size
+
+    num_nodes = nodes_individual.size
+    num_mutations = mutations_node.size
+    num_edges = edges_parent.size
+    num_individuals = individuals_unphased.size
+
+    indexes_mutation = np.argsort(mutations_position)
+    position_insert = edges_left[indexes_insert]
+    position_remove = edges_right[indexes_remove]
+    position_mutation = mutations_position[indexes_mutation]
+
+    individuals_edges = np.full((num_individuals, 2), tskit.NULL)
+    individuals_position = np.full(num_individuals, np.nan)
+    individuals_singletons = np.zeros(num_individuals)
+    individuals_block = np.full(num_individuals, tskit.NULL)
+    mutations_block = np.full(num_mutations, tskit.NULL)
+
+    blocks_span = []
+    blocks_singletons = []
+    blocks_edges = []
+    
blocks_order = [] + + num_blocks = 0 + left = 0.0 + a, b, d = 0, 0, 0 + while a < num_edges or b < num_edges: + while b < num_edges and position_remove[b] == left: # edges out + e = indexes_remove[b] + p, c = edges_parent[e], edges_child[e] + i = nodes_individual[c] + if i != tskit.NULL and individuals_unphased[i]: + u, v = individuals_edges[i] + assert u == e or v == e + s = u if v == e else v + individuals_edges[i] = s, tskit.NULL + if s != tskit.NULL: # flush block + blocks_order.append(individuals_block[i]) + blocks_edges.extend([e, s]) + blocks_singletons.append(individuals_singletons[i]) + blocks_span.append(left - individuals_position[i]) + individuals_position[i] = np.nan + individuals_block[i] = tskit.NULL + individuals_singletons[i] = 0.0 + b += 1 + + while a < num_edges and position_insert[a] == left: # edges in + e = indexes_insert[a] + p, c = edges_parent[e], edges_child[e] + i = nodes_individual[c] + if i != tskit.NULL and individuals_unphased[i]: + u, v = individuals_edges[i] + assert u == tskit.NULL or v == tskit.NULL + individuals_edges[i] = [e, max(u, v)] + individuals_position[i] = left + if individuals_block[i] == tskit.NULL: + individuals_block[i] = num_blocks + num_blocks += 1 + a += 1 + + right = sequence_length + if b < num_edges: + right = min(right, position_remove[b]) + if a < num_edges: + right = min(right, position_insert[a]) + left = right + + while d < num_mutations and position_mutation[d] < right: # mutations + m = indexes_mutation[d] + c = mutations_node[m] + i = nodes_individual[c] + if i != tskit.NULL and individuals_unphased[i]: + mutations_block[m] = individuals_block[i] + individuals_singletons[i] += 1.0 + d += 1 + + mutations_block = mutations_block.astype(np.int32) + blocks_edges = np.array(blocks_edges, dtype=np.int32).reshape(-1, 2) + blocks_singletons = np.array(blocks_singletons) + blocks_span = np.array(blocks_span) + blocks_order = np.array(blocks_order) + blocks_stats = np.column_stack((blocks_singletons, blocks_span)) + assert num_blocks == blocks_edges.shape[0] == blocks_stats.shape[0] + + # sort block arrays so that mutations_block points to correct row + blocks_order = np.argsort(blocks_order) + blocks_edges = blocks_edges[blocks_order] + blocks_stats = blocks_stats[blocks_order] + + return blocks_stats, blocks_edges, mutations_block + + +def block_singletons(ts, individuals_unphased): + """ + TODO + """ + for i in ts.individuals(): + if individuals_unphased[i.id] and i.nodes.size != 2: + raise ValueError("Singleton blocking assumes diploid individuals") + + # TODO: adjust spans by an accessibility mask + return _block_singletons( + individuals_unphased, + ts.nodes_individual, + ts.mutations_node, + ts.sites_position[ts.mutations_site], + ts.edges_parent, + ts.edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, + ts.sequence_length, + ) + + +@numba.njit(_tuple((_f2w, _i1w))(_i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _i, _f)) +def _count_mutations(mutations_node, mutations_position, edges_parent, edges_child, edges_left, edges_right, indexes_insert, indexes_remove, num_nodes, sequence_length): + """ + TODO + """ + assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size + assert indexes_insert.size == indexes_remove.size == edges_parent.size + assert mutations_node.size == mutations_position.size + + num_mutations = mutations_node.size + num_edges = edges_parent.size + + indexes_mutation = np.argsort(mutations_position) + position_insert = 
edges_left[indexes_insert] + position_remove = edges_right[indexes_remove] + position_mutation = mutations_position[indexes_mutation] + + nodes_edge = np.full(num_nodes, tskit.NULL) + mutations_edge = np.full(num_mutations, tskit.NULL) + edges_mutations = np.zeros(num_edges) + edges_span = edges_right - edges_left + + left = 0.0 + a, b, d = 0, 0, 0 + while a < num_edges or b < num_edges: + while b < num_edges and position_remove[b] == left: # edges out + e = indexes_remove[b] + p, c = edges_parent[e], edges_child[e] + nodes_edge[c] = tskit.NULL + b += 1 + + while a < num_edges and position_insert[a] == left: # edges in + e = indexes_insert[a] + p, c = edges_parent[e], edges_child[e] + nodes_edge[c] = e + a += 1 + + right = sequence_length + if b < num_edges: + right = min(right, position_remove[b]) + if a < num_edges: + right = min(right, position_insert[a]) + left = right + + while d < num_mutations and position_mutation[d] < right: + m = indexes_mutation[d] + c = mutations_node[m] + e = nodes_edge[c] + if e != tskit.NULL: + mutations_edge[m] = e + edges_mutations[e] += 1.0 + d += 1 + + mutations_edge = mutations_edge.astype(np.int32) + edges_stats = np.column_stack((edges_mutations, edges_span)) + + return edges_stats, mutations_edge + + +def count_mutations(ts): + """ + TODO + """ + # TODO: adjust spans by an accessibility mask + return _count_mutations( + ts.mutations_node, + ts.sites_position[ts.mutations_site], + ts.edges_parent, + ts.edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, + ts.num_nodes, + ts.sequence_length, + ) + + +# --- helper functions --- # + def remove_singletons(ts): """ Remove all singleton mutations from the tree sequence. @@ -150,27 +401,3 @@ def insert_unphased_singletons(ts, position, individual, reference_state, altern return tables.tree_sequence() -def accumulate_unphased(edges_mutations, mutations_phase, mutations_block, block_edges): - """ - Add a proportion of each unphased singleton mutation to one of the two - edges to which it maps. 
- """ - unphased = mutations_block != tskit.NULL - assert np.all(mutations_phase[~unphased] == 1.0) - assert np.all( - np.logical_and( - mutations_phase[unphased] <= 1.0, - mutations_phase[unphased] >= 0.0, - ) - ) - for b in mutations_block[unphased]: - if b == tskit.NULL: - continue - i, j = block_edges[b] - edges_mutations[i] += mutations_phase - edges_mutations[j] += 1 - mutations_phase - assert np.sum(edges_mutations) == mutations_block.size - return edges_mutations - - -# TODO: mutation sort order diff --git a/tsdate/variational.py b/tsdate/variational.py index d5f9b398..4bbd69d4 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -49,6 +49,8 @@ from .rescaling import mutational_timescale from .rescaling import piecewise_scale_posterior from .util import contains_unary_nodes +from .phasing import count_mutations +from .phasing import block_singletons # columns for edge_factors @@ -233,51 +235,57 @@ def __init__(self, ts, likelihoods, constraints, mutations_edge): self._check_valid_inputs(ts, likelihoods, constraints, mutations_edge) # const - self.parents = ts.edges_parent - self.children = ts.edges_child - self.likelihoods = likelihoods - self.constraints = constraints - self.mutations_edge = mutations_edge + self.edge_parents = ts.edges_parent + self.edge_children = ts.edges_child + self.edge_likelihoods = likelihoods + self.node_constraints = constraints + self.mutation_edges = mutations_edge # TODO: get likelihoods + unphaseed + individual_unphased = np.full(ts.num_individuals, False) + self.edge_likelihoods, self.mutation_edges = count_mutations(ts) + self.block_likelihoods, self.block_edges, \ + self.mutation_blocks = block_singletons(ts, individual_unphased) + self.block_edges = np.ascontiguousarray(self.block_edges.T) + self.edge_unphased = np.full(ts.num_edges, False) + self.edge_unphased[self.block_edges[0]] = True + self.edge_unphased[self.block_edges[1]] = True + num_blocks = self.block_likelihoods.shape[0] # mutable self.node_factors = np.zeros((ts.num_nodes, 2, 2)) self.edge_factors = np.zeros((ts.num_edges, 2, 2)) - self.block_factors = np.zeros((num_blocks, 2, 2)) #TODO - self.posterior = np.zeros((ts.num_nodes, 2)) - #self.mutation_posterior = np.full((ts.num_mutations, np.nan)) - self.log_partition = np.zeros(ts.num_edges) - #self.edge_logconst = ... - #self.block_logconst = ... 
- self.scale = np.ones(ts.num_nodes) - assert False + self.block_factors = np.zeros((num_blocks, 2, 2)) + self.node_posterior = np.zeros((ts.num_nodes, 2)) + self.mutation_posterior = np.full((ts.num_mutations, 2), np.nan) + self.edge_logconst = np.zeros(ts.num_edges) + self.block_logconst = np.zeros(num_blocks) + self.node_scale = np.ones(ts.num_nodes) # terminal nodes has_parent = np.full(ts.num_nodes, False) has_child = np.full(ts.num_nodes, False) - has_parent[self.children] = True - has_child[self.parents] = True + has_parent[self.edge_children] = True + has_child[self.edge_parents] = True self.roots = np.logical_and(has_child, ~has_parent) self.leaves = np.logical_and(~has_child, has_parent) if np.any(np.logical_and(~has_child, ~has_parent)): raise ValueError("Tree sequence contains disconnected nodes") # edge traversal order - edges = np.arange(ts.num_edges, dtype=np.int32) - # TODO: mask singleton edges - assert False + edges = np.arange(ts.num_edges, dtype=np.int32)[~self.edge_unphased] self.edge_order = np.concatenate((edges[:-1], np.flip(edges))) self.edge_weights = edge_sampling_weight( self.leaves, - ts.edges_parent, - ts.edges_child, + self.edge_parents, + self.edge_children, ts.edges_left, ts.edges_right, ts.indexes_edge_insertion_order, ts.indexes_edge_removal_order, ) self.block_order = np.arange(num_blocks, dtype=np.int32) + self.mutation_order = np.arange(ts.num_mutations, dtype=np.int32) @staticmethod @numba.njit(_f(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f)) @@ -293,6 +301,7 @@ def propagate_likelihood( scale, max_shape, min_step, + #unphased, ): """ Update approximating factors for Poisson mutation likelihoods on edges. @@ -330,13 +339,16 @@ def cavity_damping(x, y): def posterior_damping(x): return _rescale(x, max_shape) - # TODO - # in "unphased" mode, edges are singleton blocks, and the two parents + # if "unphased" edges are singleton blocks, and the two parents # of each block are given by "parents" and "children" - assert False - leafward_projection = approx.leafward_projection if unphased else approx.unphased_fixed_projection - rootward_projection = approx.rootward_projection if unphased else approx.unphased_fixed_projection - gamma_projection = approx.gamma_projection if unphased else approx.unphased_projection + #if unphased: + # leafward_projection = approx.sideways_projection + # rootward_projection = approx.sideways_projection + # gamma_projection = approx.unphased_projection + #else: + leafward_projection = approx.leafward_projection + rootward_projection = approx.rootward_projection + gamma_projection = approx.gamma_projection fixed = constraints[:, LOWER] == constraints[:, UPPER] @@ -474,11 +486,11 @@ def posterior_damping(x): return np.nan - # TODO add arguments, void return @staticmethod - @numba.njit(_f2w(_i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r)) + @numba.njit(_f(_i1r, _f2w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r)) def propagate_mutations( - TODO, + mutations_order, + mutations_posterior, mutations_edge, edges_parent, edges_child, @@ -495,7 +507,6 @@ def propagate_mutations( which to traverse mutations :param ndarray mutations_posterior: array of dimension `[num_mutations, 2]` containing natural parameters for each mutation - :param ndarray mutations_edge: integer array giving edge for each mutation :param ndarray edges_parent: integer array of parent ids per edge @@ -517,35 +528,43 @@ def propagate_mutations( # TODO: we don't seem to need to damp? 
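+        # Hedged aside: when both endpoint ages are fixed, the projections
+        # applied below reduce to closed forms -- uniform on (t_j, t_i) for
+        # a phased edge (mutation_edge_moments), and a length-weighted
+        # mixture of uniforms for an unphased block (mutation_block_moments).
+        # A standalone sanity check of the mixture moments (plain numpy,
+        # not the numba-compiled versions from approx.py):
+        #
+        #   import numpy as np
+        #   rng = np.random.default_rng(1)
+        #   t_i, t_j = 2.0, 0.5
+        #   pr = t_i / (t_i + t_j)   # probability of the first branch
+        #   mn = pr * t_i / 2 + (1 - pr) * t_j / 2
+        #   va = pr * t_i**2 / 3 + (1 - pr) * t_j**2 / 3 - mn**2
+        #   pick = rng.random(200_000) < pr
+        #   sim = np.where(pick, rng.uniform(0, t_i, pick.size),
+        #                  rng.uniform(0, t_j, pick.size))
+        #   assert np.isclose(sim.mean(), mn, rtol=1e-2)
+        #   assert np.isclose(sim.var(), va, rtol=2e-2)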
# TODO: might as well copy format in other functions and have void return + # assert stuff here assert constraints.shape == posterior.shape assert edges_child.size == edges_parent.size assert factors.shape == (edges_parent.size, 2, 2) assert likelihoods.shape == (edges_parent.size, 2) - #mutations_posterior = np.zeros((mutations_edge.size, 2)) - #pass in mutations posterior filled with nan TODO + #if unphased: + # gamma_projection = approx.mutation_unphased_projection + # leafward_projection = approx.mutation_sideways_projection + # rootward_projection = approx.mutation_sideways_projection + # fixed_projection = approx.mutation_block_projection + #else: + gamma_projection = approx.mutation_gamma_projection + leafward_projection = approx.mutation_leafward_projection + rootward_projection = approx.mutation_rootward_projection + fixed_projection = approx.mutation_edge_projection + + phase = np.zeros(mutations_edge.size) # TODO fixed = constraints[:, LOWER] == constraints[:, UPPER] for m in mutations_order: i = mutations_edge[m] - if i == tskit.NULL: # skip mutations above root or unphased + if i == tskit.NULL: # skip mutations above root continue p, c = edges_parent[i], edges_child[i] if fixed[p] and fixed[c]: child_age = constraints[c, 0] parent_age = constraints[p, 0] - mutations_posterior[m] = approx.mutation_fixed_projection( - parent_age, child_age + phase[m], mutations_posterior[m] = fixed_projection( + parent_age, child_age ) - #mean = 1 / 2 * (child_age + parent_age) - #variance = 1 / 12 * (parent_age - child_age) ** 2 - #mutations_posterior[m] = approx.approximate_gamma_mom(mean, variance) elif fixed[p] and not fixed[c]: child_message = factors[i, LEAFWARD] * scale[c] child_delta = 1.0 # hopefully we don't need to damp child_cavity = posterior[c] - child_delta * child_message edge_likelihood = child_delta * likelihoods[i] parent_age = constraints[p, LOWER] - mutations_posterior[m] = approx.mutation_leafward_projection( + phase[m], mutations_posterior[m] = leafward_projection( parent_age, child_cavity, edge_likelihood, ) elif fixed[c] and not fixed[p]: @@ -554,7 +573,7 @@ def propagate_mutations( parent_cavity = posterior[p] - parent_delta * parent_message edge_likelihood = parent_delta * likelihoods[i] child_age = constraints[c, LOWER] - mutations_posterior[m] = approx.mutation_rootward_projection( + phase[m], mutations_posterior[m] = rootward_projection( child_age, parent_cavity, edge_likelihood, ) else: @@ -566,11 +585,11 @@ def propagate_mutations( parent_cavity = posterior[p] - delta * parent_message child_cavity = posterior[c] - delta * child_message edge_likelihood = delta * likelihoods[i] - mutations_posterior[m] = approx.mutation_gamma_projection( + phase[m], mutations_posterior[m] = gamma_projection( parent_cavity, child_cavity, edge_likelihood, ) - return mutations_posterior + return np.nan # TODO more arguments, blck_factors and block_parents @staticmethod @@ -580,9 +599,10 @@ def rescale_factors(edges_parent, edges_child, node_factors, edge_factors, scale p, c = edges_parent, edges_child edge_factors[:, ROOTWARD] *= scale[p, np.newaxis] edge_factors[:, LEAFWARD] *= scale[c, np.newaxis] - j, k = block_parents - block_factors[:, ROOTWARD] *= scale[j, np.newaxis] - block_factors[:, LEAFWARD] *= scale[k, np.newaxis] + # TODO + #j, k = blocks_parents + #block_factors[:, ROOTWARD] *= scale[j, np.newaxis] + #block_factors[:, LEAFWARD] *= scale[k, np.newaxis] node_factors[:, MIXPRIOR] *= scale[:, np.newaxis] node_factors[:, CONSTRNT] *= scale[:, np.newaxis] scale[:] = 1.0 @@ 
-603,10 +623,10 @@ def iterate( # self.block_parents[0], # self.block_parents[1], # self.block_likelihoods, - # self.constraints, + # self.node_constraints, # self.block_factors, # self.block_log_partition, - # self.scale, + # self.node_scale, # max_shape, # min_step, # USE_BLOCK_LIKELIHOOD, @@ -615,26 +635,26 @@ def iterate( # rootward + leafward pass through edges self.propagate_likelihood( self.edge_order, - self.parents, - self.children, - self.likelihoods, - self.constraints, - self.posterior, + self.edge_parents, + self.edge_children, + self.edge_likelihoods, + self.node_constraints, + self.node_posterior, self.edge_factors, - self.log_partition, - self.scale, + self.edge_logconst, + self.node_scale, max_shape, min_step, - USE_EDGE_LIKELIHOOD, + #USE_EDGE_LIKELIHOOD, ) # exponential regularization on roots if regularise: self.propagate_prior( self.roots, - self.posterior, + self.node_posterior, self.node_factors, - self.scale, + self.node_scale, max_shape, em_maxitt, em_reltol, @@ -642,18 +662,18 @@ def iterate( # absorb the scaling term into the factors self.rescale_factors( - self.parents, - self.children, + self.edge_parents, + self.edge_children, self.node_factors, self.edge_factors, - self.scale, + self.node_scale, ) if check_valid: # for debugging assert self._check_valid_state( - self.parents, - self.children, - self.posterior, + self.edge_parents, + self.edge_children, + self.node_posterior, self.node_factors, self.edge_factors, ) @@ -672,25 +692,27 @@ def rescale( edge_weights = ( np.ones(self.edge_weights.size) if rescale_segsites else self.edge_weights ) - nodes_time = self._point_estimate(self.posterior, self.constraints, use_median) + nodes_time = self._point_estimate( + self.node_posterior, self.node_constraints, use_median + ) original_breaks, rescaled_breaks = mutational_timescale( nodes_time, - self.likelihoods, - self.constraints, - self.parents, - self.children, + self.edge_likelihoods, + self.node_constraints, + self.edge_parents, + self.edge_children, edge_weights, rescale_intervals, ) - self.posterior[:] = piecewise_scale_posterior( - self.posterior, + self.node_posterior[:] = piecewise_scale_posterior( + self.node_posterior, original_breaks, rescaled_breaks, quantile_width, use_median, ) - self.mutations_posterior[:] = piecewise_scale_posterior( - self.mutations_posterior, + self.mutation_posterior[:] = piecewise_scale_posterior( + self.mutation_posterior, original_breaks, rescaled_breaks, quantile_width, @@ -720,24 +742,26 @@ def run( regularise=regularise, ) nodes_timing -= time.time() - skipped_nodes = np.sum(np.isnan(self.log_partition)) - if skipped_nodes: - logging.info(f"Skipped {skipped_nodes} nodes with invalid posteriors") + skipped_edges = np.sum(np.isnan(self.edge_logconst)) + if skipped_edges: + logging.info(f"Skipped {skipped_edges} edges with invalid factors") logging.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds") muts_timing = time.time() - self.mutations_posterior = self.propagate_mutations( - self.mutations_edge, - self.parents, - self.children, - self.likelihoods, - self.constraints, - self.posterior, + self.propagate_mutations( + self.mutation_order, + self.mutation_posterior, + self.mutation_edges, + self.edge_parents, + self.edge_children, + self.edge_likelihoods, + self.node_constraints, + self.node_posterior, self.edge_factors, - self.scale, + self.node_scale, ) muts_timing -= time.time() - skipped_muts = np.sum(np.isnan(self.mutations_posterior[:, 0])) + skipped_muts = np.sum(np.isnan(self.mutation_posterior[:, 0])) 
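+        # a NaN shape marks a mutation whose projection failed and which is
+        # left undated; downstream (cf. the mutation handling in core.py
+        # earlier in this patch) rows are converted back to moments roughly as:
+        #
+        #   valid = mutation_posterior[:, 1] > 0
+        #   shape = mutation_posterior[valid, 0] + 1.0
+        #   rate = mutation_posterior[valid, 1]
+        #   mutation_mean, mutation_var = shape / rate, shape / rate**2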
if skipped_muts: logging.info(f"Skipped {skipped_muts} mutations with invalid posteriors") logging.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds") From 186347a446f8ec31cab6203cfb86a723f4e72804 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 11 Jun 2024 10:11:45 -0700 Subject: [PATCH 09/29] Tests and such --- tests/test_phasing.py | 229 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 tests/test_phasing.py diff --git a/tests/test_phasing.py b/tests/test_phasing.py new file mode 100644 index 00000000..da015194 --- /dev/null +++ b/tests/test_phasing.py @@ -0,0 +1,229 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# Copyright (c) 2020-21 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for the gamma-variational approximations in tsdate +""" + +import numpy as np +import pytest +import tskit +import msprime +import tsinfer + +from tsdate.phasing import block_singletons +from tsdate.phasing import count_mutations + + +@pytest.fixture(scope="session") +def inferred_ts(): + ts = msprime.sim_ancestry( + 10, + population_size=1e4, + recombination_rate=1e-8, + sequence_length=1e6, + random_seed=1, + ) + ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) + sample_data = tsinfer.SampleData.from_tree_sequence(ts) + inferred_ts = tsinfer.infer(sample_data).simplify() + return inferred_ts + + +class TestCountMutations: + def test_count_mutations(self, inferred_ts): + edge_stats, muts_edge = count_mutations(inferred_ts) + ck_edge_muts = np.zeros(inferred_ts.num_edges) + ck_muts_edge = np.full(inferred_ts.num_mutations, tskit.NULL) + for m in inferred_ts.mutations(): + if m.edge != tskit.NULL: + ck_edge_muts[m.edge] += 1.0 + ck_muts_edge[m.id] = m.edge + ck_edge_span = inferred_ts.edges_right - inferred_ts.edges_left + np.testing.assert_array_almost_equal(ck_edge_muts, edge_stats[:, 0]) + np.testing.assert_array_almost_equal(ck_edge_span, edge_stats[:, 1]) + np.testing.assert_array_equal(ck_muts_edge, muts_edge) + + +class TestBlockSingletons: + + @staticmethod + def naive_block_singletons(ts, individual): + """ + Get all intervals where the two intermediate parents of an individual are + unchanged over the interval. 
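+        Returns per-block (singleton count, span) rows, the edge pair for
+        each block, and the edge pair for each singleton mutation.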
+ """ + i = individual + j, k = ts.individual(i).nodes + last_block = np.full(2, tskit.NULL) + last_span = np.zeros(2) + muts_edges = np.full((ts.num_mutations, 2), tskit.NULL) + blocks_edge = [] + blocks_span = [] + for tree in ts.trees(): + if tree.num_edges == 0: # skip tree + muts = [] + span = 0.0 + block = tskit.NULL, tskit.NULL + else: + muts = [m.id for m in tree.mutations() if m.node == j or m.node == k] + span = tree.interval.span + block = tree.edge(j), tree.edge(k) + for m in muts: + muts_edges[m] = block + if last_block[0] != tskit.NULL and not np.array_equal(block, last_block): # flush block + blocks_edge.extend(last_block) + blocks_span.extend(last_span) + last_span[:] = 0.0 + last_span += len(muts), span + last_block[:] = block + if last_block[0] != tskit.NULL: # flush last block + blocks_edge.extend(last_block) + blocks_span.extend(last_span) + blocks_edge = np.array(blocks_edge).reshape(-1, 2) + blocks_span = np.array(blocks_span).reshape(-1, 2) + total_span = np.sum([t.interval.span for t in ts.trees() if t.num_edges > 0]) + total_muts = np.sum(np.logical_or(ts.mutations_node == j, ts.mutations_node == k)) + assert np.sum(blocks_span[:, 0]) == total_muts + assert np.sum(blocks_span[:, 1]) == total_span + return blocks_span, blocks_edge, muts_edges + + def test_against_naive(self, inferred_ts): + """ + Test fast routine against simpler tree-by-tree, + individual-by-individual implementation + """ + ts = inferred_ts + individuals_unphased = np.full(ts.num_individuals, False) + unphased_individuals = np.arange(0, ts.num_individuals // 2) + individuals_unphased[unphased_individuals] = True + block_stats, block_edges, muts_block = block_singletons(ts, individuals_unphased) + block_edges = block_edges + singletons = muts_block != tskit.NULL + muts_edges = np.full((ts.num_mutations, 2), tskit.NULL) + muts_edges[singletons] = block_edges[muts_block[singletons]] + ck_num_blocks = 0 + ck_num_singletons = 0 + for i in np.flatnonzero(individuals_unphased): + ck_block_stats, ck_block_edges, ck_muts_edges = self.naive_block_singletons(ts, i) + ck_num_blocks += ck_block_stats.shape[0] + # blocks of individual i + nodes_i = ts.individual(i).nodes + blocks_i = np.isin(ts.edges_child[block_edges.min(axis=1)], nodes_i) + np.testing.assert_allclose(block_stats[blocks_i], ck_block_stats) + np.testing.assert_array_equal( + np.min(block_edges[blocks_i], axis=1), np.min(ck_block_edges, axis=1) + ) + np.testing.assert_array_equal( + np.max(block_edges[blocks_i], axis=1), np.max(ck_block_edges, axis=1) + ) + # singleton mutations in unphased individual i + ck_muts_i = ck_muts_edges[:, 0] != tskit.NULL + np.testing.assert_array_equal( + np.min(muts_edges[ck_muts_i], axis=1), + np.min(ck_muts_edges[ck_muts_i], axis=1), + ) + np.testing.assert_array_equal( + np.max(muts_edges[ck_muts_i], axis=1), + np.max(ck_muts_edges[ck_muts_i], axis=1), + ) + ck_num_singletons += np.sum(ck_muts_i) + assert ck_num_blocks == block_stats.shape[0] == block_edges.shape[0] + assert ck_num_singletons == np.sum(singletons) + + def test_total_counts(self, inferred_ts): + """ + Sanity check: total number of mutations should equal number of singletons + and total edge span should equal sum of spans of singleton edges + """ + ts = inferred_ts + individuals_unphased = np.full(ts.num_individuals, False) + unphased_individuals = np.arange(0, ts.num_individuals // 2) + individuals_unphased[unphased_individuals] = True + unphased_nodes = np.concatenate([ts.individual(i).nodes for i in unphased_individuals]) + total_singleton_span 
= 0.0 + total_singleton_muts = 0.0 + for t in ts.trees(): + if t.num_edges == 0: continue + for s in t.samples(): + if s in unphased_nodes: + total_singleton_span += t.span + for m in t.mutations(): + if t.num_samples(m.node) == 1 and (m.node in unphased_nodes): + e = t.edge(m.node) + total_singleton_muts += 1.0 + block_stats, *_ = block_singletons(ts, individuals_unphased) + assert np.isclose(np.sum(block_stats[:, 0]), total_singleton_muts) + assert np.isclose(np.sum(block_stats[:, 1]), total_singleton_span / 2) + + def test_singleton_edges(self, inferred_ts): + """ + Sanity check: all singleton edges attached to unphased individuals + should show up in blocks + """ + ts = inferred_ts + individuals_unphased = np.full(ts.num_individuals, False) + unphased_individuals = np.arange(0, ts.num_individuals // 2) + individuals_unphased[unphased_individuals] = True + unphased_nodes = set(np.concatenate([ts.individual(i).nodes for i in unphased_individuals])) + ck_singleton_edge = set() + for t in ts.trees(): + if t.num_edges == 0: continue + for s in ts.samples(): + if s in unphased_nodes: + ck_singleton_edge.add(t.edge(s)) + _, block_edges, *_ = block_singletons(ts, individuals_unphased) + singleton_edge = set([i for i in block_edges.flatten()]) + assert singleton_edge == ck_singleton_edge + + def test_singleton_mutations(self, inferred_ts): + """ + Sanity check: all singleton mutations in unphased individuals + should show up in blocks + """ + ts = inferred_ts + individuals_unphased = np.full(ts.num_individuals, False) + unphased_individuals = np.arange(0, ts.num_individuals // 2) + individuals_unphased[unphased_individuals] = True + unphased_nodes = np.concatenate([ts.individual(i).nodes for i in unphased_individuals]) + ck_singleton_muts = set() + for t in ts.trees(): + if t.num_edges == 0: continue + for m in t.mutations(): + if t.num_samples(m.node) == 1 and (m.node in unphased_nodes): + ck_singleton_muts.add(m.id) + _, _, block_muts = block_singletons(ts, individuals_unphased) + singleton_muts = set([i for i in np.flatnonzero(block_muts != tskit.NULL)]) + assert singleton_muts == ck_singleton_muts + + def test_all_phased(self, inferred_ts): + """ + Test that empty arrays are returned when all individuals are phased + """ + ts = inferred_ts + individuals_unphased = np.full(ts.num_individuals, False) + block_stats, block_edges, block_muts = block_singletons(ts, individuals_unphased) + assert block_stats.shape == (0, 2) + assert block_edges.shape == (0, 2) + assert np.all(block_muts == tskit.NULL) + From f089878ec14af8aa342a849e6ad362fc83ca5696 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 11 Jun 2024 10:59:02 -0700 Subject: [PATCH 10/29] Tests and such --- tsdate/core.py | 14 +---------- tsdate/variational.py | 57 +++++++++++++++++-------------------------- 2 files changed, 23 insertions(+), 48 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index fe9f55e6..c4885082 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1234,20 +1234,8 @@ def mean_var(posteriors, constraints): return mn_post, va_post def main_algorithm(self): - # edge likelihoods - # TODO: variable mutation rates across genome - # TODO: truncate edge spans with accessiblity mask - likelihoods = self.edges_mutations.copy() - likelihoods[:, 1] *= self.mutation_rate - - # lower and upper bounds on node ages - sample_idx = list(self.ts.samples()) - constraints = np.zeros((self.ts.num_nodes, 2)) - constraints[:, 1] = np.inf - constraints[sample_idx, :] = self.ts.nodes_time[sample_idx, np.newaxis] - return 
variational.ExpectationPropagation( - self.ts, likelihoods, constraints, self.mutations_edge + self.ts, mutation_rate=self.mutation_rate ) def run( diff --git a/tsdate/variational.py b/tsdate/variational.py index 4bbd69d4..51115042 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -160,24 +160,11 @@ def _check_valid_constraints(constraints, edges_parent, edges_child): ) @staticmethod - def _check_valid_inputs(ts, likelihoods, constraints, mutations_edge): + def _check_valid_inputs(ts, mutation_rate): + if not mutation_rate > 0.0: + raise ValueError("Mutation rate must be positive") if contains_unary_nodes(ts): - raise ValueError( - "Tree sequence contains unary nodes, simplify before dating" - ) - if likelihoods.shape != (ts.num_edges, 2): - raise ValueError("Edge likelihoods are the wrong shape") - if constraints.shape != (ts.num_nodes, 2): - raise ValueError("Node age constraints are the wrong shape") - if np.any(likelihoods < 0.0): - raise ValueError("Edge likelihoods contains negative values") - if np.any(constraints < 0.0): - raise ValueError("Node age constraints contain negative values") - if mutations_edge.size > 0 and mutations_edge.max() >= ts.num_edges: - raise ValueError("Mutation edge indices are out-of-bounds") - ExpectationPropagation._check_valid_constraints( - constraints, ts.edges_parent, ts.edges_child - ) + raise ValueError("Tree sequence contains unary nodes, simplify first") @staticmethod def _check_valid_state( @@ -207,7 +194,7 @@ def _point_estimate(posteriors, constraints, median): point_estimate[fixed] = constraints[fixed, 0] return point_estimate - def __init__(self, ts, likelihoods, constraints, mutations_edge): + def __init__(self, ts, *, mutation_rate): """ Initialize an expectation propagation algorithm for dating nodes in a tree sequence. @@ -220,32 +207,32 @@ def __init__(self, ts, likelihoods, constraints, mutations_edge): :param ~tskit.TreeSequence ts: a tree sequence containing the partial ordering of nodes. - :param ~np.ndarray constraints: a `ts.num_nodes`-by-two array containing - lower and upper bounds for each node. If lower and upper bounds - are the same value, the node is considered fixed. - :param ~np.ndarray likelihoods: a `ts.num_edges`-by-two array containing - mutation counts and mutational spans (e.g. edge span multiplied by - mutation rate) per edge. - :param ~np.ndarray mutations_edge: an array containing edge indices - (one per mutation) for which to compute posteriors. + :param ~float mutation_rate: the expected per-base mutation rate per + time unit. 
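+
+        A minimal construction sketch (keyword-only, per this signature;
+        run() options are omitted here):
+
+            ep = ExpectationPropagation(ts, mutation_rate=1e-8)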
""" - # TODO: pass in edge table rather than tree sequence - # TODO: check valid mutations_edge - self._check_valid_inputs(ts, likelihoods, constraints, mutations_edge) - - # const + self._check_valid_inputs(ts, mutation_rate) self.edge_parents = ts.edges_parent self.edge_children = ts.edges_child - self.edge_likelihoods = likelihoods - self.node_constraints = constraints - self.mutation_edges = mutations_edge - # TODO: get likelihoods + unphaseed + # lower and upper bounds on node ages + samples = list(ts.samples()) + self.node_constraints = np.zeros((ts.num_nodes, 2)) + self.node_constraints[:, 1] = np.inf + self.node_constraints[samples, :] = ts.nodes_time[samples, np.newaxis] + self._check_valid_constraints( + self.node_constraints, self.edge_parents, self.edge_children + ) + + # count mutations on edges individual_unphased = np.full(ts.num_individuals, False) self.edge_likelihoods, self.mutation_edges = count_mutations(ts) + self.edge_likelihoods[:, 1] *= mutation_rate + + # count mutations in singleton blocks self.block_likelihoods, self.block_edges, \ self.mutation_blocks = block_singletons(ts, individual_unphased) + self.block_likelihoods[:, 1] *= mutation_rate self.block_edges = np.ascontiguousarray(self.block_edges.T) self.edge_unphased = np.full(ts.num_edges, False) self.edge_unphased[self.block_edges[0]] = True From f0a002d190025ab0bb32f5b17e78dd613d7c609e Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 11 Jun 2024 13:28:28 -0700 Subject: [PATCH 11/29] Tests and such --- tsdate/core.py | 2 ++ tsdate/variational.py | 75 ++++++++++++++++++++++++++++--------------- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index c4885082..d2b7cc7b 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -989,6 +989,7 @@ def __init__( self.priors = priors # mutation to edge mapping + # TODO: this isn't needed except for mutations_edge in constrain_mutations mutspan_timing = time.time() self.edges_mutations, self.mutations_edge = util.mutation_span_array(ts) mutspan_timing -= time.time() @@ -1011,6 +1012,7 @@ def get_modified_ts(self, result, eps): # Constrain node ages for positive branch lengths constr_timing = time.time() nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations) + # TODO: what if mutations_edge is NULL? 
mutations.time = util.constrain_mutations(ts, nodes.time, self.mutations_edge) tables.time_units = self.time_units constr_timing -= time.time() diff --git a/tsdate/variational.py b/tsdate/variational.py index 51115042..87a02c72 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -69,6 +69,10 @@ LOWER = 0 # lower bound on node UPPER = 1 # upper bound on node +# named flags for unphased updates +USE_EDGE_LIKELIHOOD = False +USE_BLOCK_LIKELIHOOD = True + @numba.njit(_f(_f1r, _f1r, _f)) def _damp(x, y, s): @@ -225,11 +229,12 @@ def __init__(self, ts, *, mutation_rate): ) # count mutations on edges - individual_unphased = np.full(ts.num_individuals, False) self.edge_likelihoods, self.mutation_edges = count_mutations(ts) self.edge_likelihoods[:, 1] *= mutation_rate # count mutations in singleton blocks + # TODO: blocks should only be built from contemporary individuals + individual_unphased = np.full(ts.num_individuals, False) self.block_likelihoods, self.block_edges, \ self.mutation_blocks = block_singletons(ts, individual_unphased) self.block_likelihoods[:, 1] *= mutation_rate @@ -275,7 +280,7 @@ def __init__(self, ts, *, mutation_rate): self.mutation_order = np.arange(ts.num_mutations, dtype=np.int32) @staticmethod - @numba.njit(_f(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f)) + @numba.njit(_f(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f, _b)) def propagate_likelihood( edge_order, edges_parent, @@ -288,7 +293,7 @@ def propagate_likelihood( scale, max_shape, min_step, - #unphased, + unphased, ): """ Update approximating factors for Poisson mutation likelihoods on edges. @@ -311,6 +316,8 @@ def propagate_likelihood( scaling factor for the posteriors, updated in-place. :param float max_shape: the maximum allowed shape for node posteriors. :param float min_step: the minimum allowed step size in (0, 1). + :param bool unphased: if True, edges are treated as blocks of unphased + singletons in contemporary individuals """ assert constraints.shape == posterior.shape @@ -326,16 +333,20 @@ def cavity_damping(x, y): def posterior_damping(x): return _rescale(x, max_shape) - # if "unphased" edges are singleton blocks, and the two parents - # of each block are given by "parents" and "children" - #if unphased: - # leafward_projection = approx.sideways_projection - # rootward_projection = approx.sideways_projection - # gamma_projection = approx.unphased_projection - #else: - leafward_projection = approx.leafward_projection - rootward_projection = approx.rootward_projection - gamma_projection = approx.gamma_projection + def leafward_projection(x, y, z): + if unphased: + return approx.sideways_projection(x, y, z) + return approx.leafward_projection(x, y, z) + + def rootward_projection(x, y, z): + if unphased: + return approx.sideways_projection(x, y, z) + return approx.rootward_projection(x, y, z) + + def gamma_projection(x, y, z): + if unphased: + return approx.unphased_projection(x, y, z) + return approx.gamma_projection(x, y, z) fixed = constraints[:, LOWER] == constraints[:, UPPER] @@ -474,7 +485,7 @@ def posterior_damping(x): return np.nan @staticmethod - @numba.njit(_f(_i1r, _f2w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r)) + @numba.njit(_f(_i1r, _f2w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b)) def propagate_mutations( mutations_order, mutations_posterior, @@ -486,6 +497,7 @@ def propagate_mutations( posterior, factors, scale, + unphased, ): """ Calculate posteriors for mutations. 
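The closures above let a single njit-compiled kernel switch between the phased and unphased projections on a runtime flag. A minimal standalone sketch of the same pattern (toy update functions, not the tsdate projections):

    import numba
    import numpy as np

    @numba.njit
    def phased_update(x):
        return x - 1.0  # stand-in for a phased (edge) projection

    @numba.njit
    def unphased_update(x):
        return x + 1.0  # stand-in for a sideways/unphased projection

    @numba.njit
    def kernel(values, unphased):
        # the inner closure branches on the captured flag; numba compiles both paths
        def update(x):
            if unphased:
                return unphased_update(x)
            return phased_update(x)

        out = np.empty_like(values)
        for i in range(values.size):
            out[i] = update(values[i])
        return out

    kernel(np.ones(5), True)  # compiles and returns an array of 2.0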
@@ -509,6 +521,8 @@ def propagate_mutations( edge, updated in-place. :param ndarray scale: array of dimension `[num_nodes]` containing a scaling factor for the posteriors, updated in-place. + :param bool unphased: if True, edges are treated as blocks of unphased + singletons in contemporary individuals """ # TODO: scale should be 1.0, can we delete @@ -521,16 +535,25 @@ def propagate_mutations( assert factors.shape == (edges_parent.size, 2, 2) assert likelihoods.shape == (edges_parent.size, 2) - #if unphased: - # gamma_projection = approx.mutation_unphased_projection - # leafward_projection = approx.mutation_sideways_projection - # rootward_projection = approx.mutation_sideways_projection - # fixed_projection = approx.mutation_block_projection - #else: - gamma_projection = approx.mutation_gamma_projection - leafward_projection = approx.mutation_leafward_projection - rootward_projection = approx.mutation_rootward_projection - fixed_projection = approx.mutation_edge_projection + def leafward_projection(x, y, z): + if unphased: + return approx.mutation_sideways_projection(x, y, z) + return approx.mutation_leafward_projection(x, y, z) + + def rootward_projection(x, y, z): + if unphased: + return approx.mutation_sideways_projection(x, y, z) + return approx.mutation_rootward_projection(x, y, z) + + def gamma_projection(x, y, z): + if unphased: + return approx.mutation_unphased_projection(x, y, z) + return approx.mutation_gamma_projection(x, y, z) + + def fixed_projection(x, y): + if unphased: + return approx.mutation_block_projection(x, y) + return approx.mutation_edge_projection(x, y) phase = np.zeros(mutations_edge.size) # TODO fixed = constraints[:, LOWER] == constraints[:, UPPER] @@ -604,6 +627,7 @@ def iterate( regularise=True, check_valid=False, ): + # TODO: pass through unphased intervals #self.propagate_likelihood( # self.block_order, @@ -632,7 +656,7 @@ def iterate( self.node_scale, max_shape, min_step, - #USE_EDGE_LIKELIHOOD, + USE_EDGE_LIKELIHOOD, ) # exponential regularization on roots @@ -746,6 +770,7 @@ def run( self.node_posterior, self.edge_factors, self.node_scale, + USE_EDGE_LIKELIHOOD, ) muts_timing -= time.time() skipped_muts = np.sum(np.isnan(self.mutation_posterior[:, 0])) From 09e9db012424e87e051742b341d32e98414b660e Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 11 Jun 2024 15:14:46 -0700 Subject: [PATCH 12/29] Better mutation counting --- tsdate/core.py | 40 ++++++++++++---- tsdate/phasing.py | 37 ++++++++------- tsdate/variational.py | 107 ++++++++++++++++++++++++++++-------------- 3 files changed, 125 insertions(+), 59 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index d2b7cc7b..7c36f563 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -881,7 +881,8 @@ def outside_maximization(self, *, eps, progress=None): "posterior_obj", "mutation_mean", "mutation_var", - "mutation_likelihood", + "mutation_lik", + "mutation_edge", ], ) @@ -990,10 +991,7 @@ def __init__( # mutation to edge mapping # TODO: this isn't needed except for mutations_edge in constrain_mutations - mutspan_timing = time.time() self.edges_mutations, self.mutations_edge = util.mutation_span_array(ts) - mutspan_timing -= time.time() - logging.info(f"Extracted mutations in {abs(mutspan_timing)} seconds") def get_modified_ts(self, result, eps): # Return a new ts based on the existing one, but with the various @@ -1003,6 +1001,7 @@ def get_modified_ts(self, result, eps): node_var_t = result.posterior_var mut_mean_t = result.mutation_mean mut_var_t = result.mutation_var + mut_edge = 
result.mutation_edge tables = ts.dump_tables() nodes = tables.nodes mutations = tables.mutations @@ -1013,7 +1012,7 @@ def get_modified_ts(self, result, eps): constr_timing = time.time() nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations) # TODO: what if mutations_edge is NULL? - mutations.time = util.constrain_mutations(ts, nodes.time, self.mutations_edge) + mutations.time = util.constrain_mutations(ts, nodes.time, mut_edge) tables.time_units = self.time_units constr_timing -= time.time() logging.info(f"Constrained node ages in {abs(constr_timing)} seconds") @@ -1078,7 +1077,7 @@ def parse_result(self, result, epsilon, extra_posterior_cols=None): pst_dict.update(extra_posterior_cols or {}) ret.append(pst_dict) if self.return_likelihood: - ret.append(result.mutation_likelihood) + ret.append(result.mutation_lik) return tuple(ret) if len(ret) > 1 else ret.pop() def get_fixed_nodes_set(self): @@ -1171,8 +1170,15 @@ def run( posterior_obj.to_probabilities() posterior_mean, posterior_var = self.mean_var(self.ts, posterior_obj) + mut_edge = np.full(self.ts.num_mutations, tskit.NULL) return Results( - posterior_mean, posterior_var, posterior_obj, None, None, marginal_likl + posterior_mean, + posterior_var, + posterior_obj, + None, + None, + marginal_likl, + mut_edge ) @@ -1200,7 +1206,16 @@ def run( dynamic_prog = self.main_algorithm(probability_space, eps, num_threads) marginal_likl = dynamic_prog.inside_pass(cache_inside=cache_inside) posterior_mean = dynamic_prog.outside_maximization(eps=eps) - return Results(posterior_mean, None, None, None, None, marginal_likl) + mut_edge = np.full(self.ts.num_mutations, tskit.NULL) + return Results( + posterior_mean, + None, + None, + None, + None, + marginal_likl, + mut_edge + ) class VariationalGammaMethod(EstimationMethod): @@ -1281,10 +1296,17 @@ def run( idx = mutation_post[:, 1] > 0 mutation_mean[idx] = (mutation_post[idx, 0] + 1) / mutation_post[idx, 1] mutation_vari[idx] = (mutation_post[idx, 0] + 1) / mutation_post[idx, 1] ** 2 + mutation_edge = dynamic_prog.mutation_edges # TODO: return marginal likelihood return Results( - posterior_mean, posterior_vari, None, mutation_mean, mutation_vari, None + posterior_mean, + posterior_vari, + None, + mutation_mean, + mutation_vari, + None, + mutation_edge ) diff --git a/tsdate/phasing.py b/tsdate/phasing.py index 525ad953..415f893c 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -40,37 +40,39 @@ from .approx import _b from .approx import _b1r from .approx import _tuple +from .approx import _void # --- machinery used by ExpectationPropagation class --- # -@numba.njit(_f1w(_f1r, _b1r, _f1r, _i1r, _i2r)) -def reallocate_unphased(edges_mutations, edges_unphased, mutations_phase, mutations_block, blocks_edges): +@numba.njit(_void(_f2w, _f1r, _i1r, _i2r)) +def reallocate_unphased(edges_likelihood, mutations_phase, mutations_block, blocks_edges): """ Add a proportion of each unphased singleton mutation to one of the two - edges to which it maps, and returns the modified `edges_mutations`. 
+ edges to which it maps """ assert mutations_phase.size == mutations_block.size - assert blocks_edges.shape[1] == 2 - assert edges_mutations.size == edges_unphased.size + assert blocks_edges.shape[0] == 2 num_mutations = mutations_phase.size - num_edges = edges_mutations.size + num_edges = edges_likelihood.shape[0] num_blocks = blocks_edges.shape[0] - new_edges_mutations = edges_mutations.copy() - new_edges_mutations[edges_unphased] = 0.0 + edges_unphased = np.full(num_edges, False) + edges_unphased[blocks_edges[0]] = True + edges_unphased[blocks_edges[1]] = True + + num_unphased = np.sum(edges_likelihood[edges_unphased, 0]) + edges_likelihood[edges_unphased, 0] = 0.0 for m, b in enumerate(mutations_block): if b == tskit.NULL: continue - i, j = blocks_edges[b] + i, j = blocks_edges[0, b], blocks_edges[1, b] assert tskit.NULL < i < num_edges and edges_unphased[i] assert tskit.NULL < j < num_edges and edges_unphased[j] assert 0.0 <= mutations_phase[m] <= 1.0 - new_edges_mutations[i] += mutations_phase[m] - new_edges_mutations[j] += 1 - mutations_phase[m] - - assert np.isclose(np.sum(new_edges_mutations), np.sum(edges_mutations)) - return new_edges_mutations + edges_likelihood[i, 0] += mutations_phase[m] + edges_likelihood[j, 0] += 1 - mutations_phase[m] + assert np.isclose(num_unphased, np.sum(edges_likelihood[edges_unphased, 0])) @numba.njit(_tuple((_f2w, _i2w, _i1w))(_b1r, _i1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f)) @@ -177,8 +179,11 @@ def block_singletons(ts, individuals_unphased): TODO """ for i in ts.individuals(): - if individuals_unphased[i.id] and i.nodes.size != 2: - raise ValueError("Singleton blocking assumes diploid individuals") + if individuals_unphased[i.id]: + if i.nodes.size != 2: + raise ValueError("Singleton blocking assumes diploid individuals") + if not np.all(ts.nodes_time[i.nodes] == 0.0): + raise ValueError("Singleton blocking assumes contemporary individuals") # TODO: adjust spans by an accessibility mask return _block_singletons( diff --git a/tsdate/variational.py b/tsdate/variational.py index 87a02c72..29b07af0 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -51,6 +51,7 @@ from .util import contains_unary_nodes from .phasing import count_mutations from .phasing import block_singletons +from .phasing import reallocate_unphased # columns for edge_factors @@ -229,20 +230,29 @@ def __init__(self, ts, *, mutation_rate): ) # count mutations on edges + count_timing = time.time() self.edge_likelihoods, self.mutation_edges = count_mutations(ts) self.edge_likelihoods[:, 1] *= mutation_rate + count_timing -= time.time() + logging.info(f"Extracted mutations in {abs(count_timing)} seconds") # count mutations in singleton blocks - # TODO: blocks should only be built from contemporary individuals - individual_unphased = np.full(ts.num_individuals, False) + phase_timing = time.time() + individual_unphased = np.full(ts.num_individuals, True) self.block_likelihoods, self.block_edges, \ self.mutation_blocks = block_singletons(ts, individual_unphased) self.block_likelihoods[:, 1] *= mutation_rate - self.block_edges = np.ascontiguousarray(self.block_edges.T) - self.edge_unphased = np.full(ts.num_edges, False) - self.edge_unphased[self.block_edges[0]] = True - self.edge_unphased[self.block_edges[1]] = True + self.block_edges = np.ascontiguousarray(self.block_edges.T) # TODO: no need to transpose + self.block_nodes = np.full(self.block_edges.shape, tskit.NULL, dtype=np.int32) + self.block_nodes[0] = self.edge_parents[self.block_edges[0]] + 
self.block_nodes[1] = self.edge_parents[self.block_edges[1]] + self.mutation_phase = np.full(ts.num_mutations, np.nan) + num_unphased = np.sum(self.mutation_blocks != tskit.NULL) num_blocks = self.block_likelihoods.shape[0] + phase_timing -= time.time() + logging.info(f"Found {num_unphased} unphased singleton mutations") + logging.info(f"Split unphased singleton edges into {num_blocks} blocks") + logging.info(f"Phased singletons in {abs(phase_timing)} seconds") # mutable self.node_factors = np.zeros((ts.num_nodes, 2, 2)) @@ -265,7 +275,10 @@ def __init__(self, ts, *, mutation_rate): raise ValueError("Tree sequence contains disconnected nodes") # edge traversal order - edges = np.arange(ts.num_edges, dtype=np.int32)[~self.edge_unphased] + edge_unphased = np.full(ts.num_edges, False) + edge_unphased[self.block_edges[0]] = True + edge_unphased[self.block_edges[1]] = True + edges = np.arange(ts.num_edges, dtype=np.int32)[~edge_unphased] self.edge_order = np.concatenate((edges[:-1], np.flip(edges))) self.edge_weights = edge_sampling_weight( self.leaves, @@ -485,10 +498,11 @@ def posterior_damping(x): return np.nan @staticmethod - @numba.njit(_f(_i1r, _f2w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b)) + @numba.njit(_f(_i1r, _f2w, _f1w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b)) def propagate_mutations( mutations_order, mutations_posterior, + mutations_phase, mutations_edge, edges_parent, edges_child, @@ -504,8 +518,10 @@ def propagate_mutations( :param ndarray mutations_order: integer array giving order in which to traverse mutations - :param ndarray mutations_posterior: array of dimension `[num_mutations, 2]` - containing natural parameters for each mutation + :param ndarray mutations_posterior: array of dimension `(num_mutations, 2)` + containing natural parameters for each mutation, modified in place + :param ndarray mutations_phase: array of dimension `(num_mutations, )` + containing mutation phase, modified in place :param ndarray mutations_edge: integer array giving edge for each mutation :param ndarray edges_parent: integer array of parent ids per edge @@ -515,12 +531,12 @@ def propagate_mutations( :param ndarray constraints: array of dimension `[num_nodes, 2]` containing lower and upper bounds for each node. :param ndarray posterior: array of dimension `[num_nodes, 2]` - containing natural parameters for each node, updated in-place. + containing natural parameters for each node :param ndarray factors: array of dimension `[num_edges, 2, 2]` containing parent and child factors (natural parameters) for each - edge, updated in-place. - :param ndarray scale: array of dimension `[num_nodes]` containing a - scaling factor for the posteriors, updated in-place. 
+ edge + :param ndarray scale: array of dimension `(num_nodes, )` containing a + scaling factor for the posteriors :param bool unphased: if True, edges are treated as blocks of unphased singletons in contemporary individuals """ @@ -530,6 +546,8 @@ def propagate_mutations( # TODO: might as well copy format in other functions and have void return # assert stuff here + assert mutations_phase.size == mutations_edge.size + assert mutations_posterior.shape == (mutations_phase.size, 2) assert constraints.shape == posterior.shape assert edges_child.size == edges_parent.size assert factors.shape == (edges_parent.size, 2, 2) @@ -555,7 +573,6 @@ def fixed_projection(x, y): return approx.mutation_block_projection(x, y) return approx.mutation_edge_projection(x, y) - phase = np.zeros(mutations_edge.size) # TODO fixed = constraints[:, LOWER] == constraints[:, UPPER] for m in mutations_order: i = mutations_edge[m] @@ -574,7 +591,7 @@ def fixed_projection(x, y): child_cavity = posterior[c] - child_delta * child_message edge_likelihood = child_delta * likelihoods[i] parent_age = constraints[p, LOWER] - phase[m], mutations_posterior[m] = leafward_projection( + mutations_phase[m], mutations_posterior[m] = leafward_projection( parent_age, child_cavity, edge_likelihood, ) elif fixed[c] and not fixed[p]: @@ -583,7 +600,7 @@ def fixed_projection(x, y): parent_cavity = posterior[p] - parent_delta * parent_message edge_likelihood = parent_delta * likelihoods[i] child_age = constraints[c, LOWER] - phase[m], mutations_posterior[m] = rootward_projection( + mutations_phase[m], mutations_posterior[m] = rootward_projection( child_age, parent_cavity, edge_likelihood, ) else: @@ -595,7 +612,7 @@ def fixed_projection(x, y): parent_cavity = posterior[p] - delta * parent_message child_cavity = posterior[c] - delta * child_message edge_likelihood = delta * likelihoods[i] - phase[m], mutations_posterior[m] = gamma_projection( + mutations_phase[m], mutations_posterior[m] = gamma_projection( parent_cavity, child_cavity, edge_likelihood, ) @@ -622,26 +639,27 @@ def iterate( *, max_shape=1000, min_step=0.1, - em_maxitt=100, + em_maxitt=10, em_reltol=1e-8, regularise=True, check_valid=False, ): - # TODO: pass through unphased intervals - #self.propagate_likelihood( - # self.block_order, - # self.block_parents[0], - # self.block_parents[1], - # self.block_likelihoods, - # self.node_constraints, - # self.block_factors, - # self.block_log_partition, - # self.node_scale, - # max_shape, - # min_step, - # USE_BLOCK_LIKELIHOOD, - #) + # pass through singleton blocks + self.propagate_likelihood( + self.block_order, + self.block_nodes[0], + self.block_nodes[1], + self.block_likelihoods, + self.node_constraints, + self.node_posterior, + self.block_factors, + self.block_logconst, + self.node_scale, + max_shape, + min_step, + USE_BLOCK_LIKELIHOOD, + ) # rootward + leafward pass through edges self.propagate_likelihood( @@ -706,6 +724,12 @@ def rescale( nodes_time = self._point_estimate( self.node_posterior, self.node_constraints, use_median ) + reallocate_unphased( # correct mutation counts for unphased singletons + self.edge_likelihoods, + self.mutation_phase, + self.mutation_blocks, + self.block_edges, + ) original_breaks, rescaled_breaks = mutational_timescale( nodes_time, self.edge_likelihoods, @@ -759,9 +783,24 @@ def run( logging.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds") muts_timing = time.time() - self.propagate_mutations( + self.propagate_mutations( # unphased singletons + self.mutation_order, + 
self.mutation_posterior, + self.mutation_phase, + self.mutation_blocks, + self.block_nodes[0], + self.block_nodes[1], + self.block_likelihoods, + self.node_constraints, + self.node_posterior, + self.block_factors, + self.node_scale, + USE_BLOCK_LIKELIHOOD, + ) + self.propagate_mutations( # phased mutations self.mutation_order, self.mutation_posterior, + self.mutation_phase, self.mutation_edges, self.edge_parents, self.edge_children, From e32bfabb4c089faa614fc1a6c905b3963e3ea1b1 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 11 Jun 2024 20:22:24 -0700 Subject: [PATCH 13/29] Working --- tests/test_phasing.py | 86 ++++++++++++----- tsdate/approx.py | 27 ++++-- tsdate/core.py | 65 ++++++------- tsdate/evaluation.py | 57 ++++++++---- tsdate/phasing.py | 208 +++++++++++++++++++++++++++++++++++------- tsdate/variational.py | 64 ++++++++----- 6 files changed, 372 insertions(+), 135 deletions(-) diff --git a/tests/test_phasing.py b/tests/test_phasing.py index da015194..f8d74dc5 100644 --- a/tests/test_phasing.py +++ b/tests/test_phasing.py @@ -24,15 +24,15 @@ """ Test cases for the gamma-variational approximations in tsdate """ - +import msprime import numpy as np import pytest -import tskit -import msprime import tsinfer +import tskit from tsdate.phasing import block_singletons from tsdate.phasing import count_mutations +from tsdate.phasing import mutation_frequency @pytest.fixture(scope="session") @@ -66,7 +66,6 @@ def test_count_mutations(self, inferred_ts): class TestBlockSingletons: - @staticmethod def naive_block_singletons(ts, individual): """ @@ -81,7 +80,7 @@ def naive_block_singletons(ts, individual): blocks_edge = [] blocks_span = [] for tree in ts.trees(): - if tree.num_edges == 0: # skip tree + if tree.num_edges == 0: # skip tree muts = [] span = 0.0 block = tskit.NULL, tskit.NULL @@ -91,19 +90,23 @@ def naive_block_singletons(ts, individual): block = tree.edge(j), tree.edge(k) for m in muts: muts_edges[m] = block - if last_block[0] != tskit.NULL and not np.array_equal(block, last_block): # flush block + if last_block[0] != tskit.NULL and not np.array_equal( + block, last_block + ): # flush block blocks_edge.extend(last_block) blocks_span.extend(last_span) last_span[:] = 0.0 last_span += len(muts), span last_block[:] = block - if last_block[0] != tskit.NULL: # flush last block + if last_block[0] != tskit.NULL: # flush last block blocks_edge.extend(last_block) blocks_span.extend(last_span) blocks_edge = np.array(blocks_edge).reshape(-1, 2) blocks_span = np.array(blocks_span).reshape(-1, 2) total_span = np.sum([t.interval.span for t in ts.trees() if t.num_edges > 0]) - total_muts = np.sum(np.logical_or(ts.mutations_node == j, ts.mutations_node == k)) + total_muts = np.sum( + np.logical_or(ts.mutations_node == j, ts.mutations_node == k) + ) assert np.sum(blocks_span[:, 0]) == total_muts assert np.sum(blocks_span[:, 1]) == total_span return blocks_span, blocks_edge, muts_edges @@ -117,7 +120,9 @@ def test_against_naive(self, inferred_ts): individuals_unphased = np.full(ts.num_individuals, False) unphased_individuals = np.arange(0, ts.num_individuals // 2) individuals_unphased[unphased_individuals] = True - block_stats, block_edges, muts_block = block_singletons(ts, individuals_unphased) + block_stats, block_edges, muts_block = block_singletons( + ts, individuals_unphased + ) block_edges = block_edges singletons = muts_block != tskit.NULL muts_edges = np.full((ts.num_mutations, 2), tskit.NULL) @@ -125,7 +130,9 @@ def test_against_naive(self, inferred_ts): ck_num_blocks = 0 
ck_num_singletons = 0 for i in np.flatnonzero(individuals_unphased): - ck_block_stats, ck_block_edges, ck_muts_edges = self.naive_block_singletons(ts, i) + ck_block_stats, ck_block_edges, ck_muts_edges = self.naive_block_singletons( + ts, i + ) ck_num_blocks += ck_block_stats.shape[0] # blocks of individual i nodes_i = ts.individual(i).nodes @@ -140,11 +147,11 @@ def test_against_naive(self, inferred_ts): # singleton mutations in unphased individual i ck_muts_i = ck_muts_edges[:, 0] != tskit.NULL np.testing.assert_array_equal( - np.min(muts_edges[ck_muts_i], axis=1), + np.min(muts_edges[ck_muts_i], axis=1), np.min(ck_muts_edges[ck_muts_i], axis=1), ) np.testing.assert_array_equal( - np.max(muts_edges[ck_muts_i], axis=1), + np.max(muts_edges[ck_muts_i], axis=1), np.max(ck_muts_edges[ck_muts_i], axis=1), ) ck_num_singletons += np.sum(ck_muts_i) @@ -160,11 +167,14 @@ def test_total_counts(self, inferred_ts): individuals_unphased = np.full(ts.num_individuals, False) unphased_individuals = np.arange(0, ts.num_individuals // 2) individuals_unphased[unphased_individuals] = True - unphased_nodes = np.concatenate([ts.individual(i).nodes for i in unphased_individuals]) + unphased_nodes = np.concatenate( + [ts.individual(i).nodes for i in unphased_individuals] + ) total_singleton_span = 0.0 total_singleton_muts = 0.0 for t in ts.trees(): - if t.num_edges == 0: continue + if t.num_edges == 0: + continue for s in t.samples(): if s in unphased_nodes: total_singleton_span += t.span @@ -185,15 +195,18 @@ def test_singleton_edges(self, inferred_ts): individuals_unphased = np.full(ts.num_individuals, False) unphased_individuals = np.arange(0, ts.num_individuals // 2) individuals_unphased[unphased_individuals] = True - unphased_nodes = set(np.concatenate([ts.individual(i).nodes for i in unphased_individuals])) + unphased_nodes = set( + np.concatenate([ts.individual(i).nodes for i in unphased_individuals]) + ) ck_singleton_edge = set() for t in ts.trees(): - if t.num_edges == 0: continue + if t.num_edges == 0: + continue for s in ts.samples(): if s in unphased_nodes: ck_singleton_edge.add(t.edge(s)) _, block_edges, *_ = block_singletons(ts, individuals_unphased) - singleton_edge = set([i for i in block_edges.flatten()]) + singleton_edge = {i for i in block_edges.flatten()} assert singleton_edge == ck_singleton_edge def test_singleton_mutations(self, inferred_ts): @@ -205,15 +218,18 @@ def test_singleton_mutations(self, inferred_ts): individuals_unphased = np.full(ts.num_individuals, False) unphased_individuals = np.arange(0, ts.num_individuals // 2) individuals_unphased[unphased_individuals] = True - unphased_nodes = np.concatenate([ts.individual(i).nodes for i in unphased_individuals]) + unphased_nodes = np.concatenate( + [ts.individual(i).nodes for i in unphased_individuals] + ) ck_singleton_muts = set() for t in ts.trees(): - if t.num_edges == 0: continue + if t.num_edges == 0: + continue for m in t.mutations(): if t.num_samples(m.node) == 1 and (m.node in unphased_nodes): ck_singleton_muts.add(m.id) _, _, block_muts = block_singletons(ts, individuals_unphased) - singleton_muts = set([i for i in np.flatnonzero(block_muts != tskit.NULL)]) + singleton_muts = {i for i in np.flatnonzero(block_muts != tskit.NULL)} assert singleton_muts == ck_singleton_muts def test_all_phased(self, inferred_ts): @@ -222,8 +238,36 @@ def test_all_phased(self, inferred_ts): """ ts = inferred_ts individuals_unphased = np.full(ts.num_individuals, False) - block_stats, block_edges, block_muts = block_singletons(ts, 
individuals_unphased) + block_stats, block_edges, block_muts = block_singletons( + ts, individuals_unphased + ) assert block_stats.shape == (0, 2) assert block_edges.shape == (0, 2) assert np.all(block_muts == tskit.NULL) + +class TestMutationFrequency: + @staticmethod + def naive_mutation_frequency(ts, sample_set): + frq = np.zeros(ts.num_mutations) + for t in ts.trees(): + for m in t.mutations(): + for s in t.samples(m.node): + frq[m.id] += int(s in sample_set) + return frq + + def test_mutation_frequency(self, inferred_ts): + ck_freq = self.naive_mutation_frequency(inferred_ts, inferred_ts.samples()) + freq = mutation_frequency(inferred_ts) + np.testing.assert_array_equal(ck_freq, freq.squeeze()) + + def test_mutation_frequency_stratified(self, inferred_ts): + sample_sets = [ + list(np.arange(5)), + list(np.arange(3, 10)), + list(np.arange(15, 20)), + ] + freqs = mutation_frequency(inferred_ts, sample_sets) + for i, s in enumerate(sample_sets): + ck_freq = self.naive_mutation_frequency(inferred_ts, s) + np.testing.assert_array_equal(ck_freq, freqs[:, i]) diff --git a/tsdate/approx.py b/tsdate/approx.py index d3b290ea..92694745 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -29,11 +29,8 @@ from math import log from math import nan -import mpmath import numba import numpy as np -from numba.types import Tuple as _tuple -from numba.types import UniTuple as _unituple from . import hypergeo @@ -59,8 +56,14 @@ _i1r = numba.types.Array(_i, 1, "C", readonly=True) _i2w = numba.types.Array(_i, 2, "C", readonly=False) _i2r = numba.types.Array(_i, 2, "C", readonly=True) +_i3w = numba.types.Array(_i, 3, "C", readonly=False) +_i3r = numba.types.Array(_i, 3, "C", readonly=True) _b1w = numba.types.Array(_b, 1, "C", readonly=False) _b1r = numba.types.Array(_b, 1, "C", readonly=True) +_b2w = numba.types.Array(_b, 2, "C", readonly=False) +_b2r = numba.types.Array(_b, 2, "C", readonly=True) +_b3w = numba.types.Array(_b, 3, "C", readonly=False) +_b3r = numba.types.Array(_b, 3, "C", readonly=True) _tuple = numba.types.Tuple _unituple = numba.types.UniTuple _void = numba.types.void @@ -242,6 +245,7 @@ def _valid_hyp2f1(a, b, c, z): # --- various EP updates --- # + @numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" @@ -462,12 +466,16 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): s1 = a * b / c s2 = s1 * (a + 1) * (b + 1) / (c + 1) - d1 = b * (b + 1) / t ** 2 + d1 = b * (b + 1) / t**2 d2 = d1 * a / c d3 = d2 * (a + 1) / (c + 1) mn_m = s1 * exp(f111 - f000) / t / 2 * (1 + z) + b / t / 2 - sq_m = d1 * exp(f020 - f000) / 3 + d2 * exp(f121 - f000) / 3 + d3 * exp(f222 - f000) / 3 + sq_m = ( + d1 * exp(f020 - f000) / 3 + + d2 * exp(f121 - f000) / 3 + + d3 * exp(f222 - f000) / 3 + ) va_m = sq_m - mn_m**2 return mn_m, va_m @@ -596,12 +604,12 @@ def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): # direct but unstable: hyperu = hypergeo._hyperu_laplace f00, d00 = hyperu(a + 0, b + 0, z) - f10, d10 = hyperu(a + 1, b + 0, z) - f21, d21 = hyperu(a + 2, b + 1, z) + f10, d10 = hyperu(a + 1, b + 0, z) + f21, d21 = hyperu(a + 2, b + 1, z) f32, d32 = hyperu(a + 3, b + 2, z) pr_m = 1.0 - exp(f10 - f00) * a mn_m = pr_m * t_i / 2 + t_i * exp(f21 - f00) * a * (a + 1) / 2 - sq_m = pr_m * t_i ** 2 / 3 + t_i ** 2 * exp(f32 - f00) * a * (a + 1) * (a + 2) / 3 + sq_m = pr_m * t_i**2 / 3 + t_i**2 * exp(f32 - f00) * a * (a + 1) * (a + 2) / 3 # TODO: use a stabler approach with derivatives # note that exp(f10 - f00) = (a + z * d00) / (a - b + 1) @@ -643,7 
+651,7 @@ def mutation_block_moments(t_i, t_j): pr_m = t_i / (t_i + t_j) mn_m = pr_m * t_i / 2 + (1 - pr_m) * t_j / 2 - sq_m = pr_m * t_i ** 2 / 3 + (1 - pr_m) * t_j ** 2 / 3 + sq_m = pr_m * t_i**2 / 3 + (1 - pr_m) * t_j**2 / 3 va_m = sq_m - mn_m**2 return pr_m, mn_m, va_m @@ -651,6 +659,7 @@ def mutation_block_moments(t_i, t_j): # --- wrappers around updates --- # + @numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) def gamma_projection(pars_i, pars_j, pars_ij): r""" diff --git a/tsdate/core.py b/tsdate/core.py index 7c36f563..0d6f00ff 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1172,13 +1172,13 @@ def run( posterior_mean, posterior_var = self.mean_var(self.ts, posterior_obj) mut_edge = np.full(self.ts.num_mutations, tskit.NULL) return Results( - posterior_mean, - posterior_var, - posterior_obj, - None, - None, - marginal_likl, - mut_edge + posterior_mean, + posterior_var, + posterior_obj, + None, + None, + marginal_likl, + mut_edge, ) @@ -1207,15 +1207,7 @@ def run( marginal_likl = dynamic_prog.inside_pass(cache_inside=cache_inside) posterior_mean = dynamic_prog.outside_maximization(eps=eps) mut_edge = np.full(self.ts.num_mutations, tskit.NULL) - return Results( - posterior_mean, - None, - None, - None, - None, - marginal_likl, - mut_edge - ) + return Results(posterior_mean, None, None, None, None, marginal_likl, mut_edge) class VariationalGammaMethod(EstimationMethod): @@ -1250,11 +1242,6 @@ def mean_var(posteriors, constraints): return mn_post, va_post - def main_algorithm(self): - return variational.ExpectationPropagation( - self.ts, mutation_rate=self.mutation_rate - ) - def run( self, eps, @@ -1263,6 +1250,7 @@ def run( rescaling_intervals, match_segregating_sites, regularise_roots, + singletons_phased, ): if self.provenance_params is not None: self.provenance_params.update( @@ -1274,8 +1262,12 @@ def run( raise ValueError("Variational gamma method requires mutation rate") # match sufficient statistics or match central moments - dynamic_prog = self.main_algorithm() - dynamic_prog.run( + posterior = variational.ExpectationPropagation( + self.ts, + mutation_rate=self.mutation_rate, + singletons_phased=singletons_phased, + ) + posterior.run( ep_maxitt=max_iterations, max_shape=max_shape, rescale_intervals=rescaling_intervals, @@ -1284,29 +1276,29 @@ def run( progress=self.pbar, ) - # TODO: use dynamic_prog.point_estimate + # TODO: use posterior.point_estimate posterior_mean, posterior_vari = self.mean_var( - dynamic_prog.node_posterior, dynamic_prog.node_constraints + posterior.node_posterior, posterior.node_constraints ) # TODO: clean up - mutation_post = dynamic_prog.mutation_posterior + mutation_post = posterior.mutation_posterior mutation_mean = np.full(mutation_post.shape[0], np.nan) mutation_vari = np.full(mutation_post.shape[0], np.nan) idx = mutation_post[:, 1] > 0 mutation_mean[idx] = (mutation_post[idx, 0] + 1) / mutation_post[idx, 1] mutation_vari[idx] = (mutation_post[idx, 0] + 1) / mutation_post[idx, 1] ** 2 - mutation_edge = dynamic_prog.mutation_edges + mutation_edge = posterior.mutation_edges # TODO: return marginal likelihood return Results( - posterior_mean, - posterior_vari, - None, - mutation_mean, - mutation_vari, - None, - mutation_edge + posterior_mean, + posterior_vari, + None, + mutation_mean, + mutation_vari, + None, + mutation_edge, ) @@ -1589,6 +1581,7 @@ def variational_gamma( max_shape=None, match_segregating_sites=None, regularise_roots=None, + singletons_phased=None, **kwargs, ): """ @@ -1655,6 +1648,8 @@ def variational_gamma( 
match_segregating_sites = False if regularise_roots is None: regularise_roots = True + if singletons_phased is None: + singletons_phased = True if tree_sequence.num_mutations == 0: raise ValueError( "No mutations present: these are required for the variational_gamma method" @@ -1669,6 +1664,7 @@ def variational_gamma( rescaling_intervals=rescaling_intervals, match_segregating_sites=match_segregating_sites, regularise_roots=regularise_roots, + singletons_phased=singletons_phased, ) return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]}) @@ -1759,7 +1755,6 @@ def date( :param bool record_provenance: Should the tsdate command be appended to the provenence information in the returned tree sequence? Default: None, treated as True. - :param float Ne: Deprecated, use the``population_size`` argument instead. :param \\**kwargs: Other keyword arguments specific to the :data:`estimation method` used. These are documented in those specific functions. diff --git a/tsdate/evaluation.py b/tsdate/evaluation.py index 1fbf2447..5a95ea92 100644 --- a/tsdate/evaluation.py +++ b/tsdate/evaluation.py @@ -34,6 +34,9 @@ import scipy.sparse import tskit +from .phasing import count_mutations +from .phasing import mutation_frequency + class CladeMap: """ @@ -533,11 +536,29 @@ def mutation_coverage(ts, inferred_ts, alpha): return prop_covered -def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None, title=None, subtending_node=False): +def mutations_time( + ts, + infer_ts, + use_posterior=False, + min_freq=None, + max_freq=None, + plotpath=None, + title=None, + subtending_node=False, +): """ Return true and inferred mutation ages, optionally creating a scatterplot and filtering by minimum or maximum frequency. """ + # mutation edges and frequency + _, true_edge = count_mutations(ts) + _, infr_edge = count_mutations(infer_ts) + true_freq = mutation_frequency(ts) + infr_freq = mutation_frequency(infer_ts) + if use_posterior: + mut_post_mn = np.zeros(infer_ts.num_mutations) + for m in infer_ts.mutations(): + mut_post_mn[m.id] = m.metadata["mn"] # find shared biallelic sites positions = {p: i for i, p in enumerate(ts.sites_position)} true_mut = np.full(ts.sites_position.size, tskit.NULL) @@ -547,7 +568,7 @@ def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None, if s.mutations[0].edge != tskit.NULL: sid = positions[s.position] true_mut[sid] = s.mutations[0].id - for s in inferred_ts.sites(): + for s in infer_ts.sites(): if len(s.mutations) == 1: if s.mutations[0].edge != tskit.NULL: sid = positions[s.position] @@ -557,21 +578,18 @@ def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None, true_mut = true_mut[~missing] # filter by frequency if min_freq is not None or max_freq is not None: - freq = np.zeros(inferred_ts.num_mutations) - for t in inferred_ts.trees(): - for m in t.mutations(): - freq[m.id] = t.num_samples(m.node) if min_freq is None: - min_freq = np.min(freq) + min_freq = 0 if max_freq is None: - max_freq = np.max(freq) - freq = freq[infr_mut] + max_freq = np.max(infr_freq) + freq = infr_freq[infr_mut] + assert np.allclose(freq, true_freq[true_mut]) is_freq = np.logical_and(freq >= min_freq, freq <= max_freq) infr_mut = infr_mut[is_freq] true_mut = true_mut[is_freq] # get age of mutation or subtended node if subtending_node: - infr_node = inferred_ts.mutations_node[infr_mut] + infr_node = infer_ts.mutations_node[infr_mut] true_node = ts.mutations_node[true_mut] _, uniq_idx = np.unique(infr_node, return_index=True) 
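+        # keep one entry per node, so nodes subtending several mutations are
+        # compared only once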
infr_node = infr_node[uniq_idx] @@ -579,14 +597,21 @@ def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None, _, uniq_idx = np.unique(true_node, return_index=True) infr_node = infr_node[uniq_idx] true_node = true_node[uniq_idx] - mean = inferred_ts.nodes_time[infr_node] + mean = infer_ts.nodes_time[infr_node] truth = ts.nodes_time[true_node] nonzero = np.logical_and(mean > 0, truth > 0) mean = mean[nonzero] truth = truth[nonzero] - else: - mean = inferred_ts.mutations_time[infr_mut] - truth = ts.mutations_time[true_mut] + else: # midpoint on branch + infr_p = infer_ts.edges_parent[infr_edge[infr_mut]] + true_p = ts.edges_parent[true_edge[true_mut]] + infr_c = infer_ts.edges_child[infr_edge[infr_mut]] + true_c = ts.edges_child[true_edge[true_mut]] + if use_posterior: + mean = mut_post_mn[infr_mut] + else: + mean = (infer_ts.nodes_time[infr_p] + infer_ts.nodes_time[infr_c]) / 2 + truth = (ts.nodes_time[true_p] + ts.nodes_time[true_c]) / 2 if plotpath is not None: rsq = np.corrcoef(np.log10(mean), np.log10(truth))[0, 1] ** 2 bias = np.mean(np.log10(mean) - np.log10(truth)) @@ -600,8 +625,8 @@ def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None, plt.xlabel("True node age") plt.ylabel("Estimated node age") else: - plt.xlabel("True mutation age") - plt.ylabel("Estimated mutation age") + plt.xlabel("True mutation age (branch midpoint)") + plt.ylabel("Estimated mutation age (branch midpoint)") if title is not None: plt.title(title) plt.tight_layout() diff --git a/tsdate/phasing.py b/tsdate/phasing.py index 415f893c..0102b463 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -22,11 +22,13 @@ """ Tools for phasing singleton mutations """ - import numba import numpy as np import tskit +from .approx import _b +from .approx import _b1r +from .approx import _b2r from .approx import _f from .approx import _f1r from .approx import _f1w @@ -37,15 +39,16 @@ from .approx import _i1w from .approx import _i2r from .approx import _i2w -from .approx import _b -from .approx import _b1r from .approx import _tuple from .approx import _void # --- machinery used by ExpectationPropagation class --- # + @numba.njit(_void(_f2w, _f1r, _i1r, _i2r)) -def reallocate_unphased(edges_likelihood, mutations_phase, mutations_block, blocks_edges): +def reallocate_unphased( + edges_likelihood, mutations_phase, mutations_block, blocks_edges +): """ Add a proportion of each unphased singleton mutation to one of the two edges to which it maps @@ -73,10 +76,26 @@ def reallocate_unphased(edges_likelihood, mutations_phase, mutations_block, bloc edges_likelihood[i, 0] += mutations_phase[m] edges_likelihood[j, 0] += 1 - mutations_phase[m] assert np.isclose(num_unphased, np.sum(edges_likelihood[edges_unphased, 0])) - -@numba.njit(_tuple((_f2w, _i2w, _i1w))(_b1r, _i1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f)) -def _block_singletons(individuals_unphased, nodes_individual, mutations_node, mutations_position, edges_parent, edges_child, edges_left, edges_right, indexes_insert, indexes_remove, sequence_length): + +@numba.njit( + _tuple((_f2w, _i2w, _i1w))( + _b1r, _i1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f + ) +) +def _block_singletons( + individuals_unphased, + nodes_individual, + mutations_node, + mutations_position, + edges_parent, + edges_child, + edges_left, + edges_right, + indexes_insert, + indexes_remove, + sequence_length, +): """ TODO """ @@ -110,7 +129,7 @@ def _block_singletons(individuals_unphased, nodes_individual, mutations_node, mu a, b, d = 0, 
0, 0 while a < num_edges or b < num_edges: while b < num_edges and position_remove[b] == left: # edges out - e = indexes_remove[b] + e = indexes_remove[b] p, c = edges_parent[e], edges_child[e] i = nodes_individual[c] if i != tskit.NULL and individuals_unphased[i]: @@ -127,7 +146,7 @@ def _block_singletons(individuals_unphased, nodes_individual, mutations_node, mu individuals_block[i] = tskit.NULL individuals_singletons[i] = 0.0 b += 1 - + while a < num_edges and position_insert[a] == left: # edges in e = indexes_insert[a] p, c = edges_parent[e], edges_child[e] @@ -148,7 +167,7 @@ def _block_singletons(individuals_unphased, nodes_individual, mutations_node, mu if a < num_edges: right = min(right, position_insert[a]) left = right - + while d < num_mutations and position_mutation[d] < right: # mutations m = indexes_mutation[d] c = mutations_node[m] @@ -157,7 +176,7 @@ def _block_singletons(individuals_unphased, nodes_individual, mutations_node, mu mutations_block[m] = individuals_block[i] individuals_singletons[i] += 1.0 d += 1 - + mutations_block = mutations_block.astype(np.int32) blocks_edges = np.array(blocks_edges, dtype=np.int32).reshape(-1, 2) blocks_singletons = np.array(blocks_singletons) @@ -201,8 +220,21 @@ def block_singletons(ts, individuals_unphased): ) -@numba.njit(_tuple((_f2w, _i1w))(_i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _i, _f)) -def _count_mutations(mutations_node, mutations_position, edges_parent, edges_child, edges_left, edges_right, indexes_insert, indexes_remove, num_nodes, sequence_length): +@numba.njit( + _tuple((_f2w, _i1w))(_i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _i, _f) +) +def _count_mutations( + mutations_node, + mutations_position, + edges_parent, + edges_child, + edges_left, + edges_right, + indexes_insert, + indexes_remove, + num_nodes, + sequence_length, +): """ TODO """ @@ -227,11 +259,11 @@ def _count_mutations(mutations_node, mutations_position, edges_parent, edges_chi a, b, d = 0, 0, 0 while a < num_edges or b < num_edges: while b < num_edges and position_remove[b] == left: # edges out - e = indexes_remove[b] + e = indexes_remove[b] p, c = edges_parent[e], edges_child[e] nodes_edge[c] = tskit.NULL b += 1 - + while a < num_edges and position_insert[a] == left: # edges in e = indexes_insert[a] p, c = edges_parent[e], edges_child[e] @@ -244,7 +276,7 @@ def _count_mutations(mutations_node, mutations_position, edges_parent, edges_chi if a < num_edges: right = min(right, position_insert[a]) left = right - + while d < num_mutations and position_mutation[d] < right: m = indexes_mutation[d] c = mutations_node[m] @@ -253,7 +285,7 @@ def _count_mutations(mutations_node, mutations_position, edges_parent, edges_chi mutations_edge[m] = e edges_mutations[e] += 1.0 d += 1 - + mutations_edge = mutations_edge.astype(np.int32) edges_stats = np.column_stack((edges_mutations, edges_span)) @@ -279,8 +311,109 @@ def count_mutations(ts): ) +@numba.njit(_i2w(_b2r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f)) +def _mutation_frequency( + nodes_sample, + mutations_node, + mutations_position, + edges_parent, + edges_child, + edges_left, + edges_right, + indexes_insert, + indexes_remove, + sequence_length, +): + """ + TODO + """ + assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size + assert indexes_insert.size == indexes_remove.size == edges_parent.size + assert mutations_node.size == mutations_position.size + + num_nodes, num_sample_sets = nodes_sample.shape + num_mutations = mutations_node.size + num_edges = edges_parent.size 
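+
+    # sweep across the genome in edge insertion/removal order, maintaining the
+    # number of samples from each sample set below every node; a mutation's
+    # frequency is its node's count in the tree interval containing its site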
+ + indexes_mutation = np.argsort(mutations_position) + position_insert = edges_left[indexes_insert] + position_remove = edges_right[indexes_remove] + position_mutation = mutations_position[indexes_mutation] + + nodes_parent = np.full(num_nodes, tskit.NULL) + nodes_samples = np.zeros((num_nodes, num_sample_sets), dtype=np.int32) + mutations_freq = np.zeros((num_mutations, num_sample_sets), dtype=np.int32) + + # TODO: there's a better way than passing a big bool array + for i in range(num_sample_sets): + nodes_samples[nodes_sample[:, i], i] = 1.0 + + left = 0.0 + a, b, d = 0, 0, 0 + while a < num_edges or b < num_edges: + while b < num_edges and position_remove[b] == left: # edges out + e = indexes_remove[b] + p, c = edges_parent[e], edges_child[e] + nodes_parent[c] = tskit.NULL + while p != tskit.NULL: + nodes_samples[p] -= nodes_samples[c] + p = nodes_parent[p] + b += 1 + + while a < num_edges and position_insert[a] == left: # edges in + e = indexes_insert[a] + p, c = edges_parent[e], edges_child[e] + nodes_parent[c] = p + while p != tskit.NULL: + nodes_samples[p] += nodes_samples[c] + p = nodes_parent[p] + a += 1 + + right = sequence_length + if b < num_edges: + right = min(right, position_remove[b]) + if a < num_edges: + right = min(right, position_insert[a]) + left = right + + while d < num_mutations and position_mutation[d] < right: + m = indexes_mutation[d] + c = mutations_node[m] + mutations_freq[m] = nodes_samples[c] + d += 1 + + return mutations_freq + + +def mutation_frequency(ts, sample_sets=None): + """ + TODO + """ + if sample_sets is None: + sample_sets = [list(ts.samples())] + + nodes_sample = np.full((ts.num_nodes, len(sample_sets)), False) + for i, s in enumerate(sample_sets): + assert min(s) >= 0 and max(s) < ts.num_samples, "Sample out of range" + nodes_sample[s, i] = True + + return _mutation_frequency( + nodes_sample, + ts.mutations_node, + ts.sites_position[ts.mutations_site], + ts.edges_parent, + ts.edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, + ts.sequence_length, + ).squeeze() + + # --- helper functions --- # + def remove_singletons(ts): """ Remove all singleton mutations from the tree sequence. 
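For reference, the frequency helper added above is exercised in the test suite roughly as follows (a sketch; `ts` is any tree sequence and the sample IDs are illustrative):

    import numpy as np
    from tsdate.phasing import mutation_frequency

    freq = mutation_frequency(ts)  # counts within the full sample set
    singletons = np.flatnonzero(freq == 1)

    # stratified counts: one column per sample set
    freqs = mutation_frequency(ts, sample_sets=[[0, 1, 2], [3, 4, 5]])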
@@ -294,14 +427,18 @@
     assert np.all(~nodes_sample[ts.edges_parent]), "Sample node has a child"
     singletons = nodes_sample[ts.mutations_node]

-    old_metadata = np.array(tskit.unpack_strings(
-        ts.tables.mutations.metadata,
-        ts.tables.mutations.metadata_offset,
-    ))
-    old_state = np.array(tskit.unpack_strings(
-        ts.tables.mutations.derived_state,
-        ts.tables.mutations.derived_state_offset,
-    ))
+    old_metadata = np.array(
+        tskit.unpack_strings(
+            ts.tables.mutations.metadata,
+            ts.tables.mutations.metadata_offset,
+        )
+    )
+    old_state = np.array(
+        tskit.unpack_strings(
+            ts.tables.mutations.derived_state,
+            ts.tables.mutations.derived_state_offset,
+        )
+    )
     new_metadata, new_metadata_offset = tskit.pack_strings(old_metadata[~singletons])
     new_state, new_state_offset = tskit.pack_strings(old_state[~singletons])

@@ -346,7 +483,9 @@ def rephase_singletons(ts, use_node_times=True, random_seed=None):
         nodes_id = ts.individual(individual).nodes
         nodes_length = np.array([tree.time(tree.parent(n)) - time for n in nodes_id])
         nodes_prob = nodes_length if use_node_times else np.ones(nodes_id.size)
-        mutations_node[i] = rng.choice(nodes_id, p=nodes_prob / nodes_prob.sum(), size=1)
+        mutations_node[i] = rng.choice(
+            nodes_id, p=nodes_prob / nodes_prob.sum(), size=1
+        )

         if not np.isnan(mutations_time[i]):
             mutations_time[i] = (time + tree.time(tree.parent(mutations_node[i]))) / 2

@@ -358,9 +497,16 @@
-def insert_unphased_singletons(ts, position, individual, reference_state, alternate_state, allow_overlapping_sites=False):
+def insert_unphased_singletons(
+    ts,
+    position,
+    individual,
+    reference_state,
+    alternate_state,
+    allow_overlapping_sites=False,
+):
     """
-    Insert unphased singletons into the tree sequence. The phase is arbitrarily chosen
+    Insert unphased singletons into the tree sequence. The phase is arbitrarily chosen
     so that the mutation subtends the node with the lowest id, at a given position for
     a given individual.
@@ -380,7 +526,9 @@ def insert_unphased_singletons(ts, position, individual, reference_state, altern individuals_node = {i.id: min(i.nodes) for i in ts.individuals()} sites_id = {p: i for i, p in enumerate(ts.sites_position)} overlap = False - for pos, ind, ref, alt in zip(position, individual, reference_state, alternate_state): + for pos, ind, ref, alt in zip( + position, individual, reference_state, alternate_state + ): if ind not in individuals_nodes: raise LookupError(f"Individual {ind} is not in the tree sequence") if pos in sites_id: @@ -404,5 +552,3 @@ def insert_unphased_singletons(ts, position, individual, reference_state, altern tables.build_index() tables.compute_mutation_parents() return tables.tree_sequence() - - diff --git a/tsdate/variational.py b/tsdate/variational.py index 29b07af0..a6d3af31 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -45,13 +45,13 @@ from .approx import _i from .approx import _i1r from .hypergeo import _gammainc_inv as gammainc_inv +from .phasing import block_singletons +from .phasing import count_mutations +from .phasing import reallocate_unphased from .rescaling import edge_sampling_weight from .rescaling import mutational_timescale from .rescaling import piecewise_scale_posterior from .util import contains_unary_nodes -from .phasing import count_mutations -from .phasing import block_singletons -from .phasing import reallocate_unphased # columns for edge_factors @@ -199,7 +199,7 @@ def _point_estimate(posteriors, constraints, median): point_estimate[fixed] = constraints[fixed, 0] return point_estimate - def __init__(self, ts, *, mutation_rate): + def __init__(self, ts, *, mutation_rate, singletons_phased=True): """ Initialize an expectation propagation algorithm for dating nodes in a tree sequence. 
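End to end, the new flag is surfaced through the public entry point; the phase-invariance test added later in this series calls it like so (the mutation rate is the value used in that test, not a recommendation):

    import tsdate

    dated_ts = tsdate.date(
        ts,
        mutation_rate=1.29e-8,
        method="variational_gamma",
        singletons_phased=False,  # integrate over unknown singleton phase
    )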
@@ -238,11 +238,16 @@ def __init__(self, ts, *, mutation_rate): # count mutations in singleton blocks phase_timing = time.time() - individual_unphased = np.full(ts.num_individuals, True) - self.block_likelihoods, self.block_edges, \ - self.mutation_blocks = block_singletons(ts, individual_unphased) + individual_phased = np.full(ts.num_individuals, singletons_phased) + ( + self.block_likelihoods, + self.block_edges, + self.mutation_blocks, + ) = block_singletons(ts, ~individual_phased) self.block_likelihoods[:, 1] *= mutation_rate - self.block_edges = np.ascontiguousarray(self.block_edges.T) # TODO: no need to transpose + self.block_edges = np.ascontiguousarray( + self.block_edges.T + ) # TODO: no need to transpose self.block_nodes = np.full(self.block_edges.shape, tskit.NULL, dtype=np.int32) self.block_nodes[0] = self.edge_parents[self.block_edges[0]] self.block_nodes[1] = self.edge_parents[self.block_edges[1]] @@ -353,12 +358,12 @@ def leafward_projection(x, y, z): def rootward_projection(x, y, z): if unphased: - return approx.sideways_projection(x, y, z) + return approx.sideways_projection(x, y, z) return approx.rootward_projection(x, y, z) def gamma_projection(x, y, z): if unphased: - return approx.unphased_projection(x, y, z) + return approx.unphased_projection(x, y, z) return approx.gamma_projection(x, y, z) fixed = constraints[:, LOWER] == constraints[:, UPPER] @@ -378,7 +383,9 @@ def gamma_projection(x, y, z): # match moments and update factor parent_age = constraints[p, LOWER] lognorm[i], posterior[c] = leafward_projection( - parent_age, child_cavity, edge_likelihood, + parent_age, + child_cavity, + edge_likelihood, ) factors[i, LEAFWARD] *= 1.0 - child_delta factors[i, LEAFWARD] += (posterior[c] - child_cavity) / scale[c] @@ -398,7 +405,9 @@ def gamma_projection(x, y, z): # match moments and update factor child_age = constraints[c, LOWER] lognorm[i], posterior[p] = rootward_projection( - child_age, parent_cavity, edge_likelihood, + child_age, + parent_cavity, + edge_likelihood, ) factors[i, ROOTWARD] *= 1.0 - parent_delta @@ -422,7 +431,9 @@ def gamma_projection(x, y, z): # match moments and update factors lognorm[i], posterior[p], posterior[c] = gamma_projection( - parent_cavity, child_cavity, edge_likelihood, + parent_cavity, + child_cavity, + edge_likelihood, ) factors[i, ROOTWARD] *= 1.0 - delta factors[i, ROOTWARD] += (posterior[p] - parent_cavity) / scale[p] @@ -498,7 +509,9 @@ def posterior_damping(x): return np.nan @staticmethod - @numba.njit(_f(_i1r, _f2w, _f1w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b)) + @numba.njit( + _f(_i1r, _f2w, _f1w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b) + ) def propagate_mutations( mutations_order, mutations_posterior, @@ -560,12 +573,12 @@ def leafward_projection(x, y, z): def rootward_projection(x, y, z): if unphased: - return approx.mutation_sideways_projection(x, y, z) + return approx.mutation_sideways_projection(x, y, z) return approx.mutation_rootward_projection(x, y, z) def gamma_projection(x, y, z): if unphased: - return approx.mutation_unphased_projection(x, y, z) + return approx.mutation_unphased_projection(x, y, z) return approx.mutation_gamma_projection(x, y, z) def fixed_projection(x, y): @@ -592,7 +605,9 @@ def fixed_projection(x, y): edge_likelihood = child_delta * likelihoods[i] parent_age = constraints[p, LOWER] mutations_phase[m], mutations_posterior[m] = leafward_projection( - parent_age, child_cavity, edge_likelihood, + parent_age, + child_cavity, + edge_likelihood, ) elif fixed[c] and not fixed[p]: 
parent_message = factors[i, ROOTWARD] * scale[p] @@ -601,7 +616,9 @@ def fixed_projection(x, y): edge_likelihood = parent_delta * likelihoods[i] child_age = constraints[c, LOWER] mutations_phase[m], mutations_posterior[m] = rootward_projection( - child_age, parent_cavity, edge_likelihood, + child_age, + parent_cavity, + edge_likelihood, ) else: parent_message = factors[i, ROOTWARD] * scale[p] @@ -613,7 +630,9 @@ def fixed_projection(x, y): child_cavity = posterior[c] - delta * child_message edge_likelihood = delta * likelihoods[i] mutations_phase[m], mutations_posterior[m] = gamma_projection( - parent_cavity, child_cavity, edge_likelihood, + parent_cavity, + child_cavity, + edge_likelihood, ) return np.nan @@ -627,9 +646,9 @@ def rescale_factors(edges_parent, edges_child, node_factors, edge_factors, scale edge_factors[:, ROOTWARD] *= scale[p, np.newaxis] edge_factors[:, LEAFWARD] *= scale[c, np.newaxis] # TODO - #j, k = blocks_parents - #block_factors[:, ROOTWARD] *= scale[j, np.newaxis] - #block_factors[:, LEAFWARD] *= scale[k, np.newaxis] + # j, k = blocks_parents + # block_factors[:, ROOTWARD] *= scale[j, np.newaxis] + # block_factors[:, LEAFWARD] *= scale[k, np.newaxis] node_factors[:, MIXPRIOR] *= scale[:, np.newaxis] node_factors[:, CONSTRNT] *= scale[:, np.newaxis] scale[:] = 1.0 @@ -644,7 +663,6 @@ def iterate( regularise=True, check_valid=False, ): - # pass through singleton blocks self.propagate_likelihood( self.block_order, From 5fa38d6d5ac5f254401f35bb4663286917b11488 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 12 Jun 2024 12:03:14 -0700 Subject: [PATCH 14/29] More tests --- tests/test_phasing.py | 67 ++++++++++++++++++ tsdate/phasing.py | 156 +++++++++++++++++++++++------------------- tsdate/variational.py | 22 +++--- 3 files changed, 162 insertions(+), 83 deletions(-) diff --git a/tests/test_phasing.py b/tests/test_phasing.py index f8d74dc5..5be37fa1 100644 --- a/tests/test_phasing.py +++ b/tests/test_phasing.py @@ -30,9 +30,14 @@ import tsinfer import tskit +import tsdate + from tsdate.phasing import block_singletons from tsdate.phasing import count_mutations from tsdate.phasing import mutation_frequency +from tsdate.phasing import remove_singletons +from tsdate.phasing import insert_unphased_singletons +from tsdate.phasing import rephase_singletons @pytest.fixture(scope="session") @@ -246,6 +251,33 @@ def test_all_phased(self, inferred_ts): assert np.all(block_muts == tskit.NULL) +class TestPhaseAgnosticDating: + """ + If singleton phase is randomized, we should get same results with the phase + agnostic algorithm + """ + + def test_phase_invariance(self, inferred_ts): + ts1 = inferred_ts + ts2 = rephase_singletons(ts1, use_node_times=False, random_seed=1) + frq = mutation_frequency(ts1) + assert np.all(ts1.mutations_node[frq != 1] == ts2.mutations_node[frq != 1]) + assert np.any(ts1.mutations_node[frq == 1] != ts2.mutations_node[frq == 1]) + dts1 = tsdate.date( + ts1, + mutation_rate=1.29e-8, + method='variational_gamma', + singletons_phased=False, + ) + dts2 = tsdate.date( + ts2, + mutation_rate=1.29e-8, + method='variational_gamma', + singletons_phased=False, + ) + np.testing.assert_allclose(dts1.nodes_time, dts2.nodes_time) + + class TestMutationFrequency: @staticmethod def naive_mutation_frequency(ts, sample_set): @@ -271,3 +303,38 @@ def test_mutation_frequency_stratified(self, inferred_ts): for i, s in enumerate(sample_sets): ck_freq = self.naive_mutation_frequency(inferred_ts, s) np.testing.assert_array_equal(ck_freq, freqs[:, i]) + + +class 
TestModifySingletons: + def test_remove_singletons(self, inferred_ts): + new_ts, _ = remove_singletons(inferred_ts) + old_frq = mutation_frequency(inferred_ts) + new_frq = mutation_frequency(new_ts) + num_singletons = np.sum(old_frq == 1) + assert inferred_ts.num_mutations - num_singletons == new_ts.num_mutations + assert np.any(old_frq == 1) and np.all(new_frq > 1) + + def test_insert_unphased_singletons(self, inferred_ts): + its = inferred_ts + inter_ts, removed = remove_singletons(its) + new_ts = insert_unphased_singletons(inter_ts, *removed) + old_frq = mutation_frequency(its) + new_frq = mutation_frequency(new_ts) + assert new_ts.num_mutations == its.num_mutations + old_singles = old_frq == 1 + old_singleton_pos = its.sites_position[its.mutations_site] + old_singleton_ind = its.nodes_individual[its.mutations_node] + old_order = np.argsort(old_singleton_pos) + new_singles = new_frq == 1 + new_singleton_pos = new_ts.sites_position[new_ts.mutations_site] + new_singleton_ind = new_ts.nodes_individual[new_ts.mutations_node] + new_order = np.argsort(new_singleton_pos) + np.testing.assert_array_equal( + old_singleton_pos[old_order], + new_singleton_pos[new_order], + ) + np.testing.assert_array_equal( + old_singleton_ind[old_order], + new_singleton_ind[new_order], + ) + # TODO: more thorough testing (ancestral state, etc) diff --git a/tsdate/phasing.py b/tsdate/phasing.py index 0102b463..c90d5674 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -54,22 +54,22 @@ def reallocate_unphased( edges to which it maps """ assert mutations_phase.size == mutations_block.size - assert blocks_edges.shape[0] == 2 + assert blocks_edges.shape[1] == 2 num_mutations = mutations_phase.size num_edges = edges_likelihood.shape[0] num_blocks = blocks_edges.shape[0] edges_unphased = np.full(num_edges, False) - edges_unphased[blocks_edges[0]] = True - edges_unphased[blocks_edges[1]] = True + edges_unphased[blocks_edges[:, 0]] = True + edges_unphased[blocks_edges[:, 1]] = True num_unphased = np.sum(edges_likelihood[edges_unphased, 0]) edges_likelihood[edges_unphased, 0] = 0.0 for m, b in enumerate(mutations_block): if b == tskit.NULL: continue - i, j = blocks_edges[0, b], blocks_edges[1, b] + i, j = blocks_edges[b] assert tskit.NULL < i < num_edges and edges_unphased[i] assert tskit.NULL < j < num_edges and edges_unphased[j] assert 0.0 <= mutations_phase[m] <= 1.0 @@ -311,6 +311,10 @@ def count_mutations(ts): ) +# def mutations_node(mutations_block, mutations_phase, blocks_nodes): + + + @numba.njit(_i2w(_b2r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f)) def _mutation_frequency( nodes_sample, @@ -427,120 +431,96 @@ def remove_singletons(ts): assert np.all(~nodes_sample[ts.edges_parent]), "Sample node has a child" singletons = nodes_sample[ts.mutations_node] - old_metadata = np.array( + metadata = np.array( tskit.unpack_strings( ts.tables.mutations.metadata, ts.tables.mutations.metadata_offset, ) ) - old_state = np.array( + + state = np.array( tskit.unpack_strings( ts.tables.mutations.derived_state, ts.tables.mutations.derived_state_offset, ) ) - new_metadata, new_metadata_offset = tskit.pack_strings(old_metadata[~singletons]) - new_state, new_state_offset = tskit.pack_strings(old_state[~singletons]) + + singleton_derived = state[singletons] + singleton_ancestral = np.array( + tskit.unpack_strings( + ts.tables.sites.ancestral_state, + ts.tables.sites.ancestral_state_offset, + ) + ) + singleton_ancestral = singleton_ancestral[ts.mutations_site] + singleton_ancestral = singleton_ancestral[singletons] + + 
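+    # NB: the ancestral/derived states of the removed singletons are captured
+    # above so that the variants can later be restored with
+    # insert_unphased_singletons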
+    metadata, metadata_offset = tskit.pack_strings(metadata[~singletons])
+    state, state_offset = tskit.pack_strings(state[~singletons])
 
     tables = ts.dump_tables()
     tables.mutations.set_columns(
         node=ts.mutations_node[~singletons],
         time=ts.mutations_time[~singletons],
         site=ts.mutations_site[~singletons],
-        derived_state=new_state,
-        derived_state_offset=new_state_offset,
-        metadata=new_metadata,
-        metadata_offset=new_metadata_offset,
+        derived_state=state,
+        derived_state_offset=state_offset,
+        metadata=metadata,
+        metadata_offset=metadata_offset,
     )
     tables.sort()
     tables.build_index()
     tables.compute_mutation_parents()
-    return tables.tree_sequence(), np.flatnonzero(singletons)
-
-
-def rephase_singletons(ts, use_node_times=True, random_seed=None):
-    """
-    Rephase singleton mutations in the tree sequence. If `use_node_times`
-    is True, singletons are added to permissable branches with probability
-    proportional to the branch length (and with equal probability otherwise).
-    """
-    rng = np.random.default_rng(random_seed)
-
-    mutations_node = ts.mutations_node.copy()
-    mutations_time = ts.mutations_time.copy()
-
-    singletons = np.bitwise_and(ts.nodes_flags[mutations_node], tskit.NODE_IS_SAMPLE)
-    singletons = np.flatnonzero(singletons)
-    tree = ts.first()
-    for i in singletons:
-        position = ts.sites_position[ts.mutations_site[i]]
-        individual = ts.nodes_individual[ts.mutations_node[i]]
-        time = ts.nodes_time[ts.mutations_node[i]]
-        assert individual != tskit.NULL
-        assert time == 0.0
-        tree.seek(position)
-        nodes_id = ts.individual(individual).nodes
-        nodes_length = np.array([tree.time(tree.parent(n)) - time for n in nodes_id])
-        nodes_prob = nodes_length if use_node_times else np.ones(nodes_id.size)
-        mutations_node[i] = rng.choice(
-            nodes_id, p=nodes_prob / nodes_prob.sum(), size=1
-        )
-        if not np.isnan(mutations_time[i]):
-            mutations_time[i] = (time + tree.time(tree.parent(mutations_node[i]))) / 2
+    singleton_individual = ts.nodes_individual[ts.mutations_node[singletons]]
+    singleton_position = ts.sites_position[ts.mutations_site[singletons]]
+    removed_singletons = (
+        singleton_position,
+        singleton_individual,
+        singleton_ancestral,
+        singleton_derived,
+    )
 
-    # TODO: add metadata with phase probability
-    tables = ts.dump_tables()
-    tables.mutations.node = mutations_node
-    tables.mutations.time = mutations_time
-    tables.sort()
-    return tables.tree_sequence(), singletons
+    return tables.tree_sequence(), removed_singletons
 
 
 def insert_unphased_singletons(
     ts,
     position,
     individual,
-    reference_state,
-    alternate_state,
-    allow_overlapping_sites=False,
+    ancestral_state,
+    derived_state,
 ):
     """
     Insert unphased singletons into the tree sequence. The phase is arbitrarily chosen
-    so that the mutation subtends the node with the lowest id, at a given position for a
+    so that the mutation subtends the node with the highest id, at a given position for
     a given individual.
 
     :param tskit.TreeSequence ts: the tree sequence to add singletons to
     :param np.ndarray position: the position of the variants
     :param np.ndarray individual: the individual id in which the variant occurs
-    :param np.ndarray reference_state: the reference state of the variant
-    :param np.ndarray alternate_state: the alternate state of the variant
-    :param bool allow_overlapping_sites: whether to permit insertion of
-        singletons at existing sites (in which case the reference states must be
-        consistent)
+    :param np.ndarray ancestral_state: the ancestral state of the variant
+    :param np.ndarray derived_state: the derived state of the variant
     :returns: A copy of the tree sequence with singletons inserted
     """
     # TODO: provenance / metadata
     tables = ts.dump_tables()
-    individuals_node = {i.id: min(i.nodes) for i in ts.individuals()}
+    individuals_node = {i.id: max(i.nodes) for i in ts.individuals()}
     sites_id = {p: i for i, p in enumerate(ts.sites_position)}
-    overlap = False
     for pos, ind, ref, alt in zip(
-        position, individual, reference_state, alternate_state
+        position, individual, ancestral_state, derived_state
     ):
-        if ind not in individuals_nodes:
+        if ind not in individuals_node:
             raise LookupError(f"Individual {ind} is not in the tree sequence")
         if pos in sites_id:
-            if not allow_overlapping_sites:
-                raise ValueError(f"A site already exists at position {pos}")
             if ref != ts.site(sites_id[pos]).ancestral_state:
                 raise ValueError(
                     f"Existing site at position {pos} has a different ancestral state"
                 )
-            overlap = True
         else:
             sites_id[pos] = tables.sites.add_row(position=pos, ancestral_state=ref)
+        # TODO: more efficient to do in bulk?
         tables.mutations.add_row(
             site=sites_id[pos],
             node=individuals_node[ind],
@@ -548,7 +528,45 @@ def insert_unphased_singletons(
             derived_state=alt,
         )
     tables.sort()
-    if allow_overlapping_sites and overlap:
-        tables.build_index()
-        tables.compute_mutation_parents()
+    tables.build_index()
+    tables.compute_mutation_parents()
+    return tables.tree_sequence()
+
+
+def rephase_singletons(ts, use_node_times=True, random_seed=None):
+    """
+    Rephase singleton mutations in the tree sequence. If `use_node_times`
+    is True, singletons are added to permissible branches with probability
+    proportional to the branch length (and with equal probability otherwise).
+
+    This is not efficient, and is intended for benchmarking/testing.
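+
+    A minimal usage sketch (assuming `ts` contains singleton mutations
+    carried by diploid sample individuals)::
+
+        rephased = rephase_singletons(ts, use_node_times=False, random_seed=1)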
+ """ + rng = np.random.default_rng(random_seed) + + mutations_node = ts.mutations_node.copy() + mutations_time = ts.mutations_time.copy() + + singletons = np.bitwise_and(ts.nodes_flags[mutations_node], tskit.NODE_IS_SAMPLE) + singletons = np.flatnonzero(singletons) + tree = ts.first() + for i in singletons: + position = ts.sites_position[ts.mutations_site[i]] + individual = ts.nodes_individual[ts.mutations_node[i]] + time = ts.nodes_time[ts.mutations_node[i]] + assert individual != tskit.NULL + assert time == 0.0 + tree.seek(position) + nodes_id = ts.individual(individual).nodes + nodes_length = np.array([tree.time(tree.parent(n)) - time for n in nodes_id]) + nodes_prob = nodes_length if use_node_times else np.ones(nodes_id.size) + nodes_prob /= nodes_prob.sum() + mutations_node[i] = rng.choice(nodes_id, p=nodes_prob, size=1)[0] + if not np.isnan(mutations_time[i]): + parent_time = tree.time(tree.parent(mutations_node[i])) + mutations_time[i] = (time + parent_time) / 2 + + tables = ts.dump_tables() + tables.mutations.node = mutations_node + tables.mutations.time = mutations_time + tables.sort() return tables.tree_sequence() diff --git a/tsdate/variational.py b/tsdate/variational.py index a6d3af31..a9a1f39c 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -239,21 +239,15 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): # count mutations in singleton blocks phase_timing = time.time() individual_phased = np.full(ts.num_individuals, singletons_phased) - ( - self.block_likelihoods, - self.block_edges, - self.mutation_blocks, - ) = block_singletons(ts, ~individual_phased) + self.block_likelihoods, self.block_edges, self.mutation_blocks = \ + block_singletons(ts, ~individual_phased) # fmt: skip self.block_likelihoods[:, 1] *= mutation_rate - self.block_edges = np.ascontiguousarray( - self.block_edges.T - ) # TODO: no need to transpose - self.block_nodes = np.full(self.block_edges.shape, tskit.NULL, dtype=np.int32) - self.block_nodes[0] = self.edge_parents[self.block_edges[0]] - self.block_nodes[1] = self.edge_parents[self.block_edges[1]] + num_blocks = self.block_likelihoods.shape[0] + self.block_nodes = np.full((2, num_blocks), tskit.NULL, dtype=np.int32) + self.block_nodes[0] = self.edge_parents[self.block_edges[:, 0]] + self.block_nodes[1] = self.edge_parents[self.block_edges[:, 1]] self.mutation_phase = np.full(ts.num_mutations, np.nan) num_unphased = np.sum(self.mutation_blocks != tskit.NULL) - num_blocks = self.block_likelihoods.shape[0] phase_timing -= time.time() logging.info(f"Found {num_unphased} unphased singleton mutations") logging.info(f"Split unphased singleton edges into {num_blocks} blocks") @@ -281,8 +275,8 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): # edge traversal order edge_unphased = np.full(ts.num_edges, False) - edge_unphased[self.block_edges[0]] = True - edge_unphased[self.block_edges[1]] = True + edge_unphased[self.block_edges[:, 0]] = True + edge_unphased[self.block_edges[:, 1]] = True edges = np.arange(ts.num_edges, dtype=np.int32)[~edge_unphased] self.edge_order = np.concatenate((edges[:-1], np.flip(edges))) self.edge_weights = edge_sampling_weight( From 58945ea5233cb1fd73506ac1afa77689e1d60020 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 12 Jun 2024 18:19:59 -0700 Subject: [PATCH 15/29] Fix edge case --- tests/exact_moments.py | 152 +++++++++++++++++++++++++-- tests/test_approximations.py | 28 +++++ tests/test_phasing.py | 82 +++++++++------ tsdate/approx.py | 98 ++++++++++++++++-- 
tsdate/phasing.py | 20 +--- tsdate/variational.py | 196 +++++++++++++++++++++-------------- 6 files changed, 429 insertions(+), 147 deletions(-) diff --git a/tests/exact_moments.py b/tests/exact_moments.py index d8f11793..c4444d85 100644 --- a/tests/exact_moments.py +++ b/tests/exact_moments.py @@ -121,6 +121,20 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return logl, mn_i, va_i, mn_j, va_j +def twin_moments(a_i, b_i, y_ij, mu_ij): + """ + log p(t_i) := \ + log(2 * t_i) * y_ij - mu_ij * (2 * t_i) + \ + log(t_i) * (a_i - 1) - b_i * t_i + """ + s = a_i + y_ij + r = b_i + 2 * mu_ij + logl = log(2) * y_ij + gammaln(s) - log(r) * s + mn_i = s / r + va_i = s / r**2 + return logl, mn_i, va_i + + def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): """ log p(t_j) := \ @@ -147,8 +161,8 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): log p(t_m, t_i, t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ log(t_i) * (a_i - 1) - b_i * t_i + \ - log(t_j) * (a_j - 1) - b_j * t_j - \ - log(t_i - t_j) + log(int(t_j < t_m < t_i)) + log(t_j) * (a_j - 1) - b_j * t_j + \ + log(int(t_j < t_m < t_i) / (t_i - t_j)) """ a = a_j b = a_i + a_j + y_ij @@ -180,7 +194,7 @@ def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): log p(t_m, t_i) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ log(t_i) * (a_i - 1) - b_i * t_i + \ - log(t_i - t_j) + log(int(t_j < t_m < t_i)) + log(int(t_j < t_m < t_i) / (t_i - t_j)) """ logl, mn_i, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) mn_m = mn_i / 2 + t_j / 2 @@ -193,8 +207,8 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): """ log p(t_m, t_j) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_j) * (a_j - 1) - b_j * t_j - \ - log(t_i - t_j) + log(int(t_j < t_m < t_i)) + log(t_j) * (a_j - 1) - b_j * t_j + \ + log(int(t_j < t_m < t_i) / (t_i - t_j)) """ logl, mn_j, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) mn_m = mn_j / 2 + t_i / 2 @@ -209,8 +223,8 @@ def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ log(t_i) * (a_i - 1) - b_i * t_i + \ log(t_j) * (a_j - 1) - b_j * t_j + \ - log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ - t_j / (t_i + t_j) * int(0 < t_m < t_j)) + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) / t_i + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j) / t_j) """ a = a_j b = a_j + a_i + y_ij @@ -236,13 +250,29 @@ def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return pr_m, mn_m, va_m +def mutation_twin_moments(a_i, b_i, y_ij, mu_ij): + """ + log p(t_m, t_i) := \ + log(int(0 < t_m < t_i) / t_i) + \ + log(2 * t_i) * y_ij - mu_ij * (2 * t_i) + \ + log(t_i) * (a_i - 1) - b_i * t_i + """ + s = a_i + y_ij + r = b_i + 2 * mu_ij + pr_m = 0.5 + mn_m = s / r / 2 + sq_m = (s + 1) * s / 3 / r**2 + va_m = sq_m - mn_m**2 + return pr_m, mn_m, va_m + + def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): """ log p(t_m, t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ log(t_j) * (a_j - 1) - b_j * t_j + \ - log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ - t_j / (t_i + t_j) * int(0 < t_m < t_j)) + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) / t_i + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j) / t_j) """ a = a_j b = a_j + y_ij + 1 @@ -270,8 +300,8 @@ def mutation_edge_moments(t_i, t_j): def mutation_block_moments(t_i, t_j): """ log p(t_m) := \ - log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ - t_j / (t_i + t_j) * int(0 < t_m < t_j)) + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) / t_i + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j) / t_j) """ 
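+    # Derivation sketch: with both parent ages fixed, the singleton falls on
+    # the first parental branch with probability t_i / (t_i + t_j); given the
+    # branch, the mutation age is uniform on (0, t_branch), so the mean mixes
+    # t_i / 2 with t_j / 2 and the second moment mixes t_i**2 / 3 with
+    # t_j**2 / 3.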
pr_m = t_i / (t_i + t_j) mn_m = pr_m * t_i / 2 + (1 - pr_m) * t_j / 2 @@ -347,6 +377,21 @@ def pdf_unphased(t_i, t_j, a_i, b_i, a_j, b_j, y, mu): * np.exp(-(t_i + t_j) * mu) ) + @staticmethod + def pdf_twin(t_i, a_i, b_i, y, mu): + """ + Target marginal distribution, proportional to the parent + marginal (gamma) and a Poisson mutation likelihood over the two + branches leading from (present-day) individual to a *single* parent + """ + assert t_i > 0 + return ( + t_i ** (a_i - 1) + * np.exp(-t_i * b_i) + * (t_i + t_i) ** y + * np.exp(-(t_i + t_i) * mu) + ) + @staticmethod def pdf_sideways(t_i, t_j, a_j, b_j, y, mu): """ @@ -571,6 +616,38 @@ def test_unphased_moments(self, pars): ) assert np.isclose(var_t_j, ck_var_t_j) + def test_twin_moments(self, pars): + """ + Parent age for one singleton node above an unphased individual + """ + a_i, b_i, a_j, b_j, y_ij, mu_ij = pars + pars_redux = (a_i, b_i, y_ij, mu_ij) + logconst, t_i, var_t_i = twin_moments(*pars_redux) + ck_normconst = scipy.integrate.quad( + lambda t_i: self.pdf_twin(t_i, *pars_redux), + 0, + np.inf, + epsabs=0, + )[0] + assert np.isclose(logconst, np.log(ck_normconst)) + ck_t_i = scipy.integrate.quad( + lambda t_i: t_i * self.pdf_twin(t_i, *pars_redux) / ck_normconst, + 0, + np.inf, + epsabs=0, + )[0] + assert np.isclose(t_i, ck_t_i) + ck_var_t_i = ( + scipy.integrate.quad( + lambda t_i: t_i**2 * self.pdf_twin(t_i, *pars_redux) / ck_normconst, + 0, + np.inf, + epsabs=0, + )[0] + - ck_t_i**2 + ) + assert np.isclose(var_t_i, ck_var_t_i) + def test_sideways_moments(self, pars): """ Parent ages for an singleton nodes above an unphased individual, where @@ -793,6 +870,59 @@ def f(t_i, t_j): # conditional moments ) assert np.isclose(va, ck_va) + def test_mutation_twin_moments(self, pars): + """ + Mutation mapped to two singleton branches with children fixed to time zero + and a single parent node + """ + + def f(t_i): # conditional moments + pr = 0.5 + mn = t_i / 2 + sq = t_i**2 / 3 + return pr, mn, sq + + a_i, b_i, a_j, b_j, y_ij, mu_ij = pars + pars_redux = (a_i, b_i, y_ij, mu_ij) + pr, mn, va = mutation_twin_moments(*pars_redux) + nc = scipy.integrate.quad( + lambda t_i: self.pdf_twin(t_i, *pars_redux), + 0, + np.inf, + epsabs=0, + )[0] + ck_pr = ( + scipy.integrate.quad( + lambda t_i: f(t_i)[0] * self.pdf_twin(t_i, *pars_redux), + 0, + np.inf, + epsabs=0, + )[0] + / nc + ) + assert np.isclose(pr, ck_pr) + ck_mn = ( + scipy.integrate.quad( + lambda t_i: f(t_i)[1] * self.pdf_twin(t_i, *pars_redux), + 0, + np.inf, + epsabs=0, + )[0] + / nc + ) + assert np.isclose(mn, ck_mn) + ck_va = ( + scipy.integrate.quad( + lambda t_i: f(t_i)[2] * self.pdf_twin(t_i, *pars_redux), + 0, + np.inf, + epsabs=0, + )[0] + / nc + - ck_mn**2 + ) + assert np.isclose(va, ck_va) + def test_mutation_sideways_moments(self, pars): """ Mutation mapped to two branches with children fixed to time zero, and diff --git a/tests/test_approximations.py b/tests/test_approximations.py index 5f22d26e..0371772e 100644 --- a/tests/test_approximations.py +++ b/tests/test_approximations.py @@ -39,9 +39,11 @@ from exact_moments import mutation_moments from exact_moments import mutation_rootward_moments from exact_moments import mutation_sideways_moments +from exact_moments import mutation_twin_moments from exact_moments import mutation_unphased_moments from exact_moments import rootward_moments from exact_moments import sideways_moments +from exact_moments import twin_moments from exact_moments import unphased_moments from tsdate import approx @@ -117,6 +119,19 @@ def 
test_unphased_moments(self, pars): assert np.isclose(sqrt(ck_va_i), sqrt(va_i), rtol=rtol) assert np.isclose(sqrt(ck_va_j), sqrt(va_j), rtol=rtol) + def test_twin_moments(self, pars): + """ + Parent age for a singleton node above both edges of an unphased + individual + """ + a_i, b_i, a_j, b_j, y_ij, mu_ij = pars + pars_redux = (a_i, b_i, y_ij, mu_ij) + ll, mn_i, va_i = approx.twin_moments(*pars_redux) + ck_ll, ck_mn_i, ck_va_i = twin_moments(*pars_redux) + assert np.isclose(ck_ll, ll) + assert np.isclose(ck_mn_i, mn_i) + assert np.isclose(sqrt(ck_va_i), sqrt(va_i)) + def test_sideways_moments(self, pars): """ Parent ages for an singleton nodes above an unphased individual, where @@ -180,6 +195,19 @@ def test_mutation_unphased_moments(self, pars): assert np.isclose(ck_mn, mn, rtol=rtol) assert np.isclose(sqrt(ck_va), sqrt(va), rtol=rtol) + def test_mutation_twin_moments(self, pars): + """ + Mutation mapped to two singleton branches with children fixed to time zero + and the same parent + """ + a_i, b_i, a_j, b_j, y_ij, mu_ij = pars + pars_redux = (a_i, b_i, y_ij, mu_ij) + pr, mn, va = approx.mutation_twin_moments(*pars_redux) + ck_pr, ck_mn, ck_va = mutation_twin_moments(*pars_redux) + assert np.isclose(ck_pr, pr) + assert np.isclose(ck_mn, mn) + assert np.isclose(sqrt(ck_va), sqrt(va)) + def test_mutation_sideways_moments(self, pars): """ Mutation mapped to two branches with children fixed to time zero, and diff --git a/tests/test_phasing.py b/tests/test_phasing.py index 5be37fa1..71256b08 100644 --- a/tests/test_phasing.py +++ b/tests/test_phasing.py @@ -31,12 +31,11 @@ import tskit import tsdate - from tsdate.phasing import block_singletons from tsdate.phasing import count_mutations +from tsdate.phasing import insert_unphased_singletons from tsdate.phasing import mutation_frequency from tsdate.phasing import remove_singletons -from tsdate.phasing import insert_unphased_singletons from tsdate.phasing import rephase_singletons @@ -185,7 +184,6 @@ def test_total_counts(self, inferred_ts): total_singleton_span += t.span for m in t.mutations(): if t.num_samples(m.node) == 1 and (m.node in unphased_nodes): - e = t.edge(m.node) total_singleton_muts += 1.0 block_stats, *_ = block_singletons(ts, individuals_unphased) assert np.isclose(np.sum(block_stats[:, 0]), total_singleton_muts) @@ -252,30 +250,50 @@ def test_all_phased(self, inferred_ts): class TestPhaseAgnosticDating: - """ - If singleton phase is randomized, we should get same results with the phase - agnostic algorithm - """ + """ + If singleton phase is randomized, we should get same results with the phase + agnostic algorithm + """ - def test_phase_invariance(self, inferred_ts): - ts1 = inferred_ts - ts2 = rephase_singletons(ts1, use_node_times=False, random_seed=1) - frq = mutation_frequency(ts1) - assert np.all(ts1.mutations_node[frq != 1] == ts2.mutations_node[frq != 1]) - assert np.any(ts1.mutations_node[frq == 1] != ts2.mutations_node[frq == 1]) - dts1 = tsdate.date( + def test_phase_invariance(self, inferred_ts): + ts1 = inferred_ts + ts2 = rephase_singletons(ts1, use_node_times=False, random_seed=1) + frq = mutation_frequency(ts1) + assert np.all(ts1.mutations_node[frq != 1] == ts2.mutations_node[frq != 1]) + assert np.any(ts1.mutations_node[frq == 1] != ts2.mutations_node[frq == 1]) + dts1 = tsdate.date( ts1, mutation_rate=1.29e-8, - method='variational_gamma', + method="variational_gamma", singletons_phased=False, - ) - dts2 = tsdate.date( + ) + dts2 = tsdate.date( ts2, mutation_rate=1.29e-8, - method='variational_gamma', + 
method="variational_gamma", singletons_phased=False, - ) - np.testing.assert_allclose(dts1.nodes_time, dts2.nodes_time) + ) + np.testing.assert_allclose(dts1.nodes_time, dts2.nodes_time) + + def test_not_phase_invariance(self, inferred_ts): + ts1 = inferred_ts + ts2 = rephase_singletons(ts1, use_node_times=False, random_seed=1) + frq = mutation_frequency(ts1) + assert np.all(ts1.mutations_node[frq != 1] == ts2.mutations_node[frq != 1]) + assert np.any(ts1.mutations_node[frq == 1] != ts2.mutations_node[frq == 1]) + dts1 = tsdate.date( + ts1, + mutation_rate=1.29e-8, + method="variational_gamma", + singletons_phased=True, + ) + dts2 = tsdate.date( + ts2, + mutation_rate=1.29e-8, + method="variational_gamma", + singletons_phased=True, + ) + assert not np.allclose(dts1.nodes_time, dts2.nodes_time) class TestMutationFrequency: @@ -318,23 +336,19 @@ def test_insert_unphased_singletons(self, inferred_ts): its = inferred_ts inter_ts, removed = remove_singletons(its) new_ts = insert_unphased_singletons(inter_ts, *removed) - old_frq = mutation_frequency(its) - new_frq = mutation_frequency(new_ts) assert new_ts.num_mutations == its.num_mutations - old_singles = old_frq == 1 - old_singleton_pos = its.sites_position[its.mutations_site] - old_singleton_ind = its.nodes_individual[its.mutations_node] - old_order = np.argsort(old_singleton_pos) - new_singles = new_frq == 1 - new_singleton_pos = new_ts.sites_position[new_ts.mutations_site] - new_singleton_ind = new_ts.nodes_individual[new_ts.mutations_node] - new_order = np.argsort(new_singleton_pos) + old_pos = its.sites_position[its.mutations_site] + old_ind = its.nodes_individual[its.mutations_node] + old_order = np.argsort(old_pos) + new_pos = new_ts.sites_position[new_ts.mutations_site] + new_ind = new_ts.nodes_individual[new_ts.mutations_node] + new_order = np.argsort(new_pos) np.testing.assert_array_equal( - old_singleton_pos[old_order], - new_singleton_pos[new_order], + old_pos[old_order], + new_pos[new_order], ) np.testing.assert_array_equal( - old_singleton_ind[old_order], - new_singleton_ind[new_order], + old_ind[old_order], + new_ind[new_order], ) # TODO: more thorough testing (ancestral state, etc) diff --git a/tsdate/approx.py b/tsdate/approx.py index 92694745..21feda4a 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -293,7 +293,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): r""" log p(t_i) := \ log(t_i - t_j) * y_ij - mu_ij * (t_i - t_j) + \ - log(t_i) * (a_i - 1) - b_i * t_i + log(t_i) * (a_i - 1) - b_i * t_i Returns normalizing constant, E[t_i], V[t_i]. """ @@ -406,6 +406,23 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return logl, mn_i, va_i, mn_j, va_j +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f)) +def twin_moments(a_i, b_i, y_ij, mu_ij): + r""" + log p(t_i) := \ + log(2 * t_i) * y_ij - mu_ij * (2 * t_i) + \ + log(t_i) * (a_i - 1) - b_i * t_i + + Returns normalizing constant, E[t_i], V[t_i]. 
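+
+    Derivation sketch: the integrand collects into a gamma kernel with shape
+    s = a_i + y_ij and rate r = b_i + 2 * mu_ij, scaled by 2 ** y_ij; hence
+    the normalizer is 2 ** y_ij * Gamma(s) / r ** s, with E[t_i] = s / r and
+    V[t_i] = s / r ** 2, matching the expressions below.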
+ """ + s = a_i + y_ij + r = b_i + 2 * mu_ij + logl = log(2) * y_ij + hypergeo._gammaln(s) - log(r) * s + mn_i = s / r + va_i = s / r**2 + return logl, mn_i, va_i + + @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" @@ -465,7 +482,6 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): f222 = hyp2f1(a + 2, b + 2, c + 2, z) s1 = a * b / c - s2 = s1 * (a + 1) * (b + 1) / (c + 1) d1 = b * (b + 1) / t**2 d2 = d1 * a / c d3 = d2 * (a + 1) / (c + 1) @@ -526,8 +542,8 @@ def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ log(t_i) * (a_i - 1) - b_i * t_i + \ log(t_j) * (a_j - 1) - b_j * t_j + \ - log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ - t_j / (t_i + t_j) * int(0 < t_m < t_j)) + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) / t_i + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j) / t_j) Returns P[m under i], E[t_m], V[t_m]. """ @@ -572,14 +588,33 @@ def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return pr_m, mn_m, va_m +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f)) +def mutation_twin_moments(a_i, b_i, y_ij, mu_ij): + r""" + log p(t_m, t_i) := \ + log(2 * t_i) * y_ij - mu_ij * (2 * t_i) + \ + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(int(0 < t_m < t_i) / t_i) + """ + + s = a_i + y_ij + r = b_i + 2 * mu_ij + pr_m = 0.5 + mn_m = s / r / 2 + sq_m = (s + 1) * s / 3 / r**2 + va_m = sq_m - mn_m**2 + + return pr_m, mn_m, va_m + + @numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_j) := \ log(t_i + t_j) * y_ij - mu_ij * (t_i + t_j) + \ log(t_j) * (a_j - 1) - b_j * t_j + \ - log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ - t_j / (t_i + t_j) * int(0 < t_m < t_j)) + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) / t_i + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j) / t_j) Returns P[m under i], E[t_m], V[t_m]. """ @@ -640,8 +675,8 @@ def mutation_edge_moments(t_i, t_j): def mutation_block_moments(t_i, t_j): r""" log p(t_m) := \ - log(t_i / (t_i + t_j) * int(0 < t_m < t_i) + \ - t_j / (t_i + t_j) * int(0 < t_m < t_j)) + log(t_i / (t_i + t_j) * int(0 < t_m < t_i) / t_i + \ + t_j / (t_i + t_j) * int(0 < t_m < t_j) / t_j) Returns P[m under i], E[t_m], V[t_m]. 
""" @@ -760,6 +795,29 @@ def unphased_projection(pars_i, pars_j, pars_ij): return logl, np.array(proj_i), np.array(proj_j) +@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r)) +def twin_projection(pars_i, pars_ij): + r""" + log p(t_i) := \ + log(2 * t_i) * y_ij - mu_ij * (2 * t_i) + \ + log(t_i) * (a_i - 1) - b_i * t_i + + Returns normalizing constant, gamma natural parameters for parent ages + """ + a_i, b_i = pars_i + y_ij, mu_ij = pars_ij + a_i += 1 + + logl, mn_i, va_i = twin_moments(a_i, b_i, y_ij, mu_ij) + + if not _valid_moments(mn_i, va_i): + return np.nan, pars_i + + proj_i = approximate_gamma_mom(mn_i, va_i) + + return logl, np.array(proj_i) + + @numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def sideways_projection(t_i, pars_j, pars_ij): r""" @@ -904,6 +962,30 @@ def mutation_unphased_projection(pars_i, pars_j, pars_ij): return pr_m, np.array(proj_m) +@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r)) +def mutation_twin_projection(pars_i, pars_ij): + r""" + log p(t_m, t_i) := \ + log(2 * t_i) * y_ij - mu_ij * (2 * t_i) + \ + log(t_i) * (a_i - 1) - b_i * t_i + \ + log(int(0 < t_m < t_i) / t_i) + + Returns phase probability, gamma natural parameters for mutation age + """ + a_i, b_i = pars_i + y_ij, mu_ij = pars_ij + a_i += 1 + + pr_m, mn_m, va_m = mutation_twin_moments(a_i, b_i, y_ij, mu_ij) + + if not _valid_moments(mn_m, va_m) or not (0 <= pr_m <= 1): + return np.nan, np.full(2, np.nan) + + proj_m = approximate_gamma_mom(mn_m, va_m) + + return pr_m, np.array(proj_m) + + @numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def mutation_sideways_projection(t_i, pars_j, pars_ij): r""" diff --git a/tsdate/phasing.py b/tsdate/phasing.py index c90d5674..55eaad89 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -26,13 +26,10 @@ import numpy as np import tskit -from .approx import _b from .approx import _b1r from .approx import _b2r from .approx import _f from .approx import _f1r -from .approx import _f1w -from .approx import _f2r from .approx import _f2w from .approx import _i from .approx import _i1r @@ -56,10 +53,7 @@ def reallocate_unphased( assert mutations_phase.size == mutations_block.size assert blocks_edges.shape[1] == 2 - num_mutations = mutations_phase.size num_edges = edges_likelihood.shape[0] - num_blocks = blocks_edges.shape[0] - edges_unphased = np.full(num_edges, False) edges_unphased[blocks_edges[:, 0]] = True edges_unphased[blocks_edges[:, 1]] = True @@ -103,7 +97,6 @@ def _block_singletons( assert indexes_insert.size == indexes_remove.size == edges_parent.size assert mutations_node.size == mutations_position.size - num_nodes = nodes_individual.size num_mutations = mutations_node.size num_edges = edges_parent.size num_individuals = individuals_unphased.size @@ -130,7 +123,7 @@ def _block_singletons( while a < num_edges or b < num_edges: while b < num_edges and position_remove[b] == left: # edges out e = indexes_remove[b] - p, c = edges_parent[e], edges_child[e] + c = edges_child[e] i = nodes_individual[c] if i != tskit.NULL and individuals_unphased[i]: u, v = individuals_edges[i] @@ -149,7 +142,7 @@ def _block_singletons( while a < num_edges and position_insert[a] == left: # edges in e = indexes_insert[a] - p, c = edges_parent[e], edges_child[e] + c = edges_child[e] i = nodes_individual[c] if i != tskit.NULL and individuals_unphased[i]: u, v = individuals_edges[i] @@ -260,13 +253,13 @@ def _count_mutations( while a < num_edges or b < num_edges: while b < num_edges and position_remove[b] == left: # edges out e = indexes_remove[b] - p, c = edges_parent[e], edges_child[e] + c = 
edges_child[e] nodes_edge[c] = tskit.NULL b += 1 while a < num_edges and position_insert[a] == left: # edges in e = indexes_insert[a] - p, c = edges_parent[e], edges_child[e] + c = edges_child[e] nodes_edge[c] = e a += 1 @@ -314,7 +307,6 @@ def count_mutations(ts): # def mutations_node(mutations_block, mutations_phase, blocks_nodes): - @numba.njit(_i2w(_b2r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f)) def _mutation_frequency( nodes_sample, @@ -508,9 +500,7 @@ def insert_unphased_singletons( tables = ts.dump_tables() individuals_node = {i.id: max(i.nodes) for i in ts.individuals()} sites_id = {p: i for i, p in enumerate(ts.sites_position)} - for pos, ind, ref, alt in zip( - position, individual, ancestral_state, derived_state - ): + for pos, ind, ref, alt in zip(position, individual, ancestral_state, derived_state): if ind not in individuals_node: raise LookupError(f"Individual {ind} is not in the tree sequence") if pos in sites_id: diff --git a/tsdate/variational.py b/tsdate/variational.py index a9a1f39c..bbfb1331 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -44,6 +44,7 @@ from .approx import _f3w from .approx import _i from .approx import _i1r +from .approx import _i2r from .hypergeo import _gammainc_inv as gammainc_inv from .phasing import block_singletons from .phasing import count_mutations @@ -59,8 +60,8 @@ LEAFWARD = 1 # edge likelihood to child # columns for unphased_factors -FIRSTPAR = 0 # edge likelihood to first parent -SECNDPAR = 1 # edge likelihood to second parent +NODEONE = 0 # block likelihood to first parent +NODETWO = 1 # block likelihood to second parent # columns for node_factors MIXPRIOR = 0 # mixture prior to node @@ -173,18 +174,26 @@ def _check_valid_inputs(ts, mutation_rate): @staticmethod def _check_valid_state( - edges_parent, edges_child, posterior, node_factors, edge_factors + edges_parent, + edges_child, + block_nodes, + posterior, + node_factors, + edge_factors, + block_factors, ): """Check that the messages sum to the posterior (debugging only)""" + block_one, block_two = block_nodes posterior_check = np.zeros(posterior.shape) for i, (p, c) in enumerate(zip(edges_parent, edges_child)): posterior_check[p] += edge_factors[i, ROOTWARD] posterior_check[c] += edge_factors[i, LEAFWARD] - # TODO: unphased factors - assert False + for i, (j, k) in enumerate(zip(block_one, block_two)): + posterior_check[j] += block_factors[i, ROOTWARD] + posterior_check[k] += block_factors[i, LEAFWARD] posterior_check += node_factors[:, MIXPRIOR] posterior_check += node_factors[:, CONSTRNT] - return np.allclose(posterior_check, posterior) + np.testing.assert_allclose(posterior_check, posterior) @staticmethod @numba.njit(_f1w(_f2r, _f2r, _b)) @@ -194,7 +203,8 @@ def _point_estimate(posteriors, constraints, median): point_estimate = np.zeros(posteriors.shape[0]) for i in np.flatnonzero(~fixed): alpha, beta = posteriors[i] - point_estimate[i] = gammainc_inv(alpha + 1, 0.5) if median else (alpha + 1) + point_estimate[i] = gammainc_inv(alpha + 1, 0.5) \ + if median else (alpha + 1) # fmt: skip point_estimate[i] /= beta point_estimate[fixed] = constraints[fixed, 0] return point_estimate @@ -360,6 +370,10 @@ def gamma_projection(x, y, z): return approx.unphased_projection(x, y, z) return approx.gamma_projection(x, y, z) + def twin_projection(x, y): + assert unphased, "Invalid update" + return approx.twin_projection(x, y) + fixed = constraints[:, LOWER] == constraints[:, UPPER] for i in edge_order: @@ -373,8 +387,6 @@ def gamma_projection(x, y, z): child_delta = 
cavity_damping(posterior[c], child_message) child_cavity = posterior[c] - child_delta * child_message edge_likelihood = child_delta * likelihoods[i] - - # match moments and update factor parent_age = constraints[p, LOWER] lognorm[i], posterior[c] = leafward_projection( parent_age, @@ -383,8 +395,6 @@ def gamma_projection(x, y, z): ) factors[i, LEAFWARD] *= 1.0 - child_delta factors[i, LEAFWARD] += (posterior[c] - child_cavity) / scale[c] - - # upper bound posterior child_eta = posterior_damping(posterior[c]) posterior[c] *= child_eta scale[c] *= child_eta @@ -395,52 +405,61 @@ def gamma_projection(x, y, z): parent_delta = cavity_damping(posterior[p], parent_message) parent_cavity = posterior[p] - parent_delta * parent_message edge_likelihood = parent_delta * likelihoods[i] - - # match moments and update factor child_age = constraints[c, LOWER] lognorm[i], posterior[p] = rootward_projection( child_age, parent_cavity, edge_likelihood, ) - factors[i, ROOTWARD] *= 1.0 - parent_delta factors[i, ROOTWARD] += (posterior[p] - parent_cavity) / scale[p] - - # upper-bound posterior parent_eta = posterior_damping(posterior[p]) posterior[p] *= parent_eta scale[p] *= parent_eta else: - # lower-bound cavity - parent_message = factors[i, ROOTWARD] * scale[p] - child_message = factors[i, LEAFWARD] * scale[c] - parent_delta = cavity_damping(posterior[p], parent_message) - child_delta = cavity_damping(posterior[c], child_message) - delta = min(parent_delta, child_delta) - - parent_cavity = posterior[p] - delta * parent_message - child_cavity = posterior[c] - delta * child_message - edge_likelihood = delta * likelihoods[i] - - # match moments and update factors - lognorm[i], posterior[p], posterior[c] = gamma_projection( - parent_cavity, - child_cavity, - edge_likelihood, - ) - factors[i, ROOTWARD] *= 1.0 - delta - factors[i, ROOTWARD] += (posterior[p] - parent_cavity) / scale[p] - factors[i, LEAFWARD] *= 1.0 - delta - factors[i, LEAFWARD] += (posterior[c] - child_cavity) / scale[c] - - # upper-bound posterior - parent_eta = posterior_damping(posterior[p]) - child_eta = posterior_damping(posterior[c]) - posterior[p] *= parent_eta - posterior[c] *= child_eta - scale[p] *= parent_eta - scale[c] *= child_eta + if p == c: # singleton block with single parent + parent_message = factors[i, ROOTWARD] * scale[p] + parent_delta = cavity_damping(posterior[p], parent_message) + parent_cavity = posterior[p] - parent_delta * parent_message + edge_likelihood = parent_delta * likelihoods[i] + child_age = constraints[c, LOWER] + lognorm[i], posterior[p] = \ + twin_projection(parent_cavity, edge_likelihood) # fmt: skip + factors[i, ROOTWARD] *= 1.0 - parent_delta + factors[i, ROOTWARD] += (posterior[p] - parent_cavity) / scale[p] + parent_eta = posterior_damping(posterior[p]) + posterior[p] *= parent_eta + scale[p] *= parent_eta + else: + # lower-bound cavity + parent_message = factors[i, ROOTWARD] * scale[p] + child_message = factors[i, LEAFWARD] * scale[c] + parent_delta = cavity_damping(posterior[p], parent_message) + child_delta = cavity_damping(posterior[c], child_message) + delta = min(parent_delta, child_delta) + + parent_cavity = posterior[p] - delta * parent_message + child_cavity = posterior[c] - delta * child_message + edge_likelihood = delta * likelihoods[i] + + # match moments and update factors + lognorm[i], posterior[p], posterior[c] = gamma_projection( + parent_cavity, + child_cavity, + edge_likelihood, + ) + factors[i, ROOTWARD] *= 1.0 - delta + factors[i, ROOTWARD] += (posterior[p] - parent_cavity) / 
scale[p] + factors[i, LEAFWARD] *= 1.0 - delta + factors[i, LEAFWARD] += (posterior[c] - child_cavity) / scale[c] + + # upper-bound posterior + parent_eta = posterior_damping(posterior[p]) + child_eta = posterior_damping(posterior[c]) + posterior[p] *= parent_eta + posterior[c] *= child_eta + scale[p] *= parent_eta + scale[c] *= child_eta return np.nan @@ -492,9 +511,8 @@ def posterior_damping(x): # update posteriors and rescale to keep shape bounded posterior[free, 1] = cavity[free, 1] + penalty - factors[free, MIXPRIOR] = (posterior[free] - cavity[free]) / scale[ - free, np.newaxis - ] + factors[free, MIXPRIOR] = \ + (posterior[free] - cavity[free]) / scale[free, np.newaxis] # fmt: skip for i in np.flatnonzero(free): eta = posterior_damping(posterior[i]) posterior[i] *= eta @@ -550,9 +568,8 @@ def propagate_mutations( # TODO: scale should be 1.0, can we delete # TODO: we don't seem to need to damp? - # TODO: might as well copy format in other functions and have void return - # assert stuff here + # TODO: assert more stuff here? assert mutations_phase.size == mutations_edge.size assert mutations_posterior.shape == (mutations_phase.size, 2) assert constraints.shape == posterior.shape @@ -580,7 +597,10 @@ def fixed_projection(x, y): return approx.mutation_block_projection(x, y) return approx.mutation_edge_projection(x, y) + twin_projection = approx.mutation_twin_projection + fixed = constraints[:, LOWER] == constraints[:, UPPER] + for m in mutations_order: i = mutations_edge[m] if i == tskit.NULL: # skip mutations above root @@ -589,9 +609,8 @@ def fixed_projection(x, y): if fixed[p] and fixed[c]: child_age = constraints[c, 0] parent_age = constraints[p, 0] - phase[m], mutations_posterior[m] = fixed_projection( - parent_age, child_age - ) + mutations_phase[m], mutations_posterior[m] = \ + fixed_projection(parent_age, child_age) # fmt: skip elif fixed[p] and not fixed[c]: child_message = factors[i, LEAFWARD] * scale[c] child_delta = 1.0 # hopefully we don't need to damp @@ -615,34 +634,49 @@ def fixed_projection(x, y): edge_likelihood, ) else: - parent_message = factors[i, ROOTWARD] * scale[p] - child_message = factors[i, LEAFWARD] * scale[c] - parent_delta = 1.0 # hopefully we don't need to damp - child_delta = 1.0 # hopefully we don't need to damp - delta = min(parent_delta, child_delta) - parent_cavity = posterior[p] - delta * parent_message - child_cavity = posterior[c] - delta * child_message - edge_likelihood = delta * likelihoods[i] - mutations_phase[m], mutations_posterior[m] = gamma_projection( - parent_cavity, - child_cavity, - edge_likelihood, - ) + if p == c: # singleton block with single parent + parent_message = factors[i, ROOTWARD] * scale[p] + parent_delta = 1.0 # hopefully we don't need to damp + parent_cavity = posterior[p] - parent_delta * parent_message + edge_likelihood = parent_delta * likelihoods[i] + child_age = constraints[c, LOWER] + mutations_phase[m], mutations_posterior[m] = \ + twin_projection(parent_cavity, edge_likelihood) # fmt: skip + else: + parent_message = factors[i, ROOTWARD] * scale[p] + child_message = factors[i, LEAFWARD] * scale[c] + parent_delta = 1.0 # hopefully we don't need to damp + child_delta = 1.0 # hopefully we don't need to damp + delta = min(parent_delta, child_delta) + parent_cavity = posterior[p] - delta * parent_message + child_cavity = posterior[c] - delta * child_message + edge_likelihood = delta * likelihoods[i] + mutations_phase[m], mutations_posterior[m] = gamma_projection( + parent_cavity, + child_cavity, + edge_likelihood, + ) 
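+        # NB: mutations_phase is only meaningful for unphased singletons,
+        # where it gives the posterior probability that the mutation lies on
+        # the branch above the first of the two block nodes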
return np.nan - # TODO more arguments, blck_factors and block_parents @staticmethod - @numba.njit(_void(_i1r, _i1r, _f3w, _f3w, _f1w)) - def rescale_factors(edges_parent, edges_child, node_factors, edge_factors, scale): + @numba.njit(_void(_i1r, _i1r, _i2r, _f3w, _f3w, _f3w, _f1w)) + def rescale_factors( + edges_parent, + edges_child, + block_nodes, + node_factors, + edge_factors, + block_factors, + scale, + ): """Incorporate scaling term into factors and reset""" p, c = edges_parent, edges_child + j, k = block_nodes edge_factors[:, ROOTWARD] *= scale[p, np.newaxis] edge_factors[:, LEAFWARD] *= scale[c, np.newaxis] - # TODO - # j, k = blocks_parents - # block_factors[:, ROOTWARD] *= scale[j, np.newaxis] - # block_factors[:, LEAFWARD] *= scale[k, np.newaxis] + block_factors[:, ROOTWARD] *= scale[j, np.newaxis] + block_factors[:, LEAFWARD] *= scale[k, np.newaxis] node_factors[:, MIXPRIOR] *= scale[:, np.newaxis] node_factors[:, CONSTRNT] *= scale[:, np.newaxis] scale[:] = 1.0 @@ -655,13 +689,13 @@ def iterate( em_maxitt=10, em_reltol=1e-8, regularise=True, - check_valid=False, + check_valid=True, # DEBUG ): # pass through singleton blocks self.propagate_likelihood( self.block_order, - self.block_nodes[0], - self.block_nodes[1], + self.block_nodes[ROOTWARD], + self.block_nodes[LEAFWARD], self.block_likelihoods, self.node_constraints, self.node_posterior, @@ -705,18 +739,22 @@ def iterate( self.rescale_factors( self.edge_parents, self.edge_children, + self.block_nodes, self.node_factors, self.edge_factors, + self.block_factors, self.node_scale, ) if check_valid: # for debugging - assert self._check_valid_state( + self._check_valid_state( self.edge_parents, self.edge_children, + self.block_nodes, self.node_posterior, self.node_factors, self.edge_factors, + self.block_factors, ) return np.nan # TODO: placeholder for marginal likelihood @@ -800,8 +838,8 @@ def run( self.mutation_posterior, self.mutation_phase, self.mutation_blocks, - self.block_nodes[0], - self.block_nodes[1], + self.block_nodes[ROOTWARD], + self.block_nodes[LEAFWARD], self.block_likelihoods, self.node_constraints, self.node_posterior, From d23f37dc456cf1e38a3ec94dbef5a651d0aa4075 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 13 Jun 2024 11:07:36 -0700 Subject: [PATCH 16/29] Phasing --- tests/test_phasing.py | 2 ++ tsdate/core.py | 60 ++++++++++++------------------------------ tsdate/rescaling.py | 2 +- tsdate/variational.py | 61 +++++++++++++++++++++++++++++++++++++------ 4 files changed, 73 insertions(+), 52 deletions(-) diff --git a/tests/test_phasing.py b/tests/test_phasing.py index 71256b08..4e664cdb 100644 --- a/tests/test_phasing.py +++ b/tests/test_phasing.py @@ -274,6 +274,8 @@ def test_phase_invariance(self, inferred_ts): singletons_phased=False, ) np.testing.assert_allclose(dts1.nodes_time, dts2.nodes_time) + np.testing.assert_allclose(dts1.mutations_node, dts2.mutations_node) + np.testing.assert_allclose(dts1.mutations_time, dts2.mutations_time) def test_not_phase_invariance(self, inferred_ts): ts1 = inferred_ts diff --git a/tsdate/core.py b/tsdate/core.py index 0d6f00ff..d8fc4aba 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -883,6 +883,7 @@ def outside_maximization(self, *, eps, progress=None): "mutation_var", "mutation_lik", "mutation_edge", + "mutation_node", ], ) @@ -1002,6 +1003,7 @@ def get_modified_ts(self, result, eps): mut_mean_t = result.mutation_mean mut_var_t = result.mutation_var mut_edge = result.mutation_edge + mut_node = result.mutation_node tables = ts.dump_tables() nodes = tables.nodes 
mutations = tables.mutations @@ -1011,8 +1013,8 @@ def get_modified_ts(self, result, eps): # Constrain node ages for positive branch lengths constr_timing = time.time() nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations) - # TODO: what if mutations_edge is NULL? mutations.time = util.constrain_mutations(ts, nodes.time, mut_edge) + mutations.node = mut_node tables.time_units = self.time_units constr_timing -= time.time() logging.info(f"Constrained node ages in {abs(constr_timing)} seconds") @@ -1171,6 +1173,7 @@ def run( posterior_mean, posterior_var = self.mean_var(self.ts, posterior_obj) mut_edge = np.full(self.ts.num_mutations, tskit.NULL) + mut_node = self.ts.mutations_node return Results( posterior_mean, posterior_var, @@ -1179,6 +1182,7 @@ def run( None, marginal_likl, mut_edge, + mut_node, ) @@ -1207,7 +1211,10 @@ def run( marginal_likl = dynamic_prog.inside_pass(cache_inside=cache_inside) posterior_mean = dynamic_prog.outside_maximization(eps=eps) mut_edge = np.full(self.ts.num_mutations, tskit.NULL) - return Results(posterior_mean, None, None, None, None, marginal_likl, mut_edge) + mut_node = self.ts.mutations_node + return Results( + posterior_mean, None, None, None, None, marginal_likl, mut_edge, mut_node + ) class VariationalGammaMethod(EstimationMethod): @@ -1217,31 +1224,6 @@ class VariationalGammaMethod(EstimationMethod): def __init__(self, ts, **kwargs): super().__init__(ts, **kwargs) - @staticmethod - def mean_var(posteriors, constraints): - """ - Mean and variance of node age from variational posterior (e.g. gamma - distributions). Fixed nodes will be given a mean of their exact time in - the tree sequence, and zero variance (as long as they are identified by the - fixed_node_set). This is a static method for ease of testing. 
- """ - - mn_post = np.full( - posteriors.shape[0], np.nan - ) # Fill with NaNs so we detect when - va_post = np.full(posteriors.shape[0], np.nan) # there's been an error - - fixed = constraints[:, 0] == constraints[:, 1] - mn_post[fixed] = constraints[fixed, 0] - va_post[fixed] = 0 - - for i in np.flatnonzero(~fixed): - pars = posteriors[i] - mn_post[i] = (pars[0] + 1) / pars[1] - va_post[i] = (pars[0] + 1) / pars[1] ** 2 - - return mn_post, va_post - def run( self, eps, @@ -1276,29 +1258,21 @@ def run( progress=self.pbar, ) - # TODO: use posterior.point_estimate - posterior_mean, posterior_vari = self.mean_var( - posterior.node_posterior, posterior.node_constraints - ) + node_mn, node_va = posterior.node_moments() + mutation_mn, mutation_va = posterior.mutation_moments() - # TODO: clean up - mutation_post = posterior.mutation_posterior - mutation_mean = np.full(mutation_post.shape[0], np.nan) - mutation_vari = np.full(mutation_post.shape[0], np.nan) - idx = mutation_post[:, 1] > 0 - mutation_mean[idx] = (mutation_post[idx, 0] + 1) / mutation_post[idx, 1] - mutation_vari[idx] = (mutation_post[idx, 0] + 1) / mutation_post[idx, 1] ** 2 mutation_edge = posterior.mutation_edges + mutation_edge, mutation_node = posterior.mutation_mapping() - # TODO: return marginal likelihood return Results( - posterior_mean, - posterior_vari, + node_mn, + node_va, None, - mutation_mean, - mutation_vari, + mutation_mn, + mutation_va, None, mutation_edge, + mutation_node, ) diff --git a/tsdate/rescaling.py b/tsdate/rescaling.py index 0d6dbade..a1bcec84 100644 --- a/tsdate/rescaling.py +++ b/tsdate/rescaling.py @@ -245,7 +245,7 @@ def rescale(x): # reproject posteriors using inter-quantile range # TODO: catch rare cases where lower/upper quantiles are nearly identical - new_posteriors = np.zeros(posteriors.shape) + new_posteriors = np.full(posteriors.shape, np.nan) for i in np.flatnonzero(freed): alpha, beta = approximate_gamma_iqr( quant_lower, quant_upper, lower[i], upper[i] diff --git a/tsdate/variational.py b/tsdate/variational.py index bbfb1331..230e6d73 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -256,7 +256,6 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): self.block_nodes = np.full((2, num_blocks), tskit.NULL, dtype=np.int32) self.block_nodes[0] = self.edge_parents[self.block_edges[:, 0]] self.block_nodes[1] = self.edge_parents[self.block_edges[:, 1]] - self.mutation_phase = np.full(ts.num_mutations, np.nan) num_unphased = np.sum(self.mutation_blocks != tskit.NULL) phase_timing -= time.time() logging.info(f"Found {num_unphased} unphased singleton mutations") @@ -269,6 +268,8 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): self.block_factors = np.zeros((num_blocks, 2, 2)) self.node_posterior = np.zeros((ts.num_nodes, 2)) self.mutation_posterior = np.full((ts.num_mutations, 2), np.nan) + self.mutation_phase = np.ones(ts.num_mutations) + self.mutation_nodes = ts.mutations_node.copy() self.edge_logconst = np.zeros(ts.num_edges) self.block_logconst = np.zeros(num_blocks) self.node_scale = np.ones(ts.num_nodes) @@ -689,7 +690,7 @@ def iterate( em_maxitt=10, em_reltol=1e-8, regularise=True, - check_valid=True, # DEBUG + check_valid=False, # for debugging ): # pass through singleton blocks self.propagate_likelihood( @@ -828,13 +829,13 @@ def run( ) nodes_timing -= time.time() skipped_edges = np.sum(np.isnan(self.edge_logconst)) - if skipped_edges: - logging.info(f"Skipped {skipped_edges} edges with invalid factors") + logging.info(f"Skipped 
{skipped_edges} edges with invalid factors") logging.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds") muts_timing = time.time() + mutations_phased = self.mutation_blocks == tskit.NULL self.propagate_mutations( # unphased singletons - self.mutation_order, + self.mutation_order[~mutations_phased], self.mutation_posterior, self.mutation_phase, self.mutation_blocks, @@ -848,7 +849,7 @@ def run( USE_BLOCK_LIKELIHOOD, ) self.propagate_mutations( # phased mutations - self.mutation_order, + self.mutation_order[mutations_phased], self.mutation_posterior, self.mutation_phase, self.mutation_edges, @@ -863,10 +864,22 @@ def run( ) muts_timing -= time.time() skipped_muts = np.sum(np.isnan(self.mutation_posterior[:, 0])) - if skipped_muts: - logging.info(f"Skipped {skipped_muts} mutations with invalid posteriors") + logging.info(f"Skipped {skipped_muts} mutations with invalid posteriors") logging.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds") + singletons = self.mutation_blocks != tskit.NULL + switched_blocks = self.mutation_blocks[singletons] + switched_edges = np.where( + self.mutation_phase[singletons] < 0.5, + self.block_edges[switched_blocks, 1], + self.block_edges[switched_blocks, 0], + ) + self.mutation_edges[singletons] = switched_edges + self.mutation_nodes[singletons] = self.edge_children[switched_edges] + switched = self.mutation_phase < 0.5 + self.mutation_phase[switched] = 1 - self.mutation_phase[switched] + logging.info(f"Switched phase of {np.sum(switched)} singletons") + if rescale_intervals > 0: rescale_timing = time.time() self.rescale( @@ -874,3 +887,35 @@ def run( ) rescale_timing -= time.time() logging.info(f"Timescale rescaled in {abs(rescale_timing)} seconds") + + def node_moments(self): + """ + Posterior mean and variance of node ages + """ + alpha, beta = self.node_posterior.T + nodes_mn = np.ascontiguousarray(self.node_constraints[:, 0]) + nodes_va = np.zeros(nodes_mn.size) + free = self.node_constraints[:, 0] != self.node_constraints[:, 1] + nodes_mn[free] = (alpha[free] + 1) / beta[free] + nodes_va[free] = nodes_mn[free] / beta[free] + return nodes_mn, nodes_va + + def mutation_moments(self): + """ + Posterior mean and variance of mutation ages + """ + alpha, beta = self.mutation_posterior.T + muts_mn = np.full(alpha.size, np.nan) + muts_va = np.full(alpha.size, np.nan) + free = np.isfinite(alpha) + muts_mn[free] = (alpha[free] + 1) / beta[free] + muts_va[free] = muts_mn[free] / beta[free] + return muts_mn, muts_va + + def mutation_mapping(self): + """ + Map from mutations to edges and subtended node, using estimated singleton + phase (if singletons were unphased) + """ + # TODO: should these be copies? Should members be readonly? 
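+        # NB: reflects estimated phase only after `run` has completed, since
+        # singleton mutations are switched onto the higher-probability branch
+        # at the end of `run`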
+ return self.mutation_edges, self.mutation_nodes From 80862e57cf71130b49e43759c00f277a1044f2f2 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 13 Jun 2024 21:10:22 -0700 Subject: [PATCH 17/29] Mutation parent fix --- tests/test_phasing.py | 2 ++ tsdate/core.py | 9 ++++++--- tsdate/phasing.py | 12 +++++++++--- tsdate/variational.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/tests/test_phasing.py b/tests/test_phasing.py index 4e664cdb..f91fc561 100644 --- a/tests/test_phasing.py +++ b/tests/test_phasing.py @@ -296,6 +296,8 @@ def test_not_phase_invariance(self, inferred_ts): singletons_phased=True, ) assert not np.allclose(dts1.nodes_time, dts2.nodes_time) + assert not np.allclose(dts1.mutations_node, dts2.mutations_node) + assert not np.allclose(dts1.mutations_time, dts2.mutations_time) class TestMutationFrequency: diff --git a/tsdate/core.py b/tsdate/core.py index d8fc4aba..3b65be34 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1015,6 +1015,7 @@ def get_modified_ts(self, result, eps): nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations) mutations.time = util.constrain_mutations(ts, nodes.time, mut_edge) mutations.node = mut_node + mutations.parent = np.full(mutations.num_rows, tskit.NULL, dtype=np.int32) tables.time_units = self.time_units constr_timing -= time.time() logging.info(f"Constrained node ages in {abs(constr_timing)} seconds") @@ -1030,7 +1031,12 @@ def get_modified_ts(self, result, eps): logging.info( f"Inserted node and mutation metadata in {abs(meta_timing)} seconds" ) + sort_timing = time.time() tables.sort() + tables.build_index() + tables.compute_mutation_parents() + sort_timing -= time.time() + logging.info(f"Sorted tree sequence in {abs(sort_timing)} seconds") return tables.tree_sequence() def set_time_metadata(self, table, mean, var, default_schema, overwrite=False): @@ -1243,7 +1249,6 @@ def run( if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") - # match sufficient statistics or match central moments posterior = variational.ExpectationPropagation( self.ts, mutation_rate=self.mutation_rate, @@ -1260,8 +1265,6 @@ def run( node_mn, node_va = posterior.node_moments() mutation_mn, mutation_va = posterior.mutation_moments() - - mutation_edge = posterior.mutation_edges mutation_edge, mutation_node = posterior.mutation_mapping() return Results( diff --git a/tsdate/phasing.py b/tsdate/phasing.py index 55eaad89..3e4d9273 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -508,13 +508,19 @@ def insert_unphased_singletons( raise ValueError( f"Existing site at position {pos} has a different ancestral state" ) + muts = ts.site(sites_id[pos]).mutations + set_time = len(muts) and np.isfinite(muts[0].time) else: sites_id[pos] = tables.sites.add_row(position=pos, ancestral_state=ref) + set_time = False # TODO: more efficient to do in bulk? + site = sites_id[pos] + node = individuals_node[ind] + time = ts.nodes_time[node] if set_time else tskit.UNKNOWN_TIME tables.mutations.add_row( - site=sites_id[pos], - node=individuals_node[ind], - time=tskit.UNKNOWN_TIME, + site=site, + node=node, + time=time, derived_state=alt, ) tables.sort() diff --git a/tsdate/variational.py b/tsdate/variational.py index 230e6d73..ac82737a 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -919,3 +919,46 @@ def mutation_mapping(self): """ # TODO: should these be copies? Should members be readonly? 
return self.mutation_edges, self.mutation_nodes + + +# def date( +# ts, +# *, +# mutation_rate, +# singletons_phased=True, +# max_iterations=10, +# match_segregating_sites=False, +# regularise_roots=True, +# constr_iterations=0, +# progress=True, +# ): +# """ +# Date a tree sequence with expectation propagation. Returns dated tree +# sequence and converged ExpectationPropagation object. +# """ +# +# posterior = variational.ExpectationPropagation( +# ts, +# mutation_rate=mutation_rate, +# singletons_phased=singletons_phased, +# ) +# posterior.run( +# ep_maxitt=max_iterations, +# max_shape=max_shape, +# rescale_intervals=rescaling_intervals, +# regularise=regularise_roots, +# rescale_segsites=match_segregating_sites, +# progress=progress, +# ) +# +# node_mn, node_va = posterior.node_moments() +# mutation_mn, mutation_va = posterior.mutation_moments() +# mutation_edge, mutation_node = posterior.mutation_mapping() +# +# tables = ts.dump_tables() +# tables.nodes.time = constrain_ages( +# ts, node_mn, constr_iterations=constr_iterations) +# tables.mutations.node = mutation_node +# tables.sort() +# +# return tables.tree_sequence(), posterior From 63bb5ac4f23c178b73e3c2205dc809101458e903 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Mon, 17 Jun 2024 11:15:41 -0700 Subject: [PATCH 18/29] Mutational area --- tests/test_phasing.py | 16 --- tests/test_rescaling.py | 259 +++++++++++++++++++++++++++++++++++++ tsdate/evaluation.py | 33 ++++- tsdate/phasing.py | 94 -------------- tsdate/rescaling.py | 279 +++++++++++++++++++++++++++++++++++++++- tsdate/variational.py | 5 +- 6 files changed, 566 insertions(+), 120 deletions(-) create mode 100644 tests/test_rescaling.py diff --git a/tests/test_phasing.py b/tests/test_phasing.py index f91fc561..d3d60870 100644 --- a/tests/test_phasing.py +++ b/tests/test_phasing.py @@ -32,7 +32,6 @@ import tsdate from tsdate.phasing import block_singletons -from tsdate.phasing import count_mutations from tsdate.phasing import insert_unphased_singletons from tsdate.phasing import mutation_frequency from tsdate.phasing import remove_singletons @@ -54,21 +53,6 @@ def inferred_ts(): return inferred_ts -class TestCountMutations: - def test_count_mutations(self, inferred_ts): - edge_stats, muts_edge = count_mutations(inferred_ts) - ck_edge_muts = np.zeros(inferred_ts.num_edges) - ck_muts_edge = np.full(inferred_ts.num_mutations, tskit.NULL) - for m in inferred_ts.mutations(): - if m.edge != tskit.NULL: - ck_edge_muts[m.edge] += 1.0 - ck_muts_edge[m.id] = m.edge - ck_edge_span = inferred_ts.edges_right - inferred_ts.edges_left - np.testing.assert_array_almost_equal(ck_edge_muts, edge_stats[:, 0]) - np.testing.assert_array_almost_equal(ck_edge_span, edge_stats[:, 1]) - np.testing.assert_array_equal(ck_muts_edge, muts_edge) - - class TestBlockSingletons: @staticmethod def naive_block_singletons(ts, individual): diff --git a/tests/test_rescaling.py b/tests/test_rescaling.py new file mode 100644 index 00000000..716cc513 --- /dev/null +++ b/tests/test_rescaling.py @@ -0,0 +1,259 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# Copyright (c) 2020-21 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished 
to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for the gamma-variational approximations in tsdate +""" +from collections import defaultdict + +import msprime +import numpy as np +import pytest +import tsinfer +import tskit + +import tsdate +from tsdate.rescaling import count_mutations +from tsdate.rescaling import count_sizebiased +from tsdate.rescaling import edge_sampling_weight +from tsdate.rescaling import mutational_area + + +@pytest.fixture(scope="session") +def inferred_ts(): + ts = msprime.sim_ancestry( + 10, + population_size=1e4, + recombination_rate=1e-8, + sequence_length=1e6, + random_seed=1, + ) + ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) + sample_data = tsinfer.SampleData.from_tree_sequence(ts) + inferred_ts = tsinfer.infer(sample_data).simplify() + inferred_ts = tsdate.date(inferred_ts, mutation_rate=1e-8, max_iterations=2) + return inferred_ts + + +# TODO: delete, methodology is flawed +class TestEdgeSamplingWeight: + @staticmethod + def naive_edge_sampling_weight(ts): + out = np.zeros(ts.num_edges) + tot = 0.0 + for t in ts.trees(): + if t.num_edges == 0: + continue + tot += t.num_samples() * t.span + for n in t.nodes(): + e = t.edge(n) + if e == tskit.NULL: + continue + out[e] += t.num_samples(n) * t.span + out /= tot + return out + + def test_edge_sampling_weight(self, inferred_ts): + ts = inferred_ts + is_leaf = np.full(ts.num_nodes, False) + is_leaf[list(ts.samples())] = True + edges_weight = edge_sampling_weight( + is_leaf, + ts.edges_parent, + ts.edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, + ) + ck_edges_weight = self.naive_edge_sampling_weight(ts) + np.testing.assert_allclose(edges_weight, ck_edges_weight) + + +class TestMutationalArea: + """ + Test tallying of mutational area within inter-node time intervals. + """ + + @staticmethod + def naive_mutational_area(ts): + """ + Count muts/area in each inter-node interval + """ + edges_muts = np.zeros(ts.num_edges) + for m in ts.mutations(): + if m.edge != tskit.NULL: + edges_muts[m.edge] += 1.0 + unique_node_times, node_map = np.unique(ts.nodes_time, return_inverse=True) + area = np.zeros(unique_node_times.size - 1) + muts = np.zeros(unique_node_times.size - 1) + for edge in ts.edges(): + p = node_map[edge.parent] + c = node_map[edge.child] + length = ts.nodes_time[edge.parent] - ts.nodes_time[edge.child] + width = edge.right - edge.left + area[c:p] += width + muts[c:p] += edges_muts[edge.id] / length + return muts, area, np.diff(unique_node_times) + + @staticmethod + def naive_total_path_area(ts): + """ + Count total area of all paths from samples to roots, + and all mutations on these paths. 
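+        Each edge (child c, parent p) in a tree of span s contributes
+        s * (t[p] - t[c]) * n[c] to the area, where n[c] is the number
+        of samples below c; a mutation above c is counted n[c] times.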
+ """ + muts, span = 0.0, 0.0 + for t in ts.trees(): + if t.num_edges == 0: + continue + mut_node = defaultdict(int) + for m in t.mutations(): + mut_node[m.node] += 1 + for c in t.nodes(): + p = t.parent(c) + if t.parent(c) == tskit.NULL: + continue + weight = t.num_samples(c) + length = t.time(p) - t.time(c) + muts += mut_node[c] * weight + span += t.span * length * weight + return muts, span + + def test_total_mutational_area(self, inferred_ts): + ts = inferred_ts + likelihoods, _ = count_mutations(ts) + epoch_muts, epoch_span, epoch_duration = mutational_area( + ts.nodes_time, + likelihoods, + ts.edges_parent, + ts.edges_child, + ) + segsite = np.sum(epoch_span * epoch_duration) + ck_segsite = ts.segregating_sites(mode="branch", span_normalise=False) + assert np.isclose(segsite, ck_segsite) + totmuts = np.sum(epoch_muts * epoch_duration) + ck_totmuts = ts.segregating_sites(mode="site", span_normalise=False) + assert np.isclose(totmuts, ck_totmuts) + + def test_total_path_area(self, inferred_ts): + ts = inferred_ts + constraints = np.zeros((ts.num_nodes, 2)) + constraints[:, 1] = np.inf + constraints[list(ts.samples())] = 0.0 + likelihoods, _ = tsdate.rescaling.count_sizebiased(ts, constraints) + epoch_muts, epoch_span, epoch_duration = mutational_area( + ts.nodes_time, + likelihoods, + ts.edges_parent, + ts.edges_child, + ) + totmuts = np.sum(epoch_muts * epoch_duration) + patharea = np.sum(epoch_span * epoch_duration) + ck_totmuts, ck_patharea = self.naive_total_path_area(ts) + assert np.isclose(totmuts, ck_totmuts) + assert np.isclose(patharea, ck_patharea) + + def test_vs_naive(self, inferred_ts): + ts = inferred_ts + likelihoods, _ = count_mutations(inferred_ts) + epoch_muts, epoch_span, epoch_duration = mutational_area( + ts.nodes_time, + likelihoods, + ts.edges_parent, + ts.edges_child, + ) + ck_muts, ck_span, ck_duration = self.naive_mutational_area(ts) + np.testing.assert_allclose(epoch_muts, ck_muts) + np.testing.assert_allclose(epoch_span, ck_span) + np.testing.assert_allclose(epoch_duration, ck_duration) + + # TODO: for count mutations variants: + # def test_masked_mutations(...): + # def test_masked_samples(...): + + +class TestCountMutations: + """ + Test tallying of mutations on edges + """ + + def test_count_mutations(self, inferred_ts): + constraints = np.zeros((inferred_ts.num_nodes, 2)) + constraints[:, 1] = np.inf + constraints[list(inferred_ts.samples()), 0] = 0.0 + edge_stats, muts_edge = count_mutations(inferred_ts, constraints) + ck_edge_muts = np.zeros(inferred_ts.num_edges) + ck_muts_edge = np.full(inferred_ts.num_mutations, tskit.NULL) + for m in inferred_ts.mutations(): + if m.edge != tskit.NULL: + ck_edge_muts[m.edge] += 1.0 + ck_muts_edge[m.id] = m.edge + ck_edge_span = inferred_ts.edges_right - inferred_ts.edges_left + np.testing.assert_array_almost_equal(ck_edge_muts, edge_stats[:, 0]) + np.testing.assert_array_almost_equal(ck_edge_span, edge_stats[:, 1]) + np.testing.assert_array_equal(ck_muts_edge, muts_edge) + + +class TestCountSizeBiased: + """ + Count sized-biased mutations and edge area. E.g. weighting the contribution + from each tree by the number of samples subtended by a mutation or edge. 
+ """ + + @staticmethod + def naive_count_sizebiased(ts): + muts_edge = np.full(ts.num_mutations, tskit.NULL) + edge_muts = np.zeros(ts.num_edges) + edge_span = np.zeros(ts.num_edges) + for m in ts.mutations(): + if m.edge != tskit.NULL: + muts_edge[m.id] = m.edge + for t in ts.trees(): + if t.num_edges == 0: + continue + for m in t.mutations(): + e = t.edge(m.node) + if e == tskit.NULL: + continue + edge_muts[e] += t.num_samples(m.node) + for n in t.nodes(): + e = t.edge(n) + if e == tskit.NULL: + continue + edge_span[e] += t.span * t.num_samples(n) + return np.column_stack([edge_muts, edge_span]), muts_edge + + def test_count_sizebiased(self, inferred_ts): + constraints = np.zeros((inferred_ts.num_nodes, 2)) + constraints[:, 1] = np.inf + constraints[list(inferred_ts.samples())] = 0.0 + edge_stats, muts_edge = count_sizebiased(inferred_ts, constraints) + ck_edge_stats, ck_muts_edge = self.naive_count_sizebiased(inferred_ts) + np.testing.assert_array_almost_equal(ck_edge_stats, edge_stats) + np.testing.assert_array_equal(ck_muts_edge, muts_edge) + + @pytest.mark.skip("Ancestral samples not implemented") + def test_ancestral_samples(self, inferred_ts): + # TODO: if there are ancestral samples, these should not be used as weights. + # test when ancestral samples are fully implemented. + return diff --git a/tsdate/evaluation.py b/tsdate/evaluation.py index 5a95ea92..0e9f03db 100644 --- a/tsdate/evaluation.py +++ b/tsdate/evaluation.py @@ -34,8 +34,8 @@ import scipy.sparse import tskit -from .phasing import count_mutations from .phasing import mutation_frequency +from .rescaling import count_mutations class CladeMap: @@ -544,7 +544,7 @@ def mutations_time( max_freq=None, plotpath=None, title=None, - subtending_node=False, + what=0, ): """ Return true and inferred mutation ages, optionally creating a scatterplot and @@ -588,9 +588,11 @@ def mutations_time( infr_mut = infr_mut[is_freq] true_mut = true_mut[is_freq] # get age of mutation or subtended node - if subtending_node: - infr_node = infer_ts.mutations_node[infr_mut] - true_node = ts.mutations_node[true_mut] + if what == 1: + infr_node = infer_ts.edges_child[infr_edge[infr_mut]] + assert np.allclose(infr_node, infer_ts.mutations_node[infr_mut]) + true_node = ts.edges_child[true_edge[true_mut]] + assert np.allclose(true_node, ts.mutations_node[true_mut]) _, uniq_idx = np.unique(infr_node, return_index=True) infr_node = infr_node[uniq_idx] true_node = true_node[uniq_idx] @@ -602,7 +604,22 @@ def mutations_time( nonzero = np.logical_and(mean > 0, truth > 0) mean = mean[nonzero] truth = truth[nonzero] - else: # midpoint on branch + elif what == 2: + infr_node = infer_ts.edges_parent[infr_edge[infr_mut]] + true_node = ts.edges_parent[true_edge[true_mut]] + _, uniq_idx = np.unique(infr_node, return_index=True) + infr_node = infr_node[uniq_idx] + true_node = true_node[uniq_idx] + _, uniq_idx = np.unique(true_node, return_index=True) + infr_node = infr_node[uniq_idx] + true_node = true_node[uniq_idx] + mean = infer_ts.nodes_time[infr_node] + truth = ts.nodes_time[true_node] + nonzero = np.logical_and(mean > 0, truth > 0) + mean = mean[nonzero] + truth = truth[nonzero] + elif what == 0: # midpoint on branch + # TODO clean up infr_p = infer_ts.edges_parent[infr_edge[infr_mut]] true_p = ts.edges_parent[true_edge[true_mut]] infr_c = infer_ts.edges_child[infr_edge[infr_mut]] @@ -612,6 +629,8 @@ def mutations_time( else: mean = (infer_ts.nodes_time[infr_p] + infer_ts.nodes_time[infr_c]) / 2 truth = (ts.nodes_time[true_p] + ts.nodes_time[true_c]) / 2 + 
else: + raise ValueError("Invalid choice of `what`") if plotpath is not None: rsq = np.corrcoef(np.log10(mean), np.log10(truth))[0, 1] ** 2 bias = np.mean(np.log10(mean) - np.log10(truth)) @@ -621,7 +640,7 @@ def mutations_time( plt.hexbin(truth, mean, xscale="log", yscale="log", mincnt=1) plt.text(0.01, 0.99, info, ha="left", va="top", transform=plt.gca().transAxes) plt.axline(pt1, pt2, linestyle="--", color="firebrick") - if subtending_node: + if what != 0: plt.xlabel("True node age") plt.ylabel("Estimated node age") else: diff --git a/tsdate/phasing.py b/tsdate/phasing.py index 3e4d9273..f79dcd89 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -213,100 +213,6 @@ def block_singletons(ts, individuals_unphased): ) -@numba.njit( - _tuple((_f2w, _i1w))(_i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _i, _f) -) -def _count_mutations( - mutations_node, - mutations_position, - edges_parent, - edges_child, - edges_left, - edges_right, - indexes_insert, - indexes_remove, - num_nodes, - sequence_length, -): - """ - TODO - """ - assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size - assert indexes_insert.size == indexes_remove.size == edges_parent.size - assert mutations_node.size == mutations_position.size - - num_mutations = mutations_node.size - num_edges = edges_parent.size - - indexes_mutation = np.argsort(mutations_position) - position_insert = edges_left[indexes_insert] - position_remove = edges_right[indexes_remove] - position_mutation = mutations_position[indexes_mutation] - - nodes_edge = np.full(num_nodes, tskit.NULL) - mutations_edge = np.full(num_mutations, tskit.NULL) - edges_mutations = np.zeros(num_edges) - edges_span = edges_right - edges_left - - left = 0.0 - a, b, d = 0, 0, 0 - while a < num_edges or b < num_edges: - while b < num_edges and position_remove[b] == left: # edges out - e = indexes_remove[b] - c = edges_child[e] - nodes_edge[c] = tskit.NULL - b += 1 - - while a < num_edges and position_insert[a] == left: # edges in - e = indexes_insert[a] - c = edges_child[e] - nodes_edge[c] = e - a += 1 - - right = sequence_length - if b < num_edges: - right = min(right, position_remove[b]) - if a < num_edges: - right = min(right, position_insert[a]) - left = right - - while d < num_mutations and position_mutation[d] < right: - m = indexes_mutation[d] - c = mutations_node[m] - e = nodes_edge[c] - if e != tskit.NULL: - mutations_edge[m] = e - edges_mutations[e] += 1.0 - d += 1 - - mutations_edge = mutations_edge.astype(np.int32) - edges_stats = np.column_stack((edges_mutations, edges_span)) - - return edges_stats, mutations_edge - - -def count_mutations(ts): - """ - TODO - """ - # TODO: adjust spans by an accessibility mask - return _count_mutations( - ts.mutations_node, - ts.sites_position[ts.mutations_site], - ts.edges_parent, - ts.edges_child, - ts.edges_left, - ts.edges_right, - ts.indexes_edge_insertion_order, - ts.indexes_edge_removal_order, - ts.num_nodes, - ts.sequence_length, - ) - - -# def mutations_node(mutations_block, mutations_phase, blocks_nodes): - - @numba.njit(_i2w(_b2r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f)) def _mutation_frequency( nodes_sample, diff --git a/tsdate/rescaling.py b/tsdate/rescaling.py index a1bcec84..b0bd2c3f 100644 --- a/tsdate/rescaling.py +++ b/tsdate/rescaling.py @@ -28,7 +28,6 @@ import numba import numpy as np import tskit -from numba.types import UniTuple as _unituple from .approx import _b from .approx import _b1r @@ -40,6 +39,8 @@ from .approx import _i from .approx import _i1r from 
.approx import _i1w +from .approx import _tuple +from .approx import _unituple from .approx import approximate_gamma_iqr from .hypergeo import _gammainc_inv as gammainc_inv from .util import mutation_span_array # NOQA: F401 @@ -110,6 +111,282 @@ def f(i, j): # loss return breaks +@numba.njit( + _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f) +) +def _count_mutations( + node_is_leaf, + mutations_node, + mutations_position, + edges_parent, + edges_child, + edges_left, + edges_right, + indexes_insert, + indexes_remove, + sequence_length, +): + """ + Internals for `count_mutations` + """ + assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size + assert indexes_insert.size == indexes_remove.size == edges_parent.size + assert mutations_node.size == mutations_position.size + + num_mutations = mutations_node.size + num_edges = edges_parent.size + num_nodes = node_is_leaf.size + + indexes_mutation = np.argsort(mutations_position) + position_insert = edges_left[indexes_insert] + position_remove = edges_right[indexes_remove] + position_mutation = mutations_position[indexes_mutation] + + nodes_edge = np.full(num_nodes, tskit.NULL) + mutations_edge = np.full(num_mutations, tskit.NULL) + edges_mutations = np.zeros(num_edges) + edges_span = edges_right - edges_left + + left = 0.0 + a, b, d = 0, 0, 0 + while a < num_edges or b < num_edges: + while b < num_edges and position_remove[b] == left: # edges out + e = indexes_remove[b] + c = edges_child[e] + nodes_edge[c] = tskit.NULL + b += 1 + + while a < num_edges and position_insert[a] == left: # edges in + e = indexes_insert[a] + c = edges_child[e] + nodes_edge[c] = e + a += 1 + + right = sequence_length + if b < num_edges: + right = min(right, position_remove[b]) + if a < num_edges: + right = min(right, position_insert[a]) + left = right + + while d < num_mutations and position_mutation[d] < right: + m = indexes_mutation[d] + c = mutations_node[m] + e = nodes_edge[c] + if e != tskit.NULL: + mutations_edge[m] = e + edges_mutations[e] += 1.0 + d += 1 + + mutations_edge = mutations_edge.astype(np.int32) + edges_stats = np.column_stack((edges_mutations, edges_span)) + + return edges_stats, mutations_edge + + +def count_mutations(ts, constraints=None): + """ + Return an array with `num_edges` rows, and columns that are the number of + mutations per edge and the total span per edge + """ + # TODO: adjust spans by an accessibility mask + if constraints is None: + node_is_leaf = np.full(ts.num_nodes, False) + node_is_leaf[list(ts.samples())] = True + else: + assert constraints.shape == (ts.num_nodes, 2) + node_is_leaf = np.logical_and( + constraints[:, 0] == 0.0, + constraints[:, 0] == constraints[:, 1], + ) + return _count_mutations( + node_is_leaf, + ts.mutations_node, + ts.sites_position[ts.mutations_site], + ts.edges_parent, + ts.edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, + ts.sequence_length, + ) + + +# TODO: similar enough to count_mutations to combine, with adequate testing +@numba.njit( + _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f) +) +def _count_sizebiased( + node_is_leaf, + mutations_node, + mutations_position, + edges_parent, + edges_child, + edges_left, + edges_right, + indexes_insert, + indexes_remove, + sequence_length, +): + """ + Internals for `count_sizebiased` + """ + assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size + assert indexes_insert.size == 
indexes_remove.size == edges_parent.size + assert mutations_node.size == mutations_position.size + + num_mutations = mutations_node.size + num_edges = edges_parent.size + num_nodes = node_is_leaf.size + + indexes_mutation = np.argsort(mutations_position) + position_insert = edges_left[indexes_insert] + position_remove = edges_right[indexes_remove] + position_mutation = mutations_position[indexes_mutation] + + nodes_samples = np.zeros(num_nodes) + nodes_edge = np.full(num_nodes, tskit.NULL) + nodes_parent = np.full(num_nodes, tskit.NULL) + mutations_edge = np.full(num_mutations, tskit.NULL) + edges_mutations = np.zeros(num_edges) + edges_span = np.zeros(num_edges) + + nodes_samples[node_is_leaf] = 1.0 + left = 0.0 + a, b, d = 0, 0, 0 + while a < num_edges or b < num_edges: + remainder = sequence_length - left + + while b < num_edges and position_remove[b] == left: # edges out + e = indexes_remove[b] + p, c = edges_parent[e], edges_child[e] + nodes_edge[c] = tskit.NULL + nodes_parent[c] = tskit.NULL + while p != tskit.NULL: # downdate sample counts + edges_span[e] -= nodes_samples[c] * remainder + nodes_samples[p] -= nodes_samples[c] + e, p = nodes_edge[p], nodes_parent[p] + b += 1 + + while a < num_edges and position_insert[a] == left: # edges in + e = indexes_insert[a] + p, c = edges_parent[e], edges_child[e] + nodes_edge[c] = e + nodes_parent[c] = p + while p != tskit.NULL: # update sample counts + edges_span[e] += nodes_samples[c] * remainder + nodes_samples[p] += nodes_samples[c] + e, p = nodes_edge[p], nodes_parent[p] + a += 1 + + right = sequence_length + if b < num_edges: + right = min(right, position_remove[b]) + if a < num_edges: + right = min(right, position_insert[a]) + left = right + + while d < num_mutations and position_mutation[d] < right: + m = indexes_mutation[d] + c = mutations_node[m] + e = nodes_edge[c] + if e != tskit.NULL: + mutations_edge[m] = e + edges_mutations[e] += nodes_samples[c] + d += 1 + + mutations_edge = mutations_edge.astype(np.int32) + edges_stats = np.column_stack((edges_mutations, edges_span)) + + return edges_stats, mutations_edge + + +def count_sizebiased(ts, constraints): + """ + Return an array with `num_edges` rows, and columns that are the number of + mutations per edge and the total span per edge. If `size_biased` is `True`, + then mutations and edges are weighted by frequency. + + Note that weighting edges by frequency is done tree-by-tree. + """ + # TODO: adjust spans by an accessibility mask + assert constraints.shape == (ts.num_nodes, 2) + node_is_leaf = np.logical_and( + constraints[:, 0] == 0.0, + constraints[:, 0] == constraints[:, 1], + ) + return _count_sizebiased( + node_is_leaf, + ts.mutations_node, + ts.sites_position[ts.mutations_site], + ts.edges_parent, + ts.edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, + ts.sequence_length, + ) + + +@numba.njit(_unituple(_f1w, 3)(_f1r, _f2r, _i1r, _i1r)) +def mutational_area( + nodes_time, + likelihoods, + edges_parent, + edges_child, +): + """ + Calculate the total number of mutations and mutational area per inter-node + interval. 
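+
+    In outline: each edge with positive length contributes its mutation count
+    divided by its length (a per-unit-time rate) together with its genomic
+    span; these are added at the child's epoch and subtracted at the parent's
+    epoch, so that a cumulative sum over epochs yields the totals for every
+    inter-node interval.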
+ + :param np.ndarray nodes_time: point estimates for node ages + :param np.ndarray likelihoods: edges are rows; mutation + counts and mutational span are columns + :param np.ndarray edges_parent: node index for the parent of each edge + :param np.ndarray edges_child: node index for the child of each edge + :param np.ndarray edges_weight: a weight for each edge + """ + + assert edges_parent.size == edges_child.size + assert likelihoods.shape == (edges_parent.size, 2) + + # index node by unique time breaks + nodes_order = np.argsort(nodes_time) + nodes_index = np.zeros(nodes_time.size, dtype=np.int32) + epoch_breaks = [0.0] + k = 0 + for i, j in zip(nodes_order[1:], nodes_order[:-1]): + if nodes_time[i] > nodes_time[j]: + epoch_breaks.append(nodes_time[i]) + k += 1 + nodes_index[i] = k + epoch_breaks = np.array(epoch_breaks) + epoch_length = np.diff(epoch_breaks) + num_epochs = epoch_length.size + + # instantaneous mutation rate per edge + edges_length = nodes_time[edges_parent] - nodes_time[edges_child] + edges_subset = edges_length > 0 + edges_counts = likelihoods.copy() + edges_counts[edges_subset, 0] /= edges_length[edges_subset] + + # pass over edges, measuring overlap with each time interval + epoch_counts = np.zeros((num_epochs, 2)) + for e in np.flatnonzero(edges_subset): + p, c = edges_parent[e], edges_child[e] + a, b = nodes_index[c], nodes_index[p] + if a < num_epochs: + epoch_counts[a] += edges_counts[e] + if b < num_epochs: + epoch_counts[b] -= edges_counts[e] + counts = np.cumsum(epoch_counts[:, 0]) + offset = np.cumsum(epoch_counts[:, 1]) + + return counts, offset, epoch_length + + @numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) def mutational_timescale( nodes_time, diff --git a/tsdate/variational.py b/tsdate/variational.py index ac82737a..46896453 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -47,8 +47,8 @@ from .approx import _i2r from .hypergeo import _gammainc_inv as gammainc_inv from .phasing import block_singletons -from .phasing import count_mutations from .phasing import reallocate_unphased +from .rescaling import count_mutations from .rescaling import edge_sampling_weight from .rescaling import mutational_timescale from .rescaling import piecewise_scale_posterior @@ -241,7 +241,8 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): # count mutations on edges count_timing = time.time() - self.edge_likelihoods, self.mutation_edges = count_mutations(ts) + self.edge_likelihoods, self.mutation_edges = \ + count_mutations(ts, self.node_constraints) # fmt: skip self.edge_likelihoods[:, 1] *= mutation_rate count_timing -= time.time() logging.info(f"Extracted mutations in {abs(count_timing)} seconds") From 7ebc1540dde776131e31a5a3a9dcba5cd5c6b223 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Mon, 17 Jun 2024 16:13:05 -0700 Subject: [PATCH 19/29] Mutational timescale --- tests/test_rescaling.py | 71 +++---- tsdate/rescaling.py | 402 ++++++++++++++++++++++++---------------- tsdate/variational.py | 3 +- 3 files changed, 269 insertions(+), 207 deletions(-) diff --git a/tests/test_rescaling.py b/tests/test_rescaling.py index 716cc513..b5944ee2 100644 --- a/tests/test_rescaling.py +++ b/tests/test_rescaling.py @@ -34,7 +34,6 @@ import tsdate from tsdate.rescaling import count_mutations -from tsdate.rescaling import count_sizebiased from tsdate.rescaling import edge_sampling_weight from tsdate.rescaling import mutational_area @@ -114,7 +113,7 @@ def naive_mutational_area(ts): width = edge.right - edge.left area[c:p] += 
width muts[c:p] += edges_muts[edge.id] / length - return muts, area, np.diff(unique_node_times) + return muts, area, np.diff(unique_node_times), node_map @staticmethod def naive_total_path_area(ts): @@ -142,7 +141,7 @@ def naive_total_path_area(ts): def test_total_mutational_area(self, inferred_ts): ts = inferred_ts likelihoods, _ = count_mutations(ts) - epoch_muts, epoch_span, epoch_duration = mutational_area( + epoch_muts, epoch_span, epoch_duration, _ = mutational_area( ts.nodes_time, likelihoods, ts.edges_parent, @@ -157,11 +156,8 @@ def test_total_mutational_area(self, inferred_ts): def test_total_path_area(self, inferred_ts): ts = inferred_ts - constraints = np.zeros((ts.num_nodes, 2)) - constraints[:, 1] = np.inf - constraints[list(ts.samples())] = 0.0 - likelihoods, _ = tsdate.rescaling.count_sizebiased(ts, constraints) - epoch_muts, epoch_span, epoch_duration = mutational_area( + likelihoods, _ = tsdate.rescaling.count_mutations(ts, size_biased=True) + epoch_muts, epoch_span, epoch_duration, _ = mutational_area( ts.nodes_time, likelihoods, ts.edges_parent, @@ -176,20 +172,17 @@ def test_total_path_area(self, inferred_ts): def test_vs_naive(self, inferred_ts): ts = inferred_ts likelihoods, _ = count_mutations(inferred_ts) - epoch_muts, epoch_span, epoch_duration = mutational_area( + epoch_muts, epoch_span, epoch_duration, node_index = mutational_area( ts.nodes_time, likelihoods, ts.edges_parent, ts.edges_child, ) - ck_muts, ck_span, ck_duration = self.naive_mutational_area(ts) + ck_muts, ck_span, ck_duration, ck_index = self.naive_mutational_area(ts) np.testing.assert_allclose(epoch_muts, ck_muts) np.testing.assert_allclose(epoch_span, ck_span) np.testing.assert_allclose(epoch_duration, ck_duration) - - # TODO: for count mutations variants: - # def test_masked_mutations(...): - # def test_masked_samples(...): + np.testing.assert_allclose(node_index, ck_index) class TestCountMutations: @@ -197,28 +190,16 @@ class TestCountMutations: Test tallying of mutations on edges """ - def test_count_mutations(self, inferred_ts): - constraints = np.zeros((inferred_ts.num_nodes, 2)) - constraints[:, 1] = np.inf - constraints[list(inferred_ts.samples()), 0] = 0.0 - edge_stats, muts_edge = count_mutations(inferred_ts, constraints) - ck_edge_muts = np.zeros(inferred_ts.num_edges) - ck_muts_edge = np.full(inferred_ts.num_mutations, tskit.NULL) - for m in inferred_ts.mutations(): + @staticmethod + def naive_count_mutations(ts): + edge_muts = np.zeros(ts.num_edges) + muts_edge = np.full(ts.num_mutations, tskit.NULL) + for m in ts.mutations(): if m.edge != tskit.NULL: - ck_edge_muts[m.edge] += 1.0 - ck_muts_edge[m.id] = m.edge - ck_edge_span = inferred_ts.edges_right - inferred_ts.edges_left - np.testing.assert_array_almost_equal(ck_edge_muts, edge_stats[:, 0]) - np.testing.assert_array_almost_equal(ck_edge_span, edge_stats[:, 1]) - np.testing.assert_array_equal(ck_muts_edge, muts_edge) - - -class TestCountSizeBiased: - """ - Count sized-biased mutations and edge area. E.g. weighting the contribution - from each tree by the number of samples subtended by a mutation or edge. 
- """ + edge_muts[m.edge] += 1.0 + muts_edge[m.id] = m.edge + edge_span = ts.edges_right - ts.edges_left + return np.column_stack([edge_muts, edge_span]), muts_edge @staticmethod def naive_count_sizebiased(ts): @@ -243,17 +224,25 @@ def naive_count_sizebiased(ts): edge_span[e] += t.span * t.num_samples(n) return np.column_stack([edge_muts, edge_span]), muts_edge + def test_count_mutations(self, inferred_ts): + edge_stats, muts_edge = count_mutations(inferred_ts) + ck_edge_stats, ck_muts_edge = self.naive_count_mutations(inferred_ts) + np.testing.assert_array_almost_equal(ck_edge_stats, edge_stats) + np.testing.assert_array_equal(ck_muts_edge, muts_edge) + def test_count_sizebiased(self, inferred_ts): - constraints = np.zeros((inferred_ts.num_nodes, 2)) - constraints[:, 1] = np.inf - constraints[list(inferred_ts.samples())] = 0.0 - edge_stats, muts_edge = count_sizebiased(inferred_ts, constraints) + edge_stats, muts_edge = count_mutations(inferred_ts, size_biased=True) ck_edge_stats, ck_muts_edge = self.naive_count_sizebiased(inferred_ts) np.testing.assert_array_almost_equal(ck_edge_stats, edge_stats) np.testing.assert_array_equal(ck_muts_edge, muts_edge) - @pytest.mark.skip("Ancestral samples not implemented") - def test_ancestral_samples(self, inferred_ts): + @pytest.mark.skip("Ancient samples not implemented") + def test_count_sizebiased_with_ancient(self, inferred_ts): # TODO: if there are ancestral samples, these should not be used as weights. # test when ancestral samples are fully implemented. return + + @pytest.mark.skip("Accessibility mask not implemented") + def test_count_mutations_with_accessible(self, inferred_ts): + # TODO + return diff --git a/tsdate/rescaling.py b/tsdate/rescaling.py index b0bd2c3f..9ce66d0b 100644 --- a/tsdate/rescaling.py +++ b/tsdate/rescaling.py @@ -111,114 +111,113 @@ def f(i, j): # loss return breaks -@numba.njit( - _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f) -) -def _count_mutations( - node_is_leaf, - mutations_node, - mutations_position, - edges_parent, - edges_child, - edges_left, - edges_right, - indexes_insert, - indexes_remove, - sequence_length, -): - """ - Internals for `count_mutations` - """ - assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size - assert indexes_insert.size == indexes_remove.size == edges_parent.size - assert mutations_node.size == mutations_position.size - - num_mutations = mutations_node.size - num_edges = edges_parent.size - num_nodes = node_is_leaf.size - - indexes_mutation = np.argsort(mutations_position) - position_insert = edges_left[indexes_insert] - position_remove = edges_right[indexes_remove] - position_mutation = mutations_position[indexes_mutation] - - nodes_edge = np.full(num_nodes, tskit.NULL) - mutations_edge = np.full(num_mutations, tskit.NULL) - edges_mutations = np.zeros(num_edges) - edges_span = edges_right - edges_left - - left = 0.0 - a, b, d = 0, 0, 0 - while a < num_edges or b < num_edges: - while b < num_edges and position_remove[b] == left: # edges out - e = indexes_remove[b] - c = edges_child[e] - nodes_edge[c] = tskit.NULL - b += 1 - - while a < num_edges and position_insert[a] == left: # edges in - e = indexes_insert[a] - c = edges_child[e] - nodes_edge[c] = e - a += 1 - - right = sequence_length - if b < num_edges: - right = min(right, position_remove[b]) - if a < num_edges: - right = min(right, position_insert[a]) - left = right - - while d < num_mutations and position_mutation[d] < right: - m = indexes_mutation[d] - c = 
mutations_node[m] - e = nodes_edge[c] - if e != tskit.NULL: - mutations_edge[m] = e - edges_mutations[e] += 1.0 - d += 1 - - mutations_edge = mutations_edge.astype(np.int32) - edges_stats = np.column_stack((edges_mutations, edges_span)) - - return edges_stats, mutations_edge - - -def count_mutations(ts, constraints=None): - """ - Return an array with `num_edges` rows, and columns that are the number of - mutations per edge and the total span per edge - """ - # TODO: adjust spans by an accessibility mask - if constraints is None: - node_is_leaf = np.full(ts.num_nodes, False) - node_is_leaf[list(ts.samples())] = True - else: - assert constraints.shape == (ts.num_nodes, 2) - node_is_leaf = np.logical_and( - constraints[:, 0] == 0.0, - constraints[:, 0] == constraints[:, 1], - ) - return _count_mutations( - node_is_leaf, - ts.mutations_node, - ts.sites_position[ts.mutations_site], - ts.edges_parent, - ts.edges_child, - ts.edges_left, - ts.edges_right, - ts.indexes_edge_insertion_order, - ts.indexes_edge_removal_order, - ts.sequence_length, - ) +#@numba.njit( +# _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f) +#) +#def _count_mutations( +# node_is_leaf, +# mutations_node, +# mutations_position, +# edges_parent, +# edges_child, +# edges_left, +# edges_right, +# indexes_insert, +# indexes_remove, +# sequence_length, +#): +# """ +# Internals for `count_mutations` +# """ +# assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size +# assert indexes_insert.size == indexes_remove.size == edges_parent.size +# assert mutations_node.size == mutations_position.size +# +# num_mutations = mutations_node.size +# num_edges = edges_parent.size +# num_nodes = node_is_leaf.size +# +# indexes_mutation = np.argsort(mutations_position) +# position_insert = edges_left[indexes_insert] +# position_remove = edges_right[indexes_remove] +# position_mutation = mutations_position[indexes_mutation] +# +# nodes_edge = np.full(num_nodes, tskit.NULL) +# mutations_edge = np.full(num_mutations, tskit.NULL) +# edges_mutations = np.zeros(num_edges) +# edges_span = edges_right - edges_left +# +# left = 0.0 +# a, b, d = 0, 0, 0 +# while a < num_edges or b < num_edges: +# while b < num_edges and position_remove[b] == left: # edges out +# e = indexes_remove[b] +# c = edges_child[e] +# nodes_edge[c] = tskit.NULL +# b += 1 +# +# while a < num_edges and position_insert[a] == left: # edges in +# e = indexes_insert[a] +# c = edges_child[e] +# nodes_edge[c] = e +# a += 1 +# +# right = sequence_length +# if b < num_edges: +# right = min(right, position_remove[b]) +# if a < num_edges: +# right = min(right, position_insert[a]) +# left = right +# +# while d < num_mutations and position_mutation[d] < right: +# m = indexes_mutation[d] +# c = mutations_node[m] +# e = nodes_edge[c] +# if e != tskit.NULL: +# mutations_edge[m] = e +# edges_mutations[e] += 1.0 +# d += 1 +# +# mutations_edge = mutations_edge.astype(np.int32) +# edges_stats = np.column_stack((edges_mutations, edges_span)) +# +# return edges_stats, mutations_edge +# +# +#def count_mutations(ts, constraints=None): +# """ +# Return an array with `num_edges` rows, and columns that are the number of +# mutations per edge and the total span per edge +# """ +# # TODO: adjust spans by an accessibility mask +# if constraints is None: +# node_is_leaf = np.full(ts.num_nodes, False) +# node_is_leaf[list(ts.samples())] = True +# else: +# assert constraints.shape == (ts.num_nodes, 2) +# node_is_leaf = np.logical_and( +# constraints[:, 0] == 
0.0, +# constraints[:, 0] == constraints[:, 1], +# ) +# return _count_mutations( +# node_is_leaf, +# ts.mutations_node, +# ts.sites_position[ts.mutations_site], +# ts.edges_parent, +# ts.edges_child, +# ts.edges_left, +# ts.edges_right, +# ts.indexes_edge_insertion_order, +# ts.indexes_edge_removal_order, +# ts.sequence_length, +# ) -# TODO: similar enough to count_mutations to combine, with adequate testing @numba.njit( - _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f) + _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f, _b) ) -def _count_sizebiased( - node_is_leaf, +def _count_mutations( + node_is_sample, mutations_node, mutations_position, edges_parent, @@ -228,6 +227,7 @@ def _count_sizebiased( indexes_insert, indexes_remove, sequence_length, + size_biased, ): """ Internals for `count_sizebiased` @@ -238,7 +238,7 @@ def _count_sizebiased( num_mutations = mutations_node.size num_edges = edges_parent.size - num_nodes = node_is_leaf.size + num_nodes = node_is_sample.size indexes_mutation = np.argsort(mutations_position) position_insert = edges_left[indexes_insert] @@ -252,7 +252,7 @@ def _count_sizebiased( edges_mutations = np.zeros(num_edges) edges_span = np.zeros(num_edges) - nodes_samples[node_is_leaf] = 1.0 + nodes_samples[node_is_sample] = 1.0 left = 0.0 a, b, d = 0, 0, 0 while a < num_edges or b < num_edges: @@ -263,10 +263,13 @@ def _count_sizebiased( p, c = edges_parent[e], edges_child[e] nodes_edge[c] = tskit.NULL nodes_parent[c] = tskit.NULL - while p != tskit.NULL: # downdate sample counts - edges_span[e] -= nodes_samples[c] * remainder - nodes_samples[p] -= nodes_samples[c] - e, p = nodes_edge[p], nodes_parent[p] + if size_biased: + while p != tskit.NULL: # downdate sample counts + edges_span[e] -= nodes_samples[c] * remainder + nodes_samples[p] -= nodes_samples[c] + e, p = nodes_edge[p], nodes_parent[p] + else: + edges_span[e] -= remainder b += 1 while a < num_edges and position_insert[a] == left: # edges in @@ -274,10 +277,13 @@ def _count_sizebiased( p, c = edges_parent[e], edges_child[e] nodes_edge[c] = e nodes_parent[c] = p - while p != tskit.NULL: # update sample counts - edges_span[e] += nodes_samples[c] * remainder - nodes_samples[p] += nodes_samples[c] - e, p = nodes_edge[p], nodes_parent[p] + if size_biased: + while p != tskit.NULL: # update sample counts + edges_span[e] += nodes_samples[c] * remainder + nodes_samples[p] += nodes_samples[c] + e, p = nodes_edge[p], nodes_parent[p] + else: + edges_span[e] += remainder a += 1 right = sequence_length @@ -293,7 +299,7 @@ def _count_sizebiased( e = nodes_edge[c] if e != tskit.NULL: mutations_edge[m] = e - edges_mutations[e] += nodes_samples[c] + edges_mutations[e] += nodes_samples[c] if size_biased else 1.0 d += 1 mutations_edge = mutations_edge.astype(np.int32) @@ -302,7 +308,7 @@ def _count_sizebiased( return edges_stats, mutations_edge -def count_sizebiased(ts, constraints): +def count_mutations(ts, node_is_sample=None, size_biased=False): """ Return an array with `num_edges` rows, and columns that are the number of mutations per edge and the total span per edge. If `size_biased` is `True`, @@ -310,14 +316,18 @@ def count_sizebiased(ts, constraints): Note that weighting edges by frequency is done tree-by-tree. 
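+
+    A sketch of typical usage, mirroring the tests:
+
+        edge_stats, muts_edge = count_mutations(ts, size_biased=True)
+        edge_muts, edge_span = edge_stats[:, 0], edge_stats[:, 1]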
""" - # TODO: adjust spans by an accessibility mask - assert constraints.shape == (ts.num_nodes, 2) - node_is_leaf = np.logical_and( - constraints[:, 0] == 0.0, - constraints[:, 0] == constraints[:, 1], - ) - return _count_sizebiased( - node_is_leaf, + # TODO: adjust spans by an accessibility mask: + # need to supply cumulative accessible sequence at each + # breakpoint + + if node_is_sample is None: + node_is_sample = np.full(ts.num_nodes, False) + node_is_sample[list(ts.samples())] = True + else: + assert node_is_sample.size != ts.num_nodes + + return _count_mutations( + node_is_sample, ts.mutations_node, ts.sites_position[ts.mutations_site], ts.edges_parent, @@ -327,10 +337,11 @@ def count_sizebiased(ts, constraints): ts.indexes_edge_insertion_order, ts.indexes_edge_removal_order, ts.sequence_length, + size_biased, ) -@numba.njit(_unituple(_f1w, 3)(_f1r, _f2r, _i1r, _i1r)) +@numba.njit(_tuple((_f1w, _f1w, _f1w, _i1w))(_f1r, _f2r, _i1r, _i1r)) def mutational_area( nodes_time, likelihoods, @@ -363,8 +374,7 @@ def mutational_area( k += 1 nodes_index[i] = k epoch_breaks = np.array(epoch_breaks) - epoch_length = np.diff(epoch_breaks) - num_epochs = epoch_length.size + num_epochs = epoch_breaks.size - 1 # instantaneous mutation rate per edge edges_length = nodes_time[edges_parent] - nodes_time[edges_child] @@ -383,8 +393,96 @@ def mutational_area( epoch_counts[b] -= edges_counts[e] counts = np.cumsum(epoch_counts[:, 0]) offset = np.cumsum(epoch_counts[:, 1]) - - return counts, offset, epoch_length + duration = np.diff(epoch_breaks) + + return counts, offset, duration, nodes_index + + +#@numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) +#def mutational_timescale( +# nodes_time, +# likelihoods, +# constraints, +# edges_parent, +# edges_child, +# edges_weight, +# max_intervals, +#): +# """ +# Rescale node ages so that the instantaneous mutation rate is constant. +# Edges with a negative duration are ignored when calculating the total +# rate. Returns a rescaled point estimate and the posterior. 
+# +# :param np.ndarray nodes_time: point estimates for node ages +# :param np.ndarray likelihoods: edges are rows; mutation +# counts and mutational span are columns +# :param np.ndarray constraints: lower and upper bounds on node age +# :param np.ndarray edges_parent: node index for the parent of each edge +# :param np.ndarray edges_child: node index for the child of each edge +# :param np.ndarray edges_weight: a weight for each edge +# :param int max_intervals: maximum number of intervals within which to +# estimate the time scaling +# """ +# +# assert edges_parent.size == edges_child.size == edges_weight.size +# assert likelihoods.shape[0] == edges_parent.size and likelihoods.shape[1] == 2 +# assert constraints.shape[0] == nodes_time.size and constraints.shape[1] == 2 +# assert max_intervals > 0 +# +# nodes_fixed = constraints[:, 0] == constraints[:, 1] +# assert np.all(nodes_time[nodes_fixed] == constraints[nodes_fixed, 0]) +# +# # index node by unique time breaks +# nodes_order = np.argsort(nodes_time) +# nodes_index = np.zeros(nodes_time.size, dtype=np.int32) +# epoch_breaks = [0.0] +# k = 0 +# for i, j in zip(nodes_order[1:], nodes_order[:-1]): +# if nodes_time[i] > nodes_time[j]: +# epoch_breaks.append(nodes_time[i]) +# k += 1 +# nodes_index[i] = k +# epoch_breaks = np.array(epoch_breaks) +# epoch_length = np.diff(epoch_breaks) +# num_epochs = epoch_length.size +# +# # instantaneous mutation rate per edge +# edges_length = nodes_time[edges_parent] - nodes_time[edges_child] +# edges_subset = edges_length > 0 +# edges_counts = likelihoods.copy() +# edges_counts[edges_subset, 0] /= edges_length[edges_subset] +# +# # pass over edges, measuring overlap with each time interval +# epoch_counts = np.zeros((num_epochs, 2)) +# for e in np.flatnonzero(edges_subset): +# p, c = edges_parent[e], edges_child[e] +# a, b = nodes_index[c], nodes_index[p] +# if a < num_epochs: +# epoch_counts[a] += edges_counts[e] * edges_weight[e] +# if b < num_epochs: +# epoch_counts[b] -= edges_counts[e] * edges_weight[e] +# counts = np.cumsum(epoch_counts[:, 0]) +# offset = np.cumsum(epoch_counts[:, 1]) +# +# # rescale time such that mutation density is constant between changepoints +# # TODO: use poisson changepoints to further refine +# changepoints = _fixed_changepoints(offset * epoch_length, max_intervals) +# changepoints = np.union1d(changepoints, nodes_index[nodes_fixed]) +# adjust = np.zeros(changepoints.size) +# k = 0 +# for i, j in zip(changepoints[:-1], changepoints[1:]): +# assert j > i +# # TODO: when changepoint intersects a fixed node? 
+# n = np.sum(offset[i:j]) +# y = np.sum(counts[i:j]) +# z = np.sum(epoch_length[i:j]) +# assert n > 0, "Zero edge span in interval" +# adjust[k + 1] = z * y / n +# k += 1 +# adjust = np.cumsum(adjust) +# origin = epoch_breaks[changepoints] +# +# return origin, adjust @numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) @@ -413,7 +511,7 @@ def mutational_timescale( estimate the time scaling """ - assert edges_parent.size == edges_child.size == edges_weight.size + assert edges_parent.size == edges_child.size assert likelihoods.shape[0] == edges_parent.size and likelihoods.shape[1] == 2 assert constraints.shape[0] == nodes_time.size and constraints.shape[1] == 2 assert max_intervals > 0 @@ -421,42 +519,18 @@ def mutational_timescale( nodes_fixed = constraints[:, 0] == constraints[:, 1] assert np.all(nodes_time[nodes_fixed] == constraints[nodes_fixed, 0]) - # index node by unique time breaks - nodes_order = np.argsort(nodes_time) - nodes_index = np.zeros(nodes_time.size, dtype=np.int32) - epoch_breaks = [0.0] - k = 0 - for i, j in zip(nodes_order[1:], nodes_order[:-1]): - if nodes_time[i] > nodes_time[j]: - epoch_breaks.append(nodes_time[i]) - k += 1 - nodes_index[i] = k - epoch_breaks = np.array(epoch_breaks) - epoch_length = np.diff(epoch_breaks) - num_epochs = epoch_length.size - - # instantaneous mutation rate per edge - edges_length = nodes_time[edges_parent] - nodes_time[edges_child] - edges_subset = edges_length > 0 - edges_counts = likelihoods.copy() - edges_counts[edges_subset, 0] /= edges_length[edges_subset] - - # pass over edges, measuring overlap with each time interval - epoch_counts = np.zeros((num_epochs, 2)) - for e in np.flatnonzero(edges_subset): - p, c = edges_parent[e], edges_child[e] - a, b = nodes_index[c], nodes_index[p] - if a < num_epochs: - epoch_counts[a] += edges_counts[e] * edges_weight[e] - if b < num_epochs: - epoch_counts[b] -= edges_counts[e] * edges_weight[e] - counts = np.cumsum(epoch_counts[:, 0]) - offset = np.cumsum(epoch_counts[:, 1]) + counts, offset, duration, indexes = mutational_area( + nodes_time, + likelihoods, + edges_parent, + edges_child, + ) # rescale time such that mutation density is constant between changepoints # TODO: use poisson changepoints to further refine - changepoints = _fixed_changepoints(offset * epoch_length, max_intervals) - changepoints = np.union1d(changepoints, nodes_index[nodes_fixed]) + epoch_breaks = np.append(0.0, np.cumsum(duration)) + changepoints = _fixed_changepoints(offset * duration, max_intervals) + changepoints = np.union1d(changepoints, indexes[nodes_fixed]) adjust = np.zeros(changepoints.size) k = 0 for i, j in zip(changepoints[:-1], changepoints[1:]): @@ -464,7 +538,7 @@ def mutational_timescale( # TODO: when changepoint intersects a fixed node? 
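+        # rescaling logic, roughly: this window has duration z, summed edge
+        # span n, and summed per-unit-time mutation count y, so z * y / n
+        # stretches or shrinks the window in proportion to its mutation density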
n = np.sum(offset[i:j]) y = np.sum(counts[i:j]) - z = np.sum(epoch_length[i:j]) + z = np.sum(duration[i:j]) assert n > 0, "Zero edge span in interval" adjust[k + 1] = z * y / n k += 1 diff --git a/tsdate/variational.py b/tsdate/variational.py index 46896453..2f7c51c8 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -241,8 +241,7 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): # count mutations on edges count_timing = time.time() - self.edge_likelihoods, self.mutation_edges = \ - count_mutations(ts, self.node_constraints) # fmt: skip + self.edge_likelihoods, self.mutation_edges = count_mutations(ts) self.edge_likelihoods[:, 1] *= mutation_rate count_timing -= time.time() logging.info(f"Extracted mutations in {abs(count_timing)} seconds") From fea19d4570ecd2d12e9d022882a1a22775e488b0 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Mon, 15 Jul 2024 17:23:04 -0700 Subject: [PATCH 20/29] Should be working, need to test --- tsdate/rescaling.py | 28 +++++++++++++--------------- tsdate/variational.py | 11 ++++++----- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/tsdate/rescaling.py b/tsdate/rescaling.py index 9ce66d0b..7c75ff45 100644 --- a/tsdate/rescaling.py +++ b/tsdate/rescaling.py @@ -111,10 +111,10 @@ def f(i, j): # loss return breaks -#@numba.njit( +# @numba.njit( # _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f) -#) -#def _count_mutations( +# ) +# def _count_mutations( # node_is_leaf, # mutations_node, # mutations_position, @@ -125,7 +125,7 @@ def f(i, j): # loss # indexes_insert, # indexes_remove, # sequence_length, -#): +# ): # """ # Internals for `count_mutations` # """ @@ -184,7 +184,7 @@ def f(i, j): # loss # return edges_stats, mutations_edge # # -#def count_mutations(ts, constraints=None): +# def count_mutations(ts, constraints=None): # """ # Return an array with `num_edges` rows, and columns that are the number of # mutations per edge and the total span per edge @@ -230,7 +230,7 @@ def _count_mutations( size_biased, ): """ - Internals for `count_sizebiased` + Internals for `count_mutations` """ assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size assert indexes_insert.size == indexes_remove.size == edges_parent.size @@ -398,8 +398,8 @@ def mutational_area( return counts, offset, duration, nodes_index -#@numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) -#def mutational_timescale( +# @numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) +# def mutational_timescale( # nodes_time, # likelihoods, # constraints, @@ -407,7 +407,7 @@ def mutational_area( # edges_child, # edges_weight, # max_intervals, -#): +# ): # """ # Rescale node ages so that the instantaneous mutation rate is constant. 
# Edges with a negative duration are ignored when calculating the total @@ -485,14 +485,13 @@ def mutational_area( # return origin, adjust -@numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) +@numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _i)) def mutational_timescale( nodes_time, likelihoods, constraints, edges_parent, edges_child, - edges_weight, max_intervals, ): """ @@ -506,7 +505,6 @@ def mutational_timescale( :param np.ndarray constraints: lower and upper bounds on node age :param np.ndarray edges_parent: node index for the parent of each edge :param np.ndarray edges_child: node index for the child of each edge - :param np.ndarray edges_weight: a weight for each edge :param int max_intervals: maximum number of intervals within which to estimate the time scaling """ @@ -520,9 +518,9 @@ def mutational_timescale( assert np.all(nodes_time[nodes_fixed] == constraints[nodes_fixed, 0]) counts, offset, duration, indexes = mutational_area( - nodes_time, - likelihoods, - edges_parent, + nodes_time, + likelihoods, + edges_parent, edges_child, ) diff --git a/tsdate/variational.py b/tsdate/variational.py index 2f7c51c8..bebe4aaa 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -243,6 +243,8 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): count_timing = time.time() self.edge_likelihoods, self.mutation_edges = count_mutations(ts) self.edge_likelihoods[:, 1] *= mutation_rate + self.sizebiased_likelihoods, _ = count_mutations(ts, size_biased=True) + self.sizebiased_likelihoods[:, 1] *= mutation_rate count_timing -= time.time() logging.info(f"Extracted mutations in {abs(count_timing)} seconds") @@ -769,25 +771,24 @@ def rescale( quantile_width=0.5, ): """Normalise posteriors so that empirical mutation rate is constant""" - edge_weights = ( - np.ones(self.edge_weights.size) if rescale_segsites else self.edge_weights + likelihoods = ( + self.edge_likelihoods if rescale_segsites else self.sizebiased_likelihoods ) nodes_time = self._point_estimate( self.node_posterior, self.node_constraints, use_median ) reallocate_unphased( # correct mutation counts for unphased singletons - self.edge_likelihoods, + likelihoods, self.mutation_phase, self.mutation_blocks, self.block_edges, ) original_breaks, rescaled_breaks = mutational_timescale( nodes_time, - self.edge_likelihoods, + likelihoods, self.node_constraints, self.edge_parents, self.edge_children, - edge_weights, rescale_intervals, ) self.node_posterior[:] = piecewise_scale_posterior( From bf19e03183acdd57f575c796ab1711db98364365 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Mon, 15 Jul 2024 17:24:23 -0700 Subject: [PATCH 21/29] Should be working, need to test --- tsdate/variational.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tsdate/variational.py b/tsdate/variational.py index bebe4aaa..7d1d7bf8 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -771,9 +771,8 @@ def rescale( quantile_width=0.5, ): """Normalise posteriors so that empirical mutation rate is constant""" - likelihoods = ( - self.edge_likelihoods if rescale_segsites else self.sizebiased_likelihoods - ) + likelihoods = self.edge_likelihoods if rescale_segsites \ + else self.sizebiased_likelihoods # fmt: skip nodes_time = self._point_estimate( self.node_posterior, self.node_constraints, use_median ) From 63bc3e0295116cda25f2362874f1368b82a22c6b Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 17 Jul 2024 10:32:14 -0700 Subject: [PATCH 22/29] Debugging insert --- 
tsdate/evaluation.py | 8 ++++---- tsdate/variational.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/tsdate/evaluation.py b/tsdate/evaluation.py index 0e9f03db..452ae6fc 100644 --- a/tsdate/evaluation.py +++ b/tsdate/evaluation.py @@ -544,7 +544,7 @@ def mutations_time( max_freq=None, plotpath=None, title=None, - what=0, + what="midpoint", ): """ Return true and inferred mutation ages, optionally creating a scatterplot and @@ -588,7 +588,7 @@ def mutations_time( infr_mut = infr_mut[is_freq] true_mut = true_mut[is_freq] # get age of mutation or subtended node - if what == 1: + if what == "child": infr_node = infer_ts.edges_child[infr_edge[infr_mut]] assert np.allclose(infr_node, infer_ts.mutations_node[infr_mut]) true_node = ts.edges_child[true_edge[true_mut]] @@ -604,7 +604,7 @@ def mutations_time( nonzero = np.logical_and(mean > 0, truth > 0) mean = mean[nonzero] truth = truth[nonzero] - elif what == 2: + elif what == "parent": infr_node = infer_ts.edges_parent[infr_edge[infr_mut]] true_node = ts.edges_parent[true_edge[true_mut]] _, uniq_idx = np.unique(infr_node, return_index=True) @@ -618,7 +618,7 @@ def mutations_time( nonzero = np.logical_and(mean > 0, truth > 0) mean = mean[nonzero] truth = truth[nonzero] - elif what == 0: # midpoint on branch + elif what == "midpoint": # midpoint on branch # TODO clean up infr_p = infer_ts.edges_parent[infr_edge[infr_mut]] true_p = ts.edges_parent[true_edge[true_mut]] diff --git a/tsdate/variational.py b/tsdate/variational.py index 7d1d7bf8..14533239 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -659,6 +659,19 @@ def fixed_projection(x, y): child_cavity, edge_likelihood, ) + # DEBUG: nan in phase vector + if unphased and not np.isfinite(mutations_phase[m]): + print( + "ERR\tm:", + m, + "p:", + parent_cavity, + "c:", + child_cavity, + "e:", + edge_likelihood, + ) + # /DEBUG: nan in phase vector return np.nan From 3dbc3088d7c4fe67642a415fcedc93929ddd9406 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 17 Jul 2024 10:50:13 -0700 Subject: [PATCH 23/29] Skip nan in phase --- tsdate/phasing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tsdate/phasing.py b/tsdate/phasing.py index f79dcd89..02e59fb0 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -31,7 +31,6 @@ from .approx import _f from .approx import _f1r from .approx import _f2w -from .approx import _i from .approx import _i1r from .approx import _i1w from .approx import _i2r @@ -66,6 +65,9 @@ def reallocate_unphased( i, j = blocks_edges[b] assert tskit.NULL < i < num_edges and edges_unphased[i] assert tskit.NULL < j < num_edges and edges_unphased[j] + if np.isnan(mutations_phase[m]): # DEBUG + print("ERR skip nan in phase") + continue assert 0.0 <= mutations_phase[m] <= 1.0 edges_likelihood[i, 0] += mutations_phase[m] edges_likelihood[j, 0] += 1 - mutations_phase[m] From 82db5a67b09c28aa72440bed55e29d9aa13d2a45 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 17 Jul 2024 15:00:54 -0700 Subject: [PATCH 24/29] Expose rescaling --- tsdate/core.py | 12 +++- tsdate/rescaling.py | 142 ++++++++++++++++++++++-------------------- tsdate/variational.py | 19 ++++-- 3 files changed, 97 insertions(+), 76 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index 3b65be34..4a94ba5f 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -48,6 +48,7 @@ FORMAT_NAME = "tsdate" DEFAULT_RESCALING_INTERVALS = 1000 +DEFAULT_RESCALING_ITERATIONS = 1 DEFAULT_MAX_ITERATIONS = 10 DEFAULT_EPSILON = 1e-6 @@ -1236,6 +1237,7 @@ def 
run( max_iterations, max_shape, rescaling_intervals, + rescaling_iterations, match_segregating_sites, regularise_roots, singletons_phased, @@ -1255,9 +1257,10 @@ def run( singletons_phased=singletons_phased, ) posterior.run( - ep_maxitt=max_iterations, + ep_iterations=max_iterations, max_shape=max_shape, rescale_intervals=rescaling_intervals, + rescale_iterations=rescaling_iterations, regularise=regularise_roots, rescale_segsites=match_segregating_sites, progress=self.pbar, @@ -1554,6 +1557,7 @@ def variational_gamma( eps=None, max_iterations=None, rescaling_intervals=None, + rescaling_iterations=None, # deliberately undocumented parameters below. We may eventually document these max_shape=None, match_segregating_sites=None, @@ -1589,6 +1593,9 @@ def variational_gamma( :param float rescaling_intervals: For time rescaling, the number of time intervals within which to estimate a rescaling parameter. Setting this to zero means that rescaling is not performed. Default ``None``, treated as 1000. + :param float rescaling_intervals: The number of iterations for time rescaling. + Setting this to zero means that rescaling is not performed. Default + ``None``, treated as 1. :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper function, including ``time_units``, ``progress``, and ``record_provenance``. The arguments ``return_posteriors`` and ``return_likelihood`` can be @@ -1621,6 +1628,8 @@ def variational_gamma( max_shape = 1000 if rescaling_intervals is None: rescaling_intervals = DEFAULT_RESCALING_INTERVALS + if rescaling_iterations is None: + rescaling_iterations = DEFAULT_RESCALING_ITERATIONS if match_segregating_sites is None: match_segregating_sites = False if regularise_roots is None: @@ -1639,6 +1648,7 @@ def variational_gamma( max_iterations=max_iterations, max_shape=max_shape, rescaling_intervals=rescaling_intervals, + rescaling_iterations=rescaling_iterations, match_segregating_sites=match_segregating_sites, regularise_roots=regularise_roots, singletons_phased=singletons_phased, diff --git a/tsdate/rescaling.py b/tsdate/rescaling.py index 7c75ff45..d26ba98b 100644 --- a/tsdate/rescaling.py +++ b/tsdate/rescaling.py @@ -28,6 +28,7 @@ import numba import numpy as np import tskit +from tqdm import tqdm from .approx import _b from .approx import _b1r @@ -694,72 +695,75 @@ def edge_sampling_weight( # TODO: standalone API for rescaling -# def rescale_tree_sequence( -# ts, mutation_rate, *, rescaling_intervals=1000, match_segregating_sites=False -# ): -# """ -# Adjust the time scaling of a tree sequence so that expected mutational area -# matches the expected number of mutations on a path from leaf to root, where -# the expectation is taken over all paths and bases in the sequence. 
-# -# :param tskit.TreeSequence ts: the tree sequence to rescale -# :param float mutation_rate: the per-base mutation rate -# :param int rescaling_intervals: the number of time intervals for which -# to estimate a separate time rescaling parameter -# :param bool match_segregating_sites: if True, match the total number of -# mutations rather than the average number of differences from the ancestral -# state -# """ -# if match_segregating_sites: -# edge_weights = np.ones(ts.num_edges) -# else: -# has_parent = np.full(ts.num_nodes, False) -# has_child = np.full(ts.num_nodes, False) -# has_parent[ts.edges_child] = True -# has_child[ts.edges_parent] = True -# is_leaf = np.logical_and(~has_child, has_parent) -# edge_weights = edge_sampling_weight( -# is_leaf, -# ts.edges_parent, -# ts.edges_child, -# ts.edges_left, -# ts.edges_right, -# ts.indexes_edge_insertion_order, -# ts.indexes_edge_removal_order, -# ) -# # estimate time rescaling parameter within intervals -# samples = list(ts.samples()) -# if not np.all(ts.nodes_time[samples] == 0.0): -# raise ValueError("Normalisation not implemented for ancient samples") -# constraints = np.zeros((ts.num_nodes, 2)) -# constraints[:, 1] = np.inf -# constraints[samples, :] = ts.nodes_time[samples, np.newaxis] -# mutations_span, mutations_edge = mutation_span_array(ts) -# mutations_span[:, 1] *= mutation_rate -# original_breaks, rescaled_breaks = mutational_timescale( -# ts.nodes_time, -# mutations_span, -# constraints, -# ts.edges_parent, -# ts.edges_child, -# edge_weights, -# rescaling_intervals, -# ) -# # rescale node time -# assert np.all(np.diff(rescaled_breaks) > 0) -# assert np.all(np.diff(original_breaks) > 0) -# scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) -# idx = np.searchsorted(original_breaks, ts.nodes_time, "right") - 1 -# nodes_time = rescaled_breaks[idx] + scalings[idx] * ( -# ts.nodes_time - original_breaks[idx] -# ) -# # calculate mutation time -# mutations_parent = ts.edges_parent[mutations_edge] -# mutations_child = ts.edges_child[mutations_edge] -# mutations_time = (nodes_time[mutations_parent] + nodes_time[mutations_child]) / 2 -# above_root = mutations_edge == tskit.NULL -# mutations_time[above_root] = nodes_time[mutations_child[above_root]] -# tables = ts.dump_tables() -# tables.nodes.time = nodes_time -# tables.mutations.time = mutations_time -# return tables.tree_sequence() +def rescale_tree_sequence( + ts, + mutation_rate, + *, + num_intervals=1000, + num_iterations=10, + match_segregating_sites=False, + progress=False +): + """ + Adjust the time scaling of a tree sequence so that expected mutational area + matches the expected number of mutations on a path from leaf to root, where + the expectation is taken over all paths and bases in the sequence. 
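+    The adjustment is a monotone piecewise-linear transformation of node
+    times, with a separate scaling factor estimated for each of
+    num_intervals time intervals and the fit repeated num_iterations times.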
+ + :param tskit.TreeSequence ts: the tree sequence to rescale + :param float mutation_rate: the per-base mutation rate + :param int num_intervals: the number of time intervals for which + to estimate a separate time rescaling parameter + :param int num_iterations: the number of iterations to repeat rescaling + :param bool match_segregating_sites: if True, match the total number of + mutations rather than the average number of differences from the ancestral + state + :param bool progress: if True, show a progress bar + """ + samples = list(ts.samples()) + if not np.all(ts.nodes_time[samples] == 0.0): + raise ValueError("Normalisation not implemented for ancient samples") + constraints = np.zeros((ts.num_nodes, 2)) + constraints[:, 1] = np.inf + constraints[samples, :] = ts.nodes_time[samples, np.newaxis] + if match_segregating_sites: + mutations_span, mutations_edge = count_mutations(ts) + else: + mutations_span, mutations_edge = count_mutations(ts, size_biased=True) + mutations_span[:, 1] *= mutation_rate + for _ in tqdm( + np.arange(num_iterations), + desc="Path Rescaling", + disable=not progress, + ): + original_breaks, rescaled_breaks = mutational_timescale( + ts.nodes_time, + mutations_span, + constraints, + ts.edges_parent, + ts.edges_child, + num_intervals, + ) + # rescale node time + assert np.all(np.diff(rescaled_breaks) > 0) + assert np.all(np.diff(original_breaks) > 0) + scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) + idx = np.searchsorted(original_breaks, ts.nodes_time, "right") - 1 + nodes_time = rescaled_breaks[idx] + scalings[idx] * ( + ts.nodes_time - original_breaks[idx] + ) + # calculate mutation time + mutations_parent = ts.edges_parent[mutations_edge] + mutations_child = ts.edges_child[mutations_edge] + mutations_time = ( + nodes_time[mutations_parent] + nodes_time[mutations_child] + ) / 2 + above_root = mutations_edge == tskit.NULL + mutations_time[above_root] = nodes_time[mutations_child[above_root]] + tables = ts.dump_tables() + tables.nodes.time = nodes_time + tables.mutations.time = mutations_time + tables.sort() + tables.build_index() + tables.compute_mutation_parents() + ts = tables.tree_sequence() + return ts diff --git a/tsdate/variational.py b/tsdate/variational.py index 14533239..acb69139 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -821,17 +821,18 @@ def rescale( def run( self, *, - ep_maxitt=10, + ep_iterations=10, max_shape=1000, min_step=0.1, rescale_intervals=1000, rescale_segsites=False, + rescale_iterations=5, regularise=True, progress=None, ): nodes_timing = time.time() for _ in tqdm( - np.arange(ep_maxitt), + np.arange(ep_iterations), desc="Expectation Propagation", disable=not progress, ): @@ -893,11 +894,17 @@ def run( self.mutation_phase[switched] = 1 - self.mutation_phase[switched] logging.info(f"Switched phase of {np.sum(switched)} singletons") - if rescale_intervals > 0: + if rescale_intervals > 0 and rescale_iterations > 0: rescale_timing = time.time() - self.rescale( - rescale_intervals=rescale_intervals, rescale_segsites=rescale_segsites - ) + for _ in tqdm( + np.arange(rescale_iterations), + desc="Path Rescaling", + disable=not progress, + ): + self.rescale( + rescale_intervals=rescale_intervals, + rescale_segsites=rescale_segsites, + ) rescale_timing -= time.time() logging.info(f"Timescale rescaled in {abs(rescale_timing)} seconds") From 9951383fdfb170239baa9bc400af6360b1b95db7 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 17 Jul 2024 15:30:52 -0700 Subject: [PATCH 25/29] Remove 
debugging inserts --- tsdate/core.py | 16 +++++++++------- tsdate/phasing.py | 3 +-- tsdate/variational.py | 34 +++++++++++----------------------- 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index 4a94ba5f..8a805a34 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -46,6 +46,8 @@ from . import util from . import variational +logger = logging.getLogger(__name__) + FORMAT_NAME = "tsdate" DEFAULT_RESCALING_INTERVALS = 1000 DEFAULT_RESCALING_ITERATIONS = 1 @@ -983,7 +985,7 @@ def __init__( ts, Ne, approximate_priors=approx, progress=progress ) else: - logging.info("Using user-specified priors") + logger.info("Using user-specified priors") if Ne is not None: raise ValueError( "Cannot specify population size if specifying priors " @@ -1019,7 +1021,7 @@ def get_modified_ts(self, result, eps): mutations.parent = np.full(mutations.num_rows, tskit.NULL, dtype=np.int32) tables.time_units = self.time_units constr_timing -= time.time() - logging.info(f"Constrained node ages in {abs(constr_timing)} seconds") + logger.info(f"Constrained node ages in {abs(constr_timing)} seconds") # Add posterior mean and variance to node/mutation metadata meta_timing = time.time() self.set_time_metadata( @@ -1029,7 +1031,7 @@ def get_modified_ts(self, result, eps): mutations, mut_mean_t, mut_var_t, schemas.default_mutation_schema ) meta_timing -= time.time() - logging.info( + logger.info( f"Inserted node and mutation metadata in {abs(meta_timing)} seconds" ) sort_timing = time.time() @@ -1037,7 +1039,7 @@ def get_modified_ts(self, result, eps): tables.build_index() tables.compute_mutation_parents() sort_timing -= time.time() - logging.info(f"Sorted tree sequence in {abs(sort_timing)} seconds") + logger.info(f"Sorted tree sequence in {abs(sort_timing)} seconds") return tables.tree_sequence() def set_time_metadata(self, table, mean, var, default_schema, overwrite=False): @@ -1050,9 +1052,9 @@ def set_time_metadata(self, table, mean, var, default_schema, overwrite=False): md_iter = ({} for _ in range(table.num_rows)) # For speed, assume we don't need to validate encoder = table.metadata_schema.encode_row - logging.info(f"Set metadata schema on {table_name}") + logger.info(f"Set metadata schema on {table_name}") else: - logging.warning( + logger.warning( f"Could not set metadata on {table_name}: " "data already exists with no schema" ) @@ -1073,7 +1075,7 @@ def set_time_metadata(self, table, mean, var, default_schema, overwrite=False): metadata_array.append(encoder(metadata_dict)) table.packset_metadata(metadata_array) except tskit.MetadataValidationError as e: - logging.warning(f"Could not set time metadata in {table_name}: {e}") + logger.warning(f"Could not set time metadata in {table_name}: {e}") def parse_result(self, result, epsilon, extra_posterior_cols=None): # Construct the tree sequence to return and add other stuff we might want to diff --git a/tsdate/phasing.py b/tsdate/phasing.py index 02e59fb0..b81e0fe6 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -65,8 +65,7 @@ def reallocate_unphased( i, j = blocks_edges[b] assert tskit.NULL < i < num_edges and edges_unphased[i] assert tskit.NULL < j < num_edges and edges_unphased[j] - if np.isnan(mutations_phase[m]): # DEBUG - print("ERR skip nan in phase") + if np.isnan(mutations_phase[m]): # TODO: rare numerical issue continue assert 0.0 <= mutations_phase[m] <= 1.0 edges_likelihood[i, 0] += mutations_phase[m] diff --git a/tsdate/variational.py b/tsdate/variational.py index acb69139..8dbf0b13 100644 --- 
a/tsdate/variational.py +++ b/tsdate/variational.py @@ -54,6 +54,7 @@ from .rescaling import piecewise_scale_posterior from .util import contains_unary_nodes +logger = logging.getLogger(__name__) # columns for edge_factors ROOTWARD = 0 # edge likelihood to parent @@ -246,7 +247,7 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): self.sizebiased_likelihoods, _ = count_mutations(ts, size_biased=True) self.sizebiased_likelihoods[:, 1] *= mutation_rate count_timing -= time.time() - logging.info(f"Extracted mutations in {abs(count_timing)} seconds") + logger.info(f"Extracted mutations in {abs(count_timing)} seconds") # count mutations in singleton blocks phase_timing = time.time() @@ -260,9 +261,9 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): self.block_nodes[1] = self.edge_parents[self.block_edges[:, 1]] num_unphased = np.sum(self.mutation_blocks != tskit.NULL) phase_timing -= time.time() - logging.info(f"Found {num_unphased} unphased singleton mutations") - logging.info(f"Split unphased singleton edges into {num_blocks} blocks") - logging.info(f"Phased singletons in {abs(phase_timing)} seconds") + logger.info(f"Found {num_unphased} unphased singleton mutations") + logger.info(f"Split unphased singleton edges into {num_blocks} blocks") + logger.info(f"Phased singletons in {abs(phase_timing)} seconds") # mutable self.node_factors = np.zeros((ts.num_nodes, 2, 2)) @@ -659,19 +660,6 @@ def fixed_projection(x, y): child_cavity, edge_likelihood, ) - # DEBUG: nan in phase vector - if unphased and not np.isfinite(mutations_phase[m]): - print( - "ERR\tm:", - m, - "p:", - parent_cavity, - "c:", - child_cavity, - "e:", - edge_likelihood, - ) - # /DEBUG: nan in phase vector return np.nan @@ -843,8 +831,8 @@ def run( ) nodes_timing -= time.time() skipped_edges = np.sum(np.isnan(self.edge_logconst)) - logging.info(f"Skipped {skipped_edges} edges with invalid factors") - logging.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds") + logger.info(f"Skipped {skipped_edges} edges with invalid factors") + logger.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds") muts_timing = time.time() mutations_phased = self.mutation_blocks == tskit.NULL @@ -878,8 +866,8 @@ def run( ) muts_timing -= time.time() skipped_muts = np.sum(np.isnan(self.mutation_posterior[:, 0])) - logging.info(f"Skipped {skipped_muts} mutations with invalid posteriors") - logging.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds") + logger.info(f"Skipped {skipped_muts} mutations with invalid posteriors") + logger.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds") singletons = self.mutation_blocks != tskit.NULL switched_blocks = self.mutation_blocks[singletons] @@ -892,7 +880,7 @@ def run( self.mutation_nodes[singletons] = self.edge_children[switched_edges] switched = self.mutation_phase < 0.5 self.mutation_phase[switched] = 1 - self.mutation_phase[switched] - logging.info(f"Switched phase of {np.sum(switched)} singletons") + logger.info(f"Switched phase of {np.sum(switched)} singletons") if rescale_intervals > 0 and rescale_iterations > 0: rescale_timing = time.time() @@ -906,7 +894,7 @@ def run( rescale_segsites=rescale_segsites, ) rescale_timing -= time.time() - logging.info(f"Timescale rescaled in {abs(rescale_timing)} seconds") + logger.info(f"Timescale rescaled in {abs(rescale_timing)} seconds") def node_moments(self): """ From e4901dff0e61e228a0e09b3bae8a2816bf44a7d7 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 18 Jul 2024 
20:54:58 -0700 Subject: [PATCH 26/29] Fix minor bug; faster iterated rescaling --- tsdate/core.py | 10 +- tsdate/phasing.py | 2 +- tsdate/rescaling.py | 270 +++++++----------------------------------- tsdate/variational.py | 107 +++++++---------- 4 files changed, 87 insertions(+), 302 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index 8a805a34..89e445de 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -50,7 +50,7 @@ FORMAT_NAME = "tsdate" DEFAULT_RESCALING_INTERVALS = 1000 -DEFAULT_RESCALING_ITERATIONS = 1 +DEFAULT_RESCALING_ITERATIONS = 5 DEFAULT_MAX_ITERATIONS = 10 DEFAULT_EPSILON = 1e-6 @@ -1021,7 +1021,7 @@ def get_modified_ts(self, result, eps): mutations.parent = np.full(mutations.num_rows, tskit.NULL, dtype=np.int32) tables.time_units = self.time_units constr_timing -= time.time() - logger.info(f"Constrained node ages in {abs(constr_timing)} seconds") + logger.info(f"Constrained node ages in {abs(constr_timing):.2f} seconds") # Add posterior mean and variance to node/mutation metadata meta_timing = time.time() self.set_time_metadata( @@ -1039,7 +1039,7 @@ def get_modified_ts(self, result, eps): tables.build_index() tables.compute_mutation_parents() sort_timing -= time.time() - logger.info(f"Sorted tree sequence in {abs(sort_timing)} seconds") + logger.info(f"Sorted tree sequence in {abs(sort_timing):.2f} seconds") return tables.tree_sequence() def set_time_metadata(self, table, mean, var, default_schema, overwrite=False): @@ -1595,9 +1595,9 @@ def variational_gamma( :param float rescaling_intervals: For time rescaling, the number of time intervals within which to estimate a rescaling parameter. Setting this to zero means that rescaling is not performed. Default ``None``, treated as 1000. - :param float rescaling_intervals: The number of iterations for time rescaling. + :param float rescaling_iterations: The number of iterations for time rescaling. Setting this to zero means that rescaling is not performed. Default - ``None``, treated as 1. + ``None``, treated as 5. :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper function, including ``time_units``, ``progress``, and ``record_provenance``. 
The arguments ``return_posteriors`` and ``return_likelihood`` can be diff --git a/tsdate/phasing.py b/tsdate/phasing.py index b81e0fe6..675d0ba3 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -65,7 +65,7 @@ def reallocate_unphased( i, j = blocks_edges[b] assert tskit.NULL < i < num_edges and edges_unphased[i] assert tskit.NULL < j < num_edges and edges_unphased[j] - if np.isnan(mutations_phase[m]): # TODO: rare numerical issue + if np.isnan(mutations_phase[m]): # TODO: fix rare numerical issue continue assert 0.0 <= mutations_phase[m] <= 1.0 edges_likelihood[i, 0] += mutations_phase[m] diff --git a/tsdate/rescaling.py b/tsdate/rescaling.py index d26ba98b..7fb24a58 100644 --- a/tsdate/rescaling.py +++ b/tsdate/rescaling.py @@ -28,7 +28,6 @@ import numba import numpy as np import tskit -from tqdm import tqdm from .approx import _b from .approx import _b1r @@ -112,108 +111,6 @@ def f(i, j): # loss return breaks -# @numba.njit( -# _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f) -# ) -# def _count_mutations( -# node_is_leaf, -# mutations_node, -# mutations_position, -# edges_parent, -# edges_child, -# edges_left, -# edges_right, -# indexes_insert, -# indexes_remove, -# sequence_length, -# ): -# """ -# Internals for `count_mutations` -# """ -# assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size -# assert indexes_insert.size == indexes_remove.size == edges_parent.size -# assert mutations_node.size == mutations_position.size -# -# num_mutations = mutations_node.size -# num_edges = edges_parent.size -# num_nodes = node_is_leaf.size -# -# indexes_mutation = np.argsort(mutations_position) -# position_insert = edges_left[indexes_insert] -# position_remove = edges_right[indexes_remove] -# position_mutation = mutations_position[indexes_mutation] -# -# nodes_edge = np.full(num_nodes, tskit.NULL) -# mutations_edge = np.full(num_mutations, tskit.NULL) -# edges_mutations = np.zeros(num_edges) -# edges_span = edges_right - edges_left -# -# left = 0.0 -# a, b, d = 0, 0, 0 -# while a < num_edges or b < num_edges: -# while b < num_edges and position_remove[b] == left: # edges out -# e = indexes_remove[b] -# c = edges_child[e] -# nodes_edge[c] = tskit.NULL -# b += 1 -# -# while a < num_edges and position_insert[a] == left: # edges in -# e = indexes_insert[a] -# c = edges_child[e] -# nodes_edge[c] = e -# a += 1 -# -# right = sequence_length -# if b < num_edges: -# right = min(right, position_remove[b]) -# if a < num_edges: -# right = min(right, position_insert[a]) -# left = right -# -# while d < num_mutations and position_mutation[d] < right: -# m = indexes_mutation[d] -# c = mutations_node[m] -# e = nodes_edge[c] -# if e != tskit.NULL: -# mutations_edge[m] = e -# edges_mutations[e] += 1.0 -# d += 1 -# -# mutations_edge = mutations_edge.astype(np.int32) -# edges_stats = np.column_stack((edges_mutations, edges_span)) -# -# return edges_stats, mutations_edge -# -# -# def count_mutations(ts, constraints=None): -# """ -# Return an array with `num_edges` rows, and columns that are the number of -# mutations per edge and the total span per edge -# """ -# # TODO: adjust spans by an accessibility mask -# if constraints is None: -# node_is_leaf = np.full(ts.num_nodes, False) -# node_is_leaf[list(ts.samples())] = True -# else: -# assert constraints.shape == (ts.num_nodes, 2) -# node_is_leaf = np.logical_and( -# constraints[:, 0] == 0.0, -# constraints[:, 0] == constraints[:, 1], -# ) -# return _count_mutations( -# node_is_leaf, -# ts.mutations_node, 
-# ts.sites_position[ts.mutations_site], -# ts.edges_parent, -# ts.edges_child, -# ts.edges_left, -# ts.edges_right, -# ts.indexes_edge_insertion_order, -# ts.indexes_edge_removal_order, -# ts.sequence_length, -# ) - - @numba.njit( _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f, _b) ) @@ -230,9 +127,6 @@ def _count_mutations( sequence_length, size_biased, ): - """ - Internals for `count_mutations` - """ assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size assert indexes_insert.size == indexes_remove.size == edges_parent.size assert mutations_node.size == mutations_position.size @@ -547,13 +441,12 @@ def mutational_timescale( return origin, adjust -@numba.njit(_f2w(_f2r, _f1r, _f1r, _f, _b)) +@numba.njit(_f2w(_f2r, _f1r, _f1r, _f)) def piecewise_scale_posterior( posteriors, original_breaks, rescaled_breaks, quantile_width, - use_median, ): """ :param float quantile_width: width of interquantile range to use for estimating @@ -567,7 +460,7 @@ def piecewise_scale_posterior( quant_lower = quantile_width / 2 quant_upper = 1 - quantile_width / 2 - # use posterior mean or median as a point estimate + # use posterior mean as a point estimate freed = np.logical_and(posteriors[:, 0] > -1, posteriors[:, 1] > 0) lower = np.zeros(dim) upper = np.zeros(dim) @@ -576,12 +469,11 @@ def piecewise_scale_posterior( alpha, beta = posteriors[i] lower[i] = gammainc_inv(alpha + 1, quant_lower) / beta upper[i] = gammainc_inv(alpha + 1, quant_upper) / beta - midpt[i] = gammainc_inv(alpha + 1, 0.5) if use_median else (alpha + 1) - midpt[i] /= beta + midpt[i] = (alpha + 1) / beta # rescale quantiles - assert np.all(np.diff(rescaled_breaks) > 0) - assert np.all(np.diff(original_breaks) > 0) + assert np.all(np.diff(rescaled_breaks) > 0), "Use fewer rescaling intervals" + assert np.all(np.diff(original_breaks) > 0), "Use fewer rescaling intervals" scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) def rescale(x): @@ -600,109 +492,35 @@ def rescale(x): alpha, beta = approximate_gamma_iqr( quant_lower, quant_upper, lower[i], upper[i] ) - beta = gammainc_inv(alpha + 1, 0.5) if use_median else (alpha + 1) - beta /= midpt[i] # choose rate so as to keep mean or median + beta = (alpha + 1) / midpt[i] # choose rate so as to keep mean new_posteriors[i] = alpha, beta return new_posteriors -@numba.njit(_f1w(_b1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r)) -def edge_sampling_weight( - is_leaf, - edges_parent, - edges_child, - edges_left, - edges_right, - insert_index, - remove_index, +@numba.njit(_f1w(_f1r, _f1r, _f1r)) +def piecewise_scale_point_estimate( + point_estimate, + original_breaks, + rescaled_breaks, ): - """ - Calculate the probability that a randomly selected root-to-leaf path from a - random point on the sequence contains a given edge, for all edges. 
- - :param np.ndarray is_leaf: boolean array indicating whether a node is a leaf - """ - num_nodes = is_leaf.size - num_edges = edges_child.size - - insert_position = edges_left[insert_index] - remove_position = edges_right[remove_index] - sequence_length = remove_position[-1] - - nodes_parent = np.full(num_nodes, tskit.NULL) - nodes_edge = np.full(num_nodes, tskit.NULL) - nodes_leaves = np.zeros(num_nodes) - edges_leaves = np.zeros(num_edges) - - nodes_leaves[is_leaf] = 1.0 - total_leaves = 0.0 - position = 0.0 - a, b = 0, 0 - while position < sequence_length: - edges_out = [] - while b < num_edges and remove_position[b] == position: - edges_out.append(remove_index[b]) - b += 1 - - edges_in = [] - while a < num_edges and insert_position[a] == position: # edges in - edges_in.append(insert_index[a]) - a += 1 - - remainder = sequence_length - position - - for e in edges_out: - p, c = edges_parent[e], edges_child[e] - update = nodes_leaves[c] - while p != tskit.NULL: - u = nodes_edge[c] - edges_leaves[u] -= update * remainder - c, p = p, nodes_parent[p] - p, c = edges_parent[e], edges_child[e] - while p != tskit.NULL: - nodes_leaves[p] -= update - p = nodes_parent[p] - nodes_parent[c] = tskit.NULL - nodes_edge[c] = tskit.NULL - if is_leaf[c]: - total_leaves -= remainder - - for e in edges_in: - p, c = edges_parent[e], edges_child[e] - nodes_parent[c] = p - nodes_edge[c] = e - if is_leaf[c]: - total_leaves += remainder - update = nodes_leaves[c] - while p != tskit.NULL: - nodes_leaves[p] += update - p = nodes_parent[p] - p, c = edges_parent[e], edges_child[e] - while p != tskit.NULL: - u = nodes_edge[c] - edges_leaves[u] += update * remainder - c, p = p, nodes_parent[p] - - position = sequence_length - if b < num_edges: - position = min(position, remove_position[b]) - if a < num_edges: - position = min(position, insert_position[a]) - - edges_leaves /= total_leaves - return edges_leaves + assert np.all(np.diff(rescaled_breaks) > 0), "Use fewer rescaling intervals" + assert np.all(np.diff(original_breaks) > 0), "Use fewer rescaling intervals" + scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) + idx = np.searchsorted(original_breaks, point_estimate, "right") - 1 + rescaled_estimate = rescaled_breaks[idx] + \ + scalings[idx] * (point_estimate - original_breaks[idx]) # fmt: skip + return rescaled_estimate -# TODO: standalone API for rescaling +# standalone API for rescaling (TODO: needs testing) def rescale_tree_sequence( ts, mutation_rate, *, - num_intervals=1000, + num_intervals=100, num_iterations=10, match_segregating_sites=False, - progress=False ): """ Adjust the time scaling of a tree sequence so that expected mutational area @@ -730,40 +548,32 @@ def rescale_tree_sequence( else: mutations_span, mutations_edge = count_mutations(ts, size_biased=True) mutations_span[:, 1] *= mutation_rate - for _ in tqdm( - np.arange(num_iterations), - desc="Path Rescaling", - disable=not progress, - ): + # rescale node ages + nodes_time = ts.nodes_time.copy() + for _ in np.arange(num_iterations): original_breaks, rescaled_breaks = mutational_timescale( - ts.nodes_time, + nodes_time, mutations_span, constraints, ts.edges_parent, ts.edges_child, num_intervals, ) - # rescale node time - assert np.all(np.diff(rescaled_breaks) > 0) - assert np.all(np.diff(original_breaks) > 0) - scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) - idx = np.searchsorted(original_breaks, ts.nodes_time, "right") - 1 - nodes_time = rescaled_breaks[idx] + scalings[idx] * ( - 
ts.nodes_time - original_breaks[idx] + nodes_time = piecewise_scale_point_estimate( + nodes_time, original_breaks, rescaled_breaks ) - # calculate mutation time - mutations_parent = ts.edges_parent[mutations_edge] - mutations_child = ts.edges_child[mutations_edge] - mutations_time = ( - nodes_time[mutations_parent] + nodes_time[mutations_child] - ) / 2 - above_root = mutations_edge == tskit.NULL - mutations_time[above_root] = nodes_time[mutations_child[above_root]] - tables = ts.dump_tables() - tables.nodes.time = nodes_time - tables.mutations.time = mutations_time - tables.sort() - tables.build_index() - tables.compute_mutation_parents() - ts = tables.tree_sequence() + # calculate mutation ages + mutations_parent = ts.edges_parent[mutations_edge] + mutations_child = ts.edges_child[mutations_edge] + mutations_time = (nodes_time[mutations_parent] + nodes_time[mutations_child]) / 2 + above_root = mutations_edge == tskit.NULL + assert np.allclose(mutations_child[~above_root], ts.mutations_node[~above_root]) + mutations_time[above_root] = nodes_time[ts.mutations_node[above_root]] + tables = ts.dump_tables() + tables.nodes.time = nodes_time + tables.mutations.time = mutations_time + tables.sort() + tables.build_index() + tables.compute_mutation_parents() + ts = tables.tree_sequence() return ts diff --git a/tsdate/variational.py b/tsdate/variational.py index 8dbf0b13..2f3f5b61 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -45,12 +45,11 @@ from .approx import _i from .approx import _i1r from .approx import _i2r -from .hypergeo import _gammainc_inv as gammainc_inv from .phasing import block_singletons from .phasing import reallocate_unphased from .rescaling import count_mutations -from .rescaling import edge_sampling_weight from .rescaling import mutational_timescale +from .rescaling import piecewise_scale_point_estimate from .rescaling import piecewise_scale_posterior from .util import contains_unary_nodes @@ -196,20 +195,6 @@ def _check_valid_state( posterior_check += node_factors[:, CONSTRNT] np.testing.assert_allclose(posterior_check, posterior) - @staticmethod - @numba.njit(_f1w(_f2r, _f2r, _b)) - def _point_estimate(posteriors, constraints, median): - assert posteriors.shape == constraints.shape - fixed = constraints[:, 0] == constraints[:, 1] - point_estimate = np.zeros(posteriors.shape[0]) - for i in np.flatnonzero(~fixed): - alpha, beta = posteriors[i] - point_estimate[i] = gammainc_inv(alpha + 1, 0.5) \ - if median else (alpha + 1) # fmt: skip - point_estimate[i] /= beta - point_estimate[fixed] = constraints[fixed, 0] - return point_estimate - def __init__(self, ts, *, mutation_rate, singletons_phased=True): """ Initialize an expectation propagation algorithm for dating nodes @@ -247,7 +232,7 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): self.sizebiased_likelihoods, _ = count_mutations(ts, size_biased=True) self.sizebiased_likelihoods[:, 1] *= mutation_rate count_timing -= time.time() - logger.info(f"Extracted mutations in {abs(count_timing)} seconds") + logger.info(f"Extracted mutations in {abs(count_timing):.2f} seconds") # count mutations in singleton blocks phase_timing = time.time() @@ -263,7 +248,7 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): phase_timing -= time.time() logger.info(f"Found {num_unphased} unphased singleton mutations") logger.info(f"Split unphased singleton edges into {num_blocks} blocks") - logger.info(f"Phased singletons in {abs(phase_timing)} seconds") + logger.info(f"Phased singletons in 
{abs(phase_timing):.2f} seconds") # mutable self.node_factors = np.zeros((ts.num_nodes, 2, 2)) @@ -293,20 +278,11 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True): edge_unphased[self.block_edges[:, 1]] = True edges = np.arange(ts.num_edges, dtype=np.int32)[~edge_unphased] self.edge_order = np.concatenate((edges[:-1], np.flip(edges))) - self.edge_weights = edge_sampling_weight( - self.leaves, - self.edge_parents, - self.edge_children, - ts.edges_left, - ts.edges_right, - ts.indexes_edge_insertion_order, - ts.indexes_edge_removal_order, - ) self.block_order = np.arange(num_blocks, dtype=np.int32) self.mutation_order = np.arange(ts.num_mutations, dtype=np.int32) @staticmethod - @numba.njit(_f(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f, _b)) + @numba.njit(_void(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f, _b)) def propagate_likelihood( edge_order, edges_parent, @@ -465,10 +441,8 @@ def twin_projection(x, y): scale[p] *= parent_eta scale[c] *= child_eta - return np.nan - @staticmethod - @numba.njit(_f(_b1r, _f2w, _f3w, _f1w, _f, _i, _f)) + @numba.njit(_void(_b1r, _f2w, _f3w, _f1w, _f, _i, _f)) def propagate_prior( free, posterior, factors, scale, max_shape, em_maxitt, em_reltol ): @@ -522,11 +496,9 @@ def posterior_damping(x): posterior[i] *= eta scale[i] *= eta - return np.nan - @staticmethod @numba.njit( - _f(_i1r, _f2w, _f1w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b) + _void(_i1r, _f2w, _f1w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b) ) def propagate_mutations( mutations_order, @@ -661,8 +633,6 @@ def fixed_projection(x, y): edge_likelihood, ) - return np.nan - @staticmethod @numba.njit(_void(_i1r, _i1r, _i2r, _f3w, _f3w, _f3w, _f1w)) def rescale_factors( @@ -761,49 +731,53 @@ def iterate( self.block_factors, ) - return np.nan # TODO: placeholder for marginal likelihood - def rescale( self, *, rescale_intervals=1000, rescale_segsites=False, - use_median=False, + rescale_iterations=10, quantile_width=0.5, + progress=False, ): """Normalise posteriors so that empirical mutation rate is constant""" likelihoods = self.edge_likelihoods if rescale_segsites \ else self.sizebiased_likelihoods # fmt: skip - nodes_time = self._point_estimate( - self.node_posterior, self.node_constraints, use_median - ) reallocate_unphased( # correct mutation counts for unphased singletons likelihoods, self.mutation_phase, self.mutation_blocks, self.block_edges, ) - original_breaks, rescaled_breaks = mutational_timescale( - nodes_time, - likelihoods, - self.node_constraints, - self.edge_parents, - self.edge_children, - rescale_intervals, + nodes_time, _ = self.node_moments() + rescaled_nodes_time = nodes_time.copy() + for _ in np.arange(rescale_iterations): # estimate time rescaling + original_breaks, rescaled_breaks = mutational_timescale( + rescaled_nodes_time, + likelihoods, + self.node_constraints, + self.edge_parents, + self.edge_children, + rescale_intervals, + ) + rescaled_nodes_time = piecewise_scale_point_estimate( + rescaled_nodes_time, original_breaks, rescaled_breaks + ) + _, unique = np.unique(rescaled_nodes_time, return_index=True) + original_breaks = piecewise_scale_point_estimate( + rescaled_breaks, rescaled_nodes_time[unique], nodes_time[unique] ) self.node_posterior[:] = piecewise_scale_posterior( self.node_posterior, original_breaks, rescaled_breaks, quantile_width, - use_median, ) self.mutation_posterior[:] = piecewise_scale_posterior( self.mutation_posterior, original_breaks, rescaled_breaks, quantile_width, - use_median, ) def 
run( @@ -814,7 +788,7 @@ def run( min_step=0.1, rescale_intervals=1000, rescale_segsites=False, - rescale_iterations=5, + rescale_iterations=10, regularise=True, progress=None, ): @@ -832,7 +806,7 @@ def run( nodes_timing -= time.time() skipped_edges = np.sum(np.isnan(self.edge_logconst)) logger.info(f"Skipped {skipped_edges} edges with invalid factors") - logger.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds") + logger.info(f"Calculated node posteriors in {abs(nodes_timing):.2f} seconds") muts_timing = time.time() mutations_phased = self.mutation_blocks == tskit.NULL @@ -867,7 +841,7 @@ def run( muts_timing -= time.time() skipped_muts = np.sum(np.isnan(self.mutation_posterior[:, 0])) logger.info(f"Skipped {skipped_muts} mutations with invalid posteriors") - logger.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds") + logger.info(f"Calculated mutation posteriors in {abs(muts_timing):.2f} seconds") singletons = self.mutation_blocks != tskit.NULL switched_blocks = self.mutation_blocks[singletons] @@ -884,17 +858,14 @@ def run( if rescale_intervals > 0 and rescale_iterations > 0: rescale_timing = time.time() - for _ in tqdm( - np.arange(rescale_iterations), - desc="Path Rescaling", - disable=not progress, - ): - self.rescale( - rescale_intervals=rescale_intervals, - rescale_segsites=rescale_segsites, - ) + self.rescale( + rescale_intervals=rescale_intervals, + rescale_iterations=rescale_iterations, + rescale_segsites=rescale_segsites, + progress=progress, + ) rescale_timing -= time.time() - logger.info(f"Timescale rescaled in {abs(rescale_timing)} seconds") + logger.info(f"Timescale rescaled in {abs(rescale_timing):.2f} seconds") def node_moments(self): """ @@ -929,12 +900,15 @@ def mutation_mapping(self): return self.mutation_edges, self.mutation_nodes +# NB: used for debugging # def date( # ts, # *, # mutation_rate, # singletons_phased=True, # max_iterations=10, +# rescaling_intervals=1000, +# rescaling_iterations=10, # match_segregating_sites=False, # regularise_roots=True, # constr_iterations=0, @@ -954,6 +928,7 @@ def mutation_mapping(self): # ep_maxitt=max_iterations, # max_shape=max_shape, # rescale_intervals=rescaling_intervals, +# rescale_iterations=rescaling_iterations, # regularise=regularise_roots, # rescale_segsites=match_segregating_sites, # progress=progress, @@ -964,8 +939,8 @@ def mutation_mapping(self): # mutation_edge, mutation_node = posterior.mutation_mapping() # # tables = ts.dump_tables() -# tables.nodes.time = constrain_ages( -# ts, node_mn, constr_iterations=constr_iterations) +# tables.nodes.time = \ +# constrain_ages(ts, node_mn, constr_iterations=constr_iterations) # tables.mutations.node = mutation_node # tables.sort() # From fe4a1855acd14cc91d13e392c1da445d70fe25ea Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Thu, 18 Jul 2024 20:57:37 -0700 Subject: [PATCH 27/29] Remove unused test --- tests/test_rescaling.py | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/tests/test_rescaling.py b/tests/test_rescaling.py index b5944ee2..a289f9bc 100644 --- a/tests/test_rescaling.py +++ b/tests/test_rescaling.py @@ -34,7 +34,6 @@ import tsdate from tsdate.rescaling import count_mutations -from tsdate.rescaling import edge_sampling_weight from tsdate.rescaling import mutational_area @@ -54,41 +53,6 @@ def inferred_ts(): return inferred_ts -# TODO: delete, methodology is flawed -class TestEdgeSamplingWeight: - @staticmethod - def naive_edge_sampling_weight(ts): - out = np.zeros(ts.num_edges) - tot = 0.0 - 
for t in ts.trees(): - if t.num_edges == 0: - continue - tot += t.num_samples() * t.span - for n in t.nodes(): - e = t.edge(n) - if e == tskit.NULL: - continue - out[e] += t.num_samples(n) * t.span - out /= tot - return out - - def test_edge_sampling_weight(self, inferred_ts): - ts = inferred_ts - is_leaf = np.full(ts.num_nodes, False) - is_leaf[list(ts.samples())] = True - edges_weight = edge_sampling_weight( - is_leaf, - ts.edges_parent, - ts.edges_child, - ts.edges_left, - ts.edges_right, - ts.indexes_edge_insertion_order, - ts.indexes_edge_removal_order, - ) - ck_edges_weight = self.naive_edge_sampling_weight(ts) - np.testing.assert_allclose(edges_weight, ck_edges_weight) - - class TestMutationalArea: """ Test tallying of mutational area within inter-node time intervals. From 8f4cd126933f6af70f0028832c4298a23c7e0e73 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Mon, 22 Jul 2024 09:55:34 -0700 Subject: [PATCH 28/29] Add docstring for match_segregating_sites --- tsdate/core.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tsdate/core.py b/tsdate/core.py index 89e445de..16826bbc 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1560,9 +1560,9 @@ def variational_gamma( max_iterations=None, rescaling_intervals=None, rescaling_iterations=None, + match_segregating_sites=None, # deliberately undocumented parameters below. We may eventually document these max_shape=None, - match_segregating_sites=None, regularise_roots=None, singletons_phased=None, **kwargs, @@ -1598,6 +1598,11 @@ def variational_gamma( :param float rescaling_iterations: The number of iterations for time rescaling. Setting this to zero means that rescaling is not performed. Default ``None``, treated as 5. + :param bool match_segregating_sites: If ``True``, then time is rescaled + such that branch- and site-mode segregating sites are approximately equal. + If ``False``, time is rescaled such that branch- and site-mode root-to-leaf + length are approximately equal, which gives unbiased estimates when there + are polytomies. Default ``False``. :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper function, including ``time_units``, ``progress``, and ``record_provenance``. The arguments ``return_posteriors`` and ``return_likelihood`` can be From 421628056f2c4496692c47afbb312d7b0a39d449 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Mon, 22 Jul 2024 12:38:59 -0700 Subject: [PATCH 29/29] Plot labeling for evaluation.mutations_time --- tsdate/evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsdate/evaluation.py b/tsdate/evaluation.py index 452ae6fc..229005c9 100644 --- a/tsdate/evaluation.py +++ b/tsdate/evaluation.py @@ -640,7 +640,7 @@ def mutations_time( plt.hexbin(truth, mean, xscale="log", yscale="log", mincnt=1) plt.text(0.01, 0.99, info, ha="left", va="top", transform=plt.gca().transAxes) plt.axline(pt1, pt2, linestyle="--", color="firebrick") - if what != 0: + if what != "midpoint": plt.xlabel("True node age") plt.ylabel("Estimated node age") else:
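
---

A note on the rescaling internals introduced in patches 24-26: piecewise_scale_point_estimate maps node times through a monotone piecewise-linear function defined by matching breakpoints on the original and rescaled timescales, and rescale() now re-estimates those breakpoints from the previously rescaled times on each of rescale_iterations passes. The sketch below reproduces just the mapping step in plain NumPy, without numba; the breakpoint values are made up for illustration, whereas the real ones come from mutational_timescale:

    import numpy as np

    def piecewise_scale(times, original_breaks, rescaled_breaks):
        # Monotone piecewise-linear map sending original_breaks[k] to
        # rescaled_breaks[k]; mirrors piecewise_scale_point_estimate.
        assert np.all(np.diff(original_breaks) > 0)
        assert np.all(np.diff(rescaled_breaks) > 0)
        assert np.all(times >= original_breaks[0])
        # Slope within each interval; the trailing 0 clamps any time past
        # the last breakpoint to the last rescaled breakpoint.
        slopes = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0.0)
        idx = np.searchsorted(original_breaks, times, side="right") - 1
        return rescaled_breaks[idx] + slopes[idx] * (times - original_breaks[idx])

    # Toy breakpoints: stretch [0, 1) by 2x, compress [1, 10) by 2x.
    original = np.array([0.0, 1.0, 10.0])
    rescaled = np.array([0.0, 2.0, 6.5])
    print(piecewise_scale(np.array([0.5, 1.0, 4.0]), original, rescaled))
    # -> [1.  2.  3.5]

And a usage sketch for the options exposed by this series. The parameter names are taken from the diff to variational_gamma; the example assumes that function is importable from the package root and that msprime is available to simulate input, so treat it as an illustration rather than a canonical invocation:

    import msprime
    import tsdate

    # Simulate a small tree sequence with mutations to date.
    ts = msprime.sim_ancestry(10, population_size=1e4, sequence_length=1e5,
                              recombination_rate=1e-8, random_seed=1)
    ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)

    dated = tsdate.variational_gamma(
        ts,
        mutation_rate=1e-8,
        rescaling_intervals=1000,  # time intervals per rescaling pass; 0 disables rescaling
        rescaling_iterations=5,    # rescaling passes; 0 likewise disables rescaling
        match_segregating_sites=False,  # match root-to-leaf length, robust to polytomies
    )

Note that, per the run() method above, rescaling executes only when both rescale_intervals and rescale_iterations are positive.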