-
Notifications
You must be signed in to change notification settings - Fork 17
/
Main.py
61 lines (32 loc) · 1.33 KB
/
Main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 18 16:08:57 2021
@author: Narmin Ghaffari Laleh
"""
###############################################################################
from Classic_Training import Classic_Training
from CLAM_MIL_Training import CLAM_MIL_Training
from AttMIL_Training import AttMIL_Training
import utils.utils as utils
from pathlib import Path
import warnings
import argparse
import torch
# %%
# Command-line interface: the single option points at an experiment file whose
# contents utils.ReadExperimentFile expands into the full run configuration.
parser = argparse.ArgumentParser(description = 'Main Script to Run Training')
# Built from path components rather than a '.\Experiments\...' string literal:
# that literal relied on invalid escape sequences ("\E", "\C") and only denoted
# a directory + file on Windows (on POSIX it was one name containing backslashes).
pathToExperimentFile = Path('Experiments', 'COHORT_RESNET18_CROSSVAL.txt')
# NOTE: argparse applies `type` only to *string* defaults, so args.adressExp is a
# Path when the default is used and a str when passed on the command line.
parser.add_argument('--adressExp', type = str, default = pathToExperimentFile, help = 'Adress to the experiment File')
args = parser.parse_args()
# In this file `device` is only used for the startup message below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this silences *all* warnings process-wide (including torch
# deprecation/numerical warnings); consider narrowing to specific categories.
warnings.filterwarnings("ignore")
print('\nTORCH Detected: {}\n'.format(device))
#%%
if __name__ == '__main__':
    # Replace the bare CLI namespace with the settings read from the experiment file.
    args = utils.ReadExperimentFile(args)
    if args.useClassicModel:
        # Select the GPU *before* training so the choice applies to it, and only
        # when CUDA is available (torch.cuda.set_device raises on CPU-only setups).
        if torch.cuda.is_available():
            torch.cuda.set_device(args.gpuNo)
        Classic_Training(args)
    elif args.model_name == 'attmil':
        AttMIL_Training(args)
    else:
        # Any other model_name falls through to CLAM MIL training.
        CLAM_MIL_Training(args)