This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathprint_results.py
66 lines (57 loc) · 2.04 KB
/
print_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import pandas as pd
import omegaconf
import argparse
import csv
import os
# Command-line interface. `args` is read at module level by the code below
# (and by load_best when no explicit early_stop column is passed).
parser = argparse.ArgumentParser(
    description='Summarize the best early-stopping row of each run, '
                'grouped by config/metric columns.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Required: the collection loop iterates over args.root, so a missing value
# would otherwise surface later as a TypeError on None.
parser.add_argument('--root', nargs='+', required=True,
                    help='directories searched recursively for run directories '
                         'containing both config.yaml and a logs/ subdirectory')
parser.add_argument('--group', nargs='+', default=['model'],
                    help='columns to group runs by')
parser.add_argument('--agg', nargs='+', default=['step', 'val_f1'],
                    help='columns whose per-group mean and std are reported')
parser.add_argument('--early_stop', default='val_f1',
                    help='metric column maximized to select the best row of each run')
args = parser.parse_args()
def load_best(dlog, early_stop=None):
    """Return the metrics row with the highest early-stopping value under *dlog*.

    Walks *dlog* recursively and reads every ``metrics.csv`` found. Each data
    row becomes a dict keyed by the CSV header; cells that parse as ``float``
    are converted, and all other cells (e.g. blanks) are kept as strings. Rows
    whose *early_stop* cell is missing, non-numeric, or NaN are skipped. Among
    the remaining rows, the first one with the largest value wins.

    Args:
        dlog: directory to search for ``metrics.csv`` files.
        early_stop: name of the metric column to maximize. ``None`` (the
            default) falls back to the module-level ``args.early_stop``.

    Returns:
        The best row as a dict with an extra ``'path'`` key holding the
        directory that contained its CSV, or ``None`` if no row had a usable
        value.
    """
    if early_stop is None:
        early_stop = args.early_stop
    # Start from "no value yet" rather than 0 so that metrics which are never
    # strictly positive (zero or negative) can still be selected.
    best = None
    best_row = None
    for root, _dirs, files in os.walk(dlog):
        if 'metrics.csv' not in files:
            continue
        # newline='' is what the csv module expects for files it parses.
        with open(os.path.join(root, 'metrics.csv'), newline='') as f:
            reader = csv.reader(f)
            header = next(reader, None)
            if header is None:  # empty file: nothing to read
                continue
            for r in reader:
                d = dict(zip(header, r))
                for k, cell in d.items():
                    try:
                        d[k] = float(cell)
                    except ValueError:
                        pass  # keep non-numeric cells (blanks, labels) as strings
                d['path'] = root
                metric = d.get(early_stop)
                # Skip rows where the metric was not logged or is not a number.
                if not isinstance(metric, float) or pd.isna(metric):
                    continue
                if best is None or metric > best:
                    best_row = d
                    best = metric
    return best_row
# Collect one summary row per run directory: a directory counts as a run when
# it holds both a config.yaml and a logs/ subdirectory. The run's best metrics
# row is merged with its config values (config keys overwrite same-named
# metric keys) and appended to all_logs. Runs without a usable row are skipped.
all_logs = []
for search_dir in args.root:
    for run_dir, _subdirs, _files in os.walk(search_dir):
        log_dir = os.path.join(run_dir, 'logs')
        config_path = os.path.join(run_dir, 'config.yaml')
        if not (os.path.isfile(config_path) and os.path.isdir(log_dir)):
            continue
        row = load_best(log_dir)
        if row is None:
            continue
        row.update(omegaconf.OmegaConf.load(config_path))
        all_logs.append(row)
# Build one table row per run and report per-group mean/std of each --agg
# column plus the number of runs with a non-null early-stopping metric.
df = pd.DataFrame(all_logs)
agg = {}
for a in args.agg:
    agg['mean_{}'.format(a)] = (a, 'mean')
    agg['std_{}'.format(a)] = (a, 'std')
agg['count'] = (args.early_stop, 'count')
# The 'count' aggregation reads args.early_stop, so that column must survive
# the subset even when it is not listed in --agg. dict.fromkeys removes
# duplicate names (which would otherwise create duplicate columns) while
# keeping their order.
columns = list(dict.fromkeys(args.group + args.agg + [args.early_stop]))
if df.empty:
    print('No runs with a config.yaml and a usable {!r} metric were found.'.format(
        args.early_stop))
else:
    df = df[columns]
    summary = df.groupby(args.group).agg(**agg)
    print(summary)