Skip to content

Commit

Permalink
adding the histogram plot with true marginals of mixture of vMF
Browse files Browse the repository at this point in the history
  • Loading branch information
ShantanuKodgirwar committed Oct 27, 2024
1 parent 8ff7c8a commit 7151c90
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
6 changes: 3 additions & 3 deletions geosss/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,14 @@ def gradient(self, x):
class MarginalVonMisesFisher(VonMisesFisher):
"""Computing marginals of the von Mises-Fisher distribution"""

def __init__(self, index, mu):
def __init__(self, dim_idx, mu):
super().__init__(mu)
self.index = index
self.dim_idx = dim_idx

def prob(self, x):
d = len(self.mu)
kappa = np.linalg.norm(self.mu)
mu = self.mu[self.index] / kappa
mu = self.mu[self.dim_idx] / kappa
prob = (
np.sqrt(kappa / (2 * np.pi))
/ iv(d / 2 - 1, kappa)
Expand Down
58 changes: 57 additions & 1 deletion geosss/vMF_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from csb.io import dump, load

from geosss.distributions import MarginalVonMisesFisher, MixtureModel, VonMisesFisher
from geosss.sphere import distance
from geosss.utils import acf

Expand Down Expand Up @@ -62,12 +63,67 @@ def hist_plot(samples, ndim, path, filename, fs=16, save_res=False):
fig.tight_layout()

if save_res:
print(f"Saving sampler marginals plot..")
print(f"Saving histogram plot..")
fig.savefig(f"{path}/{filename}_hist.pdf", transparent=True)

plt.close(fig)


def hist_plot_mixture_marginals(
pdf, samples, ndim, path, filename, fs=16, save_res=False
):
"""
histogram of samples.
"""
bins = 100
plt.rc("font", size=fs)

# shows a standard histogram per dimension
if ndim == 3:
figsize = (10, 10)
else:
figsize = (10, 15)
fig, rows = plt.subplots(ndim, len(METHODS), figsize=figsize, sharex=True)

if isinstance(pdf, MixtureModel):
mus = np.array([pdf.pdfs[i].mu for i in range(len(pdf.pdfs))])
elif isinstance(pdf, VonMisesFisher):
mus = pdf.mu

# reference samples
t = np.linspace(-1.0, 1.0, 1000)

for d_idx, axes in enumerate(rows):

# mixture of the marginals of von Mises-Fisher as ground truth samples
marginalvMFs = [MarginalVonMisesFisher(d_idx, mu) for mu in mus]
mixture_marginalvMFs = MixtureModel(marginalvMFs)
log_p = mixture_marginalvMFs.log_prob(t)
prob_truth = np.exp(log_p)

# show histogram
for ax, method in zip(axes, METHODS):
marginals = samples[method][:, d_idx]
bins = ax.hist(
marginals,
bins=bins,
density=True,
alpha=0.3,
color="k",
histtype="stepfilled",
)[1]
ax.plot(t, prob_truth, ls="--", c="r", lw=1)
ax.set_xlabel(rf"$e_{d_idx}^Tx_n$", fontsize=fs)

for ax, method in zip(rows[0], METHODS):
ax.set_title(ALGOS[method], fontsize=fs)
fig.tight_layout()

if save_res:
print(f"Saving histogram plot with mixture marginals..")
fig.savefig(f"{path}/{filename}_hist_marginals.pdf", transparent=True)


def trace_plots(samples, ndim, path, filename, fs=16, save_res=False):
"""
trace plots per dimension
Expand Down

0 comments on commit 7151c90

Please sign in to comment.