forked from binhdt95/Jaist-MicroNet-Challenge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test.py
executable file
·83 lines (71 loc) · 3.05 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from model import RNNModel
import math
import counting
import torch
import time
import numpy as np
import torch
from utils import batchify, get_batch, repackage_hidden
import collections
def read_model(model):
    """Describe a QRNN ``RNNModel`` as a list of MicroNet ``counting`` ops.

    Args:
        model: an ``RNNModel`` instance whose ``rnn_type`` is ``'QRNN'``;
            both conditions are checked with ``assert``.

    Returns:
        A list of three ``(name, entry)`` pairs, in this order:
        ``'embedding'`` (a single ``counting.Embedding``), ``'block_qrnn'``
        (a list of ``('qrnn_<i>', counting.QRNN)`` pairs, one per layer of
        ``model.rnns``) and ``'block_decoder'`` (a list of
        ``('decoder_weight', ...)``, ``('decoder_bias', ...)``, ``('max', ...)``).
    """
    assert isinstance(model, RNNModel)
    assert model.rnn_type == 'QRNN'

    embed_weight = model.encoder.weight
    embedding_op = counting.Embedding(
        input_size=embed_weight.size()[0],
        n_channels=embed_weight.size()[1],
    )

    # One QRNN counting op per recurrent layer, named by its position.
    qrnn_ops = [
        ('qrnn_%d' % idx,
         counting.QRNN(input_size=layer.input_size,
                       hidden_size=layer.hidden_size,
                       window=layer.window,
                       output_gate=layer.output_gate))
        for idx, layer in enumerate(model.rnns)
    ]

    decoder = model.decoder
    # presumably decoder is an nn.Linear, whose weight is laid out as
    # (out_features, in_features) -- TODO confirm against model.py
    dec_out = decoder.weight.size()[0]
    dec_in = decoder.weight.size()[1]
    has_bias = getattr(decoder, 'bias', None) is not None

    # Weight and bias are counted separately so that tied weights
    # (model.tie_weights) can be accounted for on the weight term alone.
    decoder_ops = [
        ('decoder_weight',
         counting.FullyConnected(kernel_shape=(dec_in, dec_out),
                                 use_bias=False,
                                 tied_weights=model.tie_weights)),
        ('decoder_bias',
         counting.FullyConnected(kernel_shape=(0, dec_out),
                                 use_bias=has_bias)),
        # NOTE(review): presumably counts the max over the dec_out output
        # channels -- confirm against counting.GlobalMax.
        ('max', counting.GlobalMax(input_size=1, n_channels=dec_out)),
    ]

    return [
        ('embedding', embedding_op),
        ('block_qrnn', qrnn_ops),
        ('block_decoder', decoder_ops),
    ]
def evaluate(data_source, batch_size=10):
    """Return the length-weighted average loss of the global ``model`` on ``data_source``.

    Reads the module-level ``model``, ``criterion`` and ``args`` (``model`` and
    ``criterion`` are populated by ``model_load``).

    Args:
        data_source: tensor walked along dim 0 in steps of ``args.bptt``;
            each step is sliced into ``(data, targets)`` by ``get_batch``.
        batch_size: passed to ``model.init_hidden``; presumably must match the
            batch size ``data_source`` was batchified with -- TODO confirm.

    Returns:
        float: the sum over chunks of ``len(data) * loss`` divided by
        ``len(data_source)``. Returns ``0.0`` when ``data_source`` is too short
        to yield any chunk.
    """
    # Turn on evaluation mode, which disables dropout.
    model.eval()
    if model.rnn_type == 'QRNN':
        model.reset()
    total_loss = 0.0
    hidden = model.init_hidden(batch_size)
    # Evaluation needs no gradients; without no_grad every forward pass
    # records an autograd graph that is immediately thrown away.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = get_batch(data_source, i, args, evaluation=True)
            output, hidden = model(data, hidden)
            loss = criterion(model.decoder.weight, model.decoder.bias, output, targets)
            # Weight each chunk by its length (presumably the final chunk can
            # be shorter than args.bptt). Accumulating as a tensor avoids a
            # device sync per chunk.
            total_loss += len(data) * loss
            hidden = repackage_hidden(hidden)
    # float() accepts both the initial Python 0.0 and an accumulated tensor,
    # so an input that yields no chunks no longer crashes on `.item()`.
    return float(total_loss) / len(data_source)
def model_load(fn):
    """Load a ``(model, criterion, optimizer)`` checkpoint into module globals.

    Args:
        fn: path to a file that ``torch.load`` unpickles into exactly three
            objects; they are bound to the globals ``model``, ``criterion``
            and ``optimizer``.

    Raises:
        ValueError: if the loaded object unpacks into more or fewer than
            three items.
    """
    global model, criterion, optimizer
    import inspect

    load_kwargs = {}
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
    # unpickle arbitrary objects such as whole nn.Modules or optimizers.
    # Releases before 1.13 don't accept the keyword at all, so pass it only
    # when the installed torch.load declares it.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoints from a trusted source.
    with open(fn, 'rb') as f:
        model, criterion, optimizer = torch.load(f, **load_kwargs)
# Minimal stand-in for the training script's argparse namespace, holding only
# the fields used here: `bptt` is the chunk length evaluate() steps by, and
# `cuda` is handed to batchify/get_batch (presumably moves data to the GPU, so
# a CUDA device appears to be required -- TODO confirm against utils.py).
args=(collections.namedtuple('Args',['bptt','cuda']))(bptt=140,cuda=True)
# load model
# Binds the module globals `model`, `criterion`, `optimizer` read by evaluate().
model_load('WT103.12hr.QRNN.pt')
# load test data: read vocab, process test text
# NOTE(review): torch.load fully unpickles this file (trusted input only); on
# PyTorch >= 2.6 it may need weights_only=False to load a non-tensor object.
corpus = torch.load('corpus-wikitext-103.vocab-only.data')
test_data = corpus.tokenize('data/wikitext-103/test.txt')
# Batch size 1 here matches the batch_size=1 passed to evaluate() below.
test_data = batchify(test_data, 1, args)
# Run on test data.
test_loss = evaluate(test_data, 1)
print('=' * 89)
# Perplexity is reported as exp(average loss).
print('Test ppl {:8.2f} '.format(math.exp(test_loss)))
print('=' * 89)
# read model ops
ops = read_model(model)
# print model MFLOPS and #Parameters
# NOTE(review): add_bits_base/mul_bits_base presumably set the 32-bit baseline
# widths that the counts are normalised against -- confirm in counting.py.
counter = counting.MicroNetCounter(ops, add_bits_base=32, mul_bits_base=32)
INPUT_BITS = 16
ACCUMULATOR_BITS = 32
# Parameters are counted at the same bit width as the inputs.
PARAMETER_BITS = INPUT_BITS
SUMMARIZE_BLOCKS = True
# NOTE(review): the leading 0 is presumably the weight sparsity (dense model)
# -- confirm against counting.MicroNetCounter.print_summary.
counter.print_summary(0, PARAMETER_BITS, ACCUMULATOR_BITS, INPUT_BITS, summarize_blocks=SUMMARIZE_BLOCKS)