From fd6269eaad849965e83af14c38d59c8d346b5442 Mon Sep 17 00:00:00 2001 From: evenmm Date: Fri, 5 May 2023 12:25:20 +0200 Subject: [PATCH] Slider and M protein plots --- plot_auc_simdata.py | 251 ++++++++++++++++++++++++++++++++++++++++++++ plot_exp_values.py | 18 ++++ plot_legend_CIs.py | 20 ++++ slider.py | 85 +++++++++++++++ 4 files changed, 374 insertions(+) create mode 100644 plot_auc_simdata.py create mode 100644 plot_exp_values.py create mode 100644 plot_legend_CIs.py create mode 100644 slider.py diff --git a/plot_auc_simdata.py b/plot_auc_simdata.py new file mode 100644 index 0000000..a7758c7 --- /dev/null +++ b/plot_auc_simdata.py @@ -0,0 +1,251 @@ +import numpy as np +import matplotlib.pyplot as plt +import scipy +import pandas as pd +import seaborn as sns + +from utilities import * +from BNN_model import * + +# Initialize random number generator +RANDOM_SEED = 42 +np.random.seed(RANDOM_SEED) +rng = np.random.default_rng(RANDOM_SEED) +print(f"Running on PyMC v{pm.__version__}") +#SAVEDIR = "/data/evenmm/plots/" +SAVEDIR = "./plots/Bayesian_estimates_simdata_BNN/" + +script_index = 1 + +# Settings +if int(script_index % 3) == 0: + true_sigma_obs = 0 +elif int(script_index % 3) == 1: + true_sigma_obs = 1 +elif int(script_index % 3) == 2: + true_sigma_obs = 2.5 + +if script_index >= 3: + RANDOM_EFFECTS = True +else: + RANDOM_EFFECTS = False + +RANDOM_EFFECTS_TEST = False + +model_name = "BNN" +N_patients = 150 +psi_prior="lognormal" +WEIGHT_PRIOR = "Student_out" #"Horseshoe" # "Student_out" #"symmetry_fix" #"iso_normal" "Student_out" +N_samples = 10_000 +N_tuning = 10_000 +ADADELTA = True +target_accept = 0.99 +CI_with_obs_noise = True +PLOT_RESISTANT = True +FUNNEL_REPARAMETRIZATION = False +MODEL_RANDOM_EFFECTS = True +N_HIDDEN = 2 +P = 5 # Number of covariates +P0 = int(P / 2) # A guess of the true number of nonzero parameters is needed for defining the global shrinkage parameter +true_omega = np.array([0.10,0.05,0.20]) + +M_number_of_measurements =7 +y_resolution = 80 # Number of timepoints to evaluate the posterior of y in +true_omega_for_psi = 0.1 + +max_time = 1200 #3000 #1500 +days_between_measurements = int(max_time/M_number_of_measurements) +measurement_times = days_between_measurements * np.linspace(0, M_number_of_measurements, M_number_of_measurements) + +evaluation_time = measurement_times[:4][-1] + 1 +print(evaluation_time) + +treatment_history = np.array([Treatment(start=0, end=measurement_times[-1], id=1)]) +name = "simdata_"+model_name+"_"+str(script_index)+"_M_"+str(M_number_of_measurements)+"_P_"+str(P)+"_N_pax_"+str(N_patients)+"_N_sampl_"+str(N_samples)+"_N_tune_"+str(N_tuning)+"_FUNNEL_"+str(FUNNEL_REPARAMETRIZATION)+"_RNDM_EFFECTS_"+str(RANDOM_EFFECTS)+"_WT_PRIOR_"+str(WEIGHT_PRIOR+"_N_HIDDN_"+str(N_HIDDEN)) + +N_patients_test = 50 + +recur_or_not_BNN = [ + [1.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0. +,0.,1.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0. +,1.,0.], +[1.,0.,0.,1.,0.,0.,1.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0. +,0.,1.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0. +,1.,0.], + [1.,0.,0.,1.,0.,0.,1.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0. +,0.,1.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.,0.,0.,1.,1.,0.,0.,0. +,1.,0.], + [1.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0. +,0.,1.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0. +,1.,0.], +[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, + 1, 0], +[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, + 1, 0,] +] + +p_recs_BNN = np.array([[1.00000e+00,0.00000e+00,5.19250e-02,1.00000e+00,5.00000e-05,0.00000e+00 + ,4.05575e-01,2.00000e-04,7.45575e-01,0.00000e+00,2.47500e-03,0.00000e+00 + ,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00 + ,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,9.59275e-01 + ,0.00000e+00,9.74175e-01,0.00000e+00,0.00000e+00,9.99825e-01,0.00000e+00 + ,6.75000e-04,0.00000e+00,2.50000e-05,0.00000e+00,0.00000e+00,6.13250e-02 + ,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,0.00000e+00,0.00000e+00 + ,0.00000e+00,1.50000e-04,9.42375e-01,0.00000e+00,0.00000e+00,0.00000e+00 + ,1.00000e+00,0.00000e+00], + [1.00000e+00,0.00000e+00,1.77750e-02,1.00000e+00,0.00000e+00,0.00000e+00 + ,2.96350e-01,2.50000e-05,6.81375e-01,0.00000e+00,2.25000e-04,0.00000e+00 + ,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00 + ,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,9.39000e-01 + ,0.00000e+00,9.65450e-01,0.00000e+00,0.00000e+00,9.99925e-01,0.00000e+00 + ,5.00000e-05,0.00000e+00,2.50000e-05,0.00000e+00,0.00000e+00,2.48000e-02 + ,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,0.00000e+00,0.00000e+00 + ,0.00000e+00,0.00000e+00,9.55525e-01,0.00000e+00,0.00000e+00,0.00000e+00 + ,1.00000e+00,0.00000e+00], + [1.00000e+00,0.00000e+00,3.31500e-02,1.00000e+00,2.75000e-04,0.00000e+00 + ,2.98825e-01,8.25000e-04,5.79400e-01,0.00000e+00,2.82500e-03,0.00000e+00 + ,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00 + ,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,8.51800e-01 + ,0.00000e+00,8.87375e-01,0.00000e+00,0.00000e+00,9.98425e-01,0.00000e+00 + ,1.20000e-03,0.00000e+00,6.75000e-04,0.00000e+00,0.00000e+00,5.97250e-02 + ,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,2.50000e-05,0.00000e+00 + ,0.00000e+00,1.32500e-03,8.75850e-01,0.00000e+00,0.00000e+00,0.00000e+00 + ,9.99875e-01,0.00000e+00], + [9.99875e-01,0.00000e+00,2.30250e-01,9.99225e-01,1.36500e-02,0.00000e+00 + ,4.30125e-01,2.60750e-02,6.46450e-01,0.00000e+00,7.52250e-02,0.00000e+00 + ,0.00000e+00,0.00000e+00,4.75000e-03,0.00000e+00,0.00000e+00,0.00000e+00 + ,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,8.53925e-01 + ,0.00000e+00,8.64100e-01,0.00000e+00,0.00000e+00,9.83975e-01,0.00000e+00 + ,4.88250e-02,0.00000e+00,4.55000e-03,0.00000e+00,0.00000e+00,2.28150e-01 + ,2.00000e-04,0.00000e+00,0.00000e+00,9.99925e-01,1.00000e-04,0.00000e+00 + ,0.00000e+00,2.06250e-02,7.88850e-01,0.00000e+00,0.00000e+00,0.00000e+00 + ,9.83175e-01,0.00000e+00], + [9.99650e-01, 0.00000e+00, 1.89525e-01, 9.98775e-01, 1.33500e-02, 0.00000e+00, + 3.88825e-01, 2.35000e-02, 5.92750e-01, 0.00000e+00, 5.76000e-02, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 3.47500e-03, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 7.50000e-05, 0.00000e+00, 1.00000e+00, 8.00275e-01, + 0.00000e+00, 8.17250e-01, 0.00000e+00, 0.00000e+00, 9.73450e-01, 0.00000e+00, + 3.60750e-02, 0.00000e+00, 5.22500e-03, 0.00000e+00, 0.00000e+00, 1.95325e-01, + 1.50000e-04, 0.00000e+00, 0.00000e+00, 9.99650e-01, 9.25000e-04, 0.00000e+00, + 0.00000e+00, 1.84750e-02, 7.66100e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 9.77350e-01, 0.00000e+00], + [9.99600e-01, 0.00000e+00, 1.96975e-01, 9.97275e-01, 2.17500e-02, 0.00000e+00, + 3.97275e-01, 2.98500e-02, 6.12750e-01, 0.00000e+00, 3.67500e-02, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 5.95000e-03, 0.00000e+00, 0.00000e+00, 5.00000e-05, + 0.00000e+00, 0.00000e+00, 2.00000e-04, 0.00000e+00, 1.00000e+00, 7.83650e-01, + 0.00000e+00, 7.90425e-01, 0.00000e+00, 0.00000e+00, 9.58925e-01, 0.00000e+00, + 2.50000e-02, 0.00000e+00, 1.67000e-02, 0.00000e+00, 0.00000e+00, 1.40300e-01, + 2.00000e-04, 0.00000e+00, 2.50000e-05, 9.99200e-01, 6.75000e-04, 0.00000e+00, + 0.00000e+00, 3.85750e-02, 7.53050e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 9.70200e-01, 0.00000e+00,] + ]) + + +recur_or_not_LIN = [ +[1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., + 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., + 1., 0.], + [1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., + 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., + 1., 0.], +[1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., + 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., + 1., 0.], +[1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., + 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., + 1., 0.], +[1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., + 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., + 1., 0.], +[1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., + 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., + 1., 0.] +] + +p_recs_LIN = np.array([ +[9.99175e-01, 0.00000e+00, 5.81800e-01, 9.39900e-01, 9.50000e-03, 0.00000e+00, + 1.87775e-01, 1.79375e-01, 6.75025e-01, 0.00000e+00, 8.99500e-02, 1.25000e-04, + 0.00000e+00, 0.00000e+00, 5.87750e-02, 0.00000e+00, 0.00000e+00, 1.00000e-04, + 0.00000e+00, 0.00000e+00, 1.12500e-03, 4.75000e-04, 1.00000e+00, 9.34475e-01, + 2.50000e-05, 8.79500e-01, 0.00000e+00, 0.00000e+00, 9.82225e-01, 0.00000e+00, + 8.20000e-02, 0.00000e+00, 1.67000e-02, 0.00000e+00, 2.50000e-05, 2.66025e-01, + 3.00000e-04, 0.00000e+00, 1.25000e-04, 9.94900e-01, 6.97500e-03, 0.00000e+00, + 0.00000e+00, 7.74250e-02, 5.17300e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 9.17650e-01, 0.00000e+00,], +[9.99700e-01, 0.00000e+00, 6.19600e-01, 9.59500e-01, 5.37500e-03, 0.00000e+00, + 1.84250e-01, 1.59575e-01, 7.24525e-01, 0.00000e+00, 8.30750e-02, 5.00000e-05, + 0.00000e+00, 0.00000e+00, 5.20250e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 4.00000e-04, 2.50000e-05, 1.00000e+00, 9.58650e-01, + 0.00000e+00, 9.15225e-01, 0.00000e+00, 0.00000e+00, 9.92675e-01, 0.00000e+00, + 7.89250e-02, 0.00000e+00, 8.57500e-03, 0.00000e+00, 0.00000e+00, 2.69350e-01, + 1.00000e-04, 0.00000e+00, 5.00000e-05, 9.97825e-01, 2.25000e-03, 0.00000e+00, + 0.00000e+00, 6.46500e-02, 5.55475e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 9.37575e-01, 0.00000e+00,], +[9.99750e-01, 0.00000e+00, 6.52225e-01, 9.66275e-01, 7.00000e-03, 0.00000e+00, + 1.91450e-01, 1.79550e-01, 7.28450e-01, 0.00000e+00, 7.64750e-02, 2.50000e-05, + 0.00000e+00, 0.00000e+00, 4.55750e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 4.00000e-04, 2.50000e-05, 1.00000e+00, 9.63475e-01, + 0.00000e+00, 9.21650e-01, 0.00000e+00, 0.00000e+00, 9.95000e-01, 0.00000e+00, + 7.29000e-02, 0.00000e+00, 8.00000e-03, 0.00000e+00, 0.00000e+00, 2.56025e-01, + 1.25000e-04, 0.00000e+00, 1.00000e-04, 9.98675e-01, 3.55000e-03, 0.00000e+00, + 0.00000e+00, 7.01000e-02, 6.09500e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 9.43975e-01, 0.00000e+00,], +[1.00000e+00, 0.00000e+00, 5.81725e-01, 9.89125e-01, 1.35000e-03, 0.00000e+00, + 1.12675e-01, 1.16075e-01, 7.28225e-01, 0.00000e+00, 2.96000e-02, 2.50000e-05, + 0.00000e+00, 0.00000e+00, 2.08000e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 7.50000e-05, 2.50000e-05, 1.00000e+00, 9.74750e-01, + 0.00000e+00, 9.46425e-01, 0.00000e+00, 0.00000e+00, 9.97725e-01, 0.00000e+00, + 2.30500e-02, 0.00000e+00, 3.47500e-03, 0.00000e+00, 0.00000e+00, 1.92150e-01, + 0.00000e+00, 0.00000e+00, 0.00000e+00, 9.99825e-01, 1.42500e-03, 0.00000e+00, + 0.00000e+00, 3.19000e-02, 5.33925e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 9.78950e-01, 0.00000e+00,], +[1.00000e+00, 0.00000e+00, 6.28575e-01, 9.96000e-01, 4.50000e-04, 0.00000e+00, + 1.01475e-01, 9.34500e-02, 7.89375e-01, 0.00000e+00, 2.32000e-02, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 1.48000e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00, 9.88100e-01, + 0.00000e+00, 9.74425e-01, 0.00000e+00, 0.00000e+00, 9.99425e-01, 0.00000e+00, + 1.77250e-02, 0.00000e+00, 1.60000e-03, 0.00000e+00, 0.00000e+00, 2.04900e-01, + 0.00000e+00, 0.00000e+00, 0.00000e+00, 9.99950e-01, 3.00000e-04, 0.00000e+00, + 0.00000e+00, 2.22750e-02, 5.74200e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 9.89875e-01, 0.00000e+00,], +[1.00000e+00, 0.00000e+00, 7.00900e-01, 9.92800e-01, 5.00000e-04, 0.00000e+00, + 8.87750e-02, 1.08375e-01, 8.12125e-01, 0.00000e+00, 3.29500e-02, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 2.14500e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00, 9.92275e-01, + 0.00000e+00, 9.81775e-01, 0.00000e+00, 0.00000e+00, 9.99600e-01, 0.00000e+00, + 2.54500e-02, 0.00000e+00, 1.20000e-03, 0.00000e+00, 0.00000e+00, 2.55350e-01, + 0.00000e+00, 0.00000e+00, 0.00000e+00, 9.99925e-01, 2.50000e-04, 0.00000e+00, + 0.00000e+00, 2.75000e-02, 5.86725e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00, + 9.89675e-01, 0.00000e+00,] +]) + +import sklearn.metrics as metrics +for ii in range(len(recur_or_not_BNN)): + fpr_BNN, tpr_BNN, threshold = metrics.roc_curve(recur_or_not_BNN[ii], p_recs_BNN[ii]) #(y_test, preds) + roc_auc_BNN = metrics.auc(fpr_BNN, tpr_BNN) + + fpr_LIN, tpr_LIN, threshold = metrics.roc_curve(recur_or_not_LIN[ii], p_recs_LIN[ii]) #(y_test, preds) + roc_auc_LIN = metrics.auc(fpr_LIN, tpr_LIN) + + print("threshold:\n", threshold) + print("fpr_BNN:\n", fpr_BNN) + print("tpr_BNN:\n", tpr_BNN) + print("roc_auc_BNN:\n", roc_auc_BNN) + + print("\nfpr_BNN:\n", fpr_LIN) + print("tpr_BNN:\n", tpr_LIN) + print("roc_auc_BNN:\n", roc_auc_LIN) + + #plt.title('Receiver Operating Characteristic') + plt.plot(fpr_LIN, tpr_LIN, color=plt.cm.viridis(0.3), label = 'Linear reg. (AUC = %0.2f)' % roc_auc_LIN) + plt.plot(fpr_BNN, tpr_BNN, color=plt.cm.viridis(0.7), label = 'BNN (AUC = %0.2f)' % roc_auc_BNN) + plt.legend(loc = 'lower right') + plt.plot([0,1], [0,1], color='grey', linestyle='--') + plt.xlim([0,1]) + plt.ylim([0,1]) + plt.ylabel('True Positive Rate') + plt.xlabel('False Positive Rate') + plt.savefig(SAVEDIR+"AUC_"+str(ii)+"_"+str(N_patients_test)+"_test_patients_"+name+".pdf") + plt.show() diff --git a/plot_exp_values.py b/plot_exp_values.py new file mode 100644 index 0000000..e6d0d97 --- /dev/null +++ b/plot_exp_values.py @@ -0,0 +1,18 @@ +import numpy as np +import matplotlib.pyplot as plt +import scipy +import pandas as pd +import seaborn as sns + +from utilities import * +# plotting endtime M protein value +times = np.linspace(0,18) +mean_theta = np.log(0.002) # barely an increase: 10 to 14 over the course of 6 months +theta = np.log(0.04) # already unrealistically big: 10 to 20 over 18 days +gr = np.exp(theta) +Mprot = 10*np.exp(gr*times) +fig, ax = plt.subplots() +ax.plot(times, Mprot) +#ax.plot(times, np.log(0.04)*sigmoid(times)) +#ax.plot(times, times/(np.sqrt(1+times**2))) +plt.show() diff --git a/plot_legend_CIs.py b/plot_legend_CIs.py new file mode 100644 index 0000000..5bc6945 --- /dev/null +++ b/plot_legend_CIs.py @@ -0,0 +1,20 @@ +import numpy as np +import matplotlib.pyplot as plt +import scipy +import pandas as pd +import seaborn as sns + +SAVEDIR = "./plots/Bayesian_estimates_simdata_BNN/" + +fig, ax1 = plt.subplots() +shade_array = [0.7, 0.5, 0.35] +for index, critical_value in enumerate([0.05, 0.25, 0.45]): # Corresponding to confidence levels 90, 50, and 10 + ax1.fill_between([-1,0,1], [1,-1,-1], [1,-1,-1], color=plt.cm.copper(1-critical_value), label='%3.0f %% CI, resistant M protein' % (100*(1-2*critical_value)), zorder=0+index*0.1) +for index, critical_value in enumerate([0.05, 0.25, 0.45]): # Corresponding to confidence levels 90, 50, and 10 + ax1.fill_between([-1,0,1], [1,-1,-1], [1,-1,-1], color=plt.cm.bone(shade_array[index]), label='%3.0f %% CI, total M protein' % (100*(1-2*critical_value)), zorder=1+index*0.1) +ax1.plot([-1,0,1], [1,-1,-1], linestyle='--', marker='', zorder=3, color='cyan', label="True M protein (total)") +ax1.plot([-1,0,1], [1,-1,-1], linestyle='--', marker='', zorder=2.9, color=plt.cm.hot(0.2), label="True M protein (resistant)") +ax1.plot([-1,0,1], [1,-1,-1], linestyle='', marker='x', zorder=4, color='k', label="Observed M protein") #[ax1.axvline(time, color="k", linewidth=0.5, linestyle="-") for time in measurement_times] +plt.legend(loc="upper right") +plt.savefig(SAVEDIR+"xxxxAUC_.pdf", dpi=300) +plt.show() \ No newline at end of file diff --git a/slider.py b/slider.py new file mode 100644 index 0000000..1e2054d --- /dev/null +++ b/slider.py @@ -0,0 +1,85 @@ +from utilities import * +from matplotlib.widgets import Slider, Button, RadioButtons +RANDOM_SEED = 499 +np.random.seed(RANDOM_SEED) +rng = np.random.default_rng(RANDOM_SEED) +PLOT_RESISTANT = True +FUNNEL_REPARAMETRIZATION = False +MODEL_RANDOM_EFFECTS = True +N_HIDDEN = 2 +RANDOM_EFFECTS = True +P = 5 # Number of covariates +P0 = int(P / 2) # A guess of the true number of nonzero parameters is needed for defining the global shrinkage parameter +true_omega = np.array([0.10, 0.05, 0.20]) +true_omega_for_psi = 0.1 +true_sigma_obs = 0 +M_number_of_measurements = 12 +N_patients = 150 +y_resolution = 80 # Number of timepoints to evaluate the posterior of y in +max_time = 180 +days_between_measurements = int(max_time/M_number_of_measurements) +measurement_times = days_between_measurements * np.linspace(0, M_number_of_measurements, M_number_of_measurements) +treatment_history = np.array([Treatment(start=0, end=measurement_times[-1], id=1)]) + +X, patient_dictionary, parameter_dictionary, expected_theta_1, true_theta_rho_s, true_rho_s = generate_simulated_patients(deepcopy(measurement_times), treatment_history, true_sigma_obs, N_patients, P, get_expected_theta_from_X_2, true_omega, true_omega_for_psi, seed=42, RANDOM_EFFECTS=RANDOM_EFFECTS) + +# Our patient +ii = 19 +patient = patient_dictionary[ii] +parameters = parameter_dictionary[ii] +measurement_times = patient.measurement_times +treatment_history = patient.get_treatment_history() +parameters = Parameters(Y_0=50, pi_r=0.1*parameters.pi_r, g_r=30*parameters.g_r, g_s=30*parameters.g_s, k_1=0, sigma=true_sigma_obs) +patient = Patient(parameters, measurement_times, treatment_history, name=str(ii)) +plot_mprotein(patient, "", "./obs.pdf", PLOT_PARAMETERS=True, parameters=parameters) +#Mprotein_values = patient.get_Mprotein_values() +time_zero = min(treatment_history[0].start, measurement_times[0]) +time_max = find_max_time(measurement_times) +plotting_times = np.linspace(time_zero, time_max, y_resolution) + + +################# sliders ############### +fig, ax = plt.subplots(figsize=(10, 8)) +plt.subplots_adjust(left=0.25, bottom=0.25) +mprot = measure_Mprotein_with_noise(parameters, plotting_times, treatment_history) +#parameters = Parameters(50, 0.072, -0.056, 0.01742, parameters.k_1, parameters.sigma) +resistant_parameters = Parameters((parameters.Y_0*parameters.pi_r), 1, parameters.g_r, parameters.g_s, parameters.k_1, parameters.sigma) +mres = measure_Mprotein_with_noise(resistant_parameters, plotting_times, treatment_history) +msens = mprot - mres +l, = plt.plot(plotting_times, mprot, lw=2, color='k', label="From all cells", zorder=3) +r, = plt.plot(plotting_times, mres, lw=2, color='r', linestyle="--", label="From resistant", zorder=2) +s, = plt.plot(plotting_times, msens, lw=2, color='b', linestyle="--", label="From sensitive", zorder=1) +plt.ylabel("Serum Mprotein (g/L)") +plt.axis([plotting_times[0], plotting_times[-1], 0, 80]) +plt.xlabel("Days") +plt.legend() +axcolor = 'lightgoldenrodyellow' +ax_pi = plt.axes([0.25, 0.12, 0.65, 0.03], facecolor=axcolor) +ax_rho = plt.axes([0.25, 0.07, 0.65, 0.03], facecolor=axcolor) +ax_alpha = plt.axes([0.25, 0.02, 0.65, 0.03], facecolor=axcolor) +#ax_Y_0 = plt.axes([0.25, 0.2, 0.65, 0.03], facecolor=axcolor) + +slide_pi = Slider(ax_pi, r'$\pi $', 0, 1, valinit=parameters.pi_r[0]) +slide_rho = Slider(ax_rho, r'$\rho $', 0, 10*parameters.g_r[0], valinit=parameters.g_r[0]) +slide_alpha = Slider(ax_alpha, r'$\alpha$', - 0.1, 0, valinit=parameters.g_s[0]) +#slide_Y_0 = Slider(ax_Y_0, 'Y_0', - 0.1, 0, valinit=parameters.Y_0[0]) + +def update(val): + pi = slide_pi.val + rho = slide_rho.val + alpha = slide_alpha.val + #Y_0 = slide_Y_0.val + params = Parameters(parameters.Y_0, pi, rho, alpha, parameters.k_1, parameters.sigma) + r_params = Parameters((params.Y_0*params.pi_r), 1, params.g_r, params.g_s, params.k_1, params.sigma) + mprot = measure_Mprotein_with_noise(params, plotting_times, treatment_history) + mres = measure_Mprotein_with_noise(r_params, plotting_times, treatment_history) + msens = mprot - mres + l.set_ydata(mprot) + r.set_ydata(mres) + s.set_ydata(msens) + fig.canvas.draw_idle() +slide_pi.on_changed(update) +slide_rho.on_changed(update) +slide_alpha.on_changed(update) + +plt.show()