-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvisualize_agent_graph.py
40 lines (34 loc) · 1.08 KB
/
visualize_agent_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import numpy as np
from text_localization_environment import TextLocEnv
from chainerrl.links.mlp import MLP
from chainerrl.links import Sequence
from chainerrl.experiments.train_agent import train_agent_with_evaluation
import chainer
import chainerrl
import logging
import sys
from tb_chainer import SummaryWriter
import time
import re
import chainer.computational_graph as c
from custom_model import CustomModel
from config import CONFIG, print_config
"""
Set arguments w/ config file (--config) or cli
:imagefile_path :boxfile_path
"""
def main():
print_config()
relative_paths = np.loadtxt(CONFIG['imagefile_path'], dtype=str)
images_base_path = os.path.dirname(CONFIG['imagefile_path'])
absolute_paths = [images_base_path + i.strip('.') for i in relative_paths]
bboxes = np.load(CONFIG['boxfile_path'], allow_pickle=True)
env = TextLocEnv(absolute_paths, bboxes, -1)
m = CustomModel(10)
vs = [m(env.reset())]
g = c.build_computational_graph(vs)
with open('graph.dot', 'w') as o:
o.write(g.dump())
if __name__ == '__main__':
main()