Skip to content

Commit

Permalink
Slider and M protein plots
Browse files Browse the repository at this point in the history
  • Loading branch information
evenmm committed May 5, 2023
1 parent b652ca8 commit fd6269e
Show file tree
Hide file tree
Showing 4 changed files with 374 additions and 0 deletions.
251 changes: 251 additions & 0 deletions plot_auc_simdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import seaborn as sns

from utilities import *
from BNN_model import *

# Initialize random number generator
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
rng = np.random.default_rng(RANDOM_SEED)
print(f"Running on PyMC v{pm.__version__}")
#SAVEDIR = "/data/evenmm/plots/"
SAVEDIR = "./plots/Bayesian_estimates_simdata_BNN/"

script_index = 1

# Settings
if int(script_index % 3) == 0:
true_sigma_obs = 0
elif int(script_index % 3) == 1:
true_sigma_obs = 1
elif int(script_index % 3) == 2:
true_sigma_obs = 2.5

if script_index >= 3:
RANDOM_EFFECTS = True
else:
RANDOM_EFFECTS = False

RANDOM_EFFECTS_TEST = False

model_name = "BNN"
N_patients = 150
psi_prior="lognormal"
WEIGHT_PRIOR = "Student_out" #"Horseshoe" # "Student_out" #"symmetry_fix" #"iso_normal" "Student_out"
N_samples = 10_000
N_tuning = 10_000
ADADELTA = True
target_accept = 0.99
CI_with_obs_noise = True
PLOT_RESISTANT = True
FUNNEL_REPARAMETRIZATION = False
MODEL_RANDOM_EFFECTS = True
N_HIDDEN = 2
P = 5 # Number of covariates
P0 = int(P / 2) # A guess of the true number of nonzero parameters is needed for defining the global shrinkage parameter
true_omega = np.array([0.10,0.05,0.20])

M_number_of_measurements =7
y_resolution = 80 # Number of timepoints to evaluate the posterior of y in
true_omega_for_psi = 0.1

max_time = 1200 #3000 #1500
days_between_measurements = int(max_time/M_number_of_measurements)
measurement_times = days_between_measurements * np.linspace(0, M_number_of_measurements, M_number_of_measurements)

evaluation_time = measurement_times[:4][-1] + 1
print(evaluation_time)

treatment_history = np.array([Treatment(start=0, end=measurement_times[-1], id=1)])
name = "simdata_"+model_name+"_"+str(script_index)+"_M_"+str(M_number_of_measurements)+"_P_"+str(P)+"_N_pax_"+str(N_patients)+"_N_sampl_"+str(N_samples)+"_N_tune_"+str(N_tuning)+"_FUNNEL_"+str(FUNNEL_REPARAMETRIZATION)+"_RNDM_EFFECTS_"+str(RANDOM_EFFECTS)+"_WT_PRIOR_"+str(WEIGHT_PRIOR+"_N_HIDDN_"+str(N_HIDDEN))

N_patients_test = 50

recur_or_not_BNN = [
[1.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.
,0.,1.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0.
,1.,0.],
[1.,0.,0.,1.,0.,0.,1.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.
,0.,1.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0.
,1.,0.],
[1.,0.,0.,1.,0.,0.,1.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.
,0.,1.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.,0.,0.,1.,1.,0.,0.,0.
,1.,0.],
[1.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.
,0.,1.,0.,0.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,1.,0.,0.,0.,0.,1.,0.,0.,0.
,1.,0.],
[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,
1, 0],
[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0,
1, 0,]
]

p_recs_BNN = np.array([[1.00000e+00,0.00000e+00,5.19250e-02,1.00000e+00,5.00000e-05,0.00000e+00
,4.05575e-01,2.00000e-04,7.45575e-01,0.00000e+00,2.47500e-03,0.00000e+00
,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00
,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,9.59275e-01
,0.00000e+00,9.74175e-01,0.00000e+00,0.00000e+00,9.99825e-01,0.00000e+00
,6.75000e-04,0.00000e+00,2.50000e-05,0.00000e+00,0.00000e+00,6.13250e-02
,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,0.00000e+00,0.00000e+00
,0.00000e+00,1.50000e-04,9.42375e-01,0.00000e+00,0.00000e+00,0.00000e+00
,1.00000e+00,0.00000e+00],
[1.00000e+00,0.00000e+00,1.77750e-02,1.00000e+00,0.00000e+00,0.00000e+00
,2.96350e-01,2.50000e-05,6.81375e-01,0.00000e+00,2.25000e-04,0.00000e+00
,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00
,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,9.39000e-01
,0.00000e+00,9.65450e-01,0.00000e+00,0.00000e+00,9.99925e-01,0.00000e+00
,5.00000e-05,0.00000e+00,2.50000e-05,0.00000e+00,0.00000e+00,2.48000e-02
,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,0.00000e+00,0.00000e+00
,0.00000e+00,0.00000e+00,9.55525e-01,0.00000e+00,0.00000e+00,0.00000e+00
,1.00000e+00,0.00000e+00],
[1.00000e+00,0.00000e+00,3.31500e-02,1.00000e+00,2.75000e-04,0.00000e+00
,2.98825e-01,8.25000e-04,5.79400e-01,0.00000e+00,2.82500e-03,0.00000e+00
,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00
,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,8.51800e-01
,0.00000e+00,8.87375e-01,0.00000e+00,0.00000e+00,9.98425e-01,0.00000e+00
,1.20000e-03,0.00000e+00,6.75000e-04,0.00000e+00,0.00000e+00,5.97250e-02
,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,2.50000e-05,0.00000e+00
,0.00000e+00,1.32500e-03,8.75850e-01,0.00000e+00,0.00000e+00,0.00000e+00
,9.99875e-01,0.00000e+00],
[9.99875e-01,0.00000e+00,2.30250e-01,9.99225e-01,1.36500e-02,0.00000e+00
,4.30125e-01,2.60750e-02,6.46450e-01,0.00000e+00,7.52250e-02,0.00000e+00
,0.00000e+00,0.00000e+00,4.75000e-03,0.00000e+00,0.00000e+00,0.00000e+00
,0.00000e+00,0.00000e+00,0.00000e+00,0.00000e+00,1.00000e+00,8.53925e-01
,0.00000e+00,8.64100e-01,0.00000e+00,0.00000e+00,9.83975e-01,0.00000e+00
,4.88250e-02,0.00000e+00,4.55000e-03,0.00000e+00,0.00000e+00,2.28150e-01
,2.00000e-04,0.00000e+00,0.00000e+00,9.99925e-01,1.00000e-04,0.00000e+00
,0.00000e+00,2.06250e-02,7.88850e-01,0.00000e+00,0.00000e+00,0.00000e+00
,9.83175e-01,0.00000e+00],
[9.99650e-01, 0.00000e+00, 1.89525e-01, 9.98775e-01, 1.33500e-02, 0.00000e+00,
3.88825e-01, 2.35000e-02, 5.92750e-01, 0.00000e+00, 5.76000e-02, 0.00000e+00,
0.00000e+00, 0.00000e+00, 3.47500e-03, 0.00000e+00, 0.00000e+00, 0.00000e+00,
0.00000e+00, 0.00000e+00, 7.50000e-05, 0.00000e+00, 1.00000e+00, 8.00275e-01,
0.00000e+00, 8.17250e-01, 0.00000e+00, 0.00000e+00, 9.73450e-01, 0.00000e+00,
3.60750e-02, 0.00000e+00, 5.22500e-03, 0.00000e+00, 0.00000e+00, 1.95325e-01,
1.50000e-04, 0.00000e+00, 0.00000e+00, 9.99650e-01, 9.25000e-04, 0.00000e+00,
0.00000e+00, 1.84750e-02, 7.66100e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00,
9.77350e-01, 0.00000e+00],
[9.99600e-01, 0.00000e+00, 1.96975e-01, 9.97275e-01, 2.17500e-02, 0.00000e+00,
3.97275e-01, 2.98500e-02, 6.12750e-01, 0.00000e+00, 3.67500e-02, 0.00000e+00,
0.00000e+00, 0.00000e+00, 5.95000e-03, 0.00000e+00, 0.00000e+00, 5.00000e-05,
0.00000e+00, 0.00000e+00, 2.00000e-04, 0.00000e+00, 1.00000e+00, 7.83650e-01,
0.00000e+00, 7.90425e-01, 0.00000e+00, 0.00000e+00, 9.58925e-01, 0.00000e+00,
2.50000e-02, 0.00000e+00, 1.67000e-02, 0.00000e+00, 0.00000e+00, 1.40300e-01,
2.00000e-04, 0.00000e+00, 2.50000e-05, 9.99200e-01, 6.75000e-04, 0.00000e+00,
0.00000e+00, 3.85750e-02, 7.53050e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00,
9.70200e-01, 0.00000e+00,]
])


recur_or_not_LIN = [
[1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
1., 0.],
[1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
1., 0.],
[1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
1., 0.],
[1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
1., 0.],
[1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
1., 0.],
[1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
1., 0.]
]

p_recs_LIN = np.array([
[9.99175e-01, 0.00000e+00, 5.81800e-01, 9.39900e-01, 9.50000e-03, 0.00000e+00,
1.87775e-01, 1.79375e-01, 6.75025e-01, 0.00000e+00, 8.99500e-02, 1.25000e-04,
0.00000e+00, 0.00000e+00, 5.87750e-02, 0.00000e+00, 0.00000e+00, 1.00000e-04,
0.00000e+00, 0.00000e+00, 1.12500e-03, 4.75000e-04, 1.00000e+00, 9.34475e-01,
2.50000e-05, 8.79500e-01, 0.00000e+00, 0.00000e+00, 9.82225e-01, 0.00000e+00,
8.20000e-02, 0.00000e+00, 1.67000e-02, 0.00000e+00, 2.50000e-05, 2.66025e-01,
3.00000e-04, 0.00000e+00, 1.25000e-04, 9.94900e-01, 6.97500e-03, 0.00000e+00,
0.00000e+00, 7.74250e-02, 5.17300e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00,
9.17650e-01, 0.00000e+00,],
[9.99700e-01, 0.00000e+00, 6.19600e-01, 9.59500e-01, 5.37500e-03, 0.00000e+00,
1.84250e-01, 1.59575e-01, 7.24525e-01, 0.00000e+00, 8.30750e-02, 5.00000e-05,
0.00000e+00, 0.00000e+00, 5.20250e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00,
0.00000e+00, 0.00000e+00, 4.00000e-04, 2.50000e-05, 1.00000e+00, 9.58650e-01,
0.00000e+00, 9.15225e-01, 0.00000e+00, 0.00000e+00, 9.92675e-01, 0.00000e+00,
7.89250e-02, 0.00000e+00, 8.57500e-03, 0.00000e+00, 0.00000e+00, 2.69350e-01,
1.00000e-04, 0.00000e+00, 5.00000e-05, 9.97825e-01, 2.25000e-03, 0.00000e+00,
0.00000e+00, 6.46500e-02, 5.55475e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00,
9.37575e-01, 0.00000e+00,],
[9.99750e-01, 0.00000e+00, 6.52225e-01, 9.66275e-01, 7.00000e-03, 0.00000e+00,
1.91450e-01, 1.79550e-01, 7.28450e-01, 0.00000e+00, 7.64750e-02, 2.50000e-05,
0.00000e+00, 0.00000e+00, 4.55750e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00,
0.00000e+00, 0.00000e+00, 4.00000e-04, 2.50000e-05, 1.00000e+00, 9.63475e-01,
0.00000e+00, 9.21650e-01, 0.00000e+00, 0.00000e+00, 9.95000e-01, 0.00000e+00,
7.29000e-02, 0.00000e+00, 8.00000e-03, 0.00000e+00, 0.00000e+00, 2.56025e-01,
1.25000e-04, 0.00000e+00, 1.00000e-04, 9.98675e-01, 3.55000e-03, 0.00000e+00,
0.00000e+00, 7.01000e-02, 6.09500e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00,
9.43975e-01, 0.00000e+00,],
[1.00000e+00, 0.00000e+00, 5.81725e-01, 9.89125e-01, 1.35000e-03, 0.00000e+00,
1.12675e-01, 1.16075e-01, 7.28225e-01, 0.00000e+00, 2.96000e-02, 2.50000e-05,
0.00000e+00, 0.00000e+00, 2.08000e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00,
0.00000e+00, 0.00000e+00, 7.50000e-05, 2.50000e-05, 1.00000e+00, 9.74750e-01,
0.00000e+00, 9.46425e-01, 0.00000e+00, 0.00000e+00, 9.97725e-01, 0.00000e+00,
2.30500e-02, 0.00000e+00, 3.47500e-03, 0.00000e+00, 0.00000e+00, 1.92150e-01,
0.00000e+00, 0.00000e+00, 0.00000e+00, 9.99825e-01, 1.42500e-03, 0.00000e+00,
0.00000e+00, 3.19000e-02, 5.33925e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00,
9.78950e-01, 0.00000e+00,],
[1.00000e+00, 0.00000e+00, 6.28575e-01, 9.96000e-01, 4.50000e-04, 0.00000e+00,
1.01475e-01, 9.34500e-02, 7.89375e-01, 0.00000e+00, 2.32000e-02, 0.00000e+00,
0.00000e+00, 0.00000e+00, 1.48000e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00,
0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00, 9.88100e-01,
0.00000e+00, 9.74425e-01, 0.00000e+00, 0.00000e+00, 9.99425e-01, 0.00000e+00,
1.77250e-02, 0.00000e+00, 1.60000e-03, 0.00000e+00, 0.00000e+00, 2.04900e-01,
0.00000e+00, 0.00000e+00, 0.00000e+00, 9.99950e-01, 3.00000e-04, 0.00000e+00,
0.00000e+00, 2.22750e-02, 5.74200e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00,
9.89875e-01, 0.00000e+00,],
[1.00000e+00, 0.00000e+00, 7.00900e-01, 9.92800e-01, 5.00000e-04, 0.00000e+00,
8.87750e-02, 1.08375e-01, 8.12125e-01, 0.00000e+00, 3.29500e-02, 0.00000e+00,
0.00000e+00, 0.00000e+00, 2.14500e-02, 0.00000e+00, 0.00000e+00, 0.00000e+00,
0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00, 9.92275e-01,
0.00000e+00, 9.81775e-01, 0.00000e+00, 0.00000e+00, 9.99600e-01, 0.00000e+00,
2.54500e-02, 0.00000e+00, 1.20000e-03, 0.00000e+00, 0.00000e+00, 2.55350e-01,
0.00000e+00, 0.00000e+00, 0.00000e+00, 9.99925e-01, 2.50000e-04, 0.00000e+00,
0.00000e+00, 2.75000e-02, 5.86725e-01, 0.00000e+00, 0.00000e+00, 0.00000e+00,
9.89675e-01, 0.00000e+00,]
])

import sklearn.metrics as metrics
for ii in range(len(recur_or_not_BNN)):
fpr_BNN, tpr_BNN, threshold = metrics.roc_curve(recur_or_not_BNN[ii], p_recs_BNN[ii]) #(y_test, preds)
roc_auc_BNN = metrics.auc(fpr_BNN, tpr_BNN)

fpr_LIN, tpr_LIN, threshold = metrics.roc_curve(recur_or_not_LIN[ii], p_recs_LIN[ii]) #(y_test, preds)
roc_auc_LIN = metrics.auc(fpr_LIN, tpr_LIN)

print("threshold:\n", threshold)
print("fpr_BNN:\n", fpr_BNN)
print("tpr_BNN:\n", tpr_BNN)
print("roc_auc_BNN:\n", roc_auc_BNN)

print("\nfpr_BNN:\n", fpr_LIN)
print("tpr_BNN:\n", tpr_LIN)
print("roc_auc_BNN:\n", roc_auc_LIN)

#plt.title('Receiver Operating Characteristic')
plt.plot(fpr_LIN, tpr_LIN, color=plt.cm.viridis(0.3), label = 'Linear reg. (AUC = %0.2f)' % roc_auc_LIN)
plt.plot(fpr_BNN, tpr_BNN, color=plt.cm.viridis(0.7), label = 'BNN (AUC = %0.2f)' % roc_auc_BNN)
plt.legend(loc = 'lower right')
plt.plot([0,1], [0,1], color='grey', linestyle='--')
plt.xlim([0,1])
plt.ylim([0,1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.savefig(SAVEDIR+"AUC_"+str(ii)+"_"+str(N_patients_test)+"_test_patients_"+name+".pdf")
plt.show()
18 changes: 18 additions & 0 deletions plot_exp_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import seaborn as sns

from utilities import *
# plotting endtime M protein value
times = np.linspace(0,18)
mean_theta = np.log(0.002) # barely an increase: 10 to 14 over the course of 6 months
theta = np.log(0.04) # already unrealistically big: 10 to 20 over 18 days
gr = np.exp(theta)
Mprot = 10*np.exp(gr*times)
fig, ax = plt.subplots()
ax.plot(times, Mprot)
#ax.plot(times, np.log(0.04)*sigmoid(times))
#ax.plot(times, times/(np.sqrt(1+times**2)))
plt.show()
20 changes: 20 additions & 0 deletions plot_legend_CIs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import seaborn as sns

SAVEDIR = "./plots/Bayesian_estimates_simdata_BNN/"

fig, ax1 = plt.subplots()
shade_array = [0.7, 0.5, 0.35]
for index, critical_value in enumerate([0.05, 0.25, 0.45]): # Corresponding to confidence levels 90, 50, and 10
ax1.fill_between([-1,0,1], [1,-1,-1], [1,-1,-1], color=plt.cm.copper(1-critical_value), label='%3.0f %% CI, resistant M protein' % (100*(1-2*critical_value)), zorder=0+index*0.1)
for index, critical_value in enumerate([0.05, 0.25, 0.45]): # Corresponding to confidence levels 90, 50, and 10
ax1.fill_between([-1,0,1], [1,-1,-1], [1,-1,-1], color=plt.cm.bone(shade_array[index]), label='%3.0f %% CI, total M protein' % (100*(1-2*critical_value)), zorder=1+index*0.1)
ax1.plot([-1,0,1], [1,-1,-1], linestyle='--', marker='', zorder=3, color='cyan', label="True M protein (total)")
ax1.plot([-1,0,1], [1,-1,-1], linestyle='--', marker='', zorder=2.9, color=plt.cm.hot(0.2), label="True M protein (resistant)")
ax1.plot([-1,0,1], [1,-1,-1], linestyle='', marker='x', zorder=4, color='k', label="Observed M protein") #[ax1.axvline(time, color="k", linewidth=0.5, linestyle="-") for time in measurement_times]
plt.legend(loc="upper right")
plt.savefig(SAVEDIR+"xxxxAUC_.pdf", dpi=300)
plt.show()
85 changes: 85 additions & 0 deletions slider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from utilities import *
from matplotlib.widgets import Slider, Button, RadioButtons
RANDOM_SEED = 499
np.random.seed(RANDOM_SEED)
rng = np.random.default_rng(RANDOM_SEED)
PLOT_RESISTANT = True
FUNNEL_REPARAMETRIZATION = False
MODEL_RANDOM_EFFECTS = True
N_HIDDEN = 2
RANDOM_EFFECTS = True
P = 5 # Number of covariates
P0 = int(P / 2) # A guess of the true number of nonzero parameters is needed for defining the global shrinkage parameter
true_omega = np.array([0.10, 0.05, 0.20])
true_omega_for_psi = 0.1
true_sigma_obs = 0
M_number_of_measurements = 12
N_patients = 150
y_resolution = 80 # Number of timepoints to evaluate the posterior of y in
max_time = 180
days_between_measurements = int(max_time/M_number_of_measurements)
measurement_times = days_between_measurements * np.linspace(0, M_number_of_measurements, M_number_of_measurements)
treatment_history = np.array([Treatment(start=0, end=measurement_times[-1], id=1)])

X, patient_dictionary, parameter_dictionary, expected_theta_1, true_theta_rho_s, true_rho_s = generate_simulated_patients(deepcopy(measurement_times), treatment_history, true_sigma_obs, N_patients, P, get_expected_theta_from_X_2, true_omega, true_omega_for_psi, seed=42, RANDOM_EFFECTS=RANDOM_EFFECTS)

# Our patient
ii = 19
patient = patient_dictionary[ii]
parameters = parameter_dictionary[ii]
measurement_times = patient.measurement_times
treatment_history = patient.get_treatment_history()
parameters = Parameters(Y_0=50, pi_r=0.1*parameters.pi_r, g_r=30*parameters.g_r, g_s=30*parameters.g_s, k_1=0, sigma=true_sigma_obs)
patient = Patient(parameters, measurement_times, treatment_history, name=str(ii))
plot_mprotein(patient, "", "./obs.pdf", PLOT_PARAMETERS=True, parameters=parameters)
#Mprotein_values = patient.get_Mprotein_values()
time_zero = min(treatment_history[0].start, measurement_times[0])
time_max = find_max_time(measurement_times)
plotting_times = np.linspace(time_zero, time_max, y_resolution)


################# sliders ###############
fig, ax = plt.subplots(figsize=(10, 8))
plt.subplots_adjust(left=0.25, bottom=0.25)
mprot = measure_Mprotein_with_noise(parameters, plotting_times, treatment_history)
#parameters = Parameters(50, 0.072, -0.056, 0.01742, parameters.k_1, parameters.sigma)
resistant_parameters = Parameters((parameters.Y_0*parameters.pi_r), 1, parameters.g_r, parameters.g_s, parameters.k_1, parameters.sigma)
mres = measure_Mprotein_with_noise(resistant_parameters, plotting_times, treatment_history)
msens = mprot - mres
l, = plt.plot(plotting_times, mprot, lw=2, color='k', label="From all cells", zorder=3)
r, = plt.plot(plotting_times, mres, lw=2, color='r', linestyle="--", label="From resistant", zorder=2)
s, = plt.plot(plotting_times, msens, lw=2, color='b', linestyle="--", label="From sensitive", zorder=1)
plt.ylabel("Serum Mprotein (g/L)")
plt.axis([plotting_times[0], plotting_times[-1], 0, 80])
plt.xlabel("Days")
plt.legend()
axcolor = 'lightgoldenrodyellow'
ax_pi = plt.axes([0.25, 0.12, 0.65, 0.03], facecolor=axcolor)
ax_rho = plt.axes([0.25, 0.07, 0.65, 0.03], facecolor=axcolor)
ax_alpha = plt.axes([0.25, 0.02, 0.65, 0.03], facecolor=axcolor)
#ax_Y_0 = plt.axes([0.25, 0.2, 0.65, 0.03], facecolor=axcolor)

slide_pi = Slider(ax_pi, r'$\pi $', 0, 1, valinit=parameters.pi_r[0])
slide_rho = Slider(ax_rho, r'$\rho $', 0, 10*parameters.g_r[0], valinit=parameters.g_r[0])
slide_alpha = Slider(ax_alpha, r'$\alpha$', - 0.1, 0, valinit=parameters.g_s[0])
#slide_Y_0 = Slider(ax_Y_0, 'Y_0', - 0.1, 0, valinit=parameters.Y_0[0])

def update(val):
pi = slide_pi.val
rho = slide_rho.val
alpha = slide_alpha.val
#Y_0 = slide_Y_0.val
params = Parameters(parameters.Y_0, pi, rho, alpha, parameters.k_1, parameters.sigma)
r_params = Parameters((params.Y_0*params.pi_r), 1, params.g_r, params.g_s, params.k_1, params.sigma)
mprot = measure_Mprotein_with_noise(params, plotting_times, treatment_history)
mres = measure_Mprotein_with_noise(r_params, plotting_times, treatment_history)
msens = mprot - mres
l.set_ydata(mprot)
r.set_ydata(mres)
s.set_ydata(msens)
fig.canvas.draw_idle()
slide_pi.on_changed(update)
slide_rho.on_changed(update)
slide_alpha.on_changed(update)

plt.show()

0 comments on commit fd6269e

Please sign in to comment.