Skip to content

Commit

Permalink
update plotting HF simulator LF prior
Browse files Browse the repository at this point in the history
  • Loading branch information
lenarttreven committed Jan 31, 2024
1 parent 142d6cd commit 271a1d3
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 31 deletions.
122 changes: 94 additions & 28 deletions experiments/regression_exp/plots_num_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,60 @@
import pandas as pd
import numpy as np
import argparse

import math
from typing import Tuple

import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib as mpl

from experiments.util import collect_exp_results, ucb, lcb, median, count
import math
from plotting_hyperdata import plotting_constants

# Module-level plot styling shared by the figures produced in this script.
# NOTE(review): pyplot.locator_params acts on the *current* Axes (creating one
# if none exists yet), so calling it at import time presumably only affects an
# implicit figure and not the subplots created later -- confirm the intent.
plt.locator_params(nbins=4)

# Font sizes (in points) and line width used for titles, legends, axis labels
# and plotted curves below.
TITLE_FONT_SIZE = 18
LEGEND_FONT_SIZE = 14
LABEL_FONT_SIZE = 18
YLABEL_FONT_SIZE = 18
XLABEL_FONT_SIZE = 18
TICKS_SIZE = 14
LINE_WIDTH = 3

# Render all text through LaTeX; the preamble loads amsmath and bm and defines
# the macros \vx and \vf as bold x and bold f.
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=
r'\usepackage{amsmath}'
r'\usepackage{bm}'
r'\def\vx{{\bm{x}}}'
r'\def\vf{{\bm{f}}}')

# Apply the tick label size to both axes globally.
mpl.rcParams['xtick.labelsize'] = TICKS_SIZE
mpl.rcParams['ytick.labelsize'] = TICKS_SIZE


class FirstNonZeroFormatter(ticker.Formatter):
    """Tick formatter that prints fixed-point labels with just enough decimals.

    A positive value below one is printed with as many decimal places as are
    needed to reach its first non-zero digit (rounded at that position), e.g.
    ``0.05`` becomes ``'0.05'``. Values of one or more are printed without
    decimals, and non-positive values are labelled ``'0'``.
    """

    def __call__(self, x, pos=None):
        """Return the label for tick value ``x``; ``pos`` is not used."""
        # log10 is undefined for non-positive values, so label them directly.
        if x <= 0:
            return '0'

        # ceil(-log10(x)) is positive for x < 1 and non-positive for x >= 1;
        # clamping at zero means large values get no decimal places.
        decimals = max(np.ceil(-np.log10(x)).astype(int), 0)
        label = f'{x:.{decimals}f}'

        # Drop a trailing period, should the formatted label end with one.
        return label[:-1] if label.endswith('.') else label


def different_method_plot(df_agg: pd.DataFrame, metric: str = 'nll', display: bool = True,
Expand All @@ -29,8 +78,6 @@ def different_method_plot(df_agg: pd.DataFrame, metric: str = 'nll', display: bo
linestyle='', capsize=4., color='black')
ax.set_xticks(x_pos)
ax.set_xticklabels(best_rows_df['model'], rotation=-20)
ax.set_ylabel(metric)
fig.tight_layout()

if display:
plt.show()
Expand All @@ -40,11 +87,9 @@ def different_method_plot(df_agg: pd.DataFrame, metric: str = 'nll', display: bo
return best_rows_df, fig


def main(args, drop_nan=False, fig=None, show_legend=True):
def main(args, drop_nan=False, fig=None, show_legend=True, log_scale=False):
df_full, param_names = collect_exp_results(exp_name=args.exp_name)
df_full = df_full[df_full['data_source'] == args.data_source]
#df_full = df_full[df_full['model'] == args.data_source]

df_full = df_full[df_full['num_samples_train'] >= 10]

for col in ['bandwidth_kde', 'bandwidth_ssge', 'bandwidth_score_estim']:
Expand All @@ -66,14 +111,13 @@ def main(args, drop_nan=False, fig=None, show_legend=True):
series[is_nan] = max_value + noise
df_full[col] = series


# group over everything except seeds and aggregate over the seeds
groupby_names = list(set(param_names) - {'model_seed', 'data_seed'})

# remove the likelihood_std column since it's a constant list which is not hashable
groupby_names.remove('likelihood_std')
# groupby_names.remove('added_gp_outputscale')
#df_full['added_gp_outputscale'] = df_full['added_gp_outputscale'].apply(lambda x: x[0])
# df_full['added_gp_outputscale'] = df_full['added_gp_outputscale'].apply(lambda x: x[0])

# replace all the nans in hyperparameter columns with 'N/A'
for column in groupby_names:
Expand All @@ -92,60 +136,82 @@ def main(args, drop_nan=False, fig=None, show_legend=True):
df_agg = df_agg[df_agg['rmse']['count'] >= 3]

available_models = sorted(list(set(df_agg['model'])))
print(available_models)
if fig is None:
fig, axs = plt.subplots(1, 2)
fig, axs = plt.subplots(2, 1)
else:
axs = fig.subplots(1, 2)
axs = fig.subplots(2, 1, sharex=True)
for idx, metric in enumerate(['nll', 'rmse']):
for model in available_models:
df_model = df_agg[df_agg['model'] == model].sort_values(by=[('num_samples_train', '')], ascending=True)
nice_name = plotting_constants.plot_num_data_name_transfer[model]
if args.quantile_cis:
axs[idx].plot(df_model[('num_samples_train', '')], df_model[(metric, 'median')], label=model)
axs[idx].plot(df_model[('num_samples_train', '')], df_model[(metric, 'median')],
label=nice_name,
color=plotting_constants.COLORS[nice_name],
linestyle=plotting_constants.LINE_STYLES[nice_name],
linewidth=LINE_WIDTH)
lower_ci = df_model[(metric, 'lcb')]
upper_ci = df_model[(metric, 'ucb')]
else:
axs[idx].plot(df_model[('num_samples_train', '')], df_model[(metric, 'mean')], label=model)
axs[idx].plot(df_model[('num_samples_train', '')], df_model[(metric, 'mean')],
label=nice_name,
color=plotting_constants.COLORS[nice_name],
linestyle=plotting_constants.LINE_STYLES[nice_name],
linewidth=LINE_WIDTH)
CI_width = 2 / np.sqrt(df_model[(metric, 'count')])
lower_ci = df_model[(metric, 'mean')] - CI_width * df_model[(metric, 'std')]
upper_ci = df_model[(metric, 'mean')] + CI_width * df_model[(metric, 'std')]
axs[idx].fill_between(df_model[('num_samples_train', '')], lower_ci, upper_ci, alpha=0.3)
axs[idx].set_title(f'{args.data_source} - {metric}')
# ax.set_xscale('log')
#ax.set_ylim((-3.8, -2.))

axs[idx].fill_between(df_model[('num_samples_train', '')], lower_ci, upper_ci, alpha=0.3,
color=plotting_constants.COLORS[nice_name])
if idx == 0:
axs[idx].set_title(plotting_constants.plot_num_data_data_source_transfer[args.data_source],
fontsize=TITLE_FONT_SIZE)
if args.data_source == 'racecar':
axs[idx].set_ylabel(plotting_constants.plot_num_data_metrics_transfer[metric],
fontsize=YLABEL_FONT_SIZE)
if idx == 1 and log_scale:
axs[idx].set_yscale('log')
axs[idx].yaxis.set_major_formatter(FirstNonZeroFormatter())
axs[idx].yaxis.set_minor_formatter(FirstNonZeroFormatter())
axs[idx].yaxis.set_minor_locator(plt.MaxNLocator(6))
axs[idx].set_xlabel('Number of iterations', fontsize=XLABEL_FONT_SIZE)
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
if show_legend:
fig.legend(by_label.values(), by_label.keys(), ncols=4, loc='lower center',
bbox_to_anchor=(0.5, 0), fontsize=10)
# fig.tight_layout(rect=[0, 0.1, 1, 1])

# plt.show()
print('Models:', set(df_agg['model']))



if __name__ == '__main__':
figure = plt.figure(figsize=(10, 7))
subfigs = figure.subfigures(2, 1, hspace=-0.1)

subfigs = figure.subfigures(1, 2, wspace=-0.05)

parser = argparse.ArgumentParser(description='Inspect results of a regression experiment.')
parser.add_argument('--exp_name', type=str, default='jan10_num_data')
parser.add_argument('--quantile_cis', type=int, default=1)
parser.add_argument('--data_source', type=str, default='racecar')
args = parser.parse_args()


main(args, fig=subfigs[0], show_legend=False)
main(args, fig=subfigs[0], show_legend=False, log_scale=True)

parser = argparse.ArgumentParser(description='Inspect results of a regression experiment.')
parser.add_argument('--exp_name', type=str, default='jan10_num_data')
parser.add_argument('--quantile_cis', type=int, default=1)
parser.add_argument('--data_source', type=str, default='pendulum')
args = parser.parse_args()

main(args, fig=subfigs[1])
plt.tight_layout(rect=[0, 0.2, 1, 0.9])
main(args, fig=subfigs[1], show_legend=False, log_scale=True)
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
figure.legend(by_label.values(), by_label.keys(),
ncols=4,
loc='upper center',
fontsize=LEGEND_FONT_SIZE,
frameon=False)
figure.tight_layout(rect=[0.15, -0.1, 1, 0.95])
plt.savefig('regression_exp.pdf')
plt.show()
50 changes: 47 additions & 3 deletions plotting_hyperdata/plotting_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
'long dash with offset': (5, (10, 3)),
'loosely dashed': (0, (5, 10)),
'dashed': (0, (5, 5)),
'dashdot': 'dashdot',
'densely dashed': (0, (5, 1)),
'loosely dashdotted': (0, (3, 10, 1, 10)),
'dashdotted': (0, (3, 5, 1, 5)),
Expand All @@ -16,15 +17,59 @@

METHODS = ['SVGD',
'FSVGD',
'GreyBox',
'GreyBox + FSVGD',
'Sim Model',
'FSVGD[SimPrior=GP]',
'FSVGD[SimPrior=KDE]',
'FSVGD[SimPrior=Nu-Method]',
'FSVGD[SimPrior=SSGE]',
'Sampled Prior Functions'
'Sampled Prior Functions',
'NP',
'PACOH'
]

# Hex colour per method display name; looked up by plots_num_data.py as
# COLORS[nice_name] for both the curve and its confidence band.
# Only the display names listed here have a colour; indexing with any other
# name (e.g. 'FSVGD[SimPrior=SSGE]') raises KeyError.
COLORS = {'SVGD': '#228B22',
'FSVGD': '#9BCD4C',
'GreyBox + FSVGD': '#CB3E3E',
'Sim Model': '#AF5227',
'FSVGD[SimPrior=GP]': '#DC6A73',
'FSVGD[SimPrior=KDE]': '#daa520',
'FSVGD[SimPrior=Nu-Method]': '#8486E0',
'NP': '#654321',
'PACOH': '#808080'
}

# Line style per method display name, taken from the linestyle_tuple table
# defined earlier in this module; passed as `linestyle=` when plotting.
# Uses the same set of keys as COLORS.
LINE_STYLES = {'SVGD': linestyle_tuple['densely dashdotdotted'],
'FSVGD': linestyle_tuple['dashdot'],
'GreyBox + FSVGD': linestyle_tuple['dashdotdotted'],
'Sim Model': linestyle_tuple['densely dashdotted'],
'FSVGD[SimPrior=GP]': linestyle_tuple['dashdotted'],
'FSVGD[SimPrior=KDE]': linestyle_tuple['dotted'],
'FSVGD[SimPrior=Nu-Method]': linestyle_tuple['densely dashed'],
'NP': linestyle_tuple['dashed'],
'PACOH': linestyle_tuple['loosely dashed']
}

# Maps the `data_source` identifier of an experiment to the title shown above
# its column of subplots.
plot_num_data_data_source_transfer = {
'pendulum': 'Pendulum[HF simulator][LF prior]',
'racecar': 'Racecar[HF simulator][LF prior]'
}

# Maps a metric column name to the label used on the y-axis.
plot_num_data_metrics_transfer = {
'nll': 'NLL',
'rmse': 'RMSE'
}

# Maps the raw `model` identifier found in the experiment results to its
# display name. The display names are taken from METHODS *by position*, so
# these indices must be updated whenever entries of METHODS are added,
# removed or reordered. A model identifier missing from this dict raises
# KeyError at lookup time.
# NOTE(review): indices 9 and 10 assume 'NP' and 'PACOH' are the last two
# entries of METHODS -- confirm against the list above.
plot_num_data_name_transfer = {
'BNN_FSVGD': METHODS[1],
'BNN_FSVGD_SimPrior_gp': METHODS[4],
'BNN_FSVGD_SimPrior_kde': METHODS[5],
'BNN_FSVGD_SimPrior_nu-method': METHODS[6],
'BNN_SVGD': METHODS[0],
'NP': METHODS[9],
'PACOH': METHODS[10]
}

TRUE_FUNCTION_COLOR = 'black'
TRUE_FUNCTION_LINE_STYLE = linestyle_tuple['densely dashed']
TRUE_FUNCTION_LINE_WIDTH = 4
Expand All @@ -34,7 +79,6 @@
CONFIDENCE_ALPHA = 0.2
MEAN_FUNCTION_LINE_WIDTH = 4


SAMPLES_COLOR = 'green'
SAMPLES_LINE_STYLE = linestyle_tuple['densely dashdotdotted']
SAMPLES_ALPHA = 0.6
Expand Down

0 comments on commit 271a1d3

Please sign in to comment.