Refactoring to improve module load times
shz9 committed Jun 11, 2024
1 parent 24e9eb5 commit 433e0ee
Showing 12 changed files with 142 additions and 89 deletions.
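
The common thread across these changes is moving heavy imports from module scope into the functions that actually need them, so that `import viprs` no longer pays up front for `sklearn`, `statsmodels`, `scipy.stats`, or the gridsearch machinery. A minimal sketch of the pattern (mirroring the `f1` change in binary_metrics.py below):

def f1(true_val, pred_val):
    # Lazy import: the cost is paid on the first call only, since Python
    # caches imported modules in sys.modules; later calls are cheap.
    from sklearn.metrics import f1_score
    return f1_score(true_val, pred_val)

To see which dependencies dominate load time, the import chain can be profiled with `python -X importtime -c "import viprs"`.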
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ object wasn't refreshed.
- Updated `VIPRSBMA` & `VIPRSGridSearch` to only consider models that
successfully converged.
- Fixed bug in `pseudo_metrics` when extracting summary statistics data.
- Streamlined evaluation code.
- Refactored code to slightly reduce import/load time.

### Added

2 changes: 1 addition & 1 deletion bin/viprs_evaluate
@@ -30,7 +30,7 @@ import viprs as vp
from magenpy.utils.system_utils import makedir
from magenpy import SampleTable
from viprs.eval import eval_metric_names, eval_incremental_metrics
from viprs.eval.continuous_metrics import r2_stats
from viprs.eval.eval_utils import r2_stats


print(fr"""
11 changes: 6 additions & 5 deletions bin/viprs_fit
@@ -271,13 +271,8 @@ def prepare_model(args, verbose=True):
print("- Model selection criterion:", args.grid_metric)

from functools import partial
from viprs.model.gridsearch.HyperparameterGrid import HyperparameterGrid

from viprs.model.VIPRS import VIPRS
from viprs.model.VIPRSMix import VIPRSMix
from viprs.model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
from viprs.model.gridsearch.VIPRSBMA import VIPRSBMA
from viprs.model.gridsearch.HyperparameterSearch import BayesOpt

if args.hyp_search == 'EM':
if args.model == 'VIPRS':
@@ -296,6 +291,10 @@

elif args.hyp_search in ('BMA', 'GS'):

from viprs.model.gridsearch.HyperparameterGrid import HyperparameterGrid
from viprs.model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
from viprs.model.gridsearch.VIPRSBMA import VIPRSBMA

grid = HyperparameterGrid(sigma_epsilon_grid=args.sigma_epsilon_grid,
sigma_epsilon_steps=args.sigma_epsilon_steps,
pi_grid=args.pi_grid,
@@ -316,6 +315,8 @@

else:

from viprs.model.gridsearch.HyperparameterSearch import BayesOpt

base_model = partial(VIPRS,
float_precision=args.float_precision,
low_memory=not args.use_symmetric_ld,
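
Structurally, `prepare_model` now imports only what the selected code path needs; an `EM` fit never touches the gridsearch modules at all. Condensed from the diff above (argument handling elided):

def prepare_model(args, verbose=True):
    from viprs.model.VIPRS import VIPRS  # needed on every path

    if args.hyp_search == 'EM':
        pass  # plain VIPRS / VIPRSMix; no gridsearch imports
    elif args.hyp_search in ('BMA', 'GS'):
        # Imported only when a hyperparameter grid is requested:
        from viprs.model.gridsearch.HyperparameterGrid import HyperparameterGrid
        from viprs.model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
        from viprs.model.gridsearch.VIPRSBMA import VIPRSBMA
    else:
        # Bayesian optimization over the hyperparameters:
        from viprs.model.gridsearch.HyperparameterSearch import BayesOpt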
2 changes: 1 addition & 1 deletion tests/conda_manual_testing.sh
@@ -16,7 +16,7 @@ python_versions=("3.8" "3.9" "3.10" "3.11" "3.12")
for version in "${python_versions[@]}"
do
# Create a new conda environment for the Python version
conda create --name "viprs_py$version" python="$version" -y
conda create --name "viprs_py$version" python="$version" -y || return 1

# Activate the conda environment
conda activate "viprs_py$version"
3 changes: 2 additions & 1 deletion tests/test_basic.py
@@ -3,6 +3,7 @@
import viprs as vp
from viprs.model.vi.e_step_cpp import check_blas_support, check_omp_support
import numpy as np
from viprs.model.VIPRSMix import VIPRSMix
from viprs.model.gridsearch.HyperparameterGrid import HyperparameterGrid
from viprs.model.gridsearch.VIPRSGrid import VIPRSGrid
import shutil
@@ -49,7 +50,7 @@ def viprsmix_model(gdl_object):
Initialize a VIPRS model using data pre-packaged with magenpy.
Make this data loader available to all tests.
"""
return vp.VIPRSMix(gdl_object, K=10, verbose=False)
return VIPRSMix(gdl_object, K=10, verbose=False)


@pytest.fixture(scope='module')
5 changes: 1 addition & 4 deletions viprs/__init__.py
@@ -1,9 +1,6 @@

from .model.VIPRS import VIPRS
from .model.VIPRSMix import VIPRSMix
from .model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
from .model.gridsearch.HyperparameterGrid import HyperparameterGrid
from .utils.data_utils import *

__version__ = '0.1.2'
__release_date__ = 'May 2024'
__release_date__ = 'June 2024'
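
Trimming these re-exports means the package root now pulls in only `VIPRS` and the data utilities, which is what shaves the import time; the mixture and gridsearch classes are imported from their defining modules instead, as the updated test above already does. For downstream code the migration looks like this (illustrative names):

# Before this commit:
#   import viprs as vp
#   model = vp.VIPRSMix(gdl, K=10)

# After this commit:
from viprs.model.VIPRSMix import VIPRSMix
model = VIPRSMix(gdl, K=10)  # gdl: a magenpy GWADataLoader, as in the tests above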
63 changes: 38 additions & 25 deletions viprs/eval/binary_metrics.py
@@ -1,14 +1,6 @@
import numpy as np
from sklearn.metrics import (
auc,
roc_auc_score,
average_precision_score,
precision_recall_curve,
f1_score
)
from .continuous_metrics import incremental_r2
from scipy.stats import norm
import statsmodels.api as sm
from .eval_utils import fit_linear_model
import pandas as pd


@@ -20,6 +12,7 @@ def roc_auc(true_val, pred_val):
:param true_val: The response value or phenotype (a numpy binary vector with 0s and 1s)
:param pred_val: The predicted value or PRS (a numpy vector)
"""
from sklearn.metrics import roc_auc_score
return roc_auc_score(true_val, pred_val)


@@ -31,6 +24,7 @@ def pr_auc(true_val, pred_val):
:param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s)
:param pred_val: The predicted value or PRS (a numpy vector)
"""
from sklearn.metrics import precision_recall_curve, auc
precision, recall, thresholds = precision_recall_curve(true_val, pred_val)
return auc(recall, precision)

@@ -42,6 +36,7 @@ def avg_precision(true_val, pred_val):
:param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s)
:param pred_val: The predicted value or PRS (a numpy vector)
"""
from sklearn.metrics import average_precision_score
return average_precision_score(true_val, pred_val)


@@ -52,6 +47,7 @@ def f1(true_val, pred_val):
:param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s)
:param pred_val: The predicted value or PRS (a numpy vector)
"""
from sklearn.metrics import f1_score
return f1_score(true_val, pred_val)


@@ -69,12 +65,15 @@ def mcfadden_r2(true_val, pred_val, covariates=None):
"""

if covariates is None:
add_intercept = False
covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
else:
covariates = sm.add_constant(covariates)
add_intercept = True

null_result = sm.Logit(true_val, covariates).fit(disp=0)
full_result = sm.Logit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
null_result = fit_linear_model(true_val, covariates,
family='binomial', add_intercept=add_intercept)
full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
family='binomial', add_intercept=add_intercept)

return 1. - (full_result.llf / null_result.llf)

@@ -92,12 +91,15 @@ def cox_snell_r2(true_val, pred_val, covariates=None):
"""

if covariates is None:
add_intercept = False
covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
else:
covariates = sm.add_constant(covariates)
add_intercept = True

null_result = sm.Logit(true_val, covariates).fit(disp=0)
full_result = sm.Logit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
null_result = fit_linear_model(true_val, covariates,
family='binomial', add_intercept=add_intercept)
full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
family='binomial', add_intercept=add_intercept)
n = true_val.shape[0]

return 1. - np.exp(-2 * (full_result.llf - null_result.llf) / n)
@@ -116,12 +118,15 @@ def nagelkerke_r2(true_val, pred_val, covariates=None):
"""

if covariates is None:
add_intercept = False
covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
else:
covariates = sm.add_constant(covariates)
add_intercept = True

null_result = sm.Logit(true_val, covariates).fit(disp=0)
full_result = sm.Logit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
null_result = fit_linear_model(true_val, covariates,
family='binomial', add_intercept=add_intercept)
full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
family='binomial', add_intercept=add_intercept)
n = true_val.shape[0]

# First compute the Cox & Snell R2:
@@ -158,6 +163,8 @@ def liability_r2(true_val, pred_val, covariates=None, return_all_r2=False):

# Second, compute the prevalence and the standard normal quantile of the prevalence:

from scipy.stats import norm

k = np.mean(true_val)
z2 = norm.pdf(norm.ppf(1.-k))**2
mult_factor = k*(1. - k) / z2
@@ -194,12 +201,15 @@ def liability_probit_r2(true_val, pred_val, covariates=None, return_all_r2=False):
"""

if covariates is None:
add_intercept = False
covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
else:
covariates = sm.add_constant(covariates)
add_intercept = True

null_result = sm.Probit(true_val, covariates).fit(disp=0)
full_result = sm.Probit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
null_result = fit_linear_model(true_val, covariates,
family='binomial', link='probit', add_intercept=add_intercept)
full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
family='binomial', link='probit', add_intercept=add_intercept)

null_var = np.var(null_result.predict())
null_r2 = null_var / (null_var + 1.)
@@ -224,7 +234,7 @@ def liability_logit_r2(true_val, pred_val, covariates=None, return_all_r2=False):
https://pubmed.ncbi.nlm.nih.gov/22714935/
The R^2 is defined as:
R2_{probit} = Var(pred) / (Var(pred) + pi^2 / 3)
R2_{logit} = Var(pred) / (Var(pred) + pi^2 / 3)
Where Var(pred) is the variance of the predicted liability.
Expand All @@ -239,12 +249,15 @@ def liability_logit_r2(true_val, pred_val, covariates=None, return_all_r2=False)
"""

if covariates is None:
add_intercept = False
covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
else:
covariates = sm.add_constant(covariates)
add_intercept = True

null_result = sm.Probit(true_val, covariates).fit(disp=0)
full_result = sm.Probit(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
null_result = fit_linear_model(true_val, covariates,
family='binomial', add_intercept=add_intercept)
full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
family='binomial', add_intercept=add_intercept)

null_var = np.var(null_result.predict())
null_r2 = null_var / (null_var + (np.pi**2 / 3))
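
Every pseudo-R² above now routes through a shared `fit_linear_model` helper in `viprs.eval.eval_utils`, which both removes the duplicated model-fitting boilerplate and keeps `statsmodels` off the module import path. That helper is not shown in this diff; the sketch below is a plausible reconstruction inferred from its call sites (`family`, `link`, `add_intercept`, and the use of `.llf`, `.predict()` and `.resid` on the result), not the actual implementation:

def fit_linear_model(y, X, family='gaussian', link='logit', add_intercept=False):
    # Hypothetical reconstruction of viprs.eval.eval_utils.fit_linear_model.
    import statsmodels.api as sm  # lazy: imported only when a model is fit

    if add_intercept:
        X = sm.add_constant(X)

    if family == 'gaussian':
        return sm.OLS(y, X).fit()
    elif family == 'binomial':
        model = sm.Probit(y, X) if link == 'probit' else sm.Logit(y, X)
        return model.fit(disp=0)
    else:
        raise ValueError(f"Unsupported family: {family}")

Under this reading, calling `mcfadden_r2(true_val, pred_val)` with no covariates fits an intercept-only null logit model (via the explicit constant column) against a logit model with the PRS added, exactly as the old inline `sm.Logit` calls did.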
59 changes: 10 additions & 49 deletions viprs/eval/continuous_metrics.py
@@ -1,49 +1,6 @@
import numpy as np
import pandas as pd
from scipy import stats
import statsmodels.api as sm


def r2_stats(r2_val, n):
"""
Compute the confidence interval and p-value for a given R-squared (proportion of variance
explained) value.
This function and the formulas therein are based on the following paper
by Momin et al. 2023: https://doi.org/10.1016/j.ajhg.2023.01.004 as well as
the implementation in the R package `PRSmix`:
https://github.com/buutrg/PRSmix/blob/main/R/get_PRS_acc.R#L63
:param r2_val: The R^2 value to compute the confidence interval/p-value for.
:param n: The sample size used to compute the R^2 value
:return: A dictionary with the R^2 value, the lower and upper values of the confidence interval,
the p-value, and the standard error of the R^2 metric.
"""

assert 0. < r2_val < 1., "R^2 value must be between 0 and 1."

# Compute the variance of the R^2 value:
r2_var = (4. * r2_val * (1. - r2_val) ** 2 * (n - 2) ** 2) / ((n ** 2 - 1) * (n + 3))

# Compute the standard errors for the R^2 value
# as well as the lower and upper values for
# the confidence interval:
r2_se = np.sqrt(r2_var)
lower_r2 = r2_val - 1.97 * r2_se
upper_r2 = r2_val + 1.97 * r2_se

# Compute the p-value assuming a Chi-squared distribution with 1 degree of freedom:
pval = stats.chi2.sf((r2_val / r2_se) ** 2, df=1)

return {
'R2': r2_val,
'Lower_R2': lower_r2,
'Upper_R2': upper_r2,
'P_Value': pval,
'SE': r2_se,
}
from .eval_utils import fit_linear_model


def r2(true_val, pred_val):
@@ -54,6 +11,8 @@ def r2(true_val, pred_val):
:param true_val: The response value or phenotype (a numpy vector)
:param pred_val: The predicted value or PRS (a numpy vector)
"""
from scipy import stats

_, _, r_val, _, _ = stats.linregress(pred_val, true_val)
return r_val ** 2

@@ -95,12 +54,14 @@ def incremental_r2(true_val, pred_val, covariates=None, return_all_r2=False):
"""

if covariates is None:
add_intercept = False
covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
else:
covariates = sm.add_constant(covariates)
add_intercept = True

null_result = sm.OLS(true_val, covariates).fit(disp=0)
full_result = sm.OLS(true_val, covariates.assign(pred_val=pred_val)).fit(disp=0)
null_result = fit_linear_model(true_val, covariates, add_intercept=add_intercept)
full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
add_intercept=add_intercept)

if return_all_r2:
return {
@@ -125,7 +86,7 @@ def partial_correlation(true_val, pred_val, covariates):
the same way as the predictions and response.
"""

true_response = sm.OLS(true_val, sm.add_constant(covariates)).fit(disp=0)
pred_response = sm.OLS(pred_val, sm.add_constant(covariates)).fit(disp=0)
true_response = fit_linear_model(true_val, covariates, add_intercept=True)
pred_response = fit_linear_model(pred_val, covariates, add_intercept=True)

return np.corrcoef(true_response.resid, pred_response.resid)[0, 1]
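
The `r2_stats` block deleted above was moved to `viprs.eval.eval_utils` rather than dropped, matching the updated import in `bin/viprs_evaluate`. A quick usage sketch, assuming the signature shown in the deleted code is preserved:

from viprs.eval.eval_utils import r2_stats

# Incremental R^2 of 0.12 estimated on a test set of 5,000 individuals:
print(r2_stats(0.12, 5000))
# -> dict with keys 'R2', 'Lower_R2', 'Upper_R2', 'P_Value', 'SE'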
(Diffs for the remaining 4 of the 12 changed files are not shown.)