Skip to content

Commit

Permalink
added alternative matplotlib titles for pdfs
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Apr 12, 2024
1 parent b69ad23 commit 1941a97
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions src/scripts/plots/ring/pdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,34 @@
parser.add_argument('--prog-ellipses', default=False, action='store_true',
help="Whether to plot ellipses progressively")
parser.add_argument('--title', default=False, action='store_true', help="Whether to show a title")
parser.add_argument('--alt-title', default=False, action='store_true',
help="Whether to show alternative titles")
parser.add_argument('--vertical-title', default=False, action='store_true',
help="Whether to show the title vertically")
parser.add_argument('--dpi', type=int, default=192, help="The DPI for PNG rasterization")
parser.add_argument('--prune', default=False, action='store_true',
help="Whether to prune components having weight close to zero")


def format_model_name(m: str, num_components: int) -> str:
if m == 'MonotonicPC':
return f"GMM ($K \!\! = \!\! {num_components}$)"
elif m == 'BornPC':
return f"SGMM ($K \!\! = \!\! {num_components}$)"
def format_model_name(m: str, num_components: int, alt: bool = False) -> str:
if alt:
if m == 'MonotonicPC':
if num_components == 1:
return r"$\mathcal{N}_1$"
elif num_components == 2:
return r"$w_1\mathcal{N}_1 + w_2\mathcal{N}_2$"
return r"$w_1\mathcal{N}_1 + \cdots w_K\mathcal{N}_K$"
elif m == 'BornPC':
if num_components == 1:
return r"$\mathcal{N}_1$"
elif num_components == 2:
return r"$\mathcal{N}_1 - w_2\mathcal{N}_2$"
assert False
else:
if m == 'MonotonicPC':
return fr"GMM ($K \!\! = \!\! {num_components}$)"
elif m == 'BornPC':
return fr"SGMM ($K \!\! = \!\! {num_components}$)"
return m


Expand Down Expand Up @@ -185,15 +201,15 @@ def plot_pdf(

data_pdfs = [(None, truth_pdf, 'Ground Truth', -1)] + list(zip(mixtures, pdfs, models, num_components))
for idx, (p, pdf, m, nc) in enumerate(data_pdfs):
if args.prog_ellipses:
plot_settings = [{'max_num_components': i + 1} for i in range(nc)]
if args.prog_ellipses and p is not None:
plot_settings = [{'max_num_components': i} for i in range(nc + 1)]
else:
plot_settings = [{'max_num_components': None}]
for j, ps in enumerate(plot_settings):
setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0)
fig, ax = plt.subplots(1, 1)
if args.title:
title = f"{format_model_name(m, nc)}" if p is not None else m
title = f"{format_model_name(m, nc, alt=args.alt_title)}" if p is not None else m
else:
title = None

Expand All @@ -213,16 +229,19 @@ def plot_pdf(
ax.set_yticks([])
ax.set_aspect(1.0)
if args.title:
fontsize = 7 if args.alt_title else 10
if args.vertical_title:
ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center')
ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center', fontsize=fontsize)
else:
ax.set_title(title, y=-0.275)
y = -0.225 if args.alt_title else -0.275
ax.set_title(title, y=y, fontsize=fontsize)

if args.prog_ellipses:
if args.prog_ellipses and p is not None:
filename = f'pdfs-ellipses-{idx}-{j}.png' if args.show_ellipses else f'pdfs-{idx}-{j}.png'
subdir = 'progressive'
else:
filename = f'pdfs-ellipses-{idx}.png' if args.show_ellipses else f'pdfs-{idx}.png'
subdir = 'plain'
os.makedirs(os.path.join('figures', 'gaussian-ring', subdir), exist_ok=True)
plt.savefig(os.path.join('figures', 'gaussian-ring', subdir, filename), dpi=args.dpi)
plt.close()

0 comments on commit 1941a97

Please sign in to comment.