TrainModel.py
from Preprocessing import Preprocessing
from Model import build_model
import pickle
import numpy as np
from keras.callbacks import ModelCheckpoint
import params

preprocessing = Preprocessing(size=params.VOCAB_SIZE)

# Load the training data
with open('train_data.pkl', 'rb') as f:
    train_data = pickle.load(f)

# Gather the text fields (context, question, answer) from every sample
# to build the vocabulary
texts = []
for item in train_data:
    for i in item[0:3]:
        texts.append(i)
preprocessing.create_vocabulary(texts)
print('Vocabulary size: {}'.format(preprocessing.vocab_size))

# Only consider contexts that fit within CONTEXT_LEN (325) words,
# plus the unanswerable samples (start index of -1)
preprocessed_train_data = list()
for sample in train_data:
    if sample[4] < params.CONTEXT_LEN or sample[3] == -1:
        preprocessed_train_data.append(preprocessing.text_to_seq_sample(sample))

# Persist the fitted preprocessing object so inference can reuse the
# same vocabulary and sequence conversion
with open('preprocessing.pkl', 'wb') as f:
    pickle.dump(preprocessing, f)
print('Done with preprocessing.')

model = build_model(preprocessing.tokenizer)

# Checkpoint the weights whenever the validation loss improves
filepath = "Models/weights-improvement-{epoch:02d}-{val_loss:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]

# Assemble the input and target arrays from the preprocessed samples
contexts = np.array([c[0] for c in preprocessed_train_data]).reshape(-1, params.CONTEXT_LEN)
questions = np.array([q[1] for q in preprocessed_train_data]).reshape(-1, params.QUESTION_LEN)
target_start = np.array([x[3] for x in preprocessed_train_data]).reshape(-1, params.CONTEXT_LEN)
target_end = np.array([x[4] for x in preprocessed_train_data]).reshape(-1, params.CONTEXT_LEN)
is_no_ans = np.array([x[5] for x in preprocessed_train_data]).reshape(-1, 2)

model.fit([questions, contexts], [target_start, target_end, is_no_ans], epochs=10,
          validation_split=0.05, callbacks=callbacks_list, batch_size=64, verbose=1)
model.save('Models/model_large.hdf5')
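The script assumes a params module defining VOCAB_SIZE, CONTEXT_LEN, and QUESTION_LEN. A minimal sketch of what params.py might contain; CONTEXT_LEN = 325 comes from the comment in the script, while the other values are illustrative guesses, not the repository's actual settings:

# params.py -- hypothetical sketch; adjust to match the dataset and model
VOCAB_SIZE = 50000    # illustrative: maximum vocabulary size for the tokenizer
CONTEXT_LEN = 325     # maximum context length in words (from the script's filter)
QUESTION_LEN = 30     # illustrative: maximum question length in words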
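For inference, the pickled preprocessing object and the saved model can be loaded back. A minimal sketch, assuming the model contains no custom layers (otherwise load_model would need a custom_objects argument):

# Hypothetical inference-side loading sketch (not part of this script)
import pickle
from keras.models import load_model

# Restore the fitted preprocessing object saved during training
with open('preprocessing.pkl', 'rb') as f:
    preprocessing = pickle.load(f)

# Restore the full architecture plus the weights saved above
model = load_model('Models/model_large.hdf5')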