# trainx.py - trains a Siamese network with a contrastive loss on ATAT image pairs.
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from utils import *
from model import SiameseNetwork
from loss import ContrastiveLoss
from dataset import ATATContrast
from torchsummary import summary
from matplotlib import pyplot as plt
import os

# Root directory of the ATAT dataset. The split folder names ('train',
# 'valid', 'test') are appended to this string by plain concatenation, so it
# must end with a path separator; os.path.join(root, '') guarantees that.
# The ATAT_ROOT environment variable (when set and non-empty) overrides the
# default location so the script can run on other machines without edits.
path = os.path.join(os.environ.get('ATAT_ROOT') or '/home/harsh/Downloads/ATAT/', '')

## Initialize parameters
bs = 32          # batch size used by every DataLoader
lr = 1e-3        # learning rate handed to the Adam optimizer
threshold = 0.3  # passed to evaluate_pair when scoring validation pairs
margin = 1.5     # passed to ContrastiveLoss
epochs = 40      # number of iterations of the outer training loop
## Initialize network
# Use the GPU when one is available and fall back to the CPU otherwise, so
# that building the model does not crash on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SiameseNetwork().to(device)
# Module.apply calls the function on every sub-module and on the model itself.
# initialize_weights is not defined in this file; presumably it is provided by
# the star import of utils.
model.apply(initialize_weights)

## Initialize optimizer
optim = torch.optim.Adam(model.parameters(), lr=lr)

## Initialize scheduler
# StepLR with step_size=8 and the default gamma of 0.1: the learning rate is
# multiplied by 0.1 after every 8 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=8)

## Initialize loss
criterion = ContrastiveLoss(margin)
# Alternative criterion left disabled by the author:
# criterion = torch.nn.BCEWithLogitsLoss()
## Initialize datasets and dataloaders
# One Normalize instance shared by all three pipelines: the network is trained
# on normalized tensors, so validation and test images must be normalized the
# same way or the model is evaluated on a different input distribution.
_normalize = torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                              std=(0.229, 0.224, 0.225))

# Training adds a random horizontal flip as augmentation; evaluation pipelines
# stay deterministic.
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    _normalize,
])
valid_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    _normalize,
])
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    _normalize,
])

# Each split is an ImageFolder rooted at <path><split>, wrapped in ATATContrast.
# The loops that consume these loaders unpack (input1, input2, target) per batch.
train_ds = ATATContrast(ImageFolder(root=path + 'train', transform=train_transforms))
valid_ds = ATATContrast(ImageFolder(root=path + 'valid', transform=valid_transforms))
test_ds = ATATContrast(ImageFolder(root=path + 'test', transform=test_transforms))

# Shuffle only the training data: ImageFolder lists samples grouped by class,
# and a fixed batch order every epoch is undesirable for SGD-style training.
# NOTE(review): if ATATContrast already samples pairs at random, shuffling is
# merely redundant - confirm against dataset.py.
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=bs)
test_dl = DataLoader(test_ds, batch_size=bs)
# Per-epoch history of the training and validation losses.
train_loss = []
valid_loss = []

# Send every batch to whichever device the model's parameters already live on,
# instead of assuming CUDA.
device = next(model.parameters()).device

for epoch in range(epochs):
    ## Training pass
    train_epoch_loss = 0
    model.train()
    for input1, input2, target in train_dl:
        optim.zero_grad()
        # A single forward pass per batch; a second, unused pass would only
        # double the compute and memory of every training step.
        output1, output2 = model(input1.to(device), input2.to(device))
        loss = criterion(output1, output2, target.to(device))
        train_epoch_loss += loss.item()
        loss.backward()
        optim.step()
    # NOTE(review): the sum of per-batch losses is divided by the number of
    # samples. If ContrastiveLoss already averages over the batch, len(train_dl)
    # would be the matching denominator - confirm against loss.py.
    train_epoch_loss /= len(train_ds)
    train_loss.append(train_epoch_loss)
    print("Epoch [{}/{}] ----> Training loss :{} \n".format(epoch+1,epochs,train_epoch_loss))

    ## Validation pass
    valid_epoch_loss = 0
    val_pos_accuracy = 0
    val_neg_accuracy = 0
    num_pos = 0
    num_neg = 0
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every validation batch.
    with torch.no_grad():
        for input1, input2, target in valid_dl:
            output1, output2 = model(input1.to(device), input2.to(device))
            target = target.to(device)
            loss = criterion(output1, output2, target)
            valid_epoch_loss += loss.item()
            # evaluate_pair returns four values per batch; the first and third
            # are accumulated as accuracy numerators, the second and fourth as
            # their denominators.
            pos_acc, pos_sum, neg_acc, neg_sum = evaluate_pair(output1, output2, target, threshold)
            val_pos_accuracy += pos_acc
            val_neg_accuracy += neg_acc
            num_pos += pos_sum
            num_neg += neg_sum
    # NOTE(review): same denominator question as for the training loss above.
    valid_epoch_loss /= len(valid_ds)
    val_pos_accuracy /= num_pos
    val_neg_accuracy /= num_neg
    valid_loss.append(valid_epoch_loss)
    print("Validation loss :{} \t\t\t P Acc : {}, N Acc: {}\n".format(valid_epoch_loss,val_pos_accuracy,val_neg_accuracy))

    # Advance the learning-rate schedule once per epoch, after the optimizer
    # steps; without this call the scheduler never changes the learning rate.
    scheduler.step()