-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_profile.py
103 lines (88 loc) · 3.75 KB
/
train_profile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Standard library
import argparse

# Third-party
import torch
from facenet_pytorch import InceptionResnetV1, MTCNN
from PIL import Image
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision import transforms

# Project-local
from datasets.dataset import HDTDataset, DatasetProperties
from losses.loss import DeepContrastiveLoss
from models.PDT import PDT
from utils.utils_config import get_config
def main(args):
    """Profile a few training steps of the PDT translator.

    Loads the config named by ``args.config``, builds the dataset, dataloader,
    model, optimizer and contrastive loss, then runs a handful of optimisation
    steps under ``torch.profiler`` and writes a TensorBoard trace to
    ``./logs/profiles``.

    Args:
        args: parsed CLI namespace; only ``args.config`` (path to a py config
            file) is read.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"DEVICE IS {device}")
    cfg = get_config(args.config)

    # Transforms: ImageNet-style normalisation for both splits, plus a random
    # horizontal flip for training.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])
    train_transform = transforms.Compose([
        test_transform,
        transforms.RandomHorizontalFlip(0.5)
    ])

    # Dataset and dataloader
    d_properties = DatasetProperties(path=cfg.path_to_dataset,
                                     same_label_pairs=cfg.load_same_label_pairs,
                                     save_same_label_pairs=cfg.save_same_label_pairs,
                                     subindex_for_label=cfg.load_subindex_for_label,
                                     save_subindex_for_label=cfg.save_subindex_for_label)
    train_dataset = HDTDataset(d_properties, custom_transform=train_transform)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfg.batch_size,
                                  shuffle=True,
                                  num_workers=cfg.dl_num_workers)
    # TODO: train val test split

    # Model
    translator = PDT(pool_features=6, use_se=False, use_bias=False, use_cbam=True)
    translator.to(device)
    # TODO: Make it configurable

    # Optimizer and loss
    optimizer = torch.optim.Adam(translator.parameters(), lr=cfg.learning_rate)
    criterion = DeepContrastiveLoss(device=device,
                                    pretrained_on=cfg.pretrained_on,
                                    margin=cfg.loss_margin)
    criterion.to(device)

    # The schedule below is wait=1 + warmup=1 + active=2, so at least 4
    # profiler steps must run or no trace is ever recorded.
    max_profiler_steps = 4
    samples_per_step = 4  # samples pushed through the model per profiled step
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(
            wait=1,
            warmup=1,
            active=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs/profiles', worker_name='worker0'),
        record_shapes=True,
        profile_memory=True,  # This will take 1 to 2 minutes. Setting it to False could greatly speedup.
        with_stack=True
    ) as p:
        for step, (img1, img2, same_label, only_nir) in enumerate(train_dataloader):
            m = only_nir.size()[0]
            img1, img2, same_label, only_nir = map(lambda x: x.to(device), [img1, img2, same_label, only_nir])
            # Per-sample forward/backward (batch dim re-added via unsqueeze).
            for i in range(min(m, samples_per_step)):
                # BUG FIX: clear gradients each iteration; the original never
                # called zero_grad(), so gradients accumulated across every
                # backward() call.
                optimizer.zero_grad()
                output_1 = translator(img1[i].unsqueeze(0))
                if only_nir[i]:
                    # NIR-only pair: translate the second image as well.
                    output_2 = translator(img2[i].unsqueeze(0))
                else:
                    output_2 = img2[i].unsqueeze(0)
                loss = criterion(output_1, output_2, same_label[i].item())
                loss.backward()
                optimizer.step()
            p.step()
            # BUG FIX: the original broke unconditionally after one step, so
            # the wait/warmup/active schedule never reached its active phase
            # and no trace file was written.
            if step + 1 >= max_profiler_steps:
                break
if __name__ == "__main__":
    # CLI entry point: parse the config path, then hand off to main().
    arg_parser = argparse.ArgumentParser(description="...")
    arg_parser.add_argument(
        "--config",
        type=str,
        default='configs/base',
        help="py config file",
    )
    cli_args = arg_parser.parse_args()
    main(cli_args)