From 589d70e440b1b2ada58f9d9f2d94996e4b00b340 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 17 Oct 2023 11:38:14 +0200 Subject: [PATCH 1/8] initial commit --- pydeseq2/ds.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index fba0dd95..ce4fd7e5 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -146,6 +146,7 @@ class DeseqStats: def __init__( self, dds: DeseqDataSet, + test: Literal["Wald", "LRT"] = "Wald", contrast: Optional[List[str]] = None, alpha: float = 0.05, cooks_filter: bool = True, @@ -167,6 +168,7 @@ def __init__( self.dds = dds self.alpha = alpha + self.test = test self.cooks_filter = cooks_filter self.independent_filter = independent_filter self.base_mean = self.dds.varm["_normed_means"].copy() From 1a95da83c27342302dff5b2513aefc482695f17e Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 17 Oct 2023 11:52:57 +0200 Subject: [PATCH 2/8] arg sanity check --- pydeseq2/ds.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index ce4fd7e5..92cc4d90 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -36,6 +36,9 @@ class DeseqStats: dds : DeseqDataSet DeseqDataSet for which dispersion and LFCs were already estimated. + test : Literal["wald", "lrt"] + The statistical test to use. One of ``["wald", "lrt"]``. + contrast : list or None A list of three strings, in the following format: ``['variable_of_interest', 'tested_level', 'ref_level']``. @@ -167,8 +170,11 @@ def __init__( self.dds = dds - self.alpha = alpha + if test not in ("wald", "LRT"): + raise ValueError(f"Available tests are `wald` and `LRT`. Got: {test}.") self.test = test + + self.alpha = alpha self.cooks_filter = cooks_filter self.independent_filter = independent_filter self.base_mean = self.dds.varm["_normed_means"].copy() From 86cf0e0d041834d5bb53f916490a34dd47727052 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 17 Oct 2023 11:57:26 +0200 Subject: [PATCH 3/8] setup wald and LRT --- pydeseq2/ds.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index 92cc4d90..fc5d6808 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -257,7 +257,11 @@ def summary( self.lfc_null = lfc_null self.alt_hypothesis = alt_hypothesis rerun_summary = True - self.run_wald_test() + + if self.test == "wald": + self.run_wald_test() + else: + self.run_likelihood_ratio_test() if self.cooks_filter: # Filter p-values based on Cooks outliers @@ -363,6 +367,13 @@ def run_wald_test(self) -> None: self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0 self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0 + def run_likelihood_ratio_test(self) -> None: + """Perform a Likelihood Ratio test. + + Get gene-wise p-values for gene over/under-expression. + """ + raise NotImplementedError + def lfc_shrink(self, coeff: Optional[str] = None) -> None: """LFC shrinkage with an apeGLM prior :cite:p:`DeseqStats-zhu2019heavy`. From 4a663d0e7a964863f89e3b5cffdcf8e6004fffd6 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 17 Oct 2023 15:18:07 +0200 Subject: [PATCH 4/8] implemented lrt test --- pydeseq2/ds.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++- pydeseq2/utils.py | 56 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 1 deletion(-) diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index fc5d6808..82774e8b 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -3,6 +3,7 @@ from typing import List from typing import Literal from typing import Optional +from typing import Tuple # import anndata as ad import numpy as np @@ -20,6 +21,7 @@ from pydeseq2.utils import get_num_processes from pydeseq2.utils import make_MA_plot from pydeseq2.utils import nbinomGLM +from pydeseq2.utils import lrt_test from pydeseq2.utils import wald_test @@ -372,7 +374,67 @@ def run_likelihood_ratio_test(self) -> None: Get gene-wise p-values for gene over/under-expression. """ - raise NotImplementedError + + num_genes = self.dds.n_vars + num_vars = self.design_matrix.shape[1] + + # XXX: Raise a warning if LFCs are shrunk. + + def reduce( + design_matrix: np.ndarray, ridge_factor: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + indices = np.full(design_matrix.shape[1], True, dtype=bool) + indices[self.contrast_idx] = False + return design_matrix[:, indices], ridge_factor[indices] + + # Set regularization factors. + if self.prior_LFC_var is not None: + ridge_factor = np.diag(1 / self.prior_LFC_var**2) + else: + ridge_factor = np.diag(np.repeat(1e-6, num_vars)) + + design_matrix = self.design_matrix.values + LFCs = self.LFC.values + + reduced_design_matrix, reduced_ridge_factor = reduce(design_matrix, ridge_factor) + self.dds.obsm["reduced_design_matrix"] = reduced_design_matrix + + if not self.quiet: + print("Running LRT tests...", file=sys.stderr) + start = time.time() + with parallel_backend("loky", inner_max_num_threads=1): + res = Parallel( + n_jobs=self.n_processes, + verbose=self.joblib_verbosity, + batch_size=self.batch_size, + )( + delayed(lrt_test)( + counts=self.dds.X[:, i], + design_matrix=design_matrix, + reduced_design_matrix=reduced_design_matrix, + size_factors=self.dds.obsm["size_factors"], + disp=self.dds.varm["dispersions"][i], + lfc=LFCs[i], + min_mu=self.dds.min_mu, + ridge_factor=ridge_factor, + reduced_ridge_factor=reduced_ridge_factor, + beta_tol=self.dds.beta_tol, + ) + for i in range(num_genes) + ) + end = time.time() + if not self.quiet: + print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr) + + pvals, stats = zip(*res) + + self.p_values: pd.Series = pd.Series(pvals, index=self.dds.var_names) + self.statistics: pd.Series = pd.Series(stats, index=self.dds.var_names) + + # Account for possible all_zeroes due to outlier refitting in DESeqDataSet + if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0: + self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0 + self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0 def lfc_shrink(self, coeff: Optional[str] = None) -> None: """LFC shrinkage with an apeGLM prior :cite:p:`DeseqStats-zhu2019heavy`. diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 5b18559e..581e5b8a 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -17,6 +17,7 @@ from scipy.special import gammaln # type: ignore from scipy.special import polygamma # type: ignore from scipy.stats import norm # type: ignore +from scipy.stats import chi2 # type: ignore from sklearn.linear_model import LinearRegression # type: ignore import pydeseq2 @@ -979,6 +980,61 @@ def less_abs(lfc_null): return wald_p_value, wald_statistic, wald_se +def lrt_test( + counts: np.ndarray, + design_matrix: np.ndarray, + reduced_design_matrix: np.ndarray, + size_factors: np.ndarray, + disp: float, + lfc: np.ndarray, + min_mu: float, + ridge_factor: np.ndarray, + reduced_ridge_factor: np.ndarray, + beta_tol: float, +) -> Tuple[float, float]: + """Run likelihood ratio test for differential expression. + + Compute likelihood ratio test statistics and p-values from + dispersion and LFC estimates. + + Parameters + ---------- + + Returns + ------- + lrt_p_value : float + Estimated p-value. + + lrt_statistic : float + LRT statistic. + """ + def reg_nb_nll( + beta: np.ndarray, design_matrix: np.ndarray, ridge_factor: np.ndarray + ) -> float: + # closure to minimize + mu_ = np.maximum(size_factors * np.exp(design_matrix @ beta), min_mu) + val = nb_nll(counts, mu_, disp) + 0.5 * (ridge_factor @ beta**2).sum() + return -1.0 * val # maximize the likelihood + + beta_reduced, *_ = irls_solver( + counts=counts, + size_factors=size_factors, + design_matrix=reduced_design_matrix, + disp=disp, + min_mu=min_mu, + beta_tol=beta_tol, + ) + + reduced_ll = reg_nb_nll(beta_reduced, reduced_design_matrix, reduced_ridge_factor) + full_ll = reg_nb_nll(lfc, design_matrix, ridge_factor) + + lrt_statistic = 2 * (full_ll - reduced_ll) + # df = 1 since contrast_idx is the only variable removed + lrt_p_value = chi2.sf(lrt_statistic, df=1) + + return lrt_p_value, lrt_statistic + + def fit_rough_dispersions( normed_counts: np.ndarray, design_matrix: pd.DataFrame ) -> np.ndarray: From 570f88111ee9baeb6bf606877c2ce5a3f8693050 Mon Sep 17 00:00:00 2001 From: Barbara Bodinier Date: Tue, 17 Oct 2023 16:05:24 +0200 Subject: [PATCH 5/8] fix and test on example --- pydeseq2/ds.py | 15 ++++++++------- pydeseq2/utils.py | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index 82774e8b..1e4368c7 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -19,9 +19,9 @@ from pydeseq2.dds import DeseqDataSet from pydeseq2.utils import get_num_processes +from pydeseq2.utils import lrt_test from pydeseq2.utils import make_MA_plot from pydeseq2.utils import nbinomGLM -from pydeseq2.utils import lrt_test from pydeseq2.utils import wald_test @@ -38,8 +38,8 @@ class DeseqStats: dds : DeseqDataSet DeseqDataSet for which dispersion and LFCs were already estimated. - test : Literal["wald", "lrt"] - The statistical test to use. One of ``["wald", "lrt"]``. + test : Literal["Wald", "LRT"] + The statistical test to use. One of ``["Wald", "LRT"]``. contrast : list or None A list of three strings, in the following format: @@ -172,7 +172,7 @@ def __init__( self.dds = dds - if test not in ("wald", "LRT"): + if test not in ("Wald", "LRT"): raise ValueError(f"Available tests are `wald` and `LRT`. Got: {test}.") self.test = test @@ -260,7 +260,7 @@ def summary( self.alt_hypothesis = alt_hypothesis rerun_summary = True - if self.test == "wald": + if self.test == "Wald": self.run_wald_test() else: self.run_likelihood_ratio_test() @@ -282,7 +282,8 @@ def summary( self.results_df = pd.DataFrame(index=self.dds.var_names) self.results_df["baseMean"] = self.base_mean self.results_df["log2FoldChange"] = self.LFC @ self.contrast_vector / np.log(2) - self.results_df["lfcSE"] = self.SE / np.log(2) + if self.test == "Wald": + self.results_df["lfcSE"] = self.SE / np.log(2) self.results_df["stat"] = self.statistics self.results_df["pvalue"] = self.p_values self.results_df["padj"] = self.padj @@ -385,7 +386,7 @@ def reduce( ) -> Tuple[np.ndarray, np.ndarray]: indices = np.full(design_matrix.shape[1], True, dtype=bool) indices[self.contrast_idx] = False - return design_matrix[:, indices], ridge_factor[indices] + return design_matrix[:, indices], ridge_factor[indices][:, indices] # Set regularization factors. if self.prior_LFC_var is not None: diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 581e5b8a..73bed67f 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -16,8 +16,8 @@ from scipy.optimize import minimize # type: ignore from scipy.special import gammaln # type: ignore from scipy.special import polygamma # type: ignore -from scipy.stats import norm # type: ignore from scipy.stats import chi2 # type: ignore +from scipy.stats import norm # type: ignore from sklearn.linear_model import LinearRegression # type: ignore import pydeseq2 @@ -999,6 +999,39 @@ def lrt_test( Parameters ---------- + counts : ndarray + Raw counts for a given gene. + + design_matrix : ndarray + Design matrix. + + reduced_design_matrix : ndarray + Reduced design matrix. + + size_factors : ndarray + DESeq2 normalization factors. + + disp : float + Dispersion estimate. + + lfc : ndarray + Log-fold change estimate (in natural log scale). + + min_mu : float + Lower bound on estimated means, to ensure numerical stability. + (default: ``0.5``). + + ridge_factor : ndarray + Regularization factors. + + reduced_ridge_factor : ndarray + Reduced regularization factors. + + beta_tol : float + Stopping criterion for IRWLS: + :math:`\vert dev - dev_{old}\vert / \vert dev + 0.1 \vert < \beta_{tol}`. + (default: ``1e-8``). + Returns ------- @@ -1008,6 +1041,7 @@ def lrt_test( lrt_statistic : float LRT statistic. """ + def reg_nb_nll( beta: np.ndarray, design_matrix: np.ndarray, ridge_factor: np.ndarray ) -> float: @@ -1032,6 +1066,9 @@ def reg_nb_nll( # df = 1 since contrast_idx is the only variable removed lrt_p_value = chi2.sf(lrt_statistic, df=1) + print(lrt_p_value) + print(lrt_statistic) + return lrt_p_value, lrt_statistic From 428952156494efb7c71740c54a432e667d6e1728 Mon Sep 17 00:00:00 2001 From: alexandrenowak Date: Tue, 17 Oct 2023 16:32:44 +0200 Subject: [PATCH 6/8] minor errors --- examples/plot_minimal_pydeseq2_pipeline.py | 2 +- pydeseq2/ds.py | 25 ++++++++++++---------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/examples/plot_minimal_pydeseq2_pipeline.py b/examples/plot_minimal_pydeseq2_pipeline.py index 28b27aab..25a59dd0 100644 --- a/examples/plot_minimal_pydeseq2_pipeline.py +++ b/examples/plot_minimal_pydeseq2_pipeline.py @@ -217,7 +217,7 @@ # should be a *fitted* :class:`DeseqDataSet ` # object. -stat_res = DeseqStats(dds, n_cpus=8) +stat_res = DeseqStats(dds, test="LRT", n_cpus=8) # %% # It also has a set of optional keyword arguments (see the :doc:`API documentation diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index 1e4368c7..f71d4f1a 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -151,7 +151,7 @@ class DeseqStats: def __init__( self, dds: DeseqDataSet, - test: Literal["Wald", "LRT"] = "Wald", + test: Literal["wald", "LRT"] = "wald", contrast: Optional[List[str]] = None, alpha: float = 0.05, cooks_filter: bool = True, @@ -172,7 +172,7 @@ def __init__( self.dds = dds - if test not in ("Wald", "LRT"): + if test not in ("wald", "LRT"): raise ValueError(f"Available tests are `wald` and `LRT`. Got: {test}.") self.test = test @@ -217,6 +217,9 @@ def __init__( "to False." ) + self.p_values: pd.Series + self.statistics: pd.Series + def summary( self, **kwargs, @@ -260,7 +263,7 @@ def summary( self.alt_hypothesis = alt_hypothesis rerun_summary = True - if self.test == "Wald": + if self.test == "wald": self.run_wald_test() else: self.run_likelihood_ratio_test() @@ -282,7 +285,7 @@ def summary( self.results_df = pd.DataFrame(index=self.dds.var_names) self.results_df["baseMean"] = self.base_mean self.results_df["log2FoldChange"] = self.LFC @ self.contrast_vector / np.log(2) - if self.test == "Wald": + if self.test == "wald": self.results_df["lfcSE"] = self.SE / np.log(2) self.results_df["stat"] = self.statistics self.results_df["pvalue"] = self.p_values @@ -290,11 +293,11 @@ def summary( if self.contrast[1] == self.contrast[2] == "": # The factor is continuous - print(f"Log2 fold change & Wald test p-value: " f"{self.contrast[0]}") + print(f"Log2 fold change & test p-value: " f"{self.contrast[0]}") else: # The factor is categorical print( - f"Log2 fold change & Wald test p-value: " + f"Log2 fold change & test p-value: " f"{self.contrast[0]} {self.contrast[1]} vs {self.contrast[2]}" ) display(self.results_df) @@ -360,8 +363,8 @@ def run_wald_test(self) -> None: pvals, stats, se = zip(*res) - self.p_values: pd.Series = pd.Series(pvals, index=self.dds.var_names) - self.statistics: pd.Series = pd.Series(stats, index=self.dds.var_names) + self.p_values = pd.Series(pvals, index=self.dds.var_names) + self.statistics = pd.Series(stats, index=self.dds.var_names) self.SE: pd.Series = pd.Series(se, index=self.dds.var_names) # Account for possible all_zeroes due to outlier refitting in DESeqDataSet @@ -386,7 +389,7 @@ def reduce( ) -> Tuple[np.ndarray, np.ndarray]: indices = np.full(design_matrix.shape[1], True, dtype=bool) indices[self.contrast_idx] = False - return design_matrix[:, indices], ridge_factor[indices][:, indices] + return design_matrix[:, indices], ridge_factor[:, indices][indices] # Set regularization factors. if self.prior_LFC_var is not None: @@ -429,8 +432,8 @@ def reduce( pvals, stats = zip(*res) - self.p_values: pd.Series = pd.Series(pvals, index=self.dds.var_names) - self.statistics: pd.Series = pd.Series(stats, index=self.dds.var_names) + self.p_values = pd.Series(pvals, index=self.dds.var_names) + self.statistics = pd.Series(stats, index=self.dds.var_names) # Account for possible all_zeroes due to outlier refitting in DESeqDataSet if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0: From e2afb7d9b03e9490abe43c5724f1268b78316c3d Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 17 Oct 2023 16:59:14 +0200 Subject: [PATCH 7/8] cosmit --- pydeseq2/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 73bed67f..118c6d48 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -1032,7 +1032,6 @@ def lrt_test( :math:`\vert dev - dev_{old}\vert / \vert dev + 0.1 \vert < \beta_{tol}`. (default: ``1e-8``). - Returns ------- lrt_p_value : float From b913d20c105800e13b402b3bdaae51fed21a7b93 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 17 Oct 2023 17:09:28 +0200 Subject: [PATCH 8/8] cosmit --- pydeseq2/ds.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index f71d4f1a..b2398db4 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -219,6 +219,7 @@ def __init__( self.p_values: pd.Series self.statistics: pd.Series + self.SE: pd.Series def summary( self, @@ -365,7 +366,7 @@ def run_wald_test(self) -> None: self.p_values = pd.Series(pvals, index=self.dds.var_names) self.statistics = pd.Series(stats, index=self.dds.var_names) - self.SE: pd.Series = pd.Series(se, index=self.dds.var_names) + self.SE = pd.Series(se, index=self.dds.var_names) # Account for possible all_zeroes due to outlier refitting in DESeqDataSet if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0: