Skip to content

Commit

Permalink
Added command line plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
larsson4 committed May 23, 2024
1 parent 1d73828 commit d116d87
Showing 1 changed file with 39 additions and 5 deletions.
44 changes: 39 additions & 5 deletions utils/python/linelast_cwtrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def create_scaling_plot(samples, res, scale_prefix, plt_name = "plot.png"):
axs[0].set_yscale('log')
axs[0].legend()


axs[1].plot(samples, res[:,3], label='Speedup factor', )
axs[1].set_xlabel(scale_prefix)
axs[1].legend()
Expand All @@ -169,13 +168,43 @@ def create_scaling_plot(samples, res, scale_prefix, plt_name = "plot.png"):
plt.tight_layout()
plt.savefig(plt_name)

def basis_scaling_plot(prefix, plot_path):
scale_prefix = '$n_{basis}$'
#samples = (2, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128) //regular
samples = (1, 2, 3, 4, 6, 7, 9, 12, 15, 20, 25, 32, 41, 53, 68, 87, 111, 142, 183, 234, 300, 384) #//complicated
def get_nr(txt, split_txt = 'comparison'):
return int(txt.split('.')[0].split(split_txt)[1])
def get_nrs(txts, split_txt = 'comparison'):
return [get_nr(txt, split_txt) for txt in txts]
def get_sorted_nrs(txts, split_txt = 'comparison'):
return sorted(zip(get_nrs(txts, split_txt),txts))

def basis_scaling_plot(prefix, plot_path, result_path = "basis_scaling"):
cwd = os.getcwd()
abs_scaling_folder = os.path.join(cwd,result_path)
os.chdir(abs_scaling_folder)
txts = os.listdir()
samples = [i for i,_ in get_sorted_nrs(txts, split_txt = 'comparison')]
os.chdir(cwd)
res = get_results(samples, prefix)
scale_prefix = '$n_{basis}$'
create_scaling_plot(samples, res, scale_prefix, plot_path)

def get_svs(filename):
with open(filename, 'r') as file:
return [float(line.strip()) for line in file]

def sv_plot(jointname, h_name, v_name, plt_name = "sv_plot.png"):
svs_j = get_svs(jointname)
svs_b = get_svs(h_name)
svs_c = get_svs(v_name)

plt.plot(range(len(svs_j)), svs_j, label='Joint')
plt.plot(range(len(svs_b)), svs_b, label='Beam')
plt.plot(range(len(svs_c)), svs_c, label='Column')
plt.xscale('log')
plt.yscale('log')
plt.legend()
plt.title('Advanced component SV spectrum')
plt.tight_layout()
plt.savefig(os.path.join(os.getcwd(), plt_name))

if __name__ == "__main__":
import sys
if len(sys.argv) == 1:
Expand All @@ -194,6 +223,11 @@ def basis_scaling_plot(prefix, plot_path):
prefix = sys.argv[2]
plot_path = sys.argv[3]
basis_scaling_plot(prefix, plot_path)
elif name == "svplot":
jointname = sys.argv[2]
h_name = sys.argv[3]
v_name = sys.argv[4]
sv_plot(jointname, h_name, v_name)
elif name == "cwtrain_mesh":
prefix = sys.argv[2]
create_training_meshes(prefix)
Expand Down

0 comments on commit d116d87

Please sign in to comment.