-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathcsv_utils.py
68 lines (57 loc) · 2.43 KB
/
csv_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import csv
import numpy as np
class CSVLogger(object):
    """Write one CSV row per call to :meth:`log_epoch`.

    The header is taken from the keys of the first row that is logged; later
    rows are written against that same header (``csv.DictWriter`` raises
    ``ValueError`` for keys not in the header and writes ``""`` for header
    keys that are missing from a row).

    The file can be closed explicitly with :meth:`close`, or by using the
    logger as a context manager; otherwise it is closed when the object is
    garbage-collected.
    """

    def __init__(self, log_dir, filename="progress.csv"):
        """Create (or truncate) ``os.path.join(log_dir, filename)`` for writing.

        Args:
            log_dir: Directory in which the CSV file is created.
            filename: Name of the CSV file, ``"progress.csv"`` by default.

        Raises:
            OSError: If the file cannot be opened (e.g. ``log_dir`` missing).
        """
        self.writer = None
        # newline="" is what the csv module expects: the writer emits its own
        # "\r\n" row terminator, and without this the text layer would
        # translate it again on platforms such as Windows ("\r\r\n").
        self.csvfile = open(os.path.join(log_dir, filename), "w", newline="")

    def init_writer(self, keys):
        """Create the ``DictWriter`` and write the header, the first time only.

        Args:
            keys: Iterable of column names; only used on the first call.
        """
        if self.writer is None:
            self.writer = csv.DictWriter(self.csvfile, fieldnames=list(keys))
            self.writer.writeheader()

    @staticmethod
    def _add_summary_stats(data, stats_key, prefix):
        """Replace ``data[stats_key]`` by mean/median/min/max columns, in place.

        For every ``name -> values`` item of ``data[stats_key]`` this stores
        ``data[prefix + "mean_" + name]`` (and likewise ``median_``, ``min_``,
        ``max_``) computed with NumPy, then deletes ``data[stats_key]``.
        Does nothing when ``stats_key`` is not in ``data``.
        """
        if stats_key not in data:
            return
        for name, values in data[stats_key].items():
            data[prefix + "mean_" + name] = np.mean(values)
            data[prefix + "median_" + name] = np.median(values)
            data[prefix + "min_" + name] = np.min(values)
            data[prefix + "max_" + name] = np.max(values)
        # The raw value lists are not written to the CSV, only their summary.
        del data[stats_key]

    def log_epoch(self, data):
        """Summarise ``data`` in place and append it to the CSV file.

        Args:
            data: Dict of column name -> value. Optional ``"stats"`` and
                ``"test_stats"`` entries map a name to a sequence of numbers;
                they are removed from ``data`` and replaced by
                ``mean_/median_/min_/max_<name>`` and
                ``test_mean_/test_median_/test_min_/test_max_<name>`` columns.

        Note:
            ``data`` is mutated deliberately: callers such as
            ``ConsoleCSVLogger.log_epoch`` read the added summary columns from
            the same dict after this call returns.
        """
        self._add_summary_stats(data, "stats", "")
        self._add_summary_stats(data, "test_stats", "test_")
        self.init_writer(data.keys())
        self.writer.writerow(data)
        # Push the row to the OS after every call rather than when the buffer fills.
        self.csvfile.flush()

    def close(self):
        """Close the underlying file. Safe to call more than once."""
        # getattr: __init__ may have raised before csvfile was assigned.
        csvfile = getattr(self, "csvfile", None)
        if csvfile is not None:
            csvfile.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()
class ConsoleCSVLogger(CSVLogger):
    """A ``CSVLogger`` that additionally prints a one-line summary to stdout.

    Every call to :meth:`log_epoch` writes the CSV row via the parent class;
    when ``data["iter"]`` is a multiple of ``console_log_interval`` a
    formatted summary line is printed as well.

    NOTE(review): ``console_log_interval`` is the FIRST positional parameter,
    so ``ConsoleCSVLogger(log_dir)`` binds ``log_dir`` to the interval rather
    than forwarding it to ``CSVLogger``. Pass arguments by keyword.
    """

    def __init__(self, console_log_interval=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Print a console summary once every `console_log_interval` iterations.
        self.console_log_interval = console_log_interval

    def log_epoch(self, data):
        """Write ``data`` as a CSV row, then maybe print a summary line.

        The printed line reads ``data`` AFTER the parent has summarised it,
        so it needs ``"stats"`` and ``"test_stats"`` to have contained a
        ``"rew"`` entry (giving ``mean_rew`` ... ``test_max_rew``), plus the
        keys ``iter``, ``total_num_steps``, ``fps``, ``entropy``,
        ``value_loss`` and ``action_loss``. A missing key raises ``KeyError``
        after the CSV row has already been written.
        """
        super().log_epoch(data)
        if data["iter"] % self.console_log_interval != 0:
            return

        template = (
            "Updates {}, num timesteps {}, FPS {}, "
            "mean/median reward {:.1f}/{:.1f}, "
            "min/max reward {:.1f}/{:.1f}, "
            "test_mean/median reward {:.1f}/{:.1f}, "
            "test_min/max reward {:.1f}/{:.1f}, "
            "entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
        )
        # Values for the placeholders above, in order; "policy loss" is
        # taken from the "action_loss" entry.
        field_names = (
            "iter",
            "total_num_steps",
            "fps",
            "mean_rew",
            "median_rew",
            "min_rew",
            "max_rew",
            "test_mean_rew",
            "test_median_rew",
            "test_min_rew",
            "test_max_rew",
            "entropy",
            "value_loss",
            "action_loss",
        )
        field_values = [data[name] for name in field_names]
        print(template.format(*field_values), flush=True)