Skip to content

Commit

Permalink
loading ess values from saved file if it exists
Browse files Browse the repository at this point in the history
  • Loading branch information
ShantanuKodgirwar committed Oct 24, 2024
1 parent 10549ab commit 5d97f96
Showing 1 changed file with 56 additions and 50 deletions.
106 changes: 56 additions & 50 deletions scripts/ess_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"rwmh": "RWMH",
"hmc": "HMC",
}
plt.rc("font", size=16)
plt.rc("font", size=20)


def load_samples(
Expand Down Expand Up @@ -134,7 +134,6 @@ def calc_ess_varying_param(
param_values,
datasets,
ess_filepath,
recompute_ess=False,
verbose=False,
):
"""
Expand All @@ -148,22 +147,18 @@ def calc_ess_varying_param(
Loaded datasets, indexed by parameter value.
ess_filepath : str
File path to save or load ESS values.
recompute_ess : bool
verbose : bool
Returns
-------
dict
ESS values indexed by parameter value and method.
"""
if recompute_ess or not os.path.exists(ess_filepath):
ess_vals = {}
for param_value in param_values:
print(f"Calculating ESS for {param_value=}")
ess_vals[param_value] = calc_ess(datasets[param_value], verbose=verbose)
dump(ess_vals, ess_filepath)
else:
ess_vals = load(ess_filepath)
ess_vals = {}
for param_value in param_values:
print(f"Calculating ESS for {param_value=}")
ess_vals[param_value] = calc_ess(datasets[param_value], verbose=verbose)
dump(ess_vals, ess_filepath)

return ess_vals

Expand Down Expand Up @@ -226,7 +221,7 @@ def ess_plot_varying_param(

if __name__ == "__main__":

plotting_varying_kappa = False
plotting_varying_kappa = True
plotting_varying_ndim = True

if plotting_varying_kappa:
Expand All @@ -236,26 +231,32 @@ def ess_plot_varying_param(
n_runs = 10
subdir = f"results/curve_{n_dim}d_vary_kappa_nruns_{n_runs}"
ess_filename = "ess_curve_10d_varying_kappa.pkl"

# load samples for varying kappa
datasets_varying_kappa = load_samples(
base_path=subdir,
varying_param_values=kappas,
varying_param_name="kappa",
fixed_params={"n_dim": n_dim},
n_runs=n_runs,
verbose=False,
)

# calculate ESS
ess_filepath = os.path.join(subdir, ess_filename)
ess_kappas = calc_ess_varying_param(
param_values=kappas,
datasets=datasets_varying_kappa,
ess_filepath=ess_filepath,
recompute_ess=False,
verbose=False,
)
recompute_ess = False

# load or calculate ESS
if not recompute_ess and os.path.exists(ess_filepath):
print("Loading ESS values from the file...")
ess_kappas = load(ess_filepath)
else:
# load samples for varying kappa
print(f"Loading samples for varying kappa from {subdir}...")
datasets_varying_kappa = load_samples(
base_path=subdir,
varying_param_values=kappas,
varying_param_name="kappa",
fixed_params={"n_dim": n_dim},
n_runs=n_runs,
verbose=True,
)

# calculate ESS
ess_kappas = calc_ess_varying_param(
param_values=kappas,
datasets=datasets_varying_kappa,
ess_filepath=ess_filepath,
verbose=True,
)

# plotting
fig = ess_plot_varying_param(
Expand All @@ -275,26 +276,31 @@ def ess_plot_varying_param(
n_runs = 10
subdir = f"results/curve_kappa_{float(kappa)}_vary_ndim_nruns_{n_runs}"
ess_filename = f"ess_curve_kappa_{int(kappa)}_varying_ndim.pkl"

# load samples for varying n_dim
datasets_varying_ndim = load_samples(
base_path=subdir,
varying_param_values=ndims,
varying_param_name="n_dim",
fixed_params={"kappa": kappa},
n_runs=n_runs,
verbose=False,
)

# calculate ESS
ess_filepath = os.path.join(subdir, ess_filename)
ess_ndims = calc_ess_varying_param(
param_values=ndims,
datasets=datasets_varying_ndim,
ess_filepath=ess_filepath,
recompute_ess=False,
verbose=False,
)
recompute_ess = False

if not recompute_ess and os.path.exists(ess_filepath):
print("Loading ESS values from the file...")
ess_ndims = load(ess_filepath)
else:
# load samples for varying n_dim
print(f"Loading samples for varying n_dim from {subdir}...")
datasets_varying_ndim = load_samples(
base_path=subdir,
varying_param_values=ndims,
varying_param_name="n_dim",
fixed_params={"kappa": kappa},
n_runs=n_runs,
verbose=True,
)

# calculate ESS
ess_ndims = calc_ess_varying_param(
param_values=ndims,
datasets=datasets_varying_ndim,
ess_filepath=ess_filepath,
verbose=True,
)

# plotting
fig = ess_plot_varying_param(
Expand Down

0 comments on commit 5d97f96

Please sign in to comment.