-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcommons.py
96 lines (82 loc) · 3.47 KB
/
commons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
This file contains some functions and classes which can be useful in very diverse projects.
"""
import os
import sys
import torch
import random
import logging
import traceback
import numpy as np
from os.path import join
def make_deterministic(seed=0):
    """Seed all RNGs (Python, NumPy, PyTorch CPU and CUDA) for reproducible runs.

    Args:
        seed (int): the seed to use. If seed == -1, do nothing and keep the
            run non-deterministic. Note that forcing deterministic cuDNN
            kernels might slow the script down.
    """
    if seed == -1:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    # Use deterministic cuDNN kernels and disable the autotuner, whose
    # benchmarking makes the chosen kernels (and results) nondeterministic.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_output_dim(model, pooling_type="gem"):
    """Dynamically compute the output dimensionality of a model.

    Runs a dummy forward pass with a batch of two 3x224x224 all-ones images
    and reads the size of dimension 1 of the output.

    Args:
        model (torch.nn.Module): the model to probe.
        pooling_type (str): if "netvlad", the dimension is multiplied by 64,
            as the NetVLAD layer has 64x bigger output dimensions.

    Returns:
        int: the output dimensionality.
    """
    # no_grad avoids building an autograd graph for this probe-only pass.
    with torch.no_grad():
        output_dim = model(torch.ones([2, 3, 224, 224])).shape[1]
    if pooling_type == "netvlad":
        output_dim *= 64  # NetVLAD layer has 64x bigger output dimensions
    return output_dim
class InfiniteDataLoader(torch.utils.data.DataLoader):
    """A DataLoader that never raises StopIteration: once the underlying
    dataset is exhausted, it transparently starts a new epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        # The loader itself acts as the (infinite) iterator.
        return self

    def __next__(self):
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # Epoch finished: restart from a fresh iterator over the dataset.
            self.dataset_iterator = super().__iter__()
            return next(self.dataset_iterator)
def setup_logging(save_dir, console="debug",
                  info_filename="info.log", debug_filename="debug.log"):
    """Set up logging files and console output.
    Creates one file for INFO logs and one for DEBUG logs, and installs a
    sys.excepthook that logs uncaught exceptions.

    Args:
        save_dir (str): creates the folder where to save the files.
        console (str):
            if == "debug" prints on console debug messages and higher
            if == "info" prints on console info messages and higher
            if == None does not use console (useful when a logger has already been set)
        info_filename (str): the name of the info file. if None, don't create info file
        debug_filename (str): the name of the debug file. if None, don't create debug file

    Raises:
        FileExistsError: if save_dir already exists, to avoid silently
            mixing logs from different runs.
    """
    if os.path.exists(save_dir):
        raise FileExistsError(f"{save_dir} already exists!")
    os.makedirs(save_dir, exist_ok=True)
    # logging.Logger.manager.loggerDict.keys() to check which loggers are in use
    base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
    logger = logging.getLogger('')
    logger.setLevel(logging.DEBUG)
    if info_filename is not None:
        info_file_handler = logging.FileHandler(join(save_dir, info_filename))
        info_file_handler.setLevel(logging.INFO)
        info_file_handler.setFormatter(base_formatter)
        logger.addHandler(info_file_handler)
    if debug_filename is not None:
        debug_file_handler = logging.FileHandler(join(save_dir, debug_filename))
        debug_file_handler.setLevel(logging.DEBUG)
        debug_file_handler.setFormatter(base_formatter)
        logger.addHandler(debug_file_handler)
    if console is not None:
        console_handler = logging.StreamHandler()
        if console == "debug":
            console_handler.setLevel(logging.DEBUG)
        elif console == "info":
            console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(base_formatter)
        logger.addHandler(console_handler)

    def exception_handler(type_, value, tb):
        # Log uncaught exceptions (with full traceback) before the process dies.
        # Bug fix: the original passed the builtin `type` instead of `type_`.
        logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))
    sys.excepthook = exception_handler