-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy path tools.py
56 lines (50 loc) · 2.01 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch import nn
from torch.optim import lr_scheduler
import os
def get_optimizer(model,
                  optim,
                  learning_rate,
                  momentum,
                  weight_decay,
                  amsgrad=False,):
    """Build an optimizer over the model's trainable parameters.

    Args:
        model: module whose ``parameters()`` are optimized; parameters with
            ``requires_grad == False`` are excluded.
        optim: optimizer name, either ``'sgd'`` or ``'adam'``.
        learning_rate: learning rate passed as ``lr``.
        momentum: SGD momentum (ignored for Adam).
        weight_decay: L2 penalty; if ``None`` it is omitted so the
            optimizer's own default applies.
        amsgrad: enable the AMSGrad variant (Adam only).

    Returns:
        A configured ``torch.optim.Optimizer`` instance.

    Raises:
        ValueError: if ``optim`` is not a supported optimizer name.
    """
    if optim == 'sgd':
        optim_module = torch.optim.SGD
        optim_param = {"lr": learning_rate,
                       "momentum": momentum}
    elif optim == "adam":
        optim_module = torch.optim.Adam
        optim_param = {"lr": learning_rate,
                       "amsgrad": amsgrad}
    else:
        # Fail loudly instead of printing and crashing later with
        # UnboundLocalError on optim_module.
        raise ValueError("Not supported optimizer: {}".format(optim))
    # Only forward weight_decay when explicitly given (was `!= None`);
    # applied uniformly to both branches for consistency.
    if weight_decay is not None:
        optim_param["weight_decay"] = weight_decay
    optimizer = optim_module(
        filter(lambda x: x.requires_grad, model.parameters()),
        **optim_param
    )
    return optimizer
def get_scheduler(optimizer, decay_step, gamma=0.1):
    """Wrap *optimizer* in a StepLR schedule.

    The learning rate is multiplied by *gamma* (default 0.1) once every
    *decay_step* epochs.
    """
    return lr_scheduler.StepLR(optimizer,
                               step_size=decay_step,
                               gamma=gamma)
def get_dir_name(out_dir, image_pattern, image_type, image_size, demosaic_algo, bayer_pattern, crop, crop_size, log_name):
    """Compose the output directory path for a run.

    Layout: ``out_dir/<pattern>_<size>_<demosaic>_<bayer>/<image_type>/<crop_name>/<log_name>``.
    """
    dataset_dir = "_".join((image_pattern, str(image_size), "_".join([demosaic_algo, bayer_pattern])))
    return os.path.join(out_dir, dataset_dir, image_type, get_crop_name(crop, crop_size), log_name)
def get_crop_name(crop, crop_size):
    """Return a crop descriptor: ``crop`` alone when it is ``'none'``,
    otherwise ``<crop>_<crop_size>``."""
    if crop == 'none':
        return crop
    return "_".join([crop, str(crop_size)])
def get_log_name(args):
    """Build an underscore-joined run name from the hyperparameters on *args*.

    Expects attributes: model_architecture, optimizer, learning_rate,
    decay_step, weight_decay, batch_size.
    """
    tagged = (('arch', args.model_architecture),
              ('optim', args.optimizer),
              ('lr', str(args.learning_rate)),
              ('decaystep', str(args.decay_step)),
              ('wd', str(args.weight_decay)),
              ('batch', str(args.batch_size)))
    return "_".join(part for pair in tagged for part in pair)