From 413976b5579d510f254f7e91e7e8a8c2ce4c747b Mon Sep 17 00:00:00 2001 From: SimonGrouard Date: Tue, 7 May 2024 15:42:13 +0000 Subject: [PATCH 1/6] Changes related to sensitivity analysis --- benchmark_utils/plotting_utils.py | 6 +++ benchmark_utils/training_utils.py | 9 ++-- benchmark_utils/tuning_utils.py | 70 ++++++++++++++++++++++++++++++- constants.py | 17 ++++---- run_mixupvi.py | 27 ++++++------ scvi/module/_mixupvae.py | 38 ++++++----------- tuning_configs.py | 42 ++++++++++--------- 7 files changed, 138 insertions(+), 71 deletions(-) diff --git a/benchmark_utils/plotting_utils.py b/benchmark_utils/plotting_utils.py index caecacfaff..d7066bfc3e 100644 --- a/benchmark_utils/plotting_utils.py +++ b/benchmark_utils/plotting_utils.py @@ -241,6 +241,12 @@ def compare_tuning_results( custom_palette = sns.color_palette("husl", n_colors=len(all_results[variable_tuned].unique())) all_results["epoch"] = all_results.index + if (n_nan := all_results[variable_to_plot].isna().sum()) > 0: + print( + f"There are {n_nan} missing values in the variable to plot ({variable_to_plot})." + "Filling them with the next row values." + ) + all_results[variable_to_plot] = all_results[variable_to_plot].fillna(method='bfill') sns.set_theme(style="darkgrid") sns.lineplot(x="epoch", y=variable_to_plot, hue=variable_tuned, ci="sd", data=all_results, err_style="bars", palette=custom_palette) plt.show() diff --git a/benchmark_utils/training_utils.py b/benchmark_utils/training_utils.py index 04c6c8c787..efc477e587 100644 --- a/benchmark_utils/training_utils.py +++ b/benchmark_utils/training_utils.py @@ -15,6 +15,7 @@ MAX_EPOCHS, BATCH_SIZE, LATENT_SIZE, + N_HIDDEN, N_PSEUDOBULKS, N_CELLS_PER_PSEUDOBULK, TRAIN_SIZE, @@ -29,8 +30,6 @@ MIXUP_PENALTY, DISPERSION, GENE_LIKELIHOOD, - MIXUP_PENATLY_AGGREGATION, - AVERAGE_VARIABLES_MIXUP_PENALTY, SEED, ) @@ -60,11 +59,12 @@ def tune_mixupvi(adata: ad.AnnData, search_space=search_space, num_samples=num_samples, # will randomly num_samples samples (with replacement) among the HP cominations specified max_epochs=MAX_EPOCHS, - resources={"cpu": 10, "gpu": 0.5}, + resources={"cpu": 6, "gpu": 1}, ) all_results, best_hp, tuning_path, search_path = format_and_save_tuning_results( tuning_results, variables=TUNED_VARIABLES, training_dataset=training_dataset, + cat_cov=CAT_COV, cont_cov=CONT_COV, ) return all_results, best_hp, tuning_path, search_path @@ -95,6 +95,7 @@ def fit_mixupvi(adata: ad.AnnData, n_pseudobulks=N_PSEUDOBULKS, n_cells_per_pseudobulk=N_CELLS_PER_PSEUDOBULK, n_latent=LATENT_SIZE, + n_hidden=N_HIDDEN, use_batch_norm=USE_BATCH_NORM, signature_type=SIGNATURE_TYPE, loss_computation=LOSS_COMPUTATION, @@ -103,8 +104,6 @@ def fit_mixupvi(adata: ad.AnnData, mixup_penalty=MIXUP_PENALTY, dispersion=DISPERSION, gene_likelihood=GENE_LIKELIHOOD, - mixup_penalty_aggregation=MIXUP_PENATLY_AGGREGATION, - average_variables_mixup_penalty=AVERAGE_VARIABLES_MIXUP_PENALTY, ) mixupvi_model.view_anndata_setup() mixupvi_model.train( diff --git a/benchmark_utils/tuning_utils.py b/benchmark_utils/tuning_utils.py index b53231e791..c1e77e646a 100644 --- a/benchmark_utils/tuning_utils.py +++ b/benchmark_utils/tuning_utils.py @@ -1,11 +1,17 @@ +"""Tuning utils file.""" + import json from collections import defaultdict import numpy as np import pandas as pd import os import pickle +from tuning_configs import TUNED_VARIABLES, SEARCH_SPACE, METRIC, ADDITIONAL_METRICS +from constants import TRAINING_DATASET -def format_and_save_tuning_results(tuning_results, variables: str, training_dataset : str): +def format_and_save_tuning_results( + tuning_results, variables: str, training_dataset : str, cat_cov : list, cont_cov : list, +): """Format the tuning results and save them in the project directory.""" # format the results of all experiments keys = list(tuning_results.results[0].metrics.keys()) @@ -61,6 +67,8 @@ def format_and_save_tuning_results(tuning_results, variables: str, training_data all_results.to_csv(tuning_path) search_space = tuning_results.search_space + search_space["cat_cov"] = cat_cov + search_space["cont_cov"] = cont_cov search_space["best_hp"] = best_hp with open(search_path, "wb") as ff: pickle.dump(search_space, ff) @@ -77,4 +85,62 @@ def read_search_space(search_path): search_space = pickle.load(ff) return search_space - \ No newline at end of file + +def format_and_save_tuning_results_backup(ray_directory: str = "tune_mixupvi_2024-04-08-08:55:24"): + """This function essentially does the same as format_and_save_tuning_results. + + But this one should be used in a handcrafted manner (by providing the ray directory + saved locally) when for some reason, tuning results were successfully saved locally + by ray, but not formatted and saved in the shared /project folder. + + Five global variables are used here and should be specified accordingly in the + tuning and constants config files : TUNED_VARIABLES, SEARCH_SPACE, TRAINING_DATASET, + METRIC, ADDITIONAL_METRICS + """ + directory = f"/home/owkin/deepdeconv-fork/ray/{ray_directory}/" + all_metrics = [METRIC] + ADDITIONAL_METRICS # all metric columns we want to retrieve + + all_results = [] + for path in os.listdir(directory): # loop through every result of hyperparameters tried + if path.startswith("_trainable"): + path = directory + path + results = defaultdict(list) + with open(path+"/result.json", "r") as ff: + for line in ff: + # loop through every epoch of the training + data = json.loads(line.strip()) + for key in all_metrics: + if key in data: + results[key].append(data[key]) + else: + results[key].append(np.nan) + results = pd.DataFrame(results) + + hyperparameters = path.split("/")[-1] + for i, variable in enumerate(sorted(TUNED_VARIABLES)): + hyperparameters=hyperparameters.split(f"{variable}=")[1] + if i < len(TUNED_VARIABLES)-1: + value = hyperparameters.split(",")[0] + else: + value = hyperparameters.split("-")[0][:-5] + results[variable] = value + + all_results.append(results) + + all_results = pd.concat(all_results) + # save results and search space + save_dir = f"/home/owkin/project/mixupvi_tuning/{'-'.join(TUNED_VARIABLES)}/" + new_path = save_dir + f"{TRAINING_DATASET}_dataset_{ray_directory}" + if not os.path.exists(save_dir): + # create a directory for the variable tuned + os.makedirs(save_dir) + if not os.path.exists(new_path): + # create a directory for the specific grid search performed + os.makedirs(new_path) + tuning_path = f"{new_path}/tuning_results.csv" + search_path = f"{new_path}/search_space.pkl" + all_results.to_csv(tuning_path) + + search_space = SEARCH_SPACE + with open(search_path, "wb") as ff: + pickle.dump(search_space, ff) diff --git a/constants.py b/constants.py index 494202c240..2140d5c863 100644 --- a/constants.py +++ b/constants.py @@ -2,7 +2,7 @@ ## constants for run_mixupvi.py TUNE_MIXUPVI = True -TRAINING_DATASET = "CTI_PROCESSED" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] +TRAINING_DATASET = "CTI" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] TRAINING_CELL_TYPE_GROUP = ( "2nd_level_granularity" # ["1st_level_granularity", "2nd_level_granularity", "3rd_level_granularity", "4th_level_granularity", "FACS_1st_level_granularity"] ) @@ -30,19 +30,20 @@ ## general mixupvi constants when training it or preprocessing data SAVE_MODEL = False -SEED = 0 -N_GENES = 3000 # number of input genes after preprocessing +SEED = 3 +N_GENES = 2000 # number of input genes after preprocessing # MixUpVI training hyperparameters MAX_EPOCHS = 100 -BATCH_SIZE = 2048 +BATCH_SIZE = 1024 TRAIN_SIZE = 0.7 # as opposed to validation CHECK_VAL_EVERY_N_EPOCH = None if TRAIN_SIZE < 1: CHECK_VAL_EVERY_N_EPOCH = 1 # MixUpVI model hyperparameters -N_PSEUDOBULKS = 1 -N_CELLS_PER_PSEUDOBULK = None # None (then will be batch size) or int (will cap at batch size) +N_PSEUDOBULKS = 100 +N_CELLS_PER_PSEUDOBULK = 256 # None (then will be batch size) or int (will cap at batch size) LATENT_SIZE = 30 +N_HIDDEN = 512 CONT_COV = None # None or list of continuous covariates to include CAT_COV = None # None or ["donor_id", "assay"] ENCODE_COVARIATES = False # whether to encode cont/cat covars (they are always decoded) @@ -52,8 +53,6 @@ MIXUP_PENALTY = "l2" # ["l2", "kl"] DISPERSION = "gene" # ["gene", "gene_label"] GENE_LIKELIHOOD = "zinb" # ["zinb", "nb", "poisson"] -MIXUP_PENATLY_AGGREGATION = "mean" # ["mean", "sum", "max"] -AVERAGE_VARIABLES_MIXUP_PENALTY = False USE_BATCH_NORM = "none" # ["encoder", "decoder", "none", "both"] # different possibilities of cell groupings with the CTI dataset @@ -193,3 +192,5 @@ } } + +# %% diff --git a/run_mixupvi.py b/run_mixupvi.py index 54f1afa785..f315ea6b1a 100644 --- a/run_mixupvi.py +++ b/run_mixupvi.py @@ -102,14 +102,15 @@ # %% Load model / results: Uncomment if not running previous cells # if TUNE_MIXUPVI: -# path = "/home/owkin/project/mixupvi_tuning/n_latent-seed/CTI_PROCESSED_dataset_tune_mixupvi_2024-02-21-11:25:28" +# path = "/home/owkin/project/mixupvi_tuning/n_hidden-seed/CTI_dataset_tune_mixupvi_2024-05-06-09:15:06" # all_results = read_tuning_results(f"{path}/tuning_results.csv") # search_space = read_search_space(f"{path}/search_space.pkl") -# best_hp = search_space["best_hp"] -# model_history = all_results.copy() -# for variable in best_hp : -# # plots for the best hp found by tuning -# model_history = model_history.loc[model_history[variable] == best_hp[variable]] +# if "best_hp" in search_space: +# best_hp = search_space["best_hp"] +# model_history = all_results.copy() +# for variable in best_hp : +# # plots for the best hp found by tuning +# model_history = model_history.loc[model_history[variable] == best_hp[variable]] # else: # import torch # model = torch.load(f"{model_path}/model.pt") @@ -130,17 +131,19 @@ # %% Plots to compare HPs if TUNE_MIXUPVI: - n_epochs = len(model_history["train_loss_epoch"]) + n_epochs = len(set(all_results["train_loss_epoch"].index)) hp_index_to_plot = None - # hp_index_to_plot = [1, 2, 3] # only these index (of the HPs tried) will be plotted, for clearer visualisation + # hp_index_to_plot = [2, 3] # only these index (of the HPs tried) will be plotted, for clearer visualisation - if len(best_hp) == 1 or (len(best_hp) == 2 and "seed" in best_hp): - tuned_variable = list(set(best_hp.keys()) - {"seed"})[0] + tuned_hps = all_results.T.loc[["train" not in col and "validation" not in col for col in all_results.columns]].index + if len(tuned_hps) == 1 or (len(tuned_hps) == 2 and "seed" in tuned_hps): + variable_tuned = list(set(tuned_hps) - {"seed"})[0] + # variable_tuned = "seed" for variable_to_plot in all_results.columns: if "validation" in variable_to_plot: compare_tuning_results( - all_results, variable_to_plot=variable_to_plot, - variable_tuned=tuned_variable, n_epochs=n_epochs, + all_results.copy(), variable_to_plot=variable_to_plot, + variable_tuned=variable_tuned, n_epochs=n_epochs, hp_index_to_plot=hp_index_to_plot, ) else: diff --git a/scvi/module/_mixupvae.py b/scvi/module/_mixupvae.py index bbf14463ba..32134cb63d 100644 --- a/scvi/module/_mixupvae.py +++ b/scvi/module/_mixupvae.py @@ -103,16 +103,6 @@ class MixUpVAE(VAE): Whether to concatenate covariates to expression in encoder mixup_penalty The loss to use to compare the average of encoded cell and the encoded pseudobulk. - mixup_penalty_aggregation - One of - - * ``'sum'`` - Sum the n_pseudobulk L2 losses - * ``'mean'`` - Average the n_pseudobulk L2 losses - * ``'max'`` - Take the max among the n_pseudobulk L2 losses - average_variables_mixup_penalty - Whether to average the mixup penalty across the different variables. When False, the values are just summed. - If the loss_computation is in the latent space, there are n_latent variables. - If inside the reconstructed space, there are n_input variables. deeply_inject_covariates Whether to concatenate covariates into output of hidden layers in encoder/decoder. This option only applies when `n_layers` > 1. The covariates are concatenated to the input of subsequent hidden layers. @@ -146,7 +136,7 @@ def __init__( n_input: int, n_batch: int = 0, n_labels: int = 0, - n_hidden: Tunable[int] = 128, + n_hidden: Tunable[int] = 512, n_latent: Tunable[int] = 10, n_layers: Tunable[int] = 1, seed: Tunable[int] = 0, @@ -177,8 +167,6 @@ def __init__( loss_computation: Tunable[str] = "latent_space", pseudo_bulk: Tunable[str] = "pre_encoded", mixup_penalty: Tunable[str] = "l2", - mixup_penalty_aggregation: Tunable[str] = "mean", - average_variables_mixup_penalty: Tunable[bool] = False, ): torch.manual_seed(seed) @@ -222,8 +210,6 @@ def __init__( self.loss_computation = loss_computation self.pseudo_bulk = pseudo_bulk self.mixup_penalty = mixup_penalty - self.mixup_penalty_aggregation = mixup_penalty_aggregation - self.average_variables_mixup_penalty = average_variables_mixup_penalty self.z_signature = None self.logger_messages = set() @@ -366,7 +352,7 @@ def inference( categorical_pseudobulk_input = [] categorical_signature_input = [] j=0 - for n_cat in self.z_encoder.encoder.n_cat_list: + for n_cat in self.decoder.px_decoder.n_cat_list: if n_cat > 0 : # if n_cat == 0 then no batch index was given, so skip it one_hot_cat_covs = one_hot(cat_covs[j], n_cat) @@ -653,11 +639,18 @@ def loss( cosine_deconv_results = [] mse_deconv_results = [] mae_deconv_results = [] + z_signature = inference_outputs["z_signature"] if self.z_signature is None else self.z_signature for i, pseudobulk in enumerate(pseudobulk_z.detach().cpu().numpy()): predicted_proportions = nnls( - self.z_signature.detach().cpu().numpy().T, + z_signature.detach().cpu().numpy().T, pseudobulk, )[0] + if self.z_signature is None: + # resize predicted_proportions to the right number of labels + full_predicted_proportions = np.zeros(self.n_labels) + for j, cell_type in enumerate(tensors["labels"].unique().detach().cpu()): + full_predicted_proportions[int(cell_type)] = predicted_proportions[j] + predicted_proportions = full_predicted_proportions if np.any(predicted_proportions): # if not all zeros, sum the predictions to 1 predicted_proportions = predicted_proportions / predicted_proportions.sum() @@ -727,8 +720,6 @@ def get_mix_up_loss(self, inference_outputs, generative_outputs): mean_single_cells = generative_outputs["px"].rate[pseudobulk_indices, :].mean(axis=1) pseudobulk = generative_outputs["px_pseudobulk"].rate mixup_penalty = torch.sum((pseudobulk - mean_single_cells) ** 2, axis=1) - if self.average_variables_mixup_penalty: - mixup_penalty /= mean_single_cells.shape[1] elif self.mixup_penalty == "kl": # kl of mean(cells) compared to reference pseudobulk if self.loss_computation == "latent_space": @@ -765,10 +756,7 @@ def get_mix_up_loss(self, inference_outputs, generative_outputs): mixup_penalty = kl( averaged_cells_distrib, pseudobulk_reference_distrib ).sum(dim=-1) - if self.mixup_penalty_aggregation == "max": - mixup_penalty = mixup_penalty.max() - else : - mixup_penalty = torch.sum(mixup_penalty) - if self.mixup_penalty_aggregation == "mean": - mixup_penalty /= mean_single_cells.shape[0] + + mixup_penalty = torch.sum(mixup_penalty) / mean_single_cells.shape[0] + return mixup_penalty diff --git a/tuning_configs.py b/tuning_configs.py index 0ed630c7ba..aafaaaa33f 100644 --- a/tuning_configs.py +++ b/tuning_configs.py @@ -1,3 +1,5 @@ +"""Hyperparameter search configs.""" + from ray import tune from constants import ( @@ -16,8 +18,6 @@ LATENT_SIZE, N_PSEUDOBULKS, N_CELLS_PER_PSEUDOBULK, - MIXUP_PENATLY_AGGREGATION, - AVERAGE_VARIABLES_MIXUP_PENALTY, SEED, ) @@ -32,6 +32,7 @@ repeat_with_several_seeds = { "seed": tune.grid_search( [0, 3, 8, 12, 23] + # [0,1] ) } example_with_several_seeds = { @@ -40,13 +41,10 @@ } latent_space_search_space = { "n_latent": tune.grid_search( - list(range(len(GROUPS[TRAINING_CELL_TYPE_GROUP]) - 1, 550, 20)) # from n cell types to n marker genes - ) -} -latent_space_search_space_precise = { - "n_latent": tune.grid_search( - list(range(len(GROUPS[TRAINING_CELL_TYPE_GROUP]) - 1, 60, 5)) # from n cell types to 60 - ) + # list(range(len(GROUPS[TRAINING_CELL_TYPE_GROUP]) - 1, 550, 20)) # from n cell types to n marker genes + [10, 20, 30, 40, 50, 100, 200] + ), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } batch_size_search_space = { "batch_size": tune.grid_search( @@ -61,7 +59,8 @@ signature_type_search_space = { "signature_type": tune.grid_search( ["pre_encoded", "post_inference"] - ) + ), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } loss_computation_search_space = { "loss_computation": tune.grid_search( @@ -69,20 +68,27 @@ ) } gene_likelihood_search_space = { - "gene_likelihood": tune.grid_search(["zinb", "nb", "poisson"]) + "gene_likelihood": tune.grid_search(["zinb", "nb", "poisson"]), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } n_hidden_search_space = { - "n_hidden": tune.grid_search([128, 256, 512, 1024]) + "n_hidden": tune.grid_search([128, 256, 512, 1024]), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } n_layers_search_space = { - "n_layers": tune.grid_search([1, 2, 3]) + "n_layers": tune.grid_search([1, 2, 3]), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } n_pseudobulks_search_space = { - "n_pseudobulks": tune.grid_search([1, 5, 10, 30, 50, 100]), + "n_pseudobulks": tune.grid_search([1, 100]), + "seed": tune.grid_search([3, 8, 12]) + # "seed": tune.grid_search([3, 8, 12, 23, 42]) +} +n_cells_per_pseudobulk_search_space = { + "n_cells_per_pseudobulk": tune.grid_search([100, 256, 512, 1024, 2048]), "seed": tune.grid_search([3, 8, 12]) - # "seed": tune.grid_search([3, 8, 12, 23, 42]) } -SEARCH_SPACE = n_pseudobulks_search_space +SEARCH_SPACE = n_layers_search_space TUNED_VARIABLES = list(SEARCH_SPACE.keys()) NUM_SAMPLES = 1 # will only perform once the gridsearch (useful to change if mix of grid and random search for instance) @@ -102,8 +108,6 @@ "n_latent": LATENT_SIZE, "n_pseudobulks": N_PSEUDOBULKS, "n_cells_per_pseudobulk": N_CELLS_PER_PSEUDOBULK, - "mixup_penalty_aggregation": MIXUP_PENATLY_AGGREGATION, - "average_variables_mixup_penalty": AVERAGE_VARIABLES_MIXUP_PENALTY, "seed": SEED, } for key in list(model_fixed_hps): @@ -136,4 +140,4 @@ "pearson_coeff_deconv_train", "mse_deconv_train", "mae_deconv_train", -] +] \ No newline at end of file From 81d4b0c18263c6de1c9bae88d7ec7e430c4ace22 Mon Sep 17 00:00:00 2001 From: SimonGrouard Date: Wed, 19 Jun 2024 10:54:54 +0000 Subject: [PATCH 2/6] Signature filtered genes and benchmark changes --- benchmark_utils/dataset_utils.py | 4 +- benchmark_utils/deconv_utils.py | 5 ++- benchmark_utils/latent_signature_utils.py | 50 ++++++++++++----------- benchmark_utils/plotting_utils.py | 12 +++++- benchmark_utils/sanity_checks_utils.py | 31 +++++++------- benchmark_utils/signature_utils.py | 39 +++++++++--------- constants.py | 19 +++++---- run_mixupvi.py | 6 +-- run_pseudobulk_benchmark.py | 44 ++++++++++---------- scvi/module/_mixupvae.py | 14 +++---- scvi/module/_utils.py | 49 +--------------------- tuning_configs.py | 8 +++- 12 files changed, 128 insertions(+), 153 deletions(-) diff --git a/benchmark_utils/dataset_utils.py b/benchmark_utils/dataset_utils.py index 3c01da76a3..246bd1e558 100644 --- a/benchmark_utils/dataset_utils.py +++ b/benchmark_utils/dataset_utils.py @@ -21,8 +21,9 @@ def preprocess_scrna( adata.layers["relative_counts"] = adata.X.copy() # preserve counts, used for sc.pp.log1p(adata) adata.raw = adata # freeze the state in `.raw` + adata_filtered = adata.copy() sc.pp.highly_variable_genes( - adata, + adata_filtered, n_top_genes=keep_genes, subset=True, layer="counts", @@ -30,6 +31,7 @@ def preprocess_scrna( batch_key=batch_key, ) #TODO: add the filtering / QC steps that they perform in Servier + return adata, adata_filtered.var_names def split_dataset( diff --git a/benchmark_utils/deconv_utils.py b/benchmark_utils/deconv_utils.py index c5f5fa9f08..199998ea11 100644 --- a/benchmark_utils/deconv_utils.py +++ b/benchmark_utils/deconv_utils.py @@ -47,6 +47,7 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData, scvi.model.MixUpVI, scvi.model.CondSCVI]], all_adata_samples, + filtered_genes, use_mixupvi: bool = True, use_nnls: bool = True, use_softmax: bool = False) -> pd.DataFrame: @@ -74,7 +75,7 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData, if use_mixupvi: latent_pseudobulks=[] for i in range(len(all_adata_samples)): - latent_pseudobulks.append(model.get_latent_representation(all_adata_samples[i], get_pseudobulk=True)) + latent_pseudobulks.append(model.get_latent_representation(all_adata_samples[i,filtered_genes], get_pseudobulk=True)) latent_pseudobulk = np.concatenate(latent_pseudobulks, axis=0) else: adata_pseudobulk = ad.AnnData(X=adata_pseudobulk.layers["counts"], @@ -82,7 +83,7 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData, var=adata_pseudobulk.var) adata_pseudobulk.layers["counts"] = adata_pseudobulk.X.copy() - latent_pseudobulk = model.get_latent_representation(adata_pseudobulk) + latent_pseudobulk = model.get_latent_representation(adata_pseudobulk, get_pseudobulk=False) if use_nnls: deconv = LinearRegression(positive=True).fit(adata_latent_signature.X.T, diff --git a/benchmark_utils/latent_signature_utils.py b/benchmark_utils/latent_signature_utils.py index 848e9e0a6f..731020abf0 100644 --- a/benchmark_utils/latent_signature_utils.py +++ b/benchmark_utils/latent_signature_utils.py @@ -8,6 +8,7 @@ import torch from .dataset_utils import create_anndata_pseudobulk +from constants import SIGNATURE_TYPE def create_latent_signature( @@ -16,7 +17,6 @@ def create_latent_signature( repeats: int = 1, average_all_cells: bool = True, sc_per_pseudobulk: int = 3000, - signature_type: str = "pre-encoded", cell_type_column: str = "cell_types_grouped", count_key: Optional[str] = "counts", representation_key: Optional[str] = "X_scvi", @@ -101,34 +101,36 @@ def create_latent_signature( ) adata_sampled = adata[sampled_cells] - if signature_type == "pre-encoded": - assert ( - model is not None, - "If representing a purified pseudo bulk (aggregate before embedding", - "), must give a model", - ) - assert ( - count_key is not None - ), "Must give a count key if aggregating before embedding." - - if use_mixupvi: - result = model.get_latent_representation( - adata_sampled, get_pseudobulk=True - ).reshape(-1) - else: + assert ( + model is not None, + "If representing a purified pseudo bulk (aggregate before embedding", + "), must give a model", + ) + assert ( + count_key is not None + ), "Must give a count key if aggregating before embedding." + + if use_mixupvi: + # TODO: in this case, n_cells sampled will be equal to self.n_cells_per_pseudobulk by mixupvae + # so change that to being equal to either all cells (if average_all_cells) or sc_per_pseudobulk + result = model.get_latent_representation( + adata_sampled, get_pseudobulk=True + )[0] # take first pseudobulk + else: + if SIGNATURE_TYPE == "pre_encoded": pseudobulk = ( adata_sampled.layers[count_key].mean(axis=0).reshape(1, -1) - ) # .astype(int).astype(numpy.float32) - adata_pseudobulk = create_anndata_pseudobulk( - adata_sampled, pseudobulk ) - result = model.get_latent_representation(adata_pseudobulk).reshape( - -1 + adata_sampled = create_anndata_pseudobulk( + adata_sampled, pseudobulk ) - else: - raise ValueError( - "Only pre-encoded signatures are supported for now." + result = model.get_latent_representation( + adata_sampled, get_pseudobulk = False ) + if SIGNATURE_TYPE == "pre_encoded": + result = result.reshape(-1) + elif SIGNATURE_TYPE == "post_inference": + result = result.mean(axis=0) repeat_list.append(repeat) representation_list.append(result) cell_type_list.append(cell_type) diff --git a/benchmark_utils/plotting_utils.py b/benchmark_utils/plotting_utils.py index d7066bfc3e..340736fde0 100644 --- a/benchmark_utils/plotting_utils.py +++ b/benchmark_utils/plotting_utils.py @@ -1,9 +1,12 @@ +import os import matplotlib.pyplot as plt import seaborn as sns import numpy as np import pandas as pd from typing import Dict +from datetime import datetime +from loguru import logger def plot_purified_deconv_results(deconv_results, only_fit_one_baseline, more_details=False, save=False, filename="test"): """Plot the deconv results from sanity check 1""" @@ -106,7 +109,14 @@ def plot_deconv_lineplot(results: Dict[int, pd.DataFrame], plt.show() if save: - plt.savefig(f"/home/owkin/project/plots/{filename}.png", dpi=300) + path = f"/home/owkin/project/plots/{filename}.png" + if os.path.isfile(path): + new_path = f"/home/owkin/project/plots/{filename}_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.png" + logger.warning(f"{path} already exists. Saving file on this path instead: {new_path}") + path = new_path + plt.savefig(path, dpi=300) + logger.info(f"Plot saved to the following path: {path}") + def plot_metrics(model_history, train: bool = True, n_epochs: int = 100): """Plot the train or val metrics from training.""" diff --git a/benchmark_utils/sanity_checks_utils.py b/benchmark_utils/sanity_checks_utils.py index 1682823aeb..58338e91bc 100644 --- a/benchmark_utils/sanity_checks_utils.py +++ b/benchmark_utils/sanity_checks_utils.py @@ -36,8 +36,8 @@ def run_purified_sanity_check( adata_train: ad.AnnData, adata_pseudobulk_test_counts: ad.AnnData, adata_pseudobulk_test_rc: ad.AnnData, + filtered_genes: list, signature: pd.DataFrame, - intersection: List[str], generative_models : Dict[str, Union[scvi.model.SCVI, scvi.model.CondSCVI, scvi.model.DestVI, @@ -62,8 +62,6 @@ def run_purified_sanity_check( pseudobulk RNA seq test dataset (relative counts). signature: pd.DataFrame Signature matrix. - intersection: List[str] - List of genes in common between the signature and the test dataset. generative_models: Dict[str, scvi.model] Dictionnary of generative models. baselines: List[str] @@ -79,7 +77,7 @@ def run_purified_sanity_check( deconv_results_melted_methods = pd.DataFrame(columns=["Cell type predicted", "Cell type", "Estimated Fraction", "Method"]) ## NNLS if "nnls" in baselines: - deconv_results = perform_nnls(signature, adata_pseudobulk_test_rc[:, intersection]) + deconv_results = perform_nnls(signature, adata_pseudobulk_test_rc[:, signature.index]) deconv_results_melted_methods_tmp = melt_df(deconv_results) deconv_results_melted_methods_tmp["Method"] = "nnls" deconv_results_melted_methods = pd.concat( @@ -87,6 +85,7 @@ def run_purified_sanity_check( ) # Pseudobulk Dataframe for TAPE and Scaden + intersection = set(signature.index).intersection(set(filtered_genes)) pseudobulk_test_df = pd.DataFrame( adata_pseudobulk_test_rc[:, intersection].X, index=adata_pseudobulk_test_rc.obs_names, @@ -138,13 +137,13 @@ def run_purified_sanity_check( # ) else: adata_latent_signature = create_latent_signature( - adata=adata_train, + adata=adata_train[:,filtered_genes], model=generative_models[model], average_all_cells = True, sc_per_pseudobulk=3000, ) deconv_results = perform_latent_deconv( - adata_pseudobulk=adata_pseudobulk_test_counts, + adata_pseudobulk=adata_pseudobulk_test_counts[:,filtered_genes], adata_latent_signature=adata_latent_signature, model=generative_models[model], ) @@ -161,9 +160,9 @@ def run_sanity_check( adata_pseudobulk_test_counts: ad.AnnData, adata_pseudobulk_test_rc: ad.AnnData, all_adata_samples_test: List[ad.AnnData], + filtered_genes: list, df_proportions_test: pd.DataFrame, signature: pd.DataFrame, - intersection: List[str], generative_models : Dict[str, Union[scvi.model.SCVI, scvi.model.CondSCVI, scvi.model.DestVI, @@ -183,10 +182,10 @@ def run_sanity_check( pseudobulk RNA seq test dataset. adata_pseudobulk_test_rc: ad.AnnData pseudobulk RNA seq test dataset (relative counts). + filtered_genes: list + The most variable genes filtered, used for all methods except NNLS (can be challenged). signature: pd.DataFrame Signature matrix. - intersection: List[str] - List of genes in common between the signature and the test dataset. generative_models: Dict[str, scvi.model] Dictionnary of generative models. baselines: List[str] @@ -211,12 +210,13 @@ def run_sanity_check( # 1. Linear regression (NNLS) if "nnls" in baselines: deconv_results = perform_nnls(signature, - adata_pseudobulk_test_rc[:, intersection]) + adata_pseudobulk_test_rc[:, signature.index]) correlations = compute_correlations(deconv_results, df_proportions_test) group_correlations = compute_group_correlations(deconv_results, df_proportions_test) df_test_correlations.loc[:, "nnls"] = correlations.values df_test_group_correlations.loc[:, "nnls"] = group_correlations.values + intersection = list(set(signature.index).intersection(set(filtered_genes))) pseudobulk_test_df = pd.DataFrame( adata_pseudobulk_test_rc[:, intersection].X, index=adata_pseudobulk_test_rc.obs_names, @@ -265,16 +265,19 @@ def run_sanity_check( if model == "MixupVI": use_mixupvi=True adata_latent_signature = create_latent_signature( - adata=adata_train, + adata=adata_train[:,filtered_genes], model=generative_models[model], - use_mixupvi=use_mixupvi, + use_mixupvi=False, # should be equal to use_mixupvi, but if True, + # then it averages as many cells as self.n_cells_per-pseudobulk from mixupvae + # (and not the number we wish in the benchmark) average_all_cells = True, sc_per_pseudobulk=10000, ) deconv_results = perform_latent_deconv( - adata_pseudobulk=adata_pseudobulk_test_counts, + adata_pseudobulk=adata_pseudobulk_test_counts[:,filtered_genes], all_adata_samples=all_adata_samples_test, - use_mixupvi=use_mixupvi, + filtered_genes=filtered_genes, + use_mixupvi=False, # see comment above adata_latent_signature=adata_latent_signature, model=generative_models[model], ) diff --git a/benchmark_utils/signature_utils.py b/benchmark_utils/signature_utils.py index 6dad54e394..2a1039d7b3 100644 --- a/benchmark_utils/signature_utils.py +++ b/benchmark_utils/signature_utils.py @@ -6,29 +6,33 @@ def create_signature( - adata: ad.AnnData, signature_type: str = "crosstissue_general", ): """Create the signature matrix from the single cell dataset.""" if signature_type == "laughney": - signature = pd.read_csv( - "/home/owkin/project/laughney_signature.csv", index_col=0 - ).drop(["Endothelial", "Malignant", "Stroma", "Epithelial"], axis=1) - # map the HGNC notation to ENSG if the signature matrix uses HGNC notation - mg = mygene.MyGeneInfo() - genes = mg.querymany( - signature.index, - scopes="symbol", - fields=["ensembl"], - species="human", - verbose=False, - as_dataframe=True, + raise NotImplementedError( + "Laughney signature not available now. To solve, upload it directly with " + "ENSG names." ) - ensg_names = map_hgnc_to_ensg(genes, adata) - signature.index = ensg_names + # signature = pd.read_csv( + # "/home/owkin/project/laughney_signature.csv", index_col=0 + # ).drop(["Endothelial", "Malignant", "Stroma", "Epithelial"], axis=1) + # # map the HGNC notation to ENSG if the signature matrix uses HGNC notation + # mg = mygene.MyGeneInfo() + # genes = mg.querymany( + # signature.index, + # scopes="symbol", + # fields=["ensembl"], + # species="human", + # verbose=False, + # as_dataframe=True, + # ) + # ensg_names = map_hgnc_to_ensg(genes, adata) + # signature.index = ensg_names elif signature_type == "CTI_1st_level_granularity": signature = read_txt_r_signature( "/home/owkin/project/Almudena/Output/Crosstiss_Immune_norm/CTI.txt" + # "/home/owkin/project/Almudena/Output/Crosstiss_Immune/CTI.txt" ) # it is the normalised one (using adata.X and not adata.raw.X) elif signature_type == "CTI_2nd_level_granularity": signature = read_txt_r_signature( @@ -46,10 +50,7 @@ def create_signature( signature = read_txt_r_signature( "/home/owkin/project/Simon/signature_FACS_1st_level_granularity/FACS_1st_level_granularity_ensg.txt" ) - # intersection between all genes and marker genes - intersection = list(set(adata.var_names).intersection(signature.index)) - signature = signature.loc[intersection] - return signature, intersection + return signature def read_txt_r_signature(path): diff --git a/constants.py b/constants.py index 2140d5c863..5e3a47127d 100644 --- a/constants.py +++ b/constants.py @@ -1,14 +1,14 @@ """Constants and global variables to run the different deconv files.""" -## constants for run_mixupvi.py +## Constants for run_mixupvi.py TUNE_MIXUPVI = True TRAINING_DATASET = "CTI" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] TRAINING_CELL_TYPE_GROUP = ( "2nd_level_granularity" # ["1st_level_granularity", "2nd_level_granularity", "3rd_level_granularity", "4th_level_granularity", "FACS_1st_level_granularity"] ) -## constants for run_pseudobulk_benchmark.py -SIGNATURE_CHOICE = "CTI_2nd_level_granularity" # ["laughney", "CTI_1st_level_granularity", "CTI_2nd_level_granularity", "CTI_3rd_level_granularity", "CTI_4th_level_granularity", "FACS_1st_level_granularity"] +## Constants for run_pseudobulk_benchmark.py +SIGNATURE_CHOICE = "CTI_1st_level_granularity" # ["laughney", "CTI_1st_level_granularity", "CTI_2nd_level_granularity", "CTI_3rd_level_granularity", "CTI_4th_level_granularity", "FACS_1st_level_granularity"] if SIGNATURE_CHOICE in ["laughney", "CTI_1st_level_granularity"]: BENCHMARK_CELL_TYPE_GROUP = "1st_level_granularity" elif SIGNATURE_CHOICE == "CTI_2nd_level_granularity": @@ -23,17 +23,20 @@ BENCHMARK_CELL_TYPE_GROUP = None # no signature was created BENCHMARK_DATASET = "CTI" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] N_SAMPLES = 500 # number of pseudbulk samples to create and assess for deconvolution +N_CELLS = [2000] # list of number of cells to try for the lineplot GENERATIVE_MODELS = ["MixupVI"] #, "DestVI"] # "scVI", "CondscVI", "DestVI" -# GENERATIVE_MODELS = [] # if only want baselines BASELINES = ["nnls"] # "nnls", "TAPE", "Scaden" -# BASELINES = ["nnls"] # if only want nnls +COMPUTE_SC_RESULTS_WHEN_FACS = True -## general mixupvi constants when training it or preprocessing data +## General constants to change depending on the task SAVE_MODEL = False SEED = 3 +LATENT_SIZE = 10 +MAX_EPOCHS = 100 + +## Other constants to tune and then fix N_GENES = 2000 # number of input genes after preprocessing # MixUpVI training hyperparameters -MAX_EPOCHS = 100 BATCH_SIZE = 1024 TRAIN_SIZE = 0.7 # as opposed to validation CHECK_VAL_EVERY_N_EPOCH = None @@ -42,7 +45,6 @@ # MixUpVI model hyperparameters N_PSEUDOBULKS = 100 N_CELLS_PER_PSEUDOBULK = 256 # None (then will be batch size) or int (will cap at batch size) -LATENT_SIZE = 30 N_HIDDEN = 512 CONT_COV = None # None or list of continuous covariates to include CAT_COV = None # None or ["donor_id", "assay"] @@ -189,7 +191,6 @@ "ILC3", "Erythroid", "Megakaryocytes", "Progenitor", "Alveolar macrophages","Erythrophagocytic macrophages", "Intermediate macrophages", "Intestinal macrophages","Mast cells"] - } } diff --git a/run_mixupvi.py b/run_mixupvi.py index f315ea6b1a..f0e693ea00 100644 --- a/run_mixupvi.py +++ b/run_mixupvi.py @@ -102,7 +102,7 @@ # %% Load model / results: Uncomment if not running previous cells # if TUNE_MIXUPVI: -# path = "/home/owkin/project/mixupvi_tuning/n_hidden-seed/CTI_dataset_tune_mixupvi_2024-05-06-09:15:06" +# path = "/home/owkin/project/mixupvi_tuning/n_latent-seed/CTI_dataset_tune_mixupvi_2024-06-07-18:30:37" # all_results = read_tuning_results(f"{path}/tuning_results.csv") # search_space = read_search_space(f"{path}/search_space.pkl") # if "best_hp" in search_space: @@ -132,8 +132,8 @@ # %% Plots to compare HPs if TUNE_MIXUPVI: n_epochs = len(set(all_results["train_loss_epoch"].index)) - hp_index_to_plot = None - # hp_index_to_plot = [2, 3] # only these index (of the HPs tried) will be plotted, for clearer visualisation + # hp_index_to_plot = None + hp_index_to_plot = [0,1] # only these index (of the HPs tried) will be plotted, for clearer visualisation tuned_hps = all_results.T.loc[["train" not in col and "validation" not in col for col in all_results.columns]].index if len(tuned_hps) == 1 or (len(tuned_hps) == 2 and "seed" in tuned_hps): diff --git a/run_pseudobulk_benchmark.py b/run_pseudobulk_benchmark.py index c9685295f6..11c7797dba 100644 --- a/run_pseudobulk_benchmark.py +++ b/run_pseudobulk_benchmark.py @@ -1,8 +1,8 @@ """Pseudobulk benchmark.""" # %% import scanpy as sc -from loguru import logger import warnings +from loguru import logger from constants import ( BENCHMARK_DATASET, @@ -13,6 +13,7 @@ N_SAMPLES, GENERATIVE_MODELS, BASELINES, + N_CELLS, ) from benchmark_utils import ( @@ -44,27 +45,29 @@ # preprocess_scrna(adata, keep_genes=1200) elif BENCHMARK_DATASET == "CTI": adata = sc.read("/home/owkin/project/cti/cti_adata.h5ad") - preprocess_scrna(adata, + adata, filtered_genes = preprocess_scrna(adata, keep_genes=N_GENES, batch_key="donor_id") elif BENCHMARK_DATASET == "CTI_RAW": warnings.warn("The raw data of this adata is on adata.raw.X, but the normalised " "adata.X will be used here") adata = sc.read("/home/owkin/data/cross-tissue/omics/raw/local.h5ad") - preprocess_scrna(adata, + adata, filtered_genes = preprocess_scrna(adata, keep_genes=N_GENES, batch_key="donor_id", ) elif BENCHMARK_DATASET == "CTI_PROCESSED": # Load processed for speed-up (already filtered, normalised, etc.) - adata = sc.read(f"/home/owkin/data/cti_data/processed/cti_processed_{N_GENES}.h5ad") + raise NotImplementedError( + "Not possible to use a CTI_PROCESSED dataset because we would need the " + "not-filtered adata to be processed as well. To solve: separate the " + "preprocessing function between normalization and filtering parts." + ) + # adata_filtered = sc.read(f"/home/owkin/data/cti_data/processed/cti_processed_{N_GENES}.h5ad") # %% load signature logger.info(f"Loading signature matrix: {SIGNATURE_CHOICE} | {BENCHMARK_CELL_TYPE_GROUP}...") -signature, intersection = create_signature( - adata, - signature_type=SIGNATURE_CHOICE, -) +signature = create_signature(signature_type=SIGNATURE_CHOICE) # %% add cell types groups and split train/test adata, train_test_index = add_cell_types_grouped(adata, BENCHMARK_CELL_TYPE_GROUP) @@ -80,7 +83,7 @@ if "scVI" in GENERATIVE_MODELS: logger.info("Fit scVI ...") model_path = f"project/models/{BENCHMARK_DATASET}_scvi.pkl" - scvi_model = fit_scvi(adata_train, + scvi_model = fit_scvi(adata_train[:,filtered_genes].copy(), model_path, save_model=SAVE_MODEL) generative_models["scVI"] = scvi_model @@ -94,12 +97,12 @@ # ) # Dirichlet adata_pseudobulk_train_counts, adata_pseudobulk_train_rc, df_proportions_test = create_dirichlet_pseudobulk_dataset( - adata_train, prior_alphas = None, n_sample = N_SAMPLES, + adata_train[:,filtered_genes].copy(), prior_alphas = None, n_sample = N_SAMPLES, ) model_path_1 = f"project/models/{BENCHMARK_DATASET}_condscvi.pkl" model_path_2 = f"project/models/{BENCHMARK_DATASET}_destvi.pkl" - condscvi_model , destvi_model= fit_destvi(adata_train, + condscvi_model , destvi_model= fit_destvi(adata_train[:,filtered_genes].copy(), adata_pseudobulk_train_counts, model_path_1, model_path_2, @@ -112,7 +115,7 @@ if "MixupVI" in GENERATIVE_MODELS: logger.info("Train mixupVI ...") model_path = f"project/models/{BENCHMARK_DATASET}_{BENCHMARK_CELL_TYPE_GROUP}_{N_GENES}_mixupvi.pkl" - mixupvi_model = fit_mixupvi(adata_train, + mixupvi_model = fit_mixupvi(adata_train[:,filtered_genes].copy(), model_path, cell_type_group="cell_types_grouped", save_model=SAVE_MODEL, @@ -121,14 +124,10 @@ # %% Sanity check 3 -#num_cells = [50, 100, 300, 500, 1000] - -num_cells = [2000] - results = {} results_group = {} -for n in num_cells: +for n in N_CELLS: logger.info(f"Pseudobulk simulation with {n} sampled cells ...") all_adata_samples_test, adata_pseudobulk_test_counts, adata_pseudobulk_test_rc, df_proportions_test = create_dirichlet_pseudobulk_dataset( adata_test, @@ -149,9 +148,9 @@ adata_pseudobulk_test_counts=adata_pseudobulk_test_counts, adata_pseudobulk_test_rc=adata_pseudobulk_test_rc, all_adata_samples_test=all_adata_samples_test, + filtered_genes=filtered_genes, df_proportions_test=df_proportions_test, signature=signature, - intersection=intersection, generative_models=generative_models, baselines=BASELINES, ) @@ -163,15 +162,17 @@ if len(results) > 1: plot_deconv_lineplot(results, save=True, - filename=f"sim_pseudobulk_lineplot") + filename=f"lineplot_tuned_mixupvi_third_granularity_retry_normal") else: key = list(results.keys())[0] plot_deconv_results(results[key], save=True, - filename=f"sim_pseudobulk_{key}") + # filename=f"benchmark_{key}_cells_first_granularity") + filename="test_first_type") plot_deconv_results_group(results_group[key], save=True, - filename=f"sim_pseudobulk_{key}_per_celltype") + # filename=f"benchmark_{key}_cells_first_granularity_cell_type") + filename="test_first_type_cell_type") # %% (Optional) Sanity check 1. @@ -184,7 +185,6 @@ # adata_pseudobulk_test_counts=adata_pseudobulk_test_counts, # adata_pseudobulk_test_rc=adata_pseudobulk_test_rc, # signature=signature, -# intersection=intersection, # generative_models=generative_models, # baselines=BASELINES, # ) diff --git a/scvi/module/_mixupvae.py b/scvi/module/_mixupvae.py index 32134cb63d..3794db19a4 100644 --- a/scvi/module/_mixupvae.py +++ b/scvi/module/_mixupvae.py @@ -18,7 +18,6 @@ from scvi.nn import Encoder from ._vae import VAE from ._utils import ( - run_incompatible_value_checks, get_mean_pearsonr_torch, compute_ground_truth_proportions, compute_signature, @@ -196,13 +195,12 @@ def __init__( extra_encoder_kwargs=extra_encoder_kwargs, extra_decoder_kwargs=extra_decoder_kwargs, ) - run_incompatible_value_checks( - pseudo_bulk=pseudo_bulk, - loss_computation=loss_computation, - use_batch_norm=use_batch_norm, - mixup_penalty=mixup_penalty, - gene_likelihood=gene_likelihood, - ) + if use_batch_norm != "none" and n_pseudobulks == 1: + raise ValueError( + "Batch normalization cannot be used when only one pseudobulk is " + "computed - it cannot be considered as a batch on which batch " + "normalization can be applied." + ) self.n_pseudobulks = n_pseudobulks self.n_cells_per_pseudobulk = n_cells_per_pseudobulk diff --git a/scvi/module/_utils.py b/scvi/module/_utils.py index 482f8626d8..a1fc210c80 100644 --- a/scvi/module/_utils.py +++ b/scvi/module/_utils.py @@ -140,51 +140,4 @@ def get_mean_pearsonr_torch(x, y): r_num = (xm*ym).sum(dim=1) r_den = torch.norm(xm, p=2, dim=1) * torch.norm(ym, p=2, dim=1) r_val = r_num / r_den - return torch.mean(r_val) - - -def run_incompatible_value_checks( - pseudo_bulk, loss_computation, use_batch_norm, mixup_penalty, gene_likelihood -): - """Check the values of the categorical variables to run MixUpVI are compatible. - The first 4 checks will only be relevant when pseudobulk will not be computed both - in encoder and decoder (right now, computed in both). Until then, use_batch_norm - should be None. - """ - if ( - pseudo_bulk == "pre_encoded" - and loss_computation == "latent_space" - and use_batch_norm in ["encoder", "both"] - ): - raise ValueError( - "MixUpVI cannot use batch normalization there, as the batch size of pseudobulk is 1." - ) - elif ( - pseudo_bulk == "pre_encoded" - and loss_computation == "reconstructed_space" - and use_batch_norm != "none" - ): - raise ValueError( - "MixUpVI cannot use batch normalization there, as the batch size of pseudobulk is 1." - ) - elif pseudo_bulk == "post_inference" and loss_computation == "latent_space": - raise ValueError( - "Pseudo bulk needs to be pre-encoded to compute the MixUp loss in the latent space." - ) - elif ( - pseudo_bulk == "post_inference" - and loss_computation == "reconstructed_space" - and use_batch_norm in ["decoder", "both"] - ): - raise ValueError( - "MixUpVI cannot use batch normalization there, as the batch size of pseudobulk is 1." - ) - if ( - mixup_penalty == "kl" - and loss_computation != "latent_space" - and gene_likelihood == "zinb" - ): - raise NotImplementedError( - "The KL divergence between ZINB distributions for the MixUp loss is not " - "implemented." - ) \ No newline at end of file + return torch.mean(r_val) \ No newline at end of file diff --git a/tuning_configs.py b/tuning_configs.py index aafaaaa33f..3d625b040e 100644 --- a/tuning_configs.py +++ b/tuning_configs.py @@ -42,7 +42,7 @@ latent_space_search_space = { "n_latent": tune.grid_search( # list(range(len(GROUPS[TRAINING_CELL_TYPE_GROUP]) - 1, 550, 20)) # from n cell types to n marker genes - [10, 20, 30, 40, 50, 100, 200] + [70, 100, 120, 150] ), "seed": tune.grid_search([3, 8, 12, 23, 42]) } @@ -88,7 +88,11 @@ "n_cells_per_pseudobulk": tune.grid_search([100, 256, 512, 1024, 2048]), "seed": tune.grid_search([3, 8, 12]) } -SEARCH_SPACE = n_layers_search_space +use_batch_norm_search_space = { + "use_batch_norm": tune.grid_search(["none", "encoder", "decoder", "both"]), + "seed": tune.grid_search([3, 8, 12, 23, 42]) +} +SEARCH_SPACE = latent_space_search_space TUNED_VARIABLES = list(SEARCH_SPACE.keys()) NUM_SAMPLES = 1 # will only perform once the gridsearch (useful to change if mix of grid and random search for instance) From 8457537c97df155a6d4c369d6ca90b9c9500ac8a Mon Sep 17 00:00:00 2001 From: SimonGrouard Date: Tue, 13 Aug 2024 14:13:25 +0000 Subject: [PATCH 3/6] brutal integration of facs in the pipeline --- constants.py | 8 +- run_pseudobulk_benchmark.py | 209 ++++++++++++++++++++++++++++-------- 2 files changed, 167 insertions(+), 50 deletions(-) diff --git a/constants.py b/constants.py index 5e3a47127d..8ab34e913d 100644 --- a/constants.py +++ b/constants.py @@ -8,7 +8,7 @@ ) ## Constants for run_pseudobulk_benchmark.py -SIGNATURE_CHOICE = "CTI_1st_level_granularity" # ["laughney", "CTI_1st_level_granularity", "CTI_2nd_level_granularity", "CTI_3rd_level_granularity", "CTI_4th_level_granularity", "FACS_1st_level_granularity"] +SIGNATURE_CHOICE = "FACS_1st_level_granularity" # ["laughney", "CTI_1st_level_granularity", "CTI_2nd_level_granularity", "CTI_3rd_level_granularity", "CTI_4th_level_granularity", "FACS_1st_level_granularity"] if SIGNATURE_CHOICE in ["laughney", "CTI_1st_level_granularity"]: BENCHMARK_CELL_TYPE_GROUP = "1st_level_granularity" elif SIGNATURE_CHOICE == "CTI_2nd_level_granularity": @@ -23,7 +23,7 @@ BENCHMARK_CELL_TYPE_GROUP = None # no signature was created BENCHMARK_DATASET = "CTI" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] N_SAMPLES = 500 # number of pseudbulk samples to create and assess for deconvolution -N_CELLS = [2000] # list of number of cells to try for the lineplot +N_CELLS = [100] # list of number of cells to try for the lineplot GENERATIVE_MODELS = ["MixupVI"] #, "DestVI"] # "scVI", "CondscVI", "DestVI" BASELINES = ["nnls"] # "nnls", "TAPE", "Scaden" COMPUTE_SC_RESULTS_WHEN_FACS = True @@ -31,11 +31,11 @@ ## General constants to change depending on the task SAVE_MODEL = False SEED = 3 -LATENT_SIZE = 10 +LATENT_SIZE = 25 MAX_EPOCHS = 100 ## Other constants to tune and then fix -N_GENES = 2000 # number of input genes after preprocessing +N_GENES = 4000 # number of input genes after preprocessing # MixUpVI training hyperparameters BATCH_SIZE = 1024 TRAIN_SIZE = 0.7 # as opposed to validation diff --git a/run_pseudobulk_benchmark.py b/run_pseudobulk_benchmark.py index 11c7797dba..1d0942e256 100644 --- a/run_pseudobulk_benchmark.py +++ b/run_pseudobulk_benchmark.py @@ -1,8 +1,11 @@ """Pseudobulk benchmark.""" # %% import scanpy as sc +import pandas as pd +import anndata as ad import warnings from loguru import logger +from sklearn.linear_model import LinearRegression from constants import ( BENCHMARK_DATASET, @@ -14,6 +17,7 @@ GENERATIVE_MODELS, BASELINES, N_CELLS, + COMPUTE_SC_RESULTS_WHEN_FACS, ) from benchmark_utils import ( @@ -28,11 +32,15 @@ add_cell_types_grouped, run_purified_sanity_check, run_sanity_check, + create_latent_signature, + compute_correlations, + compute_group_correlations, plot_purified_deconv_results, plot_deconv_results, plot_deconv_results_group, plot_deconv_lineplot, ) +from benchmark_utils.dataset_utils import create_anndata_pseudobulk # %% Load scRNAseq dataset logger.info(f"Loading single-cell dataset: {BENCHMARK_DATASET} ...") @@ -112,7 +120,7 @@ generative_models["DestVI"] = destvi_model # 3. MixupVI - if "MixupVI" in GENERATIVE_MODELS: + if "MixUpVI" in GENERATIVE_MODELS: logger.info("Train mixupVI ...") model_path = f"project/models/{BENCHMARK_DATASET}_{BENCHMARK_CELL_TYPE_GROUP}_{N_GENES}_mixupvi.pkl" mixupvi_model = fit_mixupvi(adata_train[:,filtered_genes].copy(), @@ -122,57 +130,166 @@ ) generative_models["MixupVI"] = mixupvi_model -# %% Sanity check 3 +# %% FACS + +if BENCHMARK_CELL_TYPE_GROUP == "FACS_1st_level_granularity": + logger.info("Computing FACS results...") -results = {} -results_group = {} - -for n in N_CELLS: - logger.info(f"Pseudobulk simulation with {n} sampled cells ...") - all_adata_samples_test, adata_pseudobulk_test_counts, adata_pseudobulk_test_rc, df_proportions_test = create_dirichlet_pseudobulk_dataset( - adata_test, - prior_alphas = None, - n_sample = N_SAMPLES, - n_cells = n, - add_sparsity=False # useless in the current modifications + # Load data + facs_results = pd.read_csv( + "/home/owkin/project/bulk_facs/240214_majorCelltypes.csv", index_col=0 + ).drop(["No.B.Cells.in.Live.Cells","NKT.Cells.in.Live.Cells"],axis=1).set_index("Sample") + facs_results = facs_results.rename( + { + "B.Cells.in.Live.Cells":"B", + "NK.Cells.in.Live.Cells":"NK", + "T.Cells.in.Live.Cells":"T", + "Monocytes.in.Live.Cells":"Mono", + "Dendritic.Cells.in.Live.Cells":"DC", + }, axis=1 ) - # decomment following for Sanity check 2. - # adata_pseudobulk_test_counts, adata_pseudobulk_test_rc, df_proportions_test = create_uniform_pseudobulk_dataset( - # adata_test, - # n_sample = N_SAMPLES, - # n_cells = n, - # ) - - df_test_correlations, df_test_group_correlations = run_sanity_check( - adata_train=adata_train, - adata_pseudobulk_test_counts=adata_pseudobulk_test_counts, - adata_pseudobulk_test_rc=adata_pseudobulk_test_rc, - all_adata_samples_test=all_adata_samples_test, - filtered_genes=filtered_genes, - df_proportions_test=df_proportions_test, - signature=signature, - generative_models=generative_models, - baselines=BASELINES, + facs_results = facs_results.dropna() + bulk_data = pd.read_csv( + ( + "/home/owkin/project/bulk_facs/" + "gene_counts20230103_batch1-5_all_cleaned-TPMnorm-allpatients.tsv" + ), + sep="\t", + index_col=0 + ).T + common_samples = pd.read_csv( + "/home/owkin/project/bulk_facs/RNA-FACS_common-samples.csv", index_col=0 ) - results[n] = df_test_correlations - results_group[n] = df_test_group_correlations + # Align bulk and facs samples + common_facs = common_samples.set_index("FACS.ID")["Patient"] + facs_results = facs_results.loc[facs_results.index.isin(common_facs.keys())] + facs_results = facs_results.rename(index=common_facs) + common_bulk = common_samples.set_index("RNAseq_ID")["Patient"] + bulk_data = bulk_data.loc[bulk_data.index.isin(common_bulk.keys())] + bulk_data = bulk_data.rename(index=common_bulk) + bulk_data = bulk_data.loc[facs_results.index].T -# %% Plots -if len(results) > 1: - plot_deconv_lineplot(results, - save=True, - filename=f"lineplot_tuned_mixupvi_third_granularity_retry_normal") -else: - key = list(results.keys())[0] - plot_deconv_results(results[key], + + ### Most of the following is repeated from the sanity checks fct, so move this code there + df_test_correlations = pd.DataFrame( + index=bulk_data.columns, + columns=["nnls", "MixUpVI"] + ) + df_test_group_correlations = pd.DataFrame( + index=facs_results.columns, + columns=["nnls", "MixUpVI"] + ) + + # NNLS + deconv = LinearRegression(positive=True).fit( + signature, bulk_data.loc[signature.index] + ) + deconv_results = pd.DataFrame( + deconv.coef_, index=bulk_data.columns, columns=signature.columns + ) + deconv_results = deconv_results.div( + deconv_results.sum(axis=1), axis=0 + ) # to sum up to 1 + correlations = compute_correlations(deconv_results, facs_results) + group_correlations = compute_group_correlations(deconv_results, facs_results) + df_test_correlations.loc[:, "nnls"] = correlations.values + df_test_group_correlations.loc[:, "nnls"] = group_correlations.values + + # MixUpVI + bulk_mixupvi = bulk_data.loc[filtered_genes] + model = "MixupVI" + adata_latent_signature = create_latent_signature( + adata=adata_train[:,filtered_genes], + model=generative_models[model], + use_mixupvi=False, # should be equal to use_mixupvi, but if True, + # then it averages as many cells as self.n_cells_per-pseudobulk from mixupvae + # (and not the number we wish in the benchmark) + average_all_cells = True, + ) + + adata_bulk = create_anndata_pseudobulk( + adata=adata_train[:,filtered_genes], x=bulk_mixupvi.T.values + ) + latent_bulk = generative_models[model].get_latent_representation( + adata_bulk, get_pseudobulk=False + ) + deconv = LinearRegression(positive=True).fit(adata_latent_signature.X.T, + latent_bulk.T) + deconv_results = pd.DataFrame( + deconv.coef_, index=bulk_data.columns, columns=signature.columns + ) + deconv_results = deconv_results.div( + deconv_results.sum(axis=1), axis=0 + ) # to sum up to 1 + correlations = compute_correlations(deconv_results, facs_results) + group_correlations = compute_group_correlations(deconv_results, facs_results) + df_test_correlations.loc[:, "MixupVI"] = correlations.values + df_test_group_correlations.loc[:, "MixupVI"] = group_correlations.values + + # Plots + plot_deconv_results(df_test_correlations, save=True, - # filename=f"benchmark_{key}_cells_first_granularity") - filename="test_first_type") - plot_deconv_results_group(results_group[key], - save=True, - # filename=f"benchmark_{key}_cells_first_granularity_cell_type") - filename="test_first_type_cell_type") + filename="facs_1st_try") + plot_deconv_results_group(df_test_group_correlations, + save=True, + filename="facs_1st_try_cell_type") + +# %% Sanity check 3 + +if ( + (BENCHMARK_CELL_TYPE_GROUP != "FACS_1st_level_granularity") or + (BENCHMARK_CELL_TYPE_GROUP == "FACS_1st_level_granularity" and COMPUTE_SC_RESULTS_WHEN_FACS) +): + results = {} + results_group = {} + + for n in N_CELLS: + logger.info(f"Pseudobulk simulation with {n} sampled cells ...") + all_adata_samples_test, adata_pseudobulk_test_counts, adata_pseudobulk_test_rc, df_proportions_test = create_dirichlet_pseudobulk_dataset( + adata_test, + prior_alphas = None, + n_sample = N_SAMPLES, + n_cells = n, + add_sparsity=False # useless in the current modifications + ) + # decomment following for Sanity check 2. + # adata_pseudobulk_test_counts, adata_pseudobulk_test_rc, df_proportions_test = create_uniform_pseudobulk_dataset( + # adata_test, + # n_sample = N_SAMPLES, + # n_cells = n, + # ) + + df_test_correlations, df_test_group_correlations = run_sanity_check( + adata_train=adata_train, + adata_pseudobulk_test_counts=adata_pseudobulk_test_counts, + adata_pseudobulk_test_rc=adata_pseudobulk_test_rc, + all_adata_samples_test=all_adata_samples_test, + filtered_genes=filtered_genes, + df_proportions_test=df_proportions_test, + signature=signature, + generative_models=generative_models, + baselines=BASELINES, + ) + + results[n] = df_test_correlations + results_group[n] = df_test_group_correlations + + # %% Plots + if len(results) > 1: + plot_deconv_lineplot(results, + save=True, + filename=f"lineplot_tuned_mixupvi_third_granularity_retry_normal") + else: + key = list(results.keys())[0] + plot_deconv_results(results[key], + save=True, + # filename=f"benchmark_{key}_cells_first_granularity") + filename="test_first_type") + plot_deconv_results_group(results_group[key], + save=True, + # filename=f"benchmark_{key}_cells_first_granularity_cell_type") + filename="test_first_type_cell_type") # %% (Optional) Sanity check 1. From 9efd0e398e1f155ae1fedf6f69b5e100ad04c910 Mon Sep 17 00:00:00 2001 From: SimonGrouard Date: Tue, 20 Aug 2024 15:57:00 +0000 Subject: [PATCH 4/6] Additions to make bulk facs work --- constants.py | 2 +- run_pseudobulk_benchmark.py | 83 ++++++++++++++++++++----------------- 2 files changed, 45 insertions(+), 40 deletions(-) diff --git a/constants.py b/constants.py index 8ab34e913d..e550027427 100644 --- a/constants.py +++ b/constants.py @@ -24,7 +24,7 @@ BENCHMARK_DATASET = "CTI" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] N_SAMPLES = 500 # number of pseudbulk samples to create and assess for deconvolution N_CELLS = [100] # list of number of cells to try for the lineplot -GENERATIVE_MODELS = ["MixupVI"] #, "DestVI"] # "scVI", "CondscVI", "DestVI" +GENERATIVE_MODELS = ["MixUpVI"] #, "DestVI"] # "scVI", "CondscVI", "DestVI" BASELINES = ["nnls"] # "nnls", "TAPE", "Scaden" COMPUTE_SC_RESULTS_WHEN_FACS = True diff --git a/run_pseudobulk_benchmark.py b/run_pseudobulk_benchmark.py index 1d0942e256..9c01cd611b 100644 --- a/run_pseudobulk_benchmark.py +++ b/run_pseudobulk_benchmark.py @@ -73,6 +73,47 @@ ) # adata_filtered = sc.read(f"/home/owkin/data/cti_data/processed/cti_processed_{N_GENES}.h5ad") +if BENCHMARK_CELL_TYPE_GROUP == "FACS_1st_level_granularity": + + # Load data + facs_results = pd.read_csv( + "/home/owkin/project/bulk_facs/240214_majorCelltypes.csv", index_col=0 + ).drop(["No.B.Cells.in.Live.Cells","NKT.Cells.in.Live.Cells"],axis=1).set_index("Sample") + facs_results = facs_results.rename( + { + "B.Cells.in.Live.Cells":"B", + "NK.Cells.in.Live.Cells":"NK", + "T.Cells.in.Live.Cells":"T", + "Monocytes.in.Live.Cells":"Mono", + "Dendritic.Cells.in.Live.Cells":"DC", + }, axis=1 + ) + facs_results = facs_results.dropna() + bulk_data = pd.read_csv( + ( + "/home/owkin/project/bulk_facs/" + "gene_counts_batchs1-5_raw.csv" + # "gene_counts20230103_batch1-5_all_cleaned-TPMnorm-allpatients.tsv" + ), + index_col=0, + # sep="\t" + ).T + common_samples = pd.read_csv( + "/home/owkin/project/bulk_facs/RNA-FACS_common-samples.csv", index_col=0 + ) + + # Align bulk and facs samples + common_facs = common_samples.set_index("FACS.ID")["Patient"] + facs_results = facs_results.loc[facs_results.index.isin(common_facs.keys())] + facs_results = facs_results.rename(index=common_facs) + common_bulk = common_samples.set_index("RNAseq_ID")["Patient"] + bulk_data = bulk_data.loc[bulk_data.index.isin(common_bulk.keys())] + bulk_data = bulk_data.rename(index=common_bulk) + bulk_data = bulk_data.loc[facs_results.index].T + + # Intersect the filtered genes with the bulk data + filtered_genes = list(set(filtered_genes).intersection(bulk_data.index)) + # %% load signature logger.info(f"Loading signature matrix: {SIGNATURE_CHOICE} | {BENCHMARK_CELL_TYPE_GROUP}...") signature = create_signature(signature_type=SIGNATURE_CHOICE) @@ -131,46 +172,9 @@ generative_models["MixupVI"] = mixupvi_model # %% FACS - if BENCHMARK_CELL_TYPE_GROUP == "FACS_1st_level_granularity": logger.info("Computing FACS results...") - # Load data - facs_results = pd.read_csv( - "/home/owkin/project/bulk_facs/240214_majorCelltypes.csv", index_col=0 - ).drop(["No.B.Cells.in.Live.Cells","NKT.Cells.in.Live.Cells"],axis=1).set_index("Sample") - facs_results = facs_results.rename( - { - "B.Cells.in.Live.Cells":"B", - "NK.Cells.in.Live.Cells":"NK", - "T.Cells.in.Live.Cells":"T", - "Monocytes.in.Live.Cells":"Mono", - "Dendritic.Cells.in.Live.Cells":"DC", - }, axis=1 - ) - facs_results = facs_results.dropna() - bulk_data = pd.read_csv( - ( - "/home/owkin/project/bulk_facs/" - "gene_counts20230103_batch1-5_all_cleaned-TPMnorm-allpatients.tsv" - ), - sep="\t", - index_col=0 - ).T - common_samples = pd.read_csv( - "/home/owkin/project/bulk_facs/RNA-FACS_common-samples.csv", index_col=0 - ) - - # Align bulk and facs samples - common_facs = common_samples.set_index("FACS.ID")["Patient"] - facs_results = facs_results.loc[facs_results.index.isin(common_facs.keys())] - facs_results = facs_results.rename(index=common_facs) - common_bulk = common_samples.set_index("RNAseq_ID")["Patient"] - bulk_data = bulk_data.loc[bulk_data.index.isin(common_bulk.keys())] - bulk_data = bulk_data.rename(index=common_bulk) - bulk_data = bulk_data.loc[facs_results.index].T - - ### Most of the following is repeated from the sanity checks fct, so move this code there df_test_correlations = pd.DataFrame( index=bulk_data.columns, @@ -182,11 +186,12 @@ ) # NNLS + intersected_signature = signature.loc[signature.index.intersection(bulk_data.index)] deconv = LinearRegression(positive=True).fit( - signature, bulk_data.loc[signature.index] + intersected_signature, bulk_data.loc[intersected_signature.index] ) deconv_results = pd.DataFrame( - deconv.coef_, index=bulk_data.columns, columns=signature.columns + deconv.coef_, index=bulk_data.columns, columns=intersected_signature.columns ) deconv_results = deconv_results.div( deconv_results.sum(axis=1), axis=0 From 235ffe5cc3bea537f95ed07645a0a395735dfcae Mon Sep 17 00:00:00 2001 From: SimonGrouard Date: Thu, 10 Oct 2024 14:14:36 +0000 Subject: [PATCH 5/6] Refacto changes --- benchmark_utils/__init__.py | 1 + benchmark_utils/dataset_utils.py | 49 +-- benchmark_utils/latent_signature_utils.py | 2 +- run_benchmark.py | 137 +++++++++ run_benchmark_config_dataclass.py | 345 ++++++++++++++++++++++ run_benchmark_constants.py | 81 +++++ 6 files changed, 596 insertions(+), 19 deletions(-) create mode 100644 run_benchmark.py create mode 100644 run_benchmark_config_dataclass.py create mode 100644 run_benchmark_constants.py diff --git a/benchmark_utils/__init__.py b/benchmark_utils/__init__.py index bef2bb3210..cf96707dcc 100644 --- a/benchmark_utils/__init__.py +++ b/benchmark_utils/__init__.py @@ -7,6 +7,7 @@ create_random_proportion, ) from .dataset_utils import ( + create_anndata_pseudobulk, preprocess_scrna, split_dataset, create_new_granularity_index, diff --git a/benchmark_utils/dataset_utils.py b/benchmark_utils/dataset_utils.py index 246bd1e558..9104c0319b 100644 --- a/benchmark_utils/dataset_utils.py +++ b/benchmark_utils/dataset_utils.py @@ -14,24 +14,33 @@ def preprocess_scrna( adata: ad.AnnData, keep_genes: int = 2000, batch_key: Optional[str] = None ): - """Preprocess single-cell RNA data for deconvolution benchmarking.""" + """Preprocess single-cell RNA data for deconvolution benchmarking. + + * in adata.X, the normalized log1p counts are saved + * in adata.layers["counts"], raw counts are saved + * in adata.layers["relative_counts"], the relative counts are saved + => The highly variable genes can be found in adata.var["highly_variable"] + + """ sc.pp.filter_genes(adata, min_counts=3) adata.layers["counts"] = adata.X.copy() # preserve counts, used for training sc.pp.normalize_total(adata, target_sum=1e4) - adata.layers["relative_counts"] = adata.X.copy() # preserve counts, used for + adata.layers["relative_counts"] = adata.X.copy() # preserve counts, used for baselines sc.pp.log1p(adata) adata.raw = adata # freeze the state in `.raw` - adata_filtered = adata.copy() sc.pp.highly_variable_genes( - adata_filtered, + adata, n_top_genes=keep_genes, - subset=True, layer="counts", flavor="seurat_v3", batch_key=batch_key, + subset=False, + inplace=True ) #TODO: add the filtering / QC steps that they perform in Servier - return adata, adata_filtered.var_names + # concat the result df to adata.var + + return adata def split_dataset( @@ -112,17 +121,21 @@ def add_cell_types_grouped( elif group == "FACS_1st_level_granularity": train_test_index = pd.read_csv("/home/owkin/project/train_test_index_dataframes/train_test_index_facs_1st_level.csv", index_col=1).iloc[:,1:] col_name = "grouping" - adata.obs["cell_types_grouped"] = train_test_index[col_name] + adata.obs[f"cell_types_grouped_{group}"] = train_test_index[col_name] return adata, train_test_index -def create_anndata_pseudobulk(adata: ad.AnnData, x: np.array) -> ad.AnnData: +def create_anndata_pseudobulk( + adata_obs: pd.DataFrame, adata_var_names: list, x: np.array +) -> ad.AnnData: """Creates an anndata object from a pseudobulk sample. Parameters ---------- - adata: ad.AnnData - AnnData aobject storing training set + adata_obs: pd.DataFrame + Obs dataframe from anndata object storing training set + adata_var_names: list + Gene names from the anndata object x: np.array pseudobulk sample @@ -132,14 +145,14 @@ def create_anndata_pseudobulk(adata: ad.AnnData, x: np.array) -> ad.AnnData: Anndata object storing the pseudobulk array """ df_obs = pd.DataFrame.from_dict( - [{col: adata.obs[col].value_counts().index[0] for col in adata.obs.columns}] + [{col: adata_obs[col].value_counts().index[0] for col in adata_obs.columns}] ) if len(x.shape) > 1 and x.shape[0] > 1: # several pseudobulks, so duplicate df_obs row df_obs = df_obs.loc[df_obs.index.repeat(x.shape[0])].reset_index(drop=True) df_obs.index = [f"sample_{idx}" for idx in df_obs.index] adata_pseudobulk = ad.AnnData(X=x, obs=df_obs) - adata_pseudobulk.var_names = adata.var_names + adata_pseudobulk.var_names = adata_var_names adata_pseudobulk.layers["counts"] = np.copy(x) adata_pseudobulk.raw = adata_pseudobulk @@ -169,10 +182,10 @@ def create_purified_pseudobulk_dataset( group.append(group_key) # pseudobulk dataset - adata_pseudobulk_rc = create_anndata_pseudobulk(adata, + adata_pseudobulk_rc = create_anndata_pseudobulk(adata.obs, adata.var_names, np.array(averaged_data["relative_counts"]) ) - adata_pseudobulk_counts = create_anndata_pseudobulk(adata, + adata_pseudobulk_counts = create_anndata_pseudobulk(adata.obs, adata.var_names, np.array(averaged_data["counts"]) ) adata_pseudobulk_rc.obs_names = group @@ -212,10 +225,10 @@ def create_uniform_pseudobulk_dataset( averaged_data["counts"].append(adata_sample.layers["counts"].sum(axis=0).tolist()[0]) # pseudobulk dataset - adata_pseudobulk_rc = create_anndata_pseudobulk(adata, + adata_pseudobulk_rc = create_anndata_pseudobulk(adata.obs, adata.var_names, np.array(averaged_data["relative_counts"]) ) - adata_pseudobulk_counts = create_anndata_pseudobulk(adata, + adata_pseudobulk_counts = create_anndata_pseudobulk(adata.obs, adata.var_names, np.array(averaged_data["counts"]) ) @@ -304,10 +317,10 @@ def create_dirichlet_pseudobulk_dataset( all_adata_samples.append(adata_sample) # pseudobulk dataset - adata_pseudobulk_rc = create_anndata_pseudobulk(adata, + adata_pseudobulk_rc = create_anndata_pseudobulk(adata.obs, adata.var_names, np.array(averaged_data["relative_counts"]) ) - adata_pseudobulk_counts = create_anndata_pseudobulk(adata, + adata_pseudobulk_counts = create_anndata_pseudobulk(adata.obs, adata.var_names, np.array(averaged_data["counts"]) ) diff --git a/benchmark_utils/latent_signature_utils.py b/benchmark_utils/latent_signature_utils.py index 731020abf0..bead373412 100644 --- a/benchmark_utils/latent_signature_utils.py +++ b/benchmark_utils/latent_signature_utils.py @@ -122,7 +122,7 @@ def create_latent_signature( adata_sampled.layers[count_key].mean(axis=0).reshape(1, -1) ) adata_sampled = create_anndata_pseudobulk( - adata_sampled, pseudobulk + adata_sampled.obs, adata_sampled.var_names, pseudobulk ) result = model.get_latent_representation( adata_sampled, get_pseudobulk = False diff --git a/run_benchmark.py b/run_benchmark.py new file mode 100644 index 0000000000..cc56128c5f --- /dev/null +++ b/run_benchmark.py @@ -0,0 +1,137 @@ +"""Pseudobulk benchmark.""" + +import argparse +import os +import scanpy as sc +import pandas as pd +import anndata as ad +import warnings +from loguru import logger +from sklearn.linear_model import LinearRegression +from typing import Optional + +from benchmark_utils import add_cell_types_grouped, create_signature +from run_benchmark_help import initialise_deconv_methods, load_preprocessed_datasets +from run_benchmark_config_dataclass import ( + RunBenchmarkConfig, + GRANULARITY_TO_DATASET, + SINGLE_CELL_DATASETS, +) + +def run_benchmark( + deconv_methods: list, + evaluation_datasets: list, + granularities: list, + evaluation_pseudobulk_samplings: Optional[list], + signature_matrices: Optional[list], + train_dataset: Optional[str], + n_variable_genes: Optional[int], + save: bool, + experiment_name: Optional[str], +): + """Run the deconvolution benchmark pipeline. + + The arguments are defined in a config yaml file passed to the RunBenchmarkConfig + dataclass. + """ + all_data = load_preprocessed_datasets( + evaluation_datasets, + train_dataset, + n_variable_genes, + ) + + if signature_matrices is not None: + all_data["signature_matrices"] = {} + for signature_matrix in signature_matrices: + logger.info(f"Loading signature matrix: {signature_matrix}...") + all_data["signature_matrices"][signature_matrix] = create_signature( + signature_matrix + ) + + # Will there be a problem for differentiation of FACS vs SC ? + for granularity in granularities: + logger.info( + f"Loading train/test index for granularity: {granularity}..." + ) + for dataset in all_data["datasets"]: + if GRANULARITY_TO_DATASET[granularity] == dataset: + all_data["datasets"][dataset]["dataset"], train_test_index = \ + add_cell_types_grouped( + all_data["datasets"][dataset]["dataset"], + granularity + ) + all_data["datasets"][dataset][granularity] = train_test_index + + + logger.info( + "All the data are now loaded." + ) + + for granularity in granularities: + logger.info( + f"Launching the deconvolution experiments for granularity: {granularity}..." + ) + deconv_methods_initialized = initialise_deconv_methods( + deconv_methods=deconv_methods, + all_data=all_data, + granularity=granularity, + train_dataset=train_dataset, + signature_matrices=signature_matrices, + ) + evaluation_dataset = GRANULARITY_TO_DATASET[granularity] + logger.info(f"Running evaluation on {evaluation_dataset}...") + if evaluation_dataset in SINGLE_CELL_DATASETS: + for evaluation_pseudobulk_sampling in evaluation_pseudobulk_samplings: + # TODO: TAKE from here, we also need the N_CELLS argument and to reformat the sanity checks to make them more readable (which was the whole point of this PR at first) + logger.info( + f"Creating pseudobulks with {evaluation_pseudobulk_sampling} " + "method..." + ) + for deconv_method_initialized in deconv_methods_initialized: + pass + else: + # direct pred + pass + + + + + # for deconv_method_name in deconv_methods: + # logger.info(f"Running deconvolution method: {deconv_method_name}") + # if not os.path.exists(f"{experiment_path}/{deconv_method_name}"): + # os.mkdir(f"{experiment_path}/{deconv_method_name}") + # # Instantiate deconvolution method + # method = deconv_methods[deconv_method_name] + # if method in FIT_SINGLE_CELL: + # # Fit deconvolution method + # method.fit(adata_train[:,filtered_genes].copy()) + # if method in CREATE_LATENT_SIGNATURE: + # # Create the signature matrix inside the latent space + # method.create_latent_signature(adata_train[:,filtered_genes]) + + # for sanity_check_name in sanity_checks: + # sanity_check_type = sanity_check_name.split("_")[3] + # logger.info(f"Run following sanity check: {sanity_check_name}") + # if not os.path.exists(f"{experiment_path}/{deconv_method_name}/{sanity_check_type}"): + # os.mkdir(f"{experiment_path}/{deconv_method_name}/{sanity_check_type}") + + # # Create the pseudobulk (or bulk) test and true simulated (or facs) proportions + # pseudobulk_test, df_proportions_test = sanity_check.create_test_data(adata_test.copy()) + + # # Perform test deconvolution + # deconv_results = method(pseudobulk_test, df_proportions_test) + # saving_path = f"{experiment_path}/{deconv_method_name}/{sanity_check_type}/{sanity_check_name}.csv" + # deconv_results.to_csv(saving_path) + # logger.info(f"The deconvolution results are saved inside the following directory: {saving_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", type=str, required=True, help="Path to the YAML configuration file" + ) + args = parser.parse_args() + + config_dict = RunBenchmarkConfig.from_config_yaml(config_path=args.config) + + run_benchmark(**config_dict) + # constants should be used for the methods HPs for now, but saved + logged as full configs \ No newline at end of file diff --git a/run_benchmark_config_dataclass.py b/run_benchmark_config_dataclass.py new file mode 100644 index 0000000000..26a70c897f --- /dev/null +++ b/run_benchmark_config_dataclass.py @@ -0,0 +1,345 @@ +"""Benchmark config class for a benchmark run.""" + +import dataclasses +import json +import os +import shutil +from loguru import logger +from dataclasses import asdict +from datetime import datetime +from pydantic.dataclasses import dataclass +from typing import Optional +import yaml +from zoneinfo import ZoneInfo + +from run_benchmark_constants import ( + DATASETS, + DECONV_METHODS, + EVALUATION_PSEUDOBULK_SAMPLINGS, + GRANULARITIES, + GRANULARITY_TO_DATASET, + MODEL_TO_FIT, + SIGNATURE_MATRIX_MODELS, + SIGNATURE_TO_GRANULARITY, + SINGLE_CELL_DATASETS, + SINGLE_CELL_GRANULARITIES, + TRAIN_DATASETS, +) + +def save_experiment(experiment_name: str): + """Save the logs and experiment outputs. + + Parameters + ---------- + experiment_name: str + The experiment directory name, unique to each experiment. + + Returns + ------- + full_path: str + The full experiment path. + """ + output_dir = "/home/owkin/project/run_benchmark_experiments" + if not os.path.exists(output_dir): + os.mkdir(output_dir) + logger.info(f"Created output dir: {output_dir}") + full_path = f"{output_dir}/{experiment_name}" + if not os.path.exists(full_path): + os.mkdir(full_path) + logger.info( + f"Created output dir: {full_path}" + ) + elif "experiment_over.txt" in os.listdir(f"{full_path}"): + raise ValueError( + f"An experiment was already finished at {full_path}. If you wish to " + "run a new experiment, please find another experiment name." + ) + else: + shutil.rmtree(full_path) + os.mkdir(full_path) + logger.warning( + "An experiment was already started but not finished at the path " + f"{full_path}. It is now deleted." + ) + # Save logs + logger.add(f"{full_path}/logs.txt") + logger.add( + f"{full_path}/warnings.txt", + level="WARNING", # also keep the warnings separate + ) + + return full_path + +@dataclass +class RunBenchmarkConfig: + """Full configuration for a benchmark deconvolution experiment. + + Parameters + ---------- + deconv_methods: list + All the deconvolution methods to compare to evaluate deconvolution. + evaluation_datasets: list + All the evaluation datasets on which to evaluate the deconvolution performance + of the deconvolution methods. Can be single cell datasets (for pseudobulk + simulations) or real bulk/facs data. + granularities: list + The granularities of the deconvolution to evaluate. + evaluation_pseudobulk_samplings: list + All type of samplings to perform for simulations on the single cell data in case + single cell datasets were given to evaluation datasets. + signature_matrices: list + The signature matrices to use for the methods requiring one. + train_dataset: str + The train dataset to use in case some deconvolution methods (usually scvi-verse + models) need fitting before evaluation. On the contrary, this is not the case + for models like NNLS, because the signature matrix was constructed on a train + dataset separately. + n_variable_genes: int + The number of most highly variable genes in case some deconvolution methods + (usually scvi-verse models) need to filter genes based on variance prior to + fitting and evaluation. + save: bool + Whether to save the deconvolution experiment outputs. + experiment_name: str + The experiment directory name, unique to each experiment. + """ + deconv_methods: list + evaluation_datasets: list + granularities: list + evaluation_pseudobulk_samplings: Optional[list] + signature_matrices: Optional[list] + train_dataset: Optional[str] + n_variable_genes: Optional[int] + save: bool + experiment_name: Optional[str] + + @classmethod + def from_config_yaml(cls, config_path: str): + """Return a dictionnary mapping evaluation config keys to their value.""" + with open(config_path, "r", encoding="utf-8") as file: + config_dict = yaml.safe_load(file) + + # Check types + config_dict = asdict( + cls( + deconv_methods = config_dict["deconv_methods"], + evaluation_datasets = config_dict["evaluation_datasets"], + granularities = config_dict["granularities"], + evaluation_pseudobulk_samplings = config_dict[ + "evaluation_pseudobulk_samplings"], + signature_matrices = config_dict["signature_matrices"], + train_dataset = config_dict["train_dataset"], + n_variable_genes = config_dict["n_variable_genes"], + save = config_dict["save"], + experiment_name = config_dict["experiment_name"], + ) + ) + + # Save experiment if required + if config_dict["save"]: + if config_dict["experiment_name"] is None: + paris_tz = ZoneInfo("Europe/Paris") + config_dict["experiment_name"] = datetime.now(paris_tz).strftime( + "%Y-%m-%d-%H-%M-%S" + ) + full_path = save_experiment(config_dict["experiment_name"]) + + # Test and homogenise missing arguments + if len(config_dict["deconv_methods"]) == 0: + message = ( + "You should provide at least one deconvolution method to run an " + "experiment." + ) + logger.error(message) + raise ValueError(message) + if len(config_dict["evaluation_datasets"]) == 0: + message = ( + "You should provide at least one evaluation dataset to run an " + "experiment." + ) + logger.error(message) + raise ValueError(message) + if len(config_dict["granularities"]) == 0: + message = ( + "You should provide at least one granularity to run an " + "experiment." + ) + logger.error(message) + raise ValueError(message) + if config_dict["evaluation_pseudobulk_samplings"] is not None and \ + len(config_dict["evaluation_pseudobulk_samplings"]) == 0: + config_dict["evaluation_pseudobulk_samplings"] = None + if config_dict["signature_matrices"] is not None and \ + len(config_dict["signature_matrices"]) == 0: + config_dict["signature_matrices"] = None + if config_dict["train_dataset"] is not None and \ + len(config_dict["train_dataset"]) == 0: + config_dict["train_dataset"] = None + + # Check that all provided arguments exist + if not set(config_dict["deconv_methods"]).issubset(DECONV_METHODS): + message = ( + "Only the following deconvolution methods can be used: " + f"{DECONV_METHODS}" + ) + logger.error(message) + raise NotImplementedError(message) + if not set(config_dict["evaluation_datasets"]).issubset(DATASETS): + message = ( + "Only the following evaluation datasets can be used: " + f"{DATASETS}" + ) + logger.error(message) + raise NotImplementedError(message) + if not set(config_dict["evaluation_pseudobulk_samplings"]).issubset( + EVALUATION_PSEUDOBULK_SAMPLINGS + ): + message = ( + "Only the following evaluation pseudobulk samplings can be used: " + f"{EVALUATION_PSEUDOBULK_SAMPLINGS}" + ) + logger.error(message) + raise NotImplementedError(message) + if not set(config_dict["granularities"]).issubset(GRANULARITIES): + message = ( + "Only the following granularities can be used: " + f"{GRANULARITIES}" + ) + logger.error(message) + raise NotImplementedError(message) + if not set(config_dict["signature_matrices"]).issubset( + set(SIGNATURE_TO_GRANULARITY.keys()) + ): + message = ( + "Only the following signature matrices can be used: " + f"{set(SIGNATURE_TO_GRANULARITY.keys())}" + ) + logger.error(message) + raise NotImplementedError(message) + if config_dict["train_dataset"] not in TRAIN_DATASETS: + message = ( + "Only the following train datasets can be used: " + f"{TRAIN_DATASETS}" + ) + logger.error(message) + raise NotImplementedError(message) + + # Check for missing arguments + if len(intersection := set(config_dict["deconv_methods"]).intersection( + MODEL_TO_FIT)) > 0: + if config_dict["train_dataset"] is None: + message = ( + "train_dataset must be provided when models that require fitting " + f"({intersection}) are provided as deconvolution methods." + ) + logger.error(message) + raise ValueError(message) + if config_dict["n_variable_genes"] is None: + message = ( + "n_variable_genes must be provided when models that require to be " + f"filtered to the most variables genes ({intersection}) are " + "provided as deconvolution methods." + ) + logger.error(message) + raise ValueError(message) + if len(intersection := set(config_dict["evaluation_datasets"]).intersection( + SINGLE_CELL_DATASETS)) > 0 and config_dict[ + "evaluation_pseudobulk_samplings" + ] is None: + message = ( + "evaluation_pseudobulk_samplings strategies must be provided when " + f"single cell datasets ({intersection}) are provided as evaluation " + "datasets." + ) + logger.error(message) + raise ValueError(message) + if len(intersection := set(config_dict["deconv_methods"]).intersection( + SIGNATURE_MATRIX_MODELS)) > 0: + if config_dict["signature_matrices"] is None: + message = ( + "No signature matrix was provided even though some methods " + f"({intersection}) require a signature matrix." + ) + logger.error(message) + raise ValueError(message) + for signature_matrix in config_dict["signature_matrices"]: + if SIGNATURE_TO_GRANULARITY[signature_matrix] not in \ + config_dict["granularities"]: + message = ( + "Signature matrix's associated granularity (signature matrix: " + f"{signature_matrix}, associated granularity: " + f"{SIGNATURE_TO_GRANULARITY[signature_matrix]}) does not match " + "any of the provided granularities: " + f"{config_dict['granularities']}." + ) + logger.error(message) + raise ValueError(message) + for granularity in config_dict["granularities"]: + if GRANULARITY_TO_DATASET[granularity] not in \ + config_dict["evaluation_datasets"]: + message = ( + f"The provided granularity {granularity} should be evaluated on " + f"the following dataset: {GRANULARITY_TO_DATASET['granularity']}. " + "However, none of the evaluation datasets (" + f"{config_dict['evaluation_datasets']}) contain this dataset." + ) + + # Check for useless arguments + if len(intersection := set(config_dict["deconv_methods"]).intersection( + MODEL_TO_FIT)) == 0: + if config_dict["train_dataset"] is not None: + logger.warning( + "A train dataset was provided even though none of the provided " + f"deconvolution methods ({config_dict['deconvolution_methods']}) " + "require fitting. Thus, train_dataset will not be used." + ) + config_dict["train_dataset"] = None + if config_dict["n_variable_genes"] is not None: + logger.warning( + "n_variable_genes was provided even though none of the provided " + f"deconvolution methods ({config_dict['deconvolution_methods']}) " + "require to be filtered to their most variable genes. Thus, this " + "argument will not be used." + ) + config_dict["n_variable_genes"] = None + if len(set(config_dict["evaluation_datasets"]).intersection( + SINGLE_CELL_DATASETS)) == 0 and config_dict[ + "evaluation_pseudobulk_samplings" + ] is not None: + logger.warning( + "evaluation_pseudobulk_samplings strategies were provided even though " + "none of the provided evaluation datasets (" + f"{config_dict['evaluation_datasets']}) are single cell and thus " + "none can be sampled for simulations. Thus, this argument will not be " + "used." + ) + config_dict["evaluation_pseudobulk_samplings"] = None + if len(set(config_dict["deconv_methods"]).intersection( + SIGNATURE_MATRIX_MODELS)) == 0 and config_dict["signature_matrices"] \ + is not None: + logger.warning( + "A signature matrix was provided even though none of the provided " + "deconvolution methods require one. Thus, this argument will not be " + "used." + ) + config_dict["signature_matrices"] = None + + # Check for likely long running time + if len(intersection1:=set(config_dict["deconv_methods"]).intersection( + MODEL_TO_FIT + ))> 0 and len(intersection2:=set(config_dict["granularities"]).intersection( + SINGLE_CELL_GRANULARITIES + )) > 1: + logger.warning( + f"Some deconvolution methods ({intersection1}) need fitting, while " + "several granularities needing fitting were provided for evaluation " + f"({intersection2}). Training on several granularities can take time." + ) + + if config_dict["save"]: + config_path = full_path + "/config.json" + with open(config_path, "w", encoding="utf-8") as json_file: + json.dump(config_dict, json_file) + logger.info(f"Saved config dict to {config_path}") + + return config_dict \ No newline at end of file diff --git a/run_benchmark_constants.py b/run_benchmark_constants.py new file mode 100644 index 0000000000..25a7a29676 --- /dev/null +++ b/run_benchmark_constants.py @@ -0,0 +1,81 @@ +"""All the constants used by run_benchmark.py to configure the pipeline.""" + +DECONV_METHODS = { + "scVI": { + "_target_": "", + "adata": None, + "model_path": "", + "save_model": False, + }, + "DestVI": { + "_target_": "", + "adata": None, + "prior_alphas": None, + "n_sample": None, + "model_path1": "", + "model_path2": "", + "cell_type_key": "", + "save_model": False, + }, + "MixUpVI": { + "_target_": "run_benchmark_help.MixUpVIMethod", + "adata_train": None, + "model_path": "", + "cell_type_group": "cell_types_grouped", + "save_model": False, + }, + "NNLS": { + "_target_": "run_benchmark_help.NNLSMethod", + "signature_matrix": None, + }, + "TAPE": { + "_target_": "", + }, + "Scaden": { + "_target_": "", + } +} + +DATASETS = { # something that takes preprocessing into account !! (not train/test split yet though) + "TOY": { + "_target_": "", + }, + "CTI": { + "_target_": "run_benchmark_help.load_cti", + "n_variable_genes": None, + }, + "BULK_FACS": { + "_target_": "run_benchmark_help.load_bulk_facs", + }, +} + +EVALUATION_PSEUDOBULK_SAMPLINGS = {"PURIFIED", "UNIFORM", "DIRICHLET"} +TRAIN_DATASETS = {"CTI"} +SINGLE_CELL_DATASETS = {"TOY", "CTI"} +MODEL_TO_FIT = {"MixUpVI", "scVI", "DestVI"} +SIGNATURE_MATRIX_MODELS = {"NNLS", "TAPE", "Scaden"} +SINGLE_CELL_GRANULARITIES = { + "1st_level_granularity", + "2nd_level_granularity", + "3rd_level_granularity", + "4th_level_granularity", +} +GRANULARITIES = SINGLE_CELL_GRANULARITIES.union({ + "FACS_1st_level_granularity", +}) +SIGNATURE_TO_GRANULARITY = { + "laughney": "1st_level_granularity", + "CTI_1st_level_granularity": "1st_level_granularity", + "CTI_2nd_level_granularity": "2nd_level_granularity", + "CTI_3rd_level_granularity": "3rd_level_granularity", + "CTI_4th_level_granularity": "4th_level_granularity", + "FACS_1st_level_granularity": "FACS_1st_level_granularity", +} +GRANULARITY_TO_DATASET = { + "1st_level_granularity": "CTI", + "2nd_level_granularity": "CTI", + "3rd_level_granularity": "CTI", + "4th_level_granularity": "CTI", + "FACS_1st_level_granularity": "CTI", + # add the one for TOY +} \ No newline at end of file From f298bca6f19d678621ffab1066933622e1a29f0f Mon Sep 17 00:00:00 2001 From: SimonGrouard Date: Thu, 10 Oct 2024 14:16:00 +0000 Subject: [PATCH 6/6] Other refacto changes --- run_benchmark_configs/config_test.yaml | 11 + run_benchmark_help.py | 272 +++++++++++++++++++++++++ run_pseudobulk_benchmark.py | 4 +- 3 files changed, 285 insertions(+), 2 deletions(-) create mode 100644 run_benchmark_configs/config_test.yaml create mode 100644 run_benchmark_help.py diff --git a/run_benchmark_configs/config_test.yaml b/run_benchmark_configs/config_test.yaml new file mode 100644 index 0000000000..cd937bf7a3 --- /dev/null +++ b/run_benchmark_configs/config_test.yaml @@ -0,0 +1,11 @@ +## A test config to run the benchmark + +deconv_methods: ["NNLS", "MixUpVI"] +evaluation_datasets: ["CTI", "BULK_FACS"] +granularities: ["2nd_level_granularity", "FACS_1st_level_granularity"] +evaluation_pseudobulk_samplings: ["DIRICHLET"] +signature_matrices: ["CTI_2nd_level_granularity", "FACS_1st_level_granularity"] +train_dataset: "CTI" +n_variable_genes: 3000 +save: False +experiment_name: "test_benchmark" diff --git a/run_benchmark_help.py b/run_benchmark_help.py new file mode 100644 index 0000000000..cb70b13bcf --- /dev/null +++ b/run_benchmark_help.py @@ -0,0 +1,272 @@ +""" +""" + +from __future__ import annotations + +import anndata as ad +import importlib +import pandas as pd +import scanpy as sc +from abc import abstractmethod +from functools import partial +from loguru import logger +from sklearn.linear_model import LinearRegression + +from benchmark_utils import ( + create_anndata_pseudobulk, + create_latent_signature, + fit_mixupvi, + preprocess_scrna, +) +from run_benchmark_constants import ( + DATASETS, + DECONV_METHODS, + MODEL_TO_FIT, + SIGNATURE_MATRIX_MODELS, + SIGNATURE_TO_GRANULARITY, +) + +# # DECONV_METHODS_FITTING_REQUIREMENTS = { +# # "MixUpVI": {"single_cell", "variable_genes"}, +# # "NNLS": {"signature_matrix"}, +# # "scVI": {"single_cell", "variable_genes"}, +# # "DestVI": {"single_cell", "variable_genes"}, +# # "TAPE": {"signature_matrix",}, +# # "Scaden": {"signature_matrix",}, +# # } +# # FITTING_DATASETS = {"CTI"} +# # SINGLE_CELL_DATASETS = {"TOY", "CTI"} + +# DECONV_EVALUATIONS = {"CORRELATION", "GROUP_CORRELATION"} + +def load_cti(n_variable_genes: int, **kwargs): + """TODO: Right now, it's just a raw function to test the code. + """ + adata = sc.read("/home/owkin/project/cti/cti_adata.h5ad") + adata = preprocess_scrna(adata, + keep_genes=n_variable_genes, + batch_key="donor_id") + return adata + +def load_bulk_facs(**kwargs): + """TODO: Right now, it's just a raw function to test the code. + """ + # Load data + facs_results = pd.read_csv( + "/home/owkin/project/bulk_facs/240214_majorCelltypes.csv", index_col=0 + ).drop(["No.B.Cells.in.Live.Cells","NKT.Cells.in.Live.Cells"],axis=1).set_index( + "Sample" + ) + facs_results = facs_results.rename( + { + "B.Cells.in.Live.Cells":"B", + "NK.Cells.in.Live.Cells":"NK", + "T.Cells.in.Live.Cells":"T", + "Monocytes.in.Live.Cells":"Mono", + "Dendritic.Cells.in.Live.Cells":"DC", + }, axis=1 + ) + facs_results = facs_results.dropna() + bulk_data = pd.read_csv( + ( + "/home/owkin/project/bulk_facs/" + "gene_counts_batchs1-5_raw.csv" + # "gene_counts20230103_batch1-5_all_cleaned-TPMnorm-allpatients.tsv" + ), + index_col=0, + # sep="\t" + ).T + common_samples = pd.read_csv( + "/home/owkin/project/bulk_facs/RNA-FACS_common-samples.csv", index_col=0 + ) + # Align bulk and facs samples + common_facs = common_samples.set_index("FACS.ID")["Patient"] + facs_results = facs_results.loc[facs_results.index.isin(common_facs.keys())] + facs_results = facs_results.rename(index=common_facs) + common_bulk = common_samples.set_index("RNAseq_ID")["Patient"] + bulk_data = bulk_data.loc[bulk_data.index.isin(common_bulk.keys())] + bulk_data = bulk_data.rename(index=common_bulk) + bulk_data = bulk_data.loc[facs_results.index].T + + return (bulk_data, facs_results) + +def use_nnls_method(to_deconvolve: pd.DataFrame, signature_matrix: pd.DataFrame): + """TODO: Right now, it's just a raw function to test the code. + """ + intersected_signature = signature_matrix.loc[signature_matrix.index.intersection( + to_deconvolve.index + )] + deconv = LinearRegression(positive=True).fit( + intersected_signature, to_deconvolve.loc[intersected_signature.index] + ) + deconv_results = pd.DataFrame( + deconv.coef_, index=to_deconvolve.columns, columns=intersected_signature.columns + ) + deconv_results = deconv_results.div( + deconv_results.sum(axis=1), axis=0 + ) # to sum up to 1 + + return deconv_results + +class AbstractDeconvolutionMethod: + """TODO: Right now, it's just a raw class to test the code. + """ + @abstractmethod + def apply_deconvolution(self, to_deconvolve: pd.DataFrame, **kwargs): + """Apply deconvolution method on data to deconvolve. + + Parameters + ---------- + to_deconvolve: pd.DataFrame + The data to deconvolve. + """ + +class NNLSMethod(AbstractDeconvolutionMethod): + """TODO: Right now, it's just a raw class to test the code. + """ + def __init__(self, signature_matrix): + self.signature_matrix = signature_matrix + + def apply_deconvolution(self, to_deconvolve: pd.DataFrame): + """ + """ + deconvolution_results = use_nnls_method(to_deconvolve, self.signature_matrix) + return deconvolution_results + +class MixUpVIMethod(AbstractDeconvolutionMethod): + """TODO: Right now, it's just a raw class to test the code. + """ + def __init__( + self, + adata_train: ad.AnnData, + cell_type_group: str, + model_path: str = "", + save_model: bool = False, + ): + self.filtered_genes = adata_train.var.index[ + adata_train.var["highly_variable"] + ].tolist() + adata_train = adata_train[:,self.filtered_genes] + self.adata_obs = adata_train.obs + + logger.info("Fitting MixUpVI...") + self.mixupvi = fit_mixupvi( + adata=adata_train.copy(), + model_path=model_path, + cell_type_group=cell_type_group, + save_model=save_model, + ) + + logger.info("Training over. Creation of latent signature matrix...") + self.adata_latent_signature = create_latent_signature( + adata=adata_train, + model=self.mixupvi, + use_mixupvi=False, # should be equal to use_mixupvi, but if True, + # then it averages as many cells as self.n_cells_per-pseudobulk from mixupvae + # (and not the number we wish in the benchmark) + average_all_cells = True, + ) + + def apply_deconvolution(self, to_deconvolve: pd.DataFrame): + """ + """ + adata_to_deconvolve = create_anndata_pseudobulk( + adata_obs=self.adata_obs, + adata_var_names=self.filtered_genes, + x=to_deconvolve.T.values, + ) + latent_adata = self.mixupvi.get_latent_representation( + adata_to_deconvolve, get_pseudobulk=False + ) + deconvolution_results = use_nnls_method( + latent_adata, self.adata_latent_signature + ) + return deconvolution_results + +def initialise_func(func_config: dict): + """ + """ + target_path = func_config["_target_"] + module_name, func_name = target_path.rsplit(".", 1) + module = importlib.import_module(module_name) + initialised_func = getattr(module, func_name) + kwargs = {k: v for k, v in func_config.items() if k != "_target_"} + return initialised_func, kwargs + +def load_preprocessed_datasets( + evaluation_datasets: list, + train_dataset: str | None = None, + n_variable_genes: int | None = None, +): + """ + """ + data = {"datasets": {}} + for evaluation_dataset in evaluation_datasets: + logger.info(f"Loading dataset: {evaluation_dataset}...") + data["datasets"][evaluation_dataset] = {} + dataset_config = DATASETS[evaluation_dataset] + initialised_func, kwargs = initialise_func(dataset_config) + kwargs["n_variable_genes"] = n_variable_genes + data["datasets"][evaluation_dataset]["dataset"] = initialised_func(**kwargs) + + if train_dataset is not None and train_dataset not in evaluation_datasets: + logger.info(f"Loading train dataset: {train_dataset}...") + data["datasets"][train_dataset] = {} + dataset_config = DATASETS[train_dataset] + initialised_func, kwargs = initialise_func(dataset_config) + kwargs["n_variable_genes"] = n_variable_genes + data["datasets"][train_dataset]["dataset"] = initialised_func(**kwargs) + + return data + +def initialise_deconv_methods( + deconv_methods, + all_data, + granularity: str, + train_dataset: str, + signature_matrices: list, +): + """ + """ + deconv_methods_initialised = {} + for deconv_method in deconv_methods: + deconv_method_func, kwargs = initialise_func(DECONV_METHODS[deconv_method]) + if (deconv_method in MODEL_TO_FIT)==(deconv_method in SIGNATURE_MATRIX_MODELS): + message = ( + "The codebase is not formatted yet to have a deconvolution method " + "needing both to be fit and a user-provided signature matrix, or none " + "of these two options. It needs one of these options only." + ) + logger.error(message) + raise NotImplementedError(message) + if deconv_method in MODEL_TO_FIT: + logger.info(f"Training deconvolution method {deconv_method}...") + all_train_dset = all_data["datasets"][train_dataset] + train_dset = all_train_dset["dataset"][ + all_train_dset[granularity]["Train index"] + ] + kwargs["adata_train"] = train_dset + if "cell_type_group" in kwargs: + # TODO: Ugly code because specific to MixUpVI only + # More generally, improve to allow to pass other kwargs arguments + kwargs["adata_train"].obs = kwargs["adata_train"].obs.rename( + {f"cell_types_grouped_{granularity}": "cell_types_grouped"}, + axis = 1 + ) + deconv_method_initialised = deconv_method_func(**kwargs) + deconv_methods_initialised[deconv_method] = deconv_method_initialised + elif deconv_method in SIGNATURE_MATRIX_MODELS: + for signature_matrix in signature_matrices: + if SIGNATURE_TO_GRANULARITY[signature_matrix]==granularity: + kwargs["signature_matrix"] = all_data["signature_matrices"][ + signature_matrix + ] + deconv_method_initialised = deconv_method_func(**kwargs) + deconv_methods_initialised[ + f"deconv_method_{signature_matrix}" + ] = deconv_method_initialised + + logger.info("Initialisation of the deconvolution methods complete.") + return deconv_methods_initialised + + \ No newline at end of file diff --git a/run_pseudobulk_benchmark.py b/run_pseudobulk_benchmark.py index 9c01cd611b..fe686ad3d8 100644 --- a/run_pseudobulk_benchmark.py +++ b/run_pseudobulk_benchmark.py @@ -166,7 +166,7 @@ model_path = f"project/models/{BENCHMARK_DATASET}_{BENCHMARK_CELL_TYPE_GROUP}_{N_GENES}_mixupvi.pkl" mixupvi_model = fit_mixupvi(adata_train[:,filtered_genes].copy(), model_path, - cell_type_group="cell_types_grouped", + cell_type_group="cell_types_grouped", # TO CHANGE IN NEW BENCHMARK (f"cell_types_grouped_{granularity}") save_model=SAVE_MODEL, ) generative_models["MixupVI"] = mixupvi_model @@ -214,7 +214,7 @@ ) adata_bulk = create_anndata_pseudobulk( - adata=adata_train[:,filtered_genes], x=bulk_mixupvi.T.values + adata_train[:,filtered_genes].obs, filtered_genes, x=bulk_mixupvi.T.values ) latent_bulk = generative_models[model].get_latent_representation( adata_bulk, get_pseudobulk=False