Skip to content

Commit

Permalink
Spring cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
evenmm committed Jun 22, 2023
1 parent 3856bb0 commit 32bfed3
Show file tree
Hide file tree
Showing 61 changed files with 15,388 additions and 8,520 deletions.
10 changes: 7 additions & 3 deletions BNN_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# X is an (N_patients, P) shaped pandas dataframe
# patient dictionary contains N_patients patients in the same order as X

def BNN_model(X, patient_dictionary, name, psi_prior="lognormal", MODEL_RANDOM_EFFECTS=True, FUNNEL_REPARAMETRIZATION=False, FUNNEL_WEIGHTS = False, WEIGHT_PRIOR = "symmetry_fix", SAVING=False, n_hidden = 3, net_list=["pi", "rho_r", "rho_s"]):
def BNN_model(X, patient_dictionary, name, psi_prior="lognormal", MODEL_RANDOM_EFFECTS=True, FUNNEL_REPARAMETRIZATION=False, FUNNEL_WEIGHTS = False, WEIGHT_PRIOR = "symmetry_fix", SAVING=False, n_hidden = 3, net_list=["pi", "rho_r", "rho_s"], empirical_mean_alpha=[]):
df = pd.DataFrame(columns=["patient_id", "mprotein_value", "time"])
for ii in range(len(patient_dictionary)):
patient = patient_dictionary[ii]
Expand Down Expand Up @@ -42,9 +42,13 @@ def BNN_model(X, patient_dictionary, name, psi_prior="lognormal", MODEL_RANDOM_E
sigma_obs = pm.HalfNormal("sigma_obs", sigma=1)

# alpha
alpha = pm.Normal("alpha", mu=np.array([np.math.log(0.002), np.math.log(0.002), np.math.log(0.5/(1-0.5))]), sigma=1, shape=3)
if len(empirical_mean_alpha) < 3:
alpha = pm.Normal("alpha", mu=np.array([np.log(0.002), np.log(0.002), np.log(0.5/(1-0.5))]), sigma=1, shape=3)
else:
print("Using empirical_mean_alpha as prior for alpha")
alpha = pm.Normal("alpha", mu=np.array([np.log(-empirical_mean_alpha[0]), np.log(empirical_mean_alpha[1]), np.log(empirical_mean_alpha[2]/(1-empirical_mean_alpha[2]))]), sigma=1, shape=3)

log_sigma_weights_in = pm.Normal("log_sigma_weights_in", mu=2*np.math.log(0.01), sigma=2.5**2, shape=(X.shape[0], 1))
log_sigma_weights_in = pm.Normal("log_sigma_weights_in", mu=2*np.log(0.01), sigma=2.5**2, shape=(X.shape[0], 1))
sigma_weights_in = pm.Deterministic("sigma_weights_in", pm.math.exp(log_sigma_weights_in))
sigma_weights_out = pm.HalfNormal("sigma_weights_out", sigma=0.1)
sigma_bias_in = pm.HalfNormal("sigma_bias_in", sigma=1, shape=(1,n_hidden))
Expand Down
5 changes: 2 additions & 3 deletions COMMPASS_BNN_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@
"var_dimensions = sample_shape[2]\n",
"# Gather all the parameter estimates and get y values for each parameter set\n",
"posterior_parameters = np.empty(shape=sample_shape, dtype=object)\n",
"y_resolution = 1000\n",
"y_resolution = 80\n",
"predicted_y_values = np.empty((sample_shape+(y_resolution,)))\n",
"predicted_y_resistant_values = np.empty_like(predicted_y_values)\n",
"for ii in range(var_dimensions): # per patient\n",
Expand All @@ -900,7 +900,6 @@
" these_parameters = posterior_parameters[ch,sa,ii]\n",
" resistant_parameters = Parameters((these_parameters.Y_0*these_parameters.pi_r), 1, these_parameters.g_r, these_parameters.g_s, these_parameters.k_1, these_parameters.sigma)\n",
" # Predicted total M protein\n",
" meas = measure_Mprotein_noiseless(these_parameters, plotting_times, treatment_history) \n",
" predicted_y_values[ch,sa,ii] = measure_Mprotein_noiseless(these_parameters, plotting_times, treatment_history) \n",
" # Predicted resistant part\n",
" predicted_y_resistant_values[ch,sa,ii] = measure_Mprotein_noiseless(resistant_parameters, plotting_times, treatment_history)\n"
Expand Down Expand Up @@ -1374,7 +1373,7 @@
"# Plot posterior confidence intervals \n",
"for training_instance_id, patient in patient_dictionary.items():\n",
" savename = \"./plots/Bayesian_estimates_COMMPASS_BNN/CI_training_id_\"+str(training_instance_id)+\"_treat_id_\"+str(treat_id)+\"_M_\"+str(M_number_of_measurements)+\"_P_\"+str(P)+\"_N_cases_\"+str(N_cases)+\"_psi_prior_\"+psi_prior+\"_N_samples_\"+str(N_samples)+\".png\"\n",
" plot_posterior_confidence_intervals(training_instance_id, patient, sorted_pred_y_values, parameter_estimates=[], PLOT_POINT_ESTIMATES=False, PLOT_TREATMENTS=False, plot_title=\"Posterior CI for patient \"+str(training_instance_id), savename=savename, y_resolution=y_resolution)"
" plot_posterior_confidence_intervals(training_instance_id, patient, sorted_pred_y_values, parameter_estimates=[], PLOT_POINT_ESTIMATES=False, PLOT_TREATMENTS=False, plot_title=\"Posterior CI for patient \"+str(training_instance_id), savename=savename, y_resolution=y_resolution, n_chains=n_chains, n_samples=n_samples)"
]
},
{
Expand Down
7 changes: 3 additions & 4 deletions COMMPASS_linearmodel_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
"target_accept = 0.99\n",
"max_treedepth = 10\n",
"FUNNEL_REPARAMETRIZATION = True\n",
"name = \"COMMPASS_treat_id_\"+str(treat_id)+\"_M_\"+str(M_number_of_measurements)+\"_P_\"+str(P)+\"_N_cases_\"+str(N_cases)+\"_psi_prior_\"+psi_prior+\"_N_samples_\"+str(N_samples)+\"_N_tuning_\"+str(N_tuning)+\"_target_accept_\"+str(target_accept)+\"_max_treedepth_\"+str(max_treedepth)+\"_FUNNEL_REPARAMETRIZATION_\"+str(FUNNEL_REPARAMETRIZATION)\n",
"name = \"COMMPASS_linear_treat_id_\"+str(treat_id)+\"_M_\"+str(M_number_of_measurements)+\"_P_\"+str(P)+\"_N_cases_\"+str(N_cases)+\"_psi_prior_\"+psi_prior+\"_N_samples_\"+str(N_samples)+\"_N_tuning_\"+str(N_tuning)+\"_target_accept_\"+str(target_accept)+\"_max_treedepth_\"+str(max_treedepth)+\"_FUNNEL_REPARAMETRIZATION_\"+str(FUNNEL_REPARAMETRIZATION)\n",
"print(\"Running\"+name)\n",
"idata = sample_from_full_model(X, patient_dictionary, name, N_samples=N_samples, N_tuning=N_tuning, target_accept=target_accept, max_treedepth=max_treedepth, psi_prior=psi_prior, FUNNEL_REPARAMETRIZATION=FUNNEL_REPARAMETRIZATION)\n",
"# This is an xArray: https://docs.xarray.dev/en/v2022.11.0/user-guide/data-structures.html\n",
Expand Down Expand Up @@ -329,7 +329,7 @@
"var_dimensions = sample_shape[2]\n",
"# Gather all the parameter estimates and get y values for each parameter set\n",
"posterior_parameters = np.empty(shape=sample_shape, dtype=object)\n",
"y_resolution = 1000\n",
"y_resolution = 80\n",
"predicted_y_values = np.empty((sample_shape+(y_resolution,)))\n",
"predicted_y_resistant_values = np.empty_like(predicted_y_values)\n",
"for ii in range(var_dimensions): # per patient\n",
Expand All @@ -349,8 +349,7 @@
" these_parameters = posterior_parameters[ch,sa,ii]\n",
" resistant_parameters = Parameters((these_parameters.Y_0*these_parameters.pi_r), 1, these_parameters.g_r, these_parameters.g_s, these_parameters.k_1, these_parameters.sigma)\n",
" # Predicted total M protein\n",
" meas = measure_Mprotein_noiseless(these_parameters, plotting_times, treatment_history) \n",
" predicted_y_values[ch,sa,ii] = measure_Mprotein_noiseless(these_parameters, plotting_times, treatment_history) \n",
" predicted_y_values[ch,sa,ii] = measure_Mprotein_noiseless(these_parameters, plotting_times, treatment_history)\n",
" # Predicted resistant part\n",
" predicted_y_resistant_values[ch,sa,ii] = measure_Mprotein_noiseless(resistant_parameters, plotting_times, treatment_history)\n"
]
Expand Down
Loading

0 comments on commit 32bfed3

Please sign in to comment.