From 7658fe92a85683ce7f727764f6634abcee0561eb Mon Sep 17 00:00:00 2001 From: brunofavs Date: Tue, 14 May 2024 17:13:47 +0100 Subject: [PATCH] Added xlim ylim to plotting #951 and fixed a small bug when legend wasnt provided --- atom_batch_execution/scripts/batch_execution | 4 ++++ atom_batch_execution/scripts/plot_graphs | 18 ++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/atom_batch_execution/scripts/batch_execution b/atom_batch_execution/scripts/batch_execution index 581c30e7..2b28cf12 100755 --- a/atom_batch_execution/scripts/batch_execution +++ b/atom_batch_execution/scripts/batch_execution @@ -138,6 +138,10 @@ def main(): # Dataset is no longer needed del dataset + # Add dataset dirname to data + dataset_dirname = os.path.dirname(data["dataset_path"]) + data["dataset_dirname"] = dataset_dirname + # Add folds to data data['folds'] = fold_list diff --git a/atom_batch_execution/scripts/plot_graphs b/atom_batch_execution/scripts/plot_graphs index a917a476..3cb943b5 100755 --- a/atom_batch_execution/scripts/plot_graphs +++ b/atom_batch_execution/scripts/plot_graphs @@ -144,7 +144,11 @@ def main(): # nightmare to work with # Get y data for name in x_data_files_names: - file_path_to_get_data_fom = glob.glob(f'{args["results_folder"]}/{name}/{plot_line["ydata"]["file"]}')[0] + try: + file_path_to_get_data_fom = glob.glob(f'{args["results_folder"]}/{name}/{plot_line["ydata"]["file"]}')[0] + except: + atomError(f"Experiment {x_data_files_names} doesn't have data file {plot_line['ydata']['file']}") + exit() # Method with pandas, not working # df = df.T # df.drop(0) @@ -206,6 +210,8 @@ def main(): markers = plot_options['markers'] xscale = plot_options['xscale'] yscale = plot_options['yscale'] + xlim = plot_options.get('xlim') + ylim = plot_options.get('ylim') # Normalizing options to allow them to be defined as lists or individual str/int @@ -237,7 +243,7 @@ def main(): x_values = plot_line['x_values'] y_values = plot_line['y_values'] - legend = plot_line['legend'] if plot_line['legend'] else None + legend = plot_line.get('legend') # Plotting plt.plot(x_values, y_values, color=colors[idx], linestyle=linestyles[idx], linewidth=linewidths[idx], @@ -250,6 +256,12 @@ def main(): plt.yscale(yscale) # Set y-scale if legend is not None: plt.legend(prop={'size': 6}) + + if xlim: + plt.xlim(xlim) + + if ylim: + plt.ylim(ylim) # Save figure in output folder @@ -269,6 +281,8 @@ def main(): if args["show_plots"]: plt.show() + plt.close() + # Getting back to cwd, to prevent confusion if this script is further modified os.chdir(cwd)