Sgrouard/extend refacto benchmark #16

Draft · wants to merge 6 commits into base: main
1 change: 1 addition & 0 deletions benchmark_utils/__init__.py
@@ -7,6 +7,7 @@
create_random_proportion,
)
from .dataset_utils import (
+create_anndata_pseudobulk,
preprocess_scrna,
split_dataset,
create_new_granularity_index,
45 changes: 30 additions & 15 deletions benchmark_utils/dataset_utils.py
@@ -14,22 +14,33 @@
def preprocess_scrna(
adata: ad.AnnData, keep_genes: int = 2000, batch_key: Optional[str] = None
):
"""Preprocess single-cell RNA data for deconvolution benchmarking."""
"""Preprocess single-cell RNA data for deconvolution benchmarking.

* in adata.X, the normalized log1p counts are saved
* in adata.layers["counts"], raw counts are saved
* in adata.layers["relative_counts"], the relative counts are saved
=> The highly variable genes can be found in adata.var["highly_variable"]

"""
sc.pp.filter_genes(adata, min_counts=3)
adata.layers["counts"] = adata.X.copy() # preserve counts, used for training
sc.pp.normalize_total(adata, target_sum=1e4)
adata.layers["relative_counts"] = adata.X.copy() # preserve counts, used for
adata.layers["relative_counts"] = adata.X.copy() # preserve counts, used for baselines
sc.pp.log1p(adata)
adata.raw = adata # freeze the state in `.raw`
sc.pp.highly_variable_genes(
adata,
n_top_genes=keep_genes,
-subset=True,
layer="counts",
flavor="seurat_v3",
batch_key=batch_key,
+subset=False,
+inplace=True
)
+# TODO: add the filtering / QC steps that are performed at Servier,
+# and concat the resulting df to adata.var
+
+return adata
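
For reference, a minimal sketch of the layout this preprocessing produces (hypothetical call; the "batch" obs column is an assumption):

adata = preprocess_scrna(adata, keep_genes=2000, batch_key="batch")
adata.X                          # log1p-normalized counts
adata.layers["counts"]           # raw counts, used for training
adata.layers["relative_counts"]  # total-count-normalized counts, used for baselines
# with subset=False, genes are flagged rather than dropped:
filtered_genes = adata.var_names[adata.var["highly_variable"]]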


def split_dataset(
@@ -110,17 +121,21 @@ def add_cell_types_grouped(
elif group == "FACS_1st_level_granularity":
train_test_index = pd.read_csv("/home/owkin/project/train_test_index_dataframes/train_test_index_facs_1st_level.csv", index_col=1).iloc[:,1:]
col_name = "grouping"
adata.obs["cell_types_grouped"] = train_test_index[col_name]
adata.obs[f"cell_types_grouped_{group}"] = train_test_index[col_name]
return adata, train_test_index


-def create_anndata_pseudobulk(adata: ad.AnnData, x: np.array) -> ad.AnnData:
+def create_anndata_pseudobulk(
+    adata_obs: pd.DataFrame, adata_var_names: list, x: np.array
+) -> ad.AnnData:
"""Creates an anndata object from a pseudobulk sample.

Parameters
----------
-adata: ad.AnnData
-    AnnData aobject storing training set
+adata_obs: pd.DataFrame
+    obs DataFrame from the AnnData object storing the training set
+adata_var_names: list
+    gene names from the AnnData object
x: np.array
pseudobulk sample

@@ -130,14 +145,14 @@ def create_anndata_pseudobulk(adata: ad.AnnData, x: np.array) -> ad.AnnData:
Anndata object storing the pseudobulk array
"""
df_obs = pd.DataFrame.from_dict(
-[{col: adata.obs[col].value_counts().index[0] for col in adata.obs.columns}]
+[{col: adata_obs[col].value_counts().index[0] for col in adata_obs.columns}]
)
if len(x.shape) > 1 and x.shape[0] > 1:
# several pseudobulks, so duplicate df_obs row
df_obs = df_obs.loc[df_obs.index.repeat(x.shape[0])].reset_index(drop=True)
df_obs.index = [f"sample_{idx}" for idx in df_obs.index]
adata_pseudobulk = ad.AnnData(X=x, obs=df_obs)
-adata_pseudobulk.var_names = adata.var_names
+adata_pseudobulk.var_names = adata_var_names
adata_pseudobulk.layers["counts"] = np.copy(x)
adata_pseudobulk.raw = adata_pseudobulk
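
A minimal usage sketch of the refactored signature, which takes obs and var_names instead of a full AnnData (hypothetical adata, single averaged sample):

import numpy as np

pseudobulk = np.asarray(adata.layers["counts"].mean(axis=0)).reshape(1, -1)
adata_pb = create_anndata_pseudobulk(adata.obs, list(adata.var_names), pseudobulk)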

@@ -167,10 +182,10 @@ def create_purified_pseudobulk_dataset(
group.append(group_key)

# pseudobulk dataset
-adata_pseudobulk_rc = create_anndata_pseudobulk(adata,
+adata_pseudobulk_rc = create_anndata_pseudobulk(adata.obs, adata.var_names,
np.array(averaged_data["relative_counts"])
)
-adata_pseudobulk_counts = create_anndata_pseudobulk(adata,
+adata_pseudobulk_counts = create_anndata_pseudobulk(adata.obs, adata.var_names,
np.array(averaged_data["counts"])
)
adata_pseudobulk_rc.obs_names = group
@@ -210,10 +225,10 @@ def create_uniform_pseudobulk_dataset(
averaged_data["counts"].append(adata_sample.layers["counts"].sum(axis=0).tolist()[0])

# pseudobulk dataset
-adata_pseudobulk_rc = create_anndata_pseudobulk(adata,
+adata_pseudobulk_rc = create_anndata_pseudobulk(adata.obs, adata.var_names,
np.array(averaged_data["relative_counts"])
)
-adata_pseudobulk_counts = create_anndata_pseudobulk(adata,
+adata_pseudobulk_counts = create_anndata_pseudobulk(adata.obs, adata.var_names,
np.array(averaged_data["counts"])
)

@@ -302,10 +317,10 @@ def create_dirichlet_pseudobulk_dataset(
all_adata_samples.append(adata_sample)

# pseudobulk dataset
-adata_pseudobulk_rc = create_anndata_pseudobulk(adata,
+adata_pseudobulk_rc = create_anndata_pseudobulk(adata.obs, adata.var_names,
np.array(averaged_data["relative_counts"])
)
-adata_pseudobulk_counts = create_anndata_pseudobulk(adata,
+adata_pseudobulk_counts = create_anndata_pseudobulk(adata.obs, adata.var_names,
np.array(averaged_data["counts"])
)

5 changes: 3 additions & 2 deletions benchmark_utils/deconv_utils.py
@@ -47,6 +47,7 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData,
scvi.model.MixUpVI,
scvi.model.CondSCVI]],
all_adata_samples,
+filtered_genes,
use_mixupvi: bool = True,
use_nnls: bool = True,
use_softmax: bool = False) -> pd.DataFrame:
@@ -74,15 +75,15 @@
if use_mixupvi:
latent_pseudobulks=[]
for i in range(len(all_adata_samples)):
-latent_pseudobulks.append(model.get_latent_representation(all_adata_samples[i], get_pseudobulk=True))
+latent_pseudobulks.append(model.get_latent_representation(all_adata_samples[i][:, filtered_genes], get_pseudobulk=True))
latent_pseudobulk = np.concatenate(latent_pseudobulks, axis=0)
else:
adata_pseudobulk = ad.AnnData(X=adata_pseudobulk.layers["counts"],
obs=adata_pseudobulk.obs,
var=adata_pseudobulk.var)
adata_pseudobulk.layers["counts"] = adata_pseudobulk.X.copy()

-latent_pseudobulk = model.get_latent_representation(adata_pseudobulk)
+latent_pseudobulk = model.get_latent_representation(adata_pseudobulk, get_pseudobulk=False)

if use_nnls:
deconv = LinearRegression(positive=True).fit(adata_latent_signature.X.T,
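
For context, a hedged sketch of the NNLS step this branch feeds into, assuming adata_latent_signature.X is (n_cell_types, n_latent) and latent_pseudobulk is (n_samples, n_latent); the normalization is an assumption, not the repository implementation:

from sklearn.linear_model import LinearRegression
import numpy as np

# regress each pseudobulk latent vector on the cell-type latent signatures,
# constraining the coefficients to be non-negative
reg = LinearRegression(positive=True).fit(adata_latent_signature.X.T, latent_pseudobulk.T)
proportions = reg.coef_ / reg.coef_.sum(axis=1, keepdims=True)  # rows: samples, columns: cell types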
50 changes: 26 additions & 24 deletions benchmark_utils/latent_signature_utils.py
@@ -8,6 +8,7 @@
import torch

from .dataset_utils import create_anndata_pseudobulk
+from constants import SIGNATURE_TYPE


def create_latent_signature(
@@ -16,7 +17,6 @@ def create_latent_signature(
repeats: int = 1,
average_all_cells: bool = True,
sc_per_pseudobulk: int = 3000,
-signature_type: str = "pre-encoded",
cell_type_column: str = "cell_types_grouped",
count_key: Optional[str] = "counts",
representation_key: Optional[str] = "X_scvi",
@@ -101,34 +101,36 @@
)
adata_sampled = adata[sampled_cells]

-if signature_type == "pre-encoded":
-    assert (
-        model is not None,
-        "If representing a purified pseudo bulk (aggregate before embedding",
-        "), must give a model",
-    )
-    assert (
-        count_key is not None
-    ), "Must give a count key if aggregating before embedding."
-
-    if use_mixupvi:
-        result = model.get_latent_representation(
-            adata_sampled, get_pseudobulk=True
-        ).reshape(-1)
-    else:
+assert model is not None, (
+    "If representing a purified pseudo bulk (aggregate before embedding), "
+    "must give a model"
+)
+assert (
+    count_key is not None
+), "Must give a count key if aggregating before embedding."
+
+if use_mixupvi:
+    # TODO: in this case, the number of cells sampled equals
+    # self.n_cells_per_pseudobulk from the MixUpVAE; change that to either
+    # all cells (if average_all_cells) or sc_per_pseudobulk
+    result = model.get_latent_representation(
+        adata_sampled, get_pseudobulk=True
+    )[0]  # take the first pseudobulk
+else:
+    if SIGNATURE_TYPE == "pre_encoded":
        pseudobulk = (
            adata_sampled.layers[count_key].mean(axis=0).reshape(1, -1)
        )  # .astype(int).astype(numpy.float32)
-        adata_pseudobulk = create_anndata_pseudobulk(
-            adata_sampled, pseudobulk
-        )
-        result = model.get_latent_representation(adata_pseudobulk).reshape(
-            -1
-        )
-else:
-    raise ValueError(
-        "Only pre-encoded signatures are supported for now."
-    )
+        adata_sampled = create_anndata_pseudobulk(
+            adata_sampled.obs, adata_sampled.var_names, pseudobulk
+        )
+    result = model.get_latent_representation(
+        adata_sampled, get_pseudobulk=False
+    )
+    if SIGNATURE_TYPE == "pre_encoded":
+        result = result.reshape(-1)
+    elif SIGNATURE_TYPE == "post_inference":
+        result = result.mean(axis=0)
repeat_list.append(repeat)
representation_list.append(result)
cell_type_list.append(cell_type)
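
For clarity, the two SIGNATURE_TYPE modes introduced here differ only in where the averaging happens; a simplified sketch (assumes a MixUpVI-style model exposing get_pseudobulk, as in this diff):

if SIGNATURE_TYPE == "pre_encoded":
    # average raw counts over the sampled cells, then encode the single pseudobulk
    pseudobulk = adata_sampled.layers["counts"].mean(axis=0).reshape(1, -1)
    adata_pb = create_anndata_pseudobulk(adata_sampled.obs, adata_sampled.var_names, pseudobulk)
    signature = model.get_latent_representation(adata_pb, get_pseudobulk=False).reshape(-1)
elif SIGNATURE_TYPE == "post_inference":
    # encode every sampled cell, then average in latent space
    latent = model.get_latent_representation(adata_sampled, get_pseudobulk=False)
    signature = latent.mean(axis=0)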
18 changes: 17 additions & 1 deletion benchmark_utils/plotting_utils.py
@@ -1,9 +1,12 @@
+import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from typing import Dict
+from datetime import datetime
+from loguru import logger

def plot_purified_deconv_results(deconv_results, only_fit_one_baseline, more_details=False, save=False, filename="test"):
"""Plot the deconv results from sanity check 1"""
@@ -106,7 +109,14 @@ def plot_deconv_lineplot(results: Dict[int, pd.DataFrame],
plt.show()

if save:
-plt.savefig(f"/home/owkin/project/plots/{filename}.png", dpi=300)
+path = f"/home/owkin/project/plots/{filename}.png"
+if os.path.isfile(path):
+    new_path = f"/home/owkin/project/plots/{filename}_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.png"
+    logger.warning(f"{path} already exists. Saving the file to {new_path} instead.")
+    path = new_path
+plt.savefig(path, dpi=300)
+logger.info(f"Plot saved to the following path: {path}")


def plot_metrics(model_history, train: bool = True, n_epochs: int = 100):
"""Plot the train or val metrics from training."""
@@ -241,6 +251,12 @@ def compare_tuning_results(

custom_palette = sns.color_palette("husl", n_colors=len(all_results[variable_tuned].unique()))
all_results["epoch"] = all_results.index
+if (n_nan := all_results[variable_to_plot].isna().sum()) > 0:
+    print(
+        f"There are {n_nan} missing values in the variable to plot ({variable_to_plot}). "
+        "Filling them with the next row's values."
+    )
+    all_results[variable_to_plot] = all_results[variable_to_plot].fillna(method="bfill")
sns.set_theme(style="darkgrid")
sns.lineplot(x="epoch", y=variable_to_plot, hue=variable_tuned, ci="sd", data=all_results, err_style="bars", palette=custom_palette)
plt.show()
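
The backfill above assumes metrics are logged at different frequencies across runs; a tiny illustration of the behaviour (hypothetical values):

import pandas as pd

s = pd.Series([0.81, None, None, 0.84])
s.fillna(method="bfill")  # -> 0.81, 0.84, 0.84, 0.84: each NaN takes the next logged value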
31 changes: 17 additions & 14 deletions benchmark_utils/sanity_checks_utils.py
@@ -36,8 +36,8 @@ def run_purified_sanity_check(
adata_train: ad.AnnData,
adata_pseudobulk_test_counts: ad.AnnData,
adata_pseudobulk_test_rc: ad.AnnData,
+filtered_genes: list,
signature: pd.DataFrame,
-intersection: List[str],
generative_models : Dict[str, Union[scvi.model.SCVI,
scvi.model.CondSCVI,
scvi.model.DestVI,
@@ -62,8 +62,6 @@
pseudobulk RNA seq test dataset (relative counts).
signature: pd.DataFrame
Signature matrix.
-intersection: List[str]
-    List of genes in common between the signature and the test dataset.
generative_models: Dict[str, scvi.model]
Dictionary of generative models.
baselines: List[str]
@@ -79,14 +77,15 @@
deconv_results_melted_methods = pd.DataFrame(columns=["Cell type predicted", "Cell type", "Estimated Fraction", "Method"])
## NNLS
if "nnls" in baselines:
-deconv_results = perform_nnls(signature, adata_pseudobulk_test_rc[:, intersection])
+deconv_results = perform_nnls(signature, adata_pseudobulk_test_rc[:, signature.index])
deconv_results_melted_methods_tmp = melt_df(deconv_results)
deconv_results_melted_methods_tmp["Method"] = "nnls"
deconv_results_melted_methods = pd.concat(
[deconv_results_melted_methods, deconv_results_melted_methods_tmp]
)

# Pseudobulk Dataframe for TAPE and Scaden
+intersection = list(set(signature.index).intersection(set(filtered_genes)))
pseudobulk_test_df = pd.DataFrame(
adata_pseudobulk_test_rc[:, intersection].X,
index=adata_pseudobulk_test_rc.obs_names,
@@ -138,13 +137,13 @@ def run_purified_sanity_check(
# )
else:
adata_latent_signature = create_latent_signature(
-adata=adata_train,
+adata=adata_train[:, filtered_genes],
model=generative_models[model],
average_all_cells = True,
sc_per_pseudobulk=3000,
)
deconv_results = perform_latent_deconv(
-adata_pseudobulk=adata_pseudobulk_test_counts,
+adata_pseudobulk=adata_pseudobulk_test_counts[:, filtered_genes],
adata_latent_signature=adata_latent_signature,
model=generative_models[model],
)
@@ -161,9 +160,9 @@ def run_sanity_check(
adata_pseudobulk_test_counts: ad.AnnData,
adata_pseudobulk_test_rc: ad.AnnData,
all_adata_samples_test: List[ad.AnnData],
+filtered_genes: list,
df_proportions_test: pd.DataFrame,
signature: pd.DataFrame,
-intersection: List[str],
generative_models : Dict[str, Union[scvi.model.SCVI,
scvi.model.CondSCVI,
scvi.model.DestVI,
@@ -183,10 +182,10 @@
pseudobulk RNA seq test dataset.
adata_pseudobulk_test_rc: ad.AnnData
pseudobulk RNA seq test dataset (relative counts).
+filtered_genes: list
+    The most variable genes kept by filtering, used by all methods except NNLS (a choice that can be revisited).
signature: pd.DataFrame
Signature matrix.
-intersection: List[str]
-    List of genes in common between the signature and the test dataset.
generative_models: Dict[str, scvi.model]
Dictionary of generative models.
baselines: List[str]
@@ -211,12 +210,13 @@
# 1. Linear regression (NNLS)
if "nnls" in baselines:
deconv_results = perform_nnls(signature,
-adata_pseudobulk_test_rc[:, intersection])
+adata_pseudobulk_test_rc[:, signature.index])
correlations = compute_correlations(deconv_results, df_proportions_test)
group_correlations = compute_group_correlations(deconv_results, df_proportions_test)
df_test_correlations.loc[:, "nnls"] = correlations.values
df_test_group_correlations.loc[:, "nnls"] = group_correlations.values
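
For reference, a minimal NNLS sketch consistent with the perform_nnls call above (not the repository implementation; assumes signature is genes x cell types and the bulk matrix is restricted to signature.index):

import numpy as np
from scipy.optimize import nnls

def nnls_proportions(signature_mat: np.ndarray, bulk: np.ndarray) -> np.ndarray:
    """Solve min ||S p - b||, p >= 0, per bulk sample, then normalize to proportions."""
    raw = np.array([nnls(signature_mat, b)[0] for b in bulk])
    return raw / raw.sum(axis=1, keepdims=True)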

+intersection = list(set(signature.index).intersection(set(filtered_genes)))
pseudobulk_test_df = pd.DataFrame(
adata_pseudobulk_test_rc[:, intersection].X,
index=adata_pseudobulk_test_rc.obs_names,
@@ -265,16 +265,19 @@
if model == "MixupVI":
use_mixupvi=True
adata_latent_signature = create_latent_signature(
-adata=adata_train,
+adata=adata_train[:, filtered_genes],
model=generative_models[model],
-use_mixupvi=use_mixupvi,
+use_mixupvi=False,  # should equal use_mixupvi, but if True, it averages as many
+# cells as self.n_cells_per_pseudobulk from the MixUpVAE
+# (and not the number we want in the benchmark)
average_all_cells = True,
sc_per_pseudobulk=10000,
)
deconv_results = perform_latent_deconv(
-adata_pseudobulk=adata_pseudobulk_test_counts,
+adata_pseudobulk=adata_pseudobulk_test_counts[:, filtered_genes],
all_adata_samples=all_adata_samples_test,
-use_mixupvi=use_mixupvi,
+filtered_genes=filtered_genes,
+use_mixupvi=False,  # see the comment above
adata_latent_signature=adata_latent_signature,
model=generative_models[model],
)