-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathshow_history.py
44 lines (34 loc) · 1.11 KB
/
show_history.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
""" Visualize training history.
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse
def main():
    """Plot the training curves stored in a CSV history file.

    The path of the history file is read from the command line. Three
    panels are drawn side by side, each against the ``epoch`` column:
    train/dev loss, dev error rate, and learning rate. The figure is then
    displayed with ``plt.show()``.
    """
    cli = argparse.ArgumentParser(description="Visualize training history.")
    cli.add_argument('history', type=str, help="Path to the history file.")
    history_path = cli.parse_args().history

    # Expected header:
    # datetime, epoch, learning rate, train loss, dev loss, error rate
    history = pd.read_csv(history_path)

    # One wide row of three panels; margins are given as figure fractions.
    fig, (loss_ax, error_ax, lr_ax) = plt.subplots(1, 3, figsize=(15, 3))
    fig.subplots_adjust(left=.05, bottom=0.15, right=.95, top=.9)

    epochs = history['epoch']

    # Panel 1: training and dev loss on shared axes.
    loss_ax.set_title("Loss")
    loss_ax.plot(epochs, history['train loss'], label='train')
    loss_ax.plot(epochs, history['dev loss'], label='dev')
    loss_ax.set_xlabel('epochs')
    loss_ax.legend()

    # Panel 2: dev error rate, with the y-axis pinned to [0, 1].
    error_ax.set_title("Dev. error rate")
    error_ax.grid()
    error_ax.plot(epochs, history['error rate'])
    error_ax.set_ylim(0, 1)
    error_ax.set_xlabel('epochs')

    # Panel 3: learning rate, y tick labels forced to scientific notation.
    lr_ax.set_title("Learning rate")
    lr_ax.plot(epochs, history['learning rate'])
    lr_ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    lr_ax.set_xlabel('epochs')

    plt.show()
# Run the plotting CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()