-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy patheval.py
57 lines (45 loc) · 1.76 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import argparse
import lightning.pytorch as pl
from glob import glob
from omegaconf import OmegaConf
from loguru import logger
import owdfa.algorithm as algorithm
from owdfa.datasets import create_dataloader
import warnings
# Ignore every Python warning emitted after this point, in this script and
# in all libraries it calls at runtime (no category/module filter is given).
# NOTE(review): this also hides deprecation and misconfiguration warnings
# (e.g. from Lightning); consider narrowing with a category or module filter.
warnings.filterwarnings("ignore")
def _find_ckpt_by_exam_id(exam_id, root='wandb'):
    """Return the last checkpoint (in sorted path order) for an experiment.

    Searches ``<root>/*<exam_id>/ckpts/*.ckpt``.

    Args:
        exam_id: Suffix of the experiment run directory to match.
        root: Directory containing the run directories.

    Returns:
        The lexicographically greatest matching path, or ``None`` when no
        checkpoint matches.
    """
    # NOTE(review): lexicographic order is not numeric order, so
    # "epoch=9.ckpt" sorts after "epoch=10.ckpt". Confirm checkpoint names
    # sort as intended (zero-padded, or a fixed "last.ckpt").
    candidates = sorted(glob(f'{root}/*{exam_id}/ckpts/*.ckpt'))
    return candidates[-1] if candidates else None


def _exam_dir_of(ckpt_path):
    """Return the experiment directory two levels above *ckpt_path*.

    For ``wandb/<run>/ckpts/x.ckpt`` this is ``wandb/<run>``. A bare file
    name yields ``''``.
    """
    return os.path.dirname(os.path.dirname(ckpt_path))


def main():
    """Evaluate a checkpoint on one dataset split with a Lightning Trainer.

    CLI arguments are parsed first. Every top-level key of the YAML file given
    by ``--config`` is then copied onto the argument namespace, overriding any
    same-named CLI value. The checkpoint is the sorted-last file under
    ``wandb/*<exam_id>/ckpts/`` when ``exam_id`` is set and matches something;
    otherwise ``ckpt_path`` is used. Exits with a usage error when no
    checkpoint can be determined.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str,
                        default='./configs/eval.yaml')
    parser.add_argument('--exam_id', type=str, default='')
    parser.add_argument('--ckpt_path', type=str, default='')
    parser.add_argument('--output_log', type=str, default='eval.log')
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--distributed', type=int, default=0)
    parser.add_argument('--debug', action='store_true', default=False)
    args = parser.parse_args()

    # Config keys are applied after parsing, so they take precedence over
    # CLI flags of the same name.
    local_config = OmegaConf.load(args.config)
    for key, value in local_config.items():
        setattr(args, key, value)

    # Only override TORCH_HOME when the config provides it; accessing a
    # missing attribute would otherwise abort with AttributeError.
    torch_home = getattr(args, 'torch_home', None)
    if torch_home is not None:
        os.environ['TORCH_HOME'] = str(torch_home)

    # Resolve the checkpoint from the experiment ID when one is given,
    # keeping --ckpt_path as the fallback when nothing matches.
    if args.exam_id:
        found = _find_ckpt_by_exam_id(args.exam_id)
        if found is not None:
            args.ckpt_path = found
        else:
            logger.warning(
                f'No checkpoint found for exam_id={args.exam_id!r}; '
                f'falling back to ckpt_path={args.ckpt_path!r}')
    # Fail fast: without a checkpoint, validate() would only error out after
    # the dataloader and model have been built.
    if not args.ckpt_path:
        parser.error('no checkpoint to evaluate: pass --ckpt_path or an '
                     '--exam_id matching wandb/*<exam_id>/ckpts/*.ckpt')

    exam_dir = _exam_dir_of(args.ckpt_path)

    # os.path.join keeps the log in the working directory when exam_dir is
    # '' (bare checkpoint file name) instead of writing to '/<output_log>'.
    if args.output_log:
        logger.add(os.path.join(exam_dir, args.output_log), level="INFO")

    test_dataloader = create_dataloader(args, split=args.split)
    # The model class is looked up by name in the owdfa.algorithm namespace.
    method = algorithm.__dict__[args.method.name](args)
    trainer = pl.Trainer(default_root_dir=exam_dir)
    trainer.validate(method, test_dataloader, ckpt_path=args.ckpt_path)
if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()