-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_infer.py
72 lines (55 loc) · 1.84 KB
/
main_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import os
import numpy as np
from infer_tfkeras import infer
import ran_func as func
# Default dataset name; override with --data_name / -dn.
# Another dataset this script has been pointed at is "unpaired_ct_lung".
name = "unpaired_ct_abdomen"

# =======================================================================================================================
# Command-line options. All are optional and fall back to the defaults shown.
parser = argparse.ArgumentParser(
    description="Run inference with a trained model on the test split of a dataset."
)
parser.add_argument(
    "--model_name",
    "-mn",
    # The name is resolved by ran_func.networks.get_net and is also part of
    # the checkpoint directory name (see model_path below).
    help="network name, resolved via ran_func.networks.get_net "
    "(default: %(default)s)",
    type=str,
    default="RAN4",
    required=False,
)
parser.add_argument(
    "--model_id",
    "-mi",
    help="index of the saved model to load: "
    "./models/<data_name>/<data_name>-<model_name>/model_<model_id> "
    "(default: %(default)s)",
    type=int,
    default=3,
    required=False,
)
parser.add_argument(
    "--data_name",
    "-dn",
    help="dataset name; test data is read from "
    "./data/<data_name>/dataset/test_proc (default: %(default)s)",
    type=str,
    default=name,
    required=False,
)
args = parser.parse_args()

# =======================================================================================================================
# Inference configuration. Every value below is forwarded to infer_tfkeras.infer.
data_name = args.data_name
model_name = args.model_name
print(model_name)

# NOTE(review): the meaning of these knobs is defined inside infer_tfkeras.infer;
# the trailing comments are assumptions to confirm there.
rescale_factor = 1  # presumably 1 == images are not rescaled
rescale_factor_label = 1  # presumably 1 == labels are not rescaled
int_range = [-100, 300]  # presumably an intensity (HU) clipping window for CT
crop_sz = np.array([0, 0, 0])  # presumably all zeros == no cropping
print(crop_sz)

# Look up the network by name.
net_core = func.networks.get_net(model_name)

# Set to True to enable label usage inside infer (semantics: see infer_tfkeras).
use_lab = False

# Saved model: ./models/<data_name>/<data_name>-<model_name>/model_<model_id>
model_path = os.path.join(
    ".",
    "models",
    data_name,
    data_name + "-" + model_name,
    "model_" + str(args.model_id),
)
# Dataset root: ./data/<data_name>/dataset; its test split lives in test_proc.
data_path = os.path.join(".", "data", data_name, "dataset")
print(os.path.abspath(data_path))
test_paths = os.path.join(data_path, "test_proc")

# =======================================================================================================================
# NOTE(review): pair_type is hard-coded to "unpaired"; presumably it must match
# the dataset chosen via --data_name — confirm before using a paired dataset.
infer(
    net_core=net_core,
    model_path=model_path,
    crop_sz=crop_sz,
    pair_type="unpaired",
    rescale_factor=rescale_factor,
    rescale_factor_label=rescale_factor_label,
    use_lab=use_lab,
    test_path=test_paths,
    model_name=model_name,
    int_range=int_range,
)