inference.py
import torch
from tqdm import tqdm

import src.utils as utils
import src.inference as infSrc
class Inferer(infSrc.Inferer):
    def run_n_minibatches(self, params, test_loader, model, numMB):
    #{{{
        """Run inference on the first numMB minibatches of test_loader without collecting metrics."""
        model.eval()

        for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader), total=len(test_loader)-1, desc='inference', leave=False):
            # move inputs and targets to GPU
            with torch.no_grad():
                device = 'cuda:' + str(params.gpuList[0])
                if params.use_cuda:
                    inputs, targets = inputs.cuda(device, non_blocking=True), targets.cuda(device, non_blocking=True)

                # perform inference
                outputs = model(inputs)

            # stop once the requested number of minibatches has been run
            if (batch_idx + 1) == numMB:
                return
    #}}}
    def run_single_minibatch(self, params, test_loader, model):
    #{{{
        """Run a single forward pass over one minibatch drawn from test_loader."""
        model.eval()

        inputs, targets = next(iter(test_loader))
        with torch.no_grad():
            # move inputs and targets to GPU
            if params.use_cuda:
                device = 'cuda:' + str(params.gpu_id)
                inputs, targets = inputs.cuda(device, non_blocking=True), targets.cuda(device, non_blocking=True)

            # perform inference
            outputs = model(inputs)
        return
    #}}}
    def test_network(self, params, test_loader, model, criterion, optimiser, verbose=True):
    #{{{
        """Evaluate the model over the entire test set and return average (loss, top-1, top-5)."""
        model.eval()

        losses = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()

        for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader), total=len(test_loader)-1, desc='inference', leave=False):
            # move inputs and targets to GPU
            with torch.no_grad():
                device = 'cuda:' + str(params.gpuList[0])
                if params.use_cuda:
                    inputs, targets = inputs.cuda(device, non_blocking=True), targets.cuda(device, non_blocking=True)

                # perform inference
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                # accumulate running averages of loss and top-1 / top-5 accuracy
                prec1, prec5 = utils.accuracy(outputs.data, targets.data)
                losses.update(loss.item())
                top1.update(prec1.item())
                top5.update(prec5.item())

        # if params.evaluate == True or (params.finetune == False and (params.entropy or params.pruneFilters)):
        # if params.evaluate or params.entropy or (params.pruneFilters and not params.finetune):
        if verbose:
            tqdm.write('Loss: {}, Top1: {}, Top5: {}'.format(losses.avg, top1.avg, top5.avg))

        return (losses.avg, top1.avg, top5.avg)
    #}}}
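

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how
# test_network might be driven end to end. Everything below is an assumption
# made for illustration only: the params object is a hypothetical
# SimpleNamespace carrying just the attributes this file reads (use_cuda,
# gpuList, gpu_id), the CIFAR-10 / ResNet-18 setup is a placeholder, and the
# base class infSrc.Inferer is assumed to be constructible with no arguments.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    import torch.nn as nn
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    params = SimpleNamespace(use_cuda=torch.cuda.is_available(), gpuList=[0], gpu_id=0)

    # placeholder model and test set; the real repo constructs these elsewhere
    model = torchvision.models.resnet18(num_classes=10)
    if params.use_cuda:
        model = model.cuda(params.gpuList[0])

    test_set = torchvision.datasets.CIFAR10('./data', train=False, download=True,
                                            transform=transforms.ToTensor())
    test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

    criterion = nn.CrossEntropyLoss()

    inferer = Inferer()
    loss, top1, top5 = inferer.test_network(params, test_loader, model, criterion,
                                            optimiser=None, verbose=True)
    print('Test loss: {:.4f}, Top-1: {:.2f}, Top-5: {:.2f}'.format(loss, top1, top5))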