From 8868dc62b45d5d7bd151aef44391d901119ea73b Mon Sep 17 00:00:00 2001 From: Axel Larsson Date: Wed, 22 May 2024 13:44:35 -0700 Subject: [PATCH] Added command line plotting --- utils/python/linelast_cwtrain.py | 44 ++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/utils/python/linelast_cwtrain.py b/utils/python/linelast_cwtrain.py index c1926408..c8217822 100644 --- a/utils/python/linelast_cwtrain.py +++ b/utils/python/linelast_cwtrain.py @@ -160,7 +160,6 @@ def create_scaling_plot(samples, res, scale_prefix, plt_name = "plot.png"): axs[0].set_yscale('log') axs[0].legend() - axs[1].plot(samples, res[:,3], label='Speedup factor', ) axs[1].set_xlabel(scale_prefix) axs[1].legend() @@ -169,13 +168,43 @@ def create_scaling_plot(samples, res, scale_prefix, plt_name = "plot.png"): plt.tight_layout() plt.savefig(plt_name) -def basis_scaling_plot(prefix, plot_path): - scale_prefix = '$n_{basis}$' - #samples = (2, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128) //regular - samples = (1, 2, 3, 4, 6, 7, 9, 12, 15, 20, 25, 32, 41, 53, 68, 87, 111, 142, 183, 234, 300, 384) #//complicated +def get_nr(txt, split_txt = 'comparison'): + return int(txt.split('.')[0].split(split_txt)[1]) +def get_nrs(txts, split_txt = 'comparison'): + return [get_nr(txt, split_txt) for txt in txts] +def get_sorted_nrs(txts, split_txt = 'comparison'): + return sorted(zip(get_nrs(txts, split_txt),txts)) + +def basis_scaling_plot(prefix, plot_path, result_path = "basis_scaling"): + cwd = os.getcwd() + abs_scaling_folder = os.path.join(cwd,result_path) + os.chdir(abs_scaling_folder) + txts = os.listdir() + samples = [i for i,_ in get_sorted_nrs(txts, split_txt = 'comparison')] + os.chdir(cwd) res = get_results(samples, prefix) + scale_prefix = '$n_{basis}$' create_scaling_plot(samples, res, scale_prefix, plot_path) +def get_svs(filename): + with open(filename, 'r') as file: + return [float(line.strip()) for line in file] + +def sv_plot(jointname, h_name, v_name, plt_name = "sv_plot.png"): + svs_j = get_svs(jointname) + svs_b = get_svs(h_name) + svs_c = get_svs(v_name) + + plt.plot(range(len(svs_j)), svs_j, label='Joint') + plt.plot(range(len(svs_b)), svs_b, label='Beam') + plt.plot(range(len(svs_c)), svs_c, label='Column') + plt.xscale('log') + plt.yscale('log') + plt.legend() + plt.title('Advanced component SV spectrum') + plt.tight_layout() + plt.savefig(os.path.join(os.getcwd(), plt_name)) + if __name__ == "__main__": import sys if len(sys.argv) == 1: @@ -194,6 +223,11 @@ def basis_scaling_plot(prefix, plot_path): prefix = sys.argv[2] plot_path = sys.argv[3] basis_scaling_plot(prefix, plot_path) + elif name == "svplot": + jointname = sys.argv[2] + h_name = sys.argv[3] + v_name = sys.argv[4] + sv_plot(jointname, h_name, v_name) elif name == "cwtrain_mesh": prefix = sys.argv[2] create_training_meshes(prefix)