diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py
index 186548d9..ce724fc5 100644
--- a/pydeseq2/ds.py
+++ b/pydeseq2/ds.py
@@ -16,6 +16,7 @@
 from pydeseq2.inference import Inference
 from pydeseq2.utils import lowess
 from pydeseq2.utils import make_MA_plot
+from pydeseq2.utils import make_volcano_plot
 from pydeseq2.utils import n_or_more_replicates
 
 
@@ -506,6 +507,67 @@ def plot_MA(self, log: bool = True, save_path: Optional[str] = None, **kwargs):
             **kwargs,
         )
 
+    def plot_volcano(
+        self,
+        LFC_threshold: float = 2.0,
+        pval_threshold: float = 0.05,
+        annotate_genes: bool = True,
+        write_legend: bool = False,
+        save_path: Optional[str] = None,
+        figsize: tuple = (6, 6),
+        varying_marker_size: bool = True,
+    ) -> None:
+        """
+        Create a volcano plot using matplotlib.
+
+        Summarizes the results of a differential expression analysis by plotting
+        the negative log10-transformed adjusted p-values against the log2 fold change.
+
+        Parameters
+        ----------
+        LFC_threshold : float
+            Log2 fold change threshold above which genes are considered differentially
+            expressed. (default: ``2.``).
+
+        pval_threshold : float
+            P-value threshold below which genes are considered differentially expressed.
+            (default: ``0.05``).
+
+        annotate_genes : bool
+            Whether or not to annotate genes that pass the LFC and p-value thresholds.
+            (default: ``True``).
+
+        write_legend : bool
+            Whether or not to write the legend on the plot. (default: ``False``).
+
+        save_path : str or None
+            The path where to save the plot. If left None, the plot won't be saved
+            (``default=None``).
+
+        figsize : tuple
+            The size of the figure. (default: ``(6, 6)``).
+
+        varying_marker_size : bool
+            Whether to vary the marker size based on the base mean. (default: ``True``).
+        """
+        # Raise an error if results_df are missing
+        if not hasattr(self, "results_df"):
+            raise AttributeError(
+                "Trying to make a volcano plot but p-values were not computed yet. "
+                "Please run the summary() method first."
+            )
+
+        make_volcano_plot(
+            self.results_df,
+            LFC_threshold=LFC_threshold,
+            pval_threshold=pval_threshold,
+            annotate_genes=annotate_genes,
+            write_legend=write_legend,
+            save_path=save_path,
+            figsize=figsize,
+            varying_marker_size=varying_marker_size,
+        )
+
     def _independent_filtering(self) -> None:
         """Compute adjusted p-values using independent filtering.
 
diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py
index 95d93ac1..84f162e2 100644
--- a/pydeseq2/utils.py
+++ b/pydeseq2/utils.py
@@ -12,6 +12,8 @@
 
 import numpy as np
 import pandas as pd
+import seaborn as sns
+from adjustText import adjust_text  # type: ignore
 from matplotlib import pyplot as plt
 from scipy.linalg import solve  # type: ignore
 from scipy.optimize import minimize  # type: ignore
@@ -1536,6 +1538,130 @@ def make_MA_plot(
     if save_path is not None:
         plt.savefig(save_path, bbox_inches="tight")
 
+
+# Adapted from https://github.com/mousepixels/sanbomics_scripts/blob/main/high_quality_volcano_plots.ipynb  # noqa: E501
+def make_volcano_plot(
+    results_df: pd.DataFrame,
+    LFC_threshold: float = 2,
+    pval_threshold: float = 0.05,
+    annotate_genes: bool = True,
+    write_legend: bool = True,
+    save_path: Optional[str] = None,
+    figsize: tuple = (6, 6),
+    varying_marker_size: bool = True,
+) -> None:
+    """
+    Create a volcano plot using matplotlib.
+
+    Summarizes the results of a differential expression analysis by plotting
+    the negative log10-transformed adjusted p-values against the log2 fold change.
+
+    Parameters
+    ----------
+    results_df : pd.DataFrame
+        Dataframe obtained after running DeseqStats.summary() (the
+        ``results_df`` attribute).
+
+    LFC_threshold : float
+        Log2 fold change threshold above which genes are considered differentially
+        expressed. (default: ``2.``).
+
+    pval_threshold : float
+        P-value threshold below which genes are considered differentially expressed.
+        (default: ``0.05``).
+
+    annotate_genes : bool
+        Whether or not to annotate genes that pass the LFC and p-value thresholds.
+        (default: ``True``).
+
+    write_legend : bool
+        Whether or not to write the legend on the plot. (default: ``True``).
+
+    save_path : str or None
+        The path where to save the plot. If left None, the plot won't be saved
+        (``default=None``).
+
+    figsize : tuple
+        The size of the figure. (default: ``(6, 6)``).
+
+    varying_marker_size : bool
+        Whether or not to vary the marker size based on the base mean.
+        (default: ``True``)
+    """
+    plt.figure(figsize=figsize)
+
+    nlog10_pval_threshold = -np.log10(pval_threshold)
+
+    df = results_df.copy()
+    df["nlog10"] = -np.log10(df["padj"])
+
+    # Rows with a NaN padj or LFC compare False everywhere and stay "none".
+    significant = df["nlog10"] > nlog10_pval_threshold
+    df["DE"] = np.select(
+        [
+            significant & (df["log2FoldChange"] > LFC_threshold),
+            significant & (df["log2FoldChange"] < -LFC_threshold),
+        ],
+        ["positive", "negative"],
+        default="none",
+    )
+
+    ax = sns.scatterplot(
+        data=df,
+        x="log2FoldChange",
+        y="nlog10",
+        hue="DE",
+        hue_order=["none", "positive", "negative"],
+        palette=["lightgrey", "indianred", "cornflowerblue"],
+        size="baseMean" if varying_marker_size else 40,
+        sizes=(40, 400) if varying_marker_size else None,
+    )
+
+    ax.axhline(nlog10_pval_threshold, zorder=0, c="k", lw=2, ls="--")
+    ax.axvline(LFC_threshold, zorder=0, c="k", lw=2, ls="--")
+    ax.axvline(-LFC_threshold, zorder=0, c="k", lw=2, ls="--")
+
+    if annotate_genes:
+        # Label only the genes passing both thresholds, i.e. those flagged above.
+        de_genes = df[df["DE"] != "none"]
+        texts = [
+            plt.text(x=lfc, y=nlog10, s=gene, fontsize=12, weight="bold")
+            for gene, lfc, nlog10 in zip(
+                de_genes.index, de_genes["log2FoldChange"], de_genes["nlog10"]
+            )
+        ]
+        adjust_text(texts, arrowprops={"arrowstyle": "-", "color": "k"})
+
+    if write_legend:
+        plt.legend(
+            loc=1, bbox_to_anchor=(1.4, 1), frameon=False, prop={"weight": "bold"}
+        )
+    else:
+        # get_legend() returns None when no legend was drawn on the axes.
+        legend = ax.get_legend()
+        if legend is not None:
+            legend.remove()
+
+    for axis in ["bottom", "left"]:
+        ax.spines[axis].set_linewidth(2)
+
+    ax.spines["top"].set_visible(False)
+    ax.spines["right"].set_visible(False)
+
+    ax.tick_params(width=2)
+
+    plt.xticks(size=12, weight="bold")
+    plt.yticks(size=12, weight="bold")
+
+    plt.xlabel("$log_{2}$ fold change", size=15)
+    plt.ylabel("-$log_{10}$ FDR", size=15)
+
+    if save_path is not None:
+        # savefig expects the output path as its first positional argument (fname).
+        plt.savefig(save_path, bbox_inches="tight")
+
+    plt.show()
+
 
 # Authors: Alexandre Gramfort
 #
diff --git a/pyproject.toml b/pyproject.toml
index e30e39a8..41bb1349 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
-[tool.black]
+[tool.linting.black]
 line-length = 89
 
-[tool.ruff]
+[tool.linting.ruff]
 line-length = 89
 select = [
     "F",  # Errors detected by Pyflakes
diff --git a/setup.py b/setup.py
index d0217943..cedf73d4 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,8 @@
         "scikit-learn>=1.1.0",
         "scipy>=1.11.0",
         "matplotlib>=3.6.2",  # not sure why sphinx_gallery does not work without it
+        "seaborn",
+        "adjustText",
     ],  # external packages as dependencies
     extras_require={
         "dev": [
diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py
index 494da4c8..489786da 100644
--- a/tests/test_edge_cases.py
+++ b/tests/test_edge_cases.py
@@ -522,8 +522,8 @@ def test_zero_inflated():
 
 def test_plot_MA():
     """
-    Test that a KeyError is thrown when attempting to run plot_MA without running the
-    statistical analysis first.
+    Test that an AttributeError is thrown when attempting to run plot_MA without running
+    the statistical analysis first.
     """
 
     counts_df = load_example_data(
@@ -551,3 +551,36 @@ def test_plot_MA():
     res.summary()
     # Now this shouldn't throw an error
     res.plot_MA()
+
+
+def test_plot_volcano():
+    """
+    Test that an AttributeError is thrown when attempting to run plot_volcano() without
+    running the statistical analysis first.
+    """
+
+    counts_df = load_example_data(
+        modality="raw_counts",
+        dataset="synthetic",
+        debug=False,
+    )
+
+    metadata = load_example_data(
+        modality="metadata",
+        dataset="synthetic",
+        debug=False,
+    )
+
+    dds = DeseqDataSet(counts=counts_df, metadata=metadata)
+    dds.deseq2()
+
+    # Initialize a DeseqStats object without running the analysis
+    ds = DeseqStats(dds)
+
+    with pytest.raises(AttributeError):
+        ds.plot_volcano()
+
+    # Run the analysis
+    ds.summary()
+    # Now this shouldn't throw an error
+    ds.plot_volcano()