From 251d85f2e26f885356ac2d92de4b8dbce132cd6b Mon Sep 17 00:00:00 2001 From: Anna Foix Date: Mon, 29 Jul 2024 21:10:59 +0100 Subject: [PATCH] Updated graphs titles --- scripts/shapeembed/evaluation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/shapeembed/evaluation.py b/scripts/shapeembed/evaluation.py index 1a872eeb..d530e9f6 100644 --- a/scripts/shapeembed/evaluation.py +++ b/scripts/shapeembed/evaluation.py @@ -303,7 +303,8 @@ def save_barplot( scores_df ax.set(title=f'f1 score against batch size ({m}, compression factor {cf})') plt.savefig(f"{outputdir}/barplot_{m}_x_bs_cf{cf}.pdf") plt.close() - ax = seaborn.catplot( data=melted_df.loc[ (melted_df['model'] == m) & (melted_df['batch_size'] == cf) + for bs in melted_df['batch_size'].unique(): + ax = seaborn.catplot( data=melted_df.loc[ (melted_df['model'] == m) & (melted_df['batch_size'] == bs) , ['compression_factor', 'beta', 'Metric', 'Score'] ] , kind="bar" , x='compression_factor' @@ -314,7 +315,7 @@ def save_barplot( scores_df , aspect=width * 2**0.5 / height ) ax.tick_params(axis='x', rotation=90) ax.fig.subplots_adjust(top=0.9) - ax.set(title=f'f1 score against batch size ({m}, compression factor {cf})') + ax.set(title=f'f1 score against compression factor ({m}, compression batch size {bs})') plt.savefig(f"{outputdir}/barplot_{m}_x_cf_bs{bs}.pdf") plt.close() # log info