From 02a484aa7e24a06d12e64ccda9474c03094c4783 Mon Sep 17 00:00:00 2001 From: ShantanuKodgirwar Date: Thu, 24 Oct 2024 20:49:57 +0200 Subject: [PATCH] some other minor changes --- scripts/curve_3d.py | 8 ++++---- scripts/ess_curve.py | 23 +++++++++++++++++------ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/scripts/curve_3d.py b/scripts/curve_3d.py index f0b9820..8e188fc 100644 --- a/scripts/curve_3d.py +++ b/scripts/curve_3d.py @@ -13,7 +13,7 @@ import geosss as gs from geosss.distributions import CurvedVonMisesFisher, Distribution -from geosss.spherical_curve import SlerpCurve +from geosss.spherical_curve import SlerpCurve, constrained_brownian_curve plt.rc("font", size=16) @@ -415,7 +415,7 @@ def _get_view_vector(elev, azim): burnin = int(0.1 * n_samples) # burn-in # optional controls - brownian_curve = False # fix curve (target) + is_brownian_curve = False # fix curve (target) reprod_switch = True # seeds samplers for reproducibility savefig = True # save the plots rerun_if_file_exists = False # rerun even if file exists @@ -426,7 +426,7 @@ def _get_view_vector(elev, azim): setup_logging(savedir, kappa) # define curve on the sphere - if not brownian_curve: + if not is_brownian_curve: knots = np.array( [ [-0.25882694, 0.95006168, 0.17433133], @@ -442,7 +442,7 @@ def _get_view_vector(elev, azim): ] ) else: - knots = gs.sphere.constrained_brownian_curve( + knots = constrained_brownian_curve( n_points=10, dimension=3, step_size=0.3, diff --git a/scripts/ess_curve.py b/scripts/ess_curve.py index ce57749..0306061 100644 --- a/scripts/ess_curve.py +++ b/scripts/ess_curve.py @@ -1,6 +1,6 @@ # ESS computed for the curve on the sphere by varying the number of dimensions and # the concentration parameter kappa. - +# %% import os import arviz as az @@ -168,6 +168,7 @@ def ess_plot_varying_param( param_values, param_name, select_dim_idx: int = 0, + y_lim_factor: float = 18.0, ) -> plt.Figure: """ Plot ESS values against a varying parameter. @@ -182,7 +183,8 @@ def ess_plot_varying_param( Name of the parameter for labeling. select_dim_idx : int Index to select the ESS values that are computed for every dimension. - + y_lim_factor : float + Factor to multiply the y limit by. Returns ------- plt.Figure @@ -208,9 +210,14 @@ def ess_plot_varying_param( color=color_palette[i], ) ax.set_yscale("log") - ax.legend() + + # Adjust y limit + ymin, ymax = ax.get_ylim() + ax.set_ylim(ymin, ymax * y_lim_factor) + ax.legend(loc="upper right") + ax.set_xlabel(param_name) - ax.set_ylabel("Relative ESS (log)") + ax.set_ylabel("relative ESS (log)") ax.set_xticks(param_values) ax.set_xticklabels(param_values) @@ -262,8 +269,9 @@ def ess_plot_varying_param( fig = ess_plot_varying_param( ess_vals=ess_kappas, param_values=kappas, - param_name=r"Concentration parameter $\kappa$", + param_name=r"concentration parameter $\kappa$", select_dim_idx=0, + y_lim_factor=28, ) fig.savefig( f"{subdir}/ess_curve_10d_varying_kappa.pdf", transparent=True, dpi=150 @@ -306,9 +314,12 @@ def ess_plot_varying_param( fig = ess_plot_varying_param( ess_vals=ess_ndims, param_values=ndims, - param_name="Dimensions $d$", + param_name="dimension $d$", select_dim_idx=0, + y_lim_factor=13, ) fig.savefig( f"{subdir}/ess_curve_kappa_{int(kappa)}_varying_ndim.pdf", ) + +# %%