# Main.py
import matplotlib.pyplot as plt
import torch
from torch.utils import data
import model
import utils
import MetricAndLoss
from Dataset import DatasetMRI
from Log import SegmentationLoss
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#%% Dataset
# Define folders
folder_training = '../Training/'
folder_validation = '../Validation/'
# create a DataFrame describing the dataset
tableTraining = utils.CreateDataTable(folder_training,True)
# hold out the tail of tableTraining as a validation set (roughly an 80/20 split)
num_training = 269
tableValidation = tableTraining.iloc[num_training:]
tableTraining = tableTraining.iloc[:num_training]
# compute the z-score normalization statistics for every input modality
Dict_stats = utils.CalculateStats(tableTraining,True)
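# z-score: x_norm = (x - mean) / std, with the per-modality mean/std taken from
# the training set (presumably applied inside DatasetMRI via Dict_stats)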
#%% Create datasets and data loaders
# define batch size
batch_size_train = 2
batch_size_validation = 2
# define dataset and dataloader for training
train_dataset = DatasetMRI(tableTraining,Dict_stats)
train_loader = data.DataLoader(train_dataset,batch_size=batch_size_train,shuffle=True)
# define dataset and dataloader for validation
validation_dataset = DatasetMRI(tableValidation,Dict_stats)
validation_loader = data.DataLoader(validation_dataset,batch_size=batch_size_validation,shuffle=True)
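# each batch is a dict holding the four modality tensors ('T1', 'T1 ce', 'T2',
# 'FLAIR') and the segmentation mask ('Label'), as unpacked in the loops below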
#%% Define parameters
# number of epochs
num_epochs = 20
# instantiate the model and move it to the device
# (note: this rebinds the name `model` from the module to the network instance)
model = model.MRIModel().to(device)
utils.count_parameters(model)
# pass the model parameters to the Adam optimizer
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# define loss function
#criterion = MetricAndLoss.DiceLoss()
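# Note: MetricAndLoss.DiceLoss is this repo's own implementation (not shown here).
# A minimal soft-Dice sketch, assuming `output` holds per-voxel foreground
# probabilities and `labels` is a binary mask of the same shape (eps is a small
# smoothing constant):
#   intersection = (output * labels).sum()
#   dice = (2.0 * intersection + eps) / (output.sum() + labels.sum() + eps)
#   loss = 1.0 - dice  # minimizing the loss maximizes the Dice overlap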
# initiate logs
trainLog = SegmentationLoss()
validationLog = SegmentationLoss()
#%% Training
for epoch in range(num_epochs):
    ##################
    ### TRAIN LOOP ###
    ##################
    # set the model to train mode
    model.train()
    # initiate the training loss
    train_loss = 0
    i = 0  # batch index for the log
    for batch in train_loader:
        # get the batch images and labels
        T1 = batch['T1'].to(device)
        T1_ce = batch['T1 ce'].to(device)
        T2 = batch['T2'].to(device)
        FLAIR = batch['FLAIR'].to(device)
        labels = batch['Label'].to(device)
        # clear the old gradients from the optimizer
        optimizer.zero_grad()
        # forward pass: feed the inputs to the model to get outputs
        output = model(T1, T1_ce, T2, FLAIR)
        # calculate the training batch loss
        #loss = criterion(output, torch.max(labels, 1)[1])
        loss = MetricAndLoss.DiceLoss(output, labels)
        # backward pass: compute the gradients of the loss w.r.t. the model parameters
        loss.backward()
        # update the model parameters by performing a single optimization step
        optimizer.step()
        # accumulate the training loss
        train_loss += loss.item()
        # update the training log
        print('Epoch %d, Batch %d/%d, loss: %.4f' % (epoch, i, len(train_loader), loss.item()))
        trainLog.BatchUpdate(epoch, i, loss)
        i += 1  # advance the batch index
    #######################
    ### VALIDATION LOOP ###
    #######################
    # set the model to eval mode
    model.eval()
    # initiate the validation loss
    valid_loss = 0
    i = 0  # batch index for the log
    # turn off gradients for validation
    with torch.no_grad():
        for batch in validation_loader:
            # get the batch images and labels
            T1 = batch['T1'].to(device)
            T1_ce = batch['T1 ce'].to(device)
            T2 = batch['T2'].to(device)
            FLAIR = batch['FLAIR'].to(device)
            labels = batch['Label'].to(device)
            # forward pass
            output = model(T1, T1_ce, T2, FLAIR)
            # validation batch loss
            #loss = criterion(output, torch.max(labels, 1)[1])
            loss = MetricAndLoss.DiceLoss(output, labels)
            # accumulate the validation loss
            valid_loss += loss.item()
            # update the validation log
            print('Epoch %d, Batch %d/%d, loss: %.4f' % (epoch, i, len(validation_loader), loss.item()))
            validationLog.BatchUpdate(epoch, i, loss)
            i += 1  # advance the batch index
    #########################
    ## PRINT EPOCH RESULTS ##
    #########################
    # average the accumulated losses over the number of batches
    train_loss /= len(train_loader)
    valid_loss /= len(validation_loader)
    # update the training and validation logs
    trainLog.EpochUpdate(epoch, train_loss)
    validationLog.EpochUpdate(epoch, valid_loss)
    # print the epoch results
    print('Epoch: %s/%s: Training loss: %.3f. Validation Loss: %.3f.'
          % (epoch + 1, num_epochs, train_loss, valid_loss))
#%% Save the model
PATH = '../model_16_01_2020_3D_U_net.pt'
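# retrieve the per-epoch loss histories from the logs
# (this rebinds train_loss from the scalar epoch loss to the full history)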
train_loss = trainLog.getLoss()
validation_loss = validationLog.getLoss()
torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'validation_loss': validation_loss}, PATH)
#%% Load the model
if False:  # flip to True to restore the saved checkpoint
    PATH = '../model_16_01_2020_3D_U_net.pt'
    # map_location makes the checkpoint loadable on CPU-only machines as well
    checkpoint = torch.load(PATH, map_location=device)
    import model  # re-import the module: the name `model` was rebound to the network above
    model2 = model.MRIModel()
    model2.load_state_dict(checkpoint['model_state_dict'])
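    # move the restored model to the device and switch to eval mode before inference
    model2.to(device)
    model2.eval()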
#%% Plot the loss curves
plt.figure()
plt.plot(range(num_epochs),train_loss,label='Training Loss')
plt.plot(range(num_epochs),validation_loss,label='Validation Loss')
plt.grid(); plt.xlabel('Number of epochs'); plt.ylabel('Loss')
plt.title('Loss for 3D-Unet for BraTS2020 Brain MRI Segmentation')
plt.legend()
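plt.show()  # required to display the figure when running as a plain script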