Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH volcano plots #317

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions pydeseq2/ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pydeseq2.inference import Inference
from pydeseq2.utils import lowess
from pydeseq2.utils import make_MA_plot
from pydeseq2.utils import make_volcano_plot
from pydeseq2.utils import n_or_more_replicates


Expand Down Expand Up @@ -506,6 +507,67 @@ def plot_MA(self, log: bool = True, save_path: Optional[str] = None, **kwargs):
**kwargs,
)

def plot_volcano(
    self,
    LFC_threshold: float = 2.0,
    pval_threshold: float = 0.05,
    annotate_genes: bool = True,
    write_legend: bool = False,
    save_path: Optional[str] = None,
    figsize: tuple = (6, 6),
    varying_marker_size: bool = True,
):
    """
    Create a volcano plot using matplotlib.

    Summarizes the results of a differential expression analysis by plotting
    the negative log10-transformed adjusted p-values against the log2 fold change.
    The plotting itself is delegated to ``make_volcano_plot``, which is called
    on ``self.results_df``.

    Parameters
    ----------
    LFC_threshold : float
        Log2 fold change threshold above which genes are considered differentially
        expressed. (default: ``2.``).

    pval_threshold : float
        P-value threshold below which genes are considered differentially expressed.
        (default: ``0.05``).

    annotate_genes : bool
        Whether or not to annotate genes that pass the LFC and p-value thresholds.
        (default: ``True``).

    write_legend : bool
        Whether or not to write the legend on the plot. (default: ``False``).

    save_path : str or None
        The path where to save the plot. If left None, the plot won't be saved
        (``default=None``).

    figsize : tuple
        The size of the figure. (default: ``(6, 6)``).

    varying_marker_size : bool
        Whether to vary the marker size based on the base mean. (default: ``True``).

    Raises
    ------
    AttributeError
        If the ``results_df`` attribute is missing.
    """
    # Raise an error if results_df are missing
    if not hasattr(self, "results_df"):
        raise AttributeError(
            "Trying to make a volcano plot but p-values were not computed yet. "
            "Please run the summary() method first."
        )

    # Forward by keyword so that each option is bound by name rather than by
    # its position in the helper's signature.
    make_volcano_plot(
        results_df=self.results_df,
        LFC_threshold=LFC_threshold,
        pval_threshold=pval_threshold,
        annotate_genes=annotate_genes,
        write_legend=write_legend,
        save_path=save_path,
        figsize=figsize,
        varying_marker_size=varying_marker_size,
    )

def _independent_filtering(self) -> None:
"""Compute adjusted p-values using independent filtering.

Expand Down
130 changes: 130 additions & 0 deletions pydeseq2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import numpy as np
import pandas as pd
import seaborn as sns
from adjustText import adjust_text # type: ignore
from matplotlib import pyplot as plt
from scipy.linalg import solve # type: ignore
from scipy.optimize import minimize # type: ignore
Expand Down Expand Up @@ -1536,6 +1538,134 @@ def make_MA_plot(
if save_path is not None:
plt.savefig(save_path, bbox_inches="tight")

# Adapted from https://github.com/mousepixels/sanbomics_scripts/blob/main/high_quality_volcano_plots.ipynb # noqa: E501


def make_volcano_plot(
    results_df: pd.DataFrame,
    LFC_threshold: float = 2,
    pval_threshold: float = 0.05,
    annotate_genes: bool = True,
    write_legend: bool = True,
    save_path: Optional[str] = None,
    figsize: tuple = (6, 6),
    varying_marker_size: bool = True,
):
    """
    Create a volcano plot using matplotlib.

    Summarizes the results of a differential expression analysis by plotting
    the negative log10-transformed adjusted p-values against the log2 fold change.
    Genes are colored according to whether they pass both thresholds with a
    positive fold change, a negative fold change, or not at all.

    Parameters
    ----------
    results_df : pd.DataFrame
        Dataframe obtained after running DeseqStats.summary() (the
        ``results_df`` attribute). The ``padj`` and ``log2FoldChange`` columns
        are read, as well as ``baseMean`` when ``varying_marker_size`` is True.
        The input is copied and left unmodified.

    LFC_threshold : float
        Log2 fold change threshold above which genes are considered differentially
        expressed. (default: ``2.``).

    pval_threshold : float
        P-value threshold below which genes are considered differentially expressed.
        (default: ``0.05``).

    annotate_genes : bool
        Whether or not to annotate genes that pass the LFC and p-value thresholds.
        (default: ``True``).

    write_legend : bool
        Whether or not to write the legend on the plot. (default: ``True``).

    save_path : str or None
        The path where to save the plot. If left None, the plot won't be saved
        (``default=None``).

    figsize : tuple
        The size of the figure. (default: ``(6, 6)``).

    varying_marker_size : bool
        Whether or not to vary the marker size based on the base mean.
        (default: ``True``)
    """
    plt.figure(figsize=figsize)

    nlog10_pval_threshold = -np.log10(pval_threshold)

    # Work on a copy so that the caller's dataframe does not gain extra columns.
    df = results_df.copy()
    df["nlog10"] = -np.log10(df["padj"])

    # Classify genes. Comparisons against NaN are False, so genes with a
    # missing adjusted p-value or fold change stay in the "none" category.
    passes_pval = df["nlog10"] > nlog10_pval_threshold
    df["DE"] = "none"
    df.loc[passes_pval & (df["log2FoldChange"] > LFC_threshold), "DE"] = "positive"
    df.loc[passes_pval & (df["log2FoldChange"] < -LFC_threshold), "DE"] = "negative"

    # ``size``/``sizes`` map a data column onto marker areas; a constant marker
    # area has to go through matplotlib's ``s`` keyword instead.
    if varying_marker_size:
        marker_kwargs = {"size": "baseMean", "sizes": (40, 400)}
    else:
        marker_kwargs = {"s": 40}

    ax = sns.scatterplot(
        data=df,
        x="log2FoldChange",
        y="nlog10",
        hue="DE",
        hue_order=["none", "positive", "negative"],
        palette=["lightgrey", "indianred", "cornflowerblue"],
        **marker_kwargs,
    )

    # Dashed guides marking the significance and fold change thresholds.
    ax.axhline(nlog10_pval_threshold, zorder=0, c="k", lw=2, ls="--")
    ax.axvline(LFC_threshold, zorder=0, c="k", lw=2, ls="--")
    ax.axvline(-LFC_threshold, zorder=0, c="k", lw=2, ls="--")

    if annotate_genes:
        de_genes = df[df["DE"] != "none"]
        texts = [
            ax.text(x=lfc, y=nlog10, s=gene, fontsize=12, weight="bold")
            for gene, lfc, nlog10 in zip(
                de_genes.index, de_genes["log2FoldChange"], de_genes["nlog10"]
            )
        ]

        # Nothing to de-overlap when no gene passes the thresholds.
        if texts:
            adjust_text(texts, arrowprops={"arrowstyle": "-", "color": "k"})

    if write_legend:
        plt.legend(
            loc=1, bbox_to_anchor=(1.4, 1), frameon=False, prop={"weight": "bold"}
        )
    else:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    for axis in ["bottom", "left"]:
        ax.spines[axis].set_linewidth(2)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    ax.tick_params(width=2)

    plt.xticks(size=12, weight="bold")
    plt.yticks(size=12, weight="bold")

    plt.xlabel("$log_{2}$ fold change", size=15)
    plt.ylabel("-$log_{10}$ FDR", size=15)

    if save_path is not None:
        # The destination is the first positional argument of ``savefig``;
        # there is no ``save_path`` keyword.
        plt.savefig(save_path, bbox_inches="tight")

    plt.show()


# Authors: Alexandre Gramfort <[email protected]>
#
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.black]
[tool.linting.black]
line-length = 89

[tool.ruff]
[tool.linting.ruff]
line-length = 89
select = [
"F", # Errors detected by Pyflakes
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
"scikit-learn>=1.1.0",
"scipy>=1.11.0",
"matplotlib>=3.6.2", # not sure why sphinx_gallery does not work without it
"seaborn",
"adjustText",
], # external packages as dependencies
extras_require={
"dev": [
Expand Down
37 changes: 35 additions & 2 deletions tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,8 @@ def test_zero_inflated():

def test_plot_MA():
"""
Test that a KeyError is thrown when attempting to run plot_MA without running the
statistical analysis first.
Test that an AttributeError is thrown when attempting to run plot_MA without running
the statistical analysis first.
"""

counts_df = load_example_data(
Expand Down Expand Up @@ -551,3 +551,36 @@ def test_plot_MA():
res.summary()
# Now this shouldn't throw an error
res.plot_MA()


def test_plot_volcano():
    """
    Check that plot_volcano() raises an AttributeError before summary() has been
    called, and runs without error afterwards.
    """

    # Load the synthetic counts and metadata with identical loader settings.
    example_data = {
        modality: load_example_data(
            modality=modality,
            dataset="synthetic",
            debug=False,
        )
        for modality in ("raw_counts", "metadata")
    }

    dataset = DeseqDataSet(
        counts=example_data["raw_counts"],
        metadata=example_data["metadata"],
    )
    dataset.deseq2()

    # Build the stats object but do not run the analysis yet
    stats = DeseqStats(dataset)

    with pytest.raises(AttributeError):
        stats.plot_volcano()

    # Once the analysis has run, plotting should succeed
    stats.summary()
    stats.plot_volcano()
Loading