Skip to content

Commit

Permalink
some other minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ShantanuKodgirwar committed Oct 24, 2024
1 parent 5b038dd commit 02a484a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
8 changes: 4 additions & 4 deletions scripts/curve_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import geosss as gs
from geosss.distributions import CurvedVonMisesFisher, Distribution
from geosss.spherical_curve import SlerpCurve
from geosss.spherical_curve import SlerpCurve, constrained_brownian_curve

plt.rc("font", size=16)

Expand Down Expand Up @@ -415,7 +415,7 @@ def _get_view_vector(elev, azim):
burnin = int(0.1 * n_samples) # burn-in

# optional controls
brownian_curve = False # fix curve (target)
is_brownian_curve = False # fix curve (target)
reprod_switch = True # seeds samplers for reproducibility
savefig = True # save the plots
rerun_if_file_exists = False # rerun even if file exists
Expand All @@ -426,7 +426,7 @@ def _get_view_vector(elev, azim):
setup_logging(savedir, kappa)

# define curve on the sphere
if not brownian_curve:
if not is_brownian_curve:
knots = np.array(
[
[-0.25882694, 0.95006168, 0.17433133],
Expand All @@ -442,7 +442,7 @@ def _get_view_vector(elev, azim):
]
)
else:
knots = gs.sphere.constrained_brownian_curve(
knots = constrained_brownian_curve(
n_points=10,
dimension=3,
step_size=0.3,
Expand Down
23 changes: 17 additions & 6 deletions scripts/ess_curve.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ESS computed for the curve on the sphere by varying the number of dimensions and
# the concentration parameter kappa.

# %%
import os

import arviz as az
Expand Down Expand Up @@ -168,6 +168,7 @@ def ess_plot_varying_param(
param_values,
param_name,
select_dim_idx: int = 0,
y_lim_factor: float = 18.0,
) -> plt.Figure:
"""
Plot ESS values against a varying parameter.
Expand All @@ -182,7 +183,8 @@ def ess_plot_varying_param(
Name of the parameter for labeling.
select_dim_idx : int
Index to select the ESS values that are computed for every dimension.
y_lim_factor : float
Factor to multiply the y limit by.
Returns
-------
plt.Figure
Expand All @@ -208,9 +210,14 @@ def ess_plot_varying_param(
color=color_palette[i],
)
ax.set_yscale("log")
ax.legend()

# Adjust y limit
ymin, ymax = ax.get_ylim()
ax.set_ylim(ymin, ymax * y_lim_factor)
ax.legend(loc="upper right")

ax.set_xlabel(param_name)
ax.set_ylabel("Relative ESS (log)")
ax.set_ylabel("relative ESS (log)")

ax.set_xticks(param_values)
ax.set_xticklabels(param_values)
Expand Down Expand Up @@ -262,8 +269,9 @@ def ess_plot_varying_param(
fig = ess_plot_varying_param(
ess_vals=ess_kappas,
param_values=kappas,
param_name=r"Concentration parameter $\kappa$",
param_name=r"concentration parameter $\kappa$",
select_dim_idx=0,
y_lim_factor=28,
)
fig.savefig(
f"{subdir}/ess_curve_10d_varying_kappa.pdf", transparent=True, dpi=150
Expand Down Expand Up @@ -306,9 +314,12 @@ def ess_plot_varying_param(
fig = ess_plot_varying_param(
ess_vals=ess_ndims,
param_values=ndims,
param_name="Dimensions $d$",
param_name="dimension $d$",
select_dim_idx=0,
y_lim_factor=13,
)
fig.savefig(
f"{subdir}/ess_curve_kappa_{int(kappa)}_varying_ndim.pdf",
)

# %%

0 comments on commit 02a484a

Please sign in to comment.