From 433e0ee3ef8157aa77b358544722e8b15c302c3b Mon Sep 17 00:00:00 2001
From: Shadi
Date: Mon, 10 Jun 2024 22:36:23 -0700
Subject: [PATCH] Refactoring to improve module load times

---
 CHANGELOG.md                     |  2 +
 bin/viprs_evaluate               |  2 +-
 bin/viprs_fit                    | 11 ++---
 tests/conda_manual_testing.sh    |  2 +-
 tests/test_basic.py              |  3 +-
 viprs/__init__.py                |  5 +--
 viprs/eval/binary_metrics.py     | 63 +++++++++++++++-----------
 viprs/eval/continuous_metrics.py | 59 +++++--------------------
 viprs/eval/eval_utils.py         | 76 ++++++++++++++++++++++++++++++++
 viprs/eval/pseudo_metrics.py     |  3 +-
 viprs/model/BayesPRSModel.py     |  3 +-
 viprs/model/VIPRS.py             |  2 +-
 12 files changed, 142 insertions(+), 89 deletions(-)
 create mode 100644 viprs/eval/eval_utils.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a40ba3b..1dfbe30 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,8 @@
 object wasn't refreshed.
 - Updated `VIPRSBMA` & `VIPRSGridSearch` to only consider models that
 successfully converged.
 - Fixed bug in `psuedo_metrics` when extracting summary statistics data.
+- Streamlined evaluation code.
+- Refactored code to slightly reduce import/load time.
 
 ### Added

diff --git a/bin/viprs_evaluate b/bin/viprs_evaluate
index 8245199..b78962d 100644
--- a/bin/viprs_evaluate
+++ b/bin/viprs_evaluate
@@ -30,7 +30,7 @@
 import viprs as vp
 from magenpy.utils.system_utils import makedir
 from magenpy import SampleTable
 from viprs.eval import eval_metric_names, eval_incremental_metrics
-from viprs.eval.continuous_metrics import r2_stats
+from viprs.eval.eval_utils import r2_stats
 
 print(fr"""

diff --git a/bin/viprs_fit b/bin/viprs_fit
index e84c52c..9fb9c9e 100644
--- a/bin/viprs_fit
+++ b/bin/viprs_fit
@@ -271,13 +271,8 @@ def prepare_model(args, verbose=True):
         print("- Model selection criterion:", args.grid_metric)
 
     from functools import partial
-    from viprs.model.gridsearch.HyperparameterGrid import HyperparameterGrid
-
     from viprs.model.VIPRS import VIPRS
     from viprs.model.VIPRSMix import VIPRSMix
-    from viprs.model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
-    from viprs.model.gridsearch.VIPRSBMA import VIPRSBMA
-    from viprs.model.gridsearch.HyperparameterSearch import BayesOpt
 
     if args.hyp_search == 'EM':
         if args.model == 'VIPRS':
@@ -296,6 +291,10 @@ def prepare_model(args, verbose=True):
 
     elif args.hyp_search in ('BMA', 'GS'):
 
+        from viprs.model.gridsearch.HyperparameterGrid import HyperparameterGrid
+        from viprs.model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
+        from viprs.model.gridsearch.VIPRSBMA import VIPRSBMA
+
         grid = HyperparameterGrid(sigma_epsilon_grid=args.sigma_epsilon_grid,
                                   sigma_epsilon_steps=args.sigma_epsilon_steps,
                                   pi_grid=args.pi_grid,
@@ -316,6 +315,8 @@ def prepare_model(args, verbose=True):
 
     else:
 
+        from viprs.model.gridsearch.HyperparameterSearch import BayesOpt
+
         base_model = partial(VIPRS,
                              float_precision=args.float_precision,
                              low_memory=not args.use_symmetric_ld,
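Note: the `bin/viprs_fit` hunks above defer the gridsearch imports into the
branches that actually instantiate those classes, so the default EM path no
longer pays their import cost. A condensed sketch of the pattern (the
hypothetical `build_model` below stands in for the real `prepare_model`):

    def build_model(hyp_search='EM'):
        # Cheap, always-needed import stays unconditional.
        from viprs.model.VIPRS import VIPRS

        if hyp_search == 'EM':
            return VIPRS

        # The heavier gridsearch machinery is imported only when requested,
        # so a default `viprs_fit` invocation never loads it.
        from viprs.model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
        return VIPRSGridSearch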
diff --git a/tests/conda_manual_testing.sh b/tests/conda_manual_testing.sh
index 0832a6b..2383b5e 100644
--- a/tests/conda_manual_testing.sh
+++ b/tests/conda_manual_testing.sh
@@ -16,7 +16,7 @@ python_versions=("3.8" "3.9" "3.10" "3.11" "3.12")
 for version in "${python_versions[@]}"
 do
   # Create a new conda environment for the Python version
-  conda create --name "viprs_py$version" python="$version" -y
+  conda create --name "viprs_py$version" python="$version" -y || return 1
 
   # Activate the conda environment
   conda activate "viprs_py$version"

diff --git a/tests/test_basic.py b/tests/test_basic.py
index 7aceec8..a4bb0da 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -3,6 +3,7 @@
 import viprs as vp
 from viprs.model.vi.e_step_cpp import check_blas_support, check_omp_support
 import numpy as np
+from viprs.model.VIPRSMix import VIPRSMix
 from viprs.model.gridsearch.HyperparameterGrid import HyperparameterGrid
 from viprs.model.gridsearch.VIPRSGrid import VIPRSGrid
 import shutil
@@ -49,7 +50,7 @@ def viprsmix_model(gdl_object):
     """
     Initialize a VIPRS model using data pre-packaged with magenpy.
     Make this data loader available to all tests.
     """
-    return vp.VIPRSMix(gdl_object, K=10, verbose=False)
+    return VIPRSMix(gdl_object, K=10, verbose=False)
 
 @pytest.fixture(scope='module')

diff --git a/viprs/__init__.py b/viprs/__init__.py
index f3b3064..7791cd3 100644
--- a/viprs/__init__.py
+++ b/viprs/__init__.py
@@ -1,9 +1,6 @@
 from .model.VIPRS import VIPRS
-from .model.VIPRSMix import VIPRSMix
-from .model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
-from .model.gridsearch.HyperparameterGrid import HyperparameterGrid
 from .utils.data_utils import *
 
 
 __version__ = '0.1.2'
-__release_date__ = 'May 2024'
+__release_date__ = 'June 2024'

diff --git a/viprs/eval/binary_metrics.py b/viprs/eval/binary_metrics.py
index 91ab010..9cd8f4d 100644
--- a/viprs/eval/binary_metrics.py
+++ b/viprs/eval/binary_metrics.py
@@ -1,14 +1,6 @@
 import numpy as np
-from sklearn.metrics import (
-    auc,
-    roc_auc_score,
-    average_precision_score,
-    precision_recall_curve,
-    f1_score
-)
 from .continuous_metrics import incremental_r2
-from scipy.stats import norm
-import statsmodels.api as sm
+from .eval_utils import fit_linear_model
 import pandas as pd
 
 
@@ -20,6 +12,7 @@ def roc_auc(true_val, pred_val):
     :param true_val: The response value or phenotype (a numpy binary vector with 0s and 1s)
     :param pred_val: The predicted value or PRS (a numpy vector)
     """
+    from sklearn.metrics import roc_auc_score
     return roc_auc_score(true_val, pred_val)
 
 
@@ -31,6 +24,7 @@ def pr_auc(true_val, pred_val):
     :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s)
     :param pred_val: The predicted value or PRS (a numpy vector)
     """
+    from sklearn.metrics import precision_recall_curve, auc
     precision, recall, thresholds = precision_recall_curve(true_val, pred_val)
     return auc(recall, precision)
 
@@ -42,6 +36,7 @@ def avg_precision(true_val, pred_val):
     :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s)
     :param pred_val: The predicted value or PRS (a numpy vector)
     """
+    from sklearn.metrics import average_precision_score
     return average_precision_score(true_val, pred_val)
 
 
@@ -52,6 +47,7 @@ def f1(true_val, pred_val):
     :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s)
     :param pred_val: The predicted value or PRS (a numpy vector)
     """
+    from sklearn.metrics import f1_score
     return f1_score(true_val, pred_val)
 
 
@@ -69,12 +65,15 @@
     """
 
     if covariates is None:
+        add_intercept = False
         covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
     else:
-        covariates = sm.add_constant(covariates)
+        add_intercept = True
 
-    null_result = sm.Logit(true_val, covariates).fit(disp=0)
-    full_result = sm.Logit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
+    null_result = fit_linear_model(true_val, covariates,
+                                   family='binomial', add_intercept=add_intercept)
+    full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
+                                   family='binomial', add_intercept=add_intercept)
 
     return 1. - (full_result.llf / null_result.llf)
 
@@ -92,12 +91,15 @@
     """
 
     if covariates is None:
+        add_intercept = False
         covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
     else:
-        covariates = sm.add_constant(covariates)
+        add_intercept = True
 
-    null_result = sm.Logit(true_val, covariates).fit(disp=0)
-    full_result = sm.Logit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
+    null_result = fit_linear_model(true_val, covariates,
+                                   family='binomial', add_intercept=add_intercept)
+    full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
+                                   family='binomial', add_intercept=add_intercept)
 
     n = true_val.shape[0]
     return 1. - np.exp(-2 * (full_result.llf - null_result.llf) / n)
@@ -116,12 +118,15 @@
     """
 
     if covariates is None:
+        add_intercept = False
         covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
     else:
-        covariates = sm.add_constant(covariates)
+        add_intercept = True
 
-    null_result = sm.Logit(true_val, covariates).fit(disp=0)
-    full_result = sm.Logit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
+    null_result = fit_linear_model(true_val, covariates,
+                                   family='binomial', add_intercept=add_intercept)
+    full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
+                                   family='binomial', add_intercept=add_intercept)
 
     n = true_val.shape[0]
     # First compute the Cox & Snell R2:
@@ -158,6 +163,8 @@
 
     # Second, compute the prevalence and the standard normal quantile of the prevalence:
 
+    from scipy.stats import norm
+
     k = np.mean(true_val)
     z2 = norm.pdf(norm.ppf(1.-k))**2
     mult_factor = k*(1. - k) / z2
@@ -194,12 +201,15 @@ def liability_probit_r2(true_val, pred_val, covariates=None, return_all_r2=False
     """
 
     if covariates is None:
+        add_intercept = False
         covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
     else:
-        covariates = sm.add_constant(covariates)
+        add_intercept = True
 
-    null_result = sm.Probit(true_val, covariates).fit(disp=0)
-    full_result = sm.Probit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
+    null_result = fit_linear_model(true_val, covariates,
+                                   family='binomial', link='probit', add_intercept=add_intercept)
+    full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
+                                   family='binomial', link='probit', add_intercept=add_intercept)
 
     null_var = np.var(null_result.predict())
     null_r2 = null_var / (null_var + 1.)
@@ -224,7 +234,7 @@ def liability_logit_r2(true_val, pred_val, covariates=None, return_all_r2=False)
     https://pubmed.ncbi.nlm.nih.gov/22714935/
 
     The R^2 is defined as:
 
-    R2_{probit} = Var(pred) / (Var(pred) + pi^2 / 3)
+    R2_{logit} = Var(pred) / (Var(pred) + pi^2 / 3)
 
     Where Var(pred) is the variance of the predicted liability.
@@ -239,12 +249,15 @@
     """
 
     if covariates is None:
+        add_intercept = False
         covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
     else:
-        covariates = sm.add_constant(covariates)
+        add_intercept = True
 
-    null_result = sm.Probit(true_val, covariates).fit(disp=0)
-    full_result = sm.Probit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
+    null_result = fit_linear_model(true_val, covariates,
+                                   family='binomial', add_intercept=add_intercept)
+    full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
+                                   family='binomial', add_intercept=add_intercept)
 
     null_var = np.var(null_result.predict())
     null_r2 = null_var / (null_var + (np.pi**2 / 3))
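Note: all of the logistic/probit fits above now go through the shared
`fit_linear_model` helper (added in `viprs/eval/eval_utils.py` below) instead
of calling `statsmodels` inline. A quick sanity check that the helper
reproduces the old inline fits on synthetic data (a sketch; assumes the
patched package and its dependencies are installed):

    import numpy as np
    import pandas as pd
    import statsmodels.api as sm
    from viprs.eval.eval_utils import fit_linear_model

    rng = np.random.default_rng(0)
    covariates = pd.DataFrame({'age': rng.normal(50., 10., 500)})
    prs = rng.normal(size=500)
    liability = 0.02 * covariates['age'] + prs - 1.
    y = (rng.random(500) < 1. / (1. + np.exp(-liability))).astype(int)

    # New helper: binomial family with the default (logit) link.
    new_fit = fit_linear_model(y, covariates.assign(pred_val=prs),
                               family='binomial', add_intercept=True)
    # What e.g. mcfadden_r2 previously did inline:
    old_fit = sm.Logit(y, sm.add_constant(covariates.assign(pred_val=prs))).fit(disp=0)

    assert np.isclose(new_fit.llf, old_fit.llf)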
diff --git a/viprs/eval/continuous_metrics.py b/viprs/eval/continuous_metrics.py
index 72ba59c..abaf6aa 100644
--- a/viprs/eval/continuous_metrics.py
+++ b/viprs/eval/continuous_metrics.py
@@ -1,49 +1,6 @@
 import numpy as np
 import pandas as pd
-from scipy import stats
-import statsmodels.api as sm
-
-
-def r2_stats(r2_val, n):
-    """
-    Compute the confidence interval and p-value for a given R-squared (proportion of variance
-    explained) value.
-
-    This function and the formulas therein are based on the following paper
-    by Momin et al. 2023: https://doi.org/10.1016/j.ajhg.2023.01.004 as well as
-    the implementation in the R package `PRSmix`:
-    https://github.com/buutrg/PRSmix/blob/main/R/get_PRS_acc.R#L63
-
-    :param r2_val: The R^2 value to compute the confidence interval/p-value for.
-    :param n: The sample size used to compute the R^2 value
-
-    :return: A dictionary with the R^2 value, the lower and upper values of the confidence interval,
-    the p-value, and the standard error of the R^2 metric.
-
-    """
-
-    assert 0. < r2_val < 1., "R^2 value must be between 0 and 1."
-
-    # Compute the variance of the R^2 value:
-    r2_var = (4. * r2_val * (1. - r2_val) ** 2 * (n - 2) ** 2) / ((n ** 2 - 1) * (n + 3))
-
-    # Compute the standard errors for the R^2 value
-    # as well as the lower and upper values for
-    # the confidence interval:
-    r2_se = np.sqrt(r2_var)
-    lower_r2 = r2_val - 1.97 * r2_se
-    upper_r2 = r2_val + 1.97 * r2_se
-
-    # Compute the p-value assuming a Chi-squared distribution with 1 degree of freedom:
-    pval = stats.chi2.sf((r2_val / r2_se) ** 2, df=1)
-
-    return {
-        'R2': r2,
-        'Lower_R2': lower_r2,
-        'Upper_R2': upper_r2,
-        'P_Value': pval,
-        'SE': r2_se,
-    }
+from .eval_utils import fit_linear_model
 
 
 def r2(true_val, pred_val):
@@ -54,6 +11,8 @@ def r2(true_val, pred_val):
     :param true_val: The response value or phenotype (a numpy vector)
     :param pred_val: The predicted value or PRS (a numpy vector)
     """
+    from scipy import stats
+
     _, _, r_val, _, _ = stats.linregress(pred_val, true_val)
     return r_val ** 2
 
@@ -95,12 +54,14 @@ def incremental_r2(true_val, pred_val, covariates=None, return_all_r2=False):
     """
 
     if covariates is None:
+        add_intercept = False
         covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
     else:
-        covariates = sm.add_constant(covariates)
+        add_intercept = True
 
-    null_result = sm.OLS(true_val, covariates).fit(disp=0)
-    full_result = sm.OLS(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
+    null_result = fit_linear_model(true_val, covariates, add_intercept=add_intercept)
+    full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
+                                   add_intercept=add_intercept)
 
     if return_all_r2:
         return {
@@ -125,7 +86,7 @@ def partial_correlation(true_val, pred_val, covariates):
     the same way as the predictions and response.
     """
 
-    true_response = sm.OLS(true_val, sm.add_constant(covariates)).fit(disp=0)
-    pred_response = sm.OLS(pred_val, sm.add_constant(covariates)).fit(disp=0)
+    true_response = fit_linear_model(true_val, covariates, add_intercept=True)
+    pred_response = fit_linear_model(pred_val, covariates, add_intercept=True)
 
     return np.corrcoef(true_response.resid, pred_response.resid)[0, 1]

diff --git a/viprs/eval/eval_utils.py b/viprs/eval/eval_utils.py
new file mode 100644
index 0000000..1deda35
--- /dev/null
+++ b/viprs/eval/eval_utils.py
@@ -0,0 +1,76 @@
+import numpy as np
+
+
+def r2_stats(r2_val, n):
+    """
+    Compute the confidence interval and p-value for a given R-squared (proportion of variance
+    explained) value.
+
+    This function and the formulas therein are based on the following paper
+    by Momin et al. 2023: https://doi.org/10.1016/j.ajhg.2023.01.004 as well as
+    the implementation in the R package `PRSmix`:
+    https://github.com/buutrg/PRSmix/blob/main/R/get_PRS_acc.R#L63
+
+    :param r2_val: The R^2 value to compute the confidence interval/p-value for.
+    :param n: The sample size used to compute the R^2 value
+
+    :return: A dictionary with the R^2 value, the lower and upper values of the confidence interval,
+    the p-value, and the standard error of the R^2 metric.
+
+    """
+
+    assert 0. < r2_val < 1., "R^2 value must be between 0 and 1."
+
+    # Compute the variance of the R^2 value:
+    r2_var = (4. * r2_val * (1. - r2_val) ** 2 * (n - 2) ** 2) / ((n ** 2 - 1) * (n + 3))
+
+    # Compute the standard errors for the R^2 value
+    # as well as the lower and upper values for
+    # the confidence interval:
+    r2_se = np.sqrt(r2_var)
+    lower_r2 = r2_val - 1.97 * r2_se
+    upper_r2 = r2_val + 1.97 * r2_se
+
+    from scipy import stats
+
+    # Compute the p-value assuming a Chi-squared distribution with 1 degree of freedom:
+    pval = stats.chi2.sf((r2_val / r2_se) ** 2, df=1)
+
+    return {
+        'R2': r2_val,
+        'Lower_R2': lower_r2,
+        'Upper_R2': upper_r2,
+        'P_Value': pval,
+        'SE': r2_se,
+    }
+
+
+def fit_linear_model(y, x, family='gaussian', link=None, add_intercept=False):
+    """
+    Fit a linear (or generalized linear) model to the data `x` and `y` and return the fitted model object.
+
+    :param y: The dependent variable (a numpy vector)
+    :param x: The design matrix (a pandas DataFrame)
+    :param family: The family of the model. Must be either 'gaussian' or 'binomial'.
+    :param link: The link function to use for the model. If None, the default link function is used.
+    :param add_intercept: If True, add an intercept term to the model.
+    """
+
+    assert y.shape[0] == x.shape[0], ("The number of rows in the design matrix "
+                                      "and the dependent variable must match.")
+    assert family in ('gaussian', 'binomial'), "The family must be either 'gaussian' or 'binomial'."
+    if family == 'binomial':
+        assert link in ('logit', 'probit', None), "The link function must be either 'logit', 'probit' or None."
+
+    import statsmodels.api as sm
+
+    if add_intercept:
+        x = sm.add_constant(x)
+
+    if family == 'gaussian':
+        return sm.OLS(y, x).fit()
+    elif family == 'binomial':
+        if link == 'logit' or link is None:
+            return sm.Logit(y, x).fit(disp=0)
+        elif link == 'probit':
+            return sm.Probit(y, x).fit(disp=0)

diff --git a/viprs/eval/pseudo_metrics.py b/viprs/eval/pseudo_metrics.py
index 758984e..3183960 100644
--- a/viprs/eval/pseudo_metrics.py
+++ b/viprs/eval/pseudo_metrics.py
@@ -1,4 +1,3 @@
-from magenpy import GWADataLoader
 import numpy as np
 
 
@@ -20,6 +19,8 @@ def _match_variant_stats(test_gdl, prs_beta_table):
     (2) The inferred PRS effect sizes,
     (3) The LD-weighted PRS effect sizes (q).
     """
+    from magenpy import GWADataLoader
+
     # Sanity checks:
     assert isinstance(test_gdl, GWADataLoader), "The test/validation set must be an instance of GWADataLoader."
     assert test_gdl.ld is not None, "The test/validation set must have LD matrices initialized."
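Note: the eval refactor above accounts for most of the change; together with
the deferred `magenpy`, `sklearn`, `statsmodels` and `scipy` imports, a bare
`import viprs` should no longer pull in the full stats stack. The effect can
be inspected per-module with `python -X importtime -c "import viprs"`, or
timed end-to-end in a fresh interpreter (a sketch; assumes `viprs` is
installed in the current environment):

    import subprocess
    import sys

    # Re-importing in the same process would just hit sys.modules,
    # so spawn a clean interpreter for a cold-import measurement.
    snippet = ("import time; t0 = time.perf_counter(); "
               "import viprs; "
               "print(time.perf_counter() - t0)")
    out = subprocess.run([sys.executable, "-c", snippet],
                         capture_output=True, text=True, check=True)
    print(f"cold import of viprs: {float(out.stdout):.3f}s")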
diff --git a/viprs/model/BayesPRSModel.py b/viprs/model/BayesPRSModel.py
index eb193a9..401836f 100644
--- a/viprs/model/BayesPRSModel.py
+++ b/viprs/model/BayesPRSModel.py
@@ -3,7 +3,6 @@
 import os.path as osp
 
 from ..utils.compute_utils import expand_column_names, dict_mean
-from magenpy.utils.model_utils import merge_snp_tables
 
 
 class BayesPRSModel:
@@ -186,6 +185,8 @@ def harmonize_data(self, gdl=None, parameter_table=None):
 
         common_chroms = sorted(list(set(snp_tables.keys()).intersection(set(parameter_table.keys()))))
 
+        from magenpy.utils.model_utils import merge_snp_tables
+
         for c in common_chroms:
 
             try:

diff --git a/viprs/model/VIPRS.py b/viprs/model/VIPRS.py
index 6527be9..597f0f2 100644
--- a/viprs/model/VIPRS.py
+++ b/viprs/model/VIPRS.py
@@ -3,7 +3,6 @@
 from tqdm.auto import tqdm
 
 from .BayesPRSModel import BayesPRSModel
-from magenpy.stats.h2.ldsc import simple_ldsc
 from ..utils.exceptions import OptimizationDivergence
 from .vi.e_step import e_step
 from .vi.e_step_cpp import cpp_e_step
@@ -269,6 +268,7 @@
             # then initialize using the SNP heritability estimate
 
             try:
+                from magenpy.stats.h2.ldsc import simple_ldsc
                 naive_h2g = np.clip(simple_ldsc(self.gdl), a_min=1e-3, a_max=1. - 1e-3)
             except Exception as e:
                 naive_h2g = np.random.uniform(low=.001, high=.999)
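For downstream users the relocated helpers behave as before: `r2_stats` simply
moved from `viprs.eval.continuous_metrics` to `viprs.eval.eval_utils` (with its
`'R2'` return entry corrected to `r2_val`), which is why `bin/viprs_evaluate`
updates its import above. A small usage sketch on synthetic data (assumes the
patched package is installed; exact return semantics of `incremental_r2`
without `return_all_r2` follow the existing implementation):

    import numpy as np
    import pandas as pd
    from viprs.eval.continuous_metrics import incremental_r2
    from viprs.eval.eval_utils import r2_stats

    rng = np.random.default_rng(1)
    n = 1000
    covariates = pd.DataFrame({'age': rng.normal(50., 10., n)})
    prs = rng.normal(size=n)
    phenotype = 0.05 * covariates['age'] + 0.5 * prs + rng.normal(size=n)

    # Incremental R^2 of the PRS over the covariates alone:
    inc_r2 = incremental_r2(phenotype, prs, covariates=covariates)

    # Standard error, confidence interval and p-value for that estimate:
    print(r2_stats(inc_r2, n))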