forked from hsakas/siamese_similarity_model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrain.py
77 lines (57 loc) · 1.93 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import DownloadData
# pos + pos -> 1
# pos + neg -> 0
from keras.applications import VGG16
import os
import DataHandler as dh
import BaseModel as bm
from keras.models import Model
# Filesystem layout: everything lives under ./data, resolved to an absolute path
# relative to the current working directory at import time.
DATA_DIR = os.path.abspath("data")
IMAGE_TEMP_DIR, MODEL_DIR = (os.path.join(DATA_DIR, sub) for sub in ("tmp", "Models"))
IMAGE_DIR = os.path.join(DATA_DIR, "images", "jpg")  # images read by DataHandler

# Training hyper-parameters.
IM_SIZE = 224     # passed as image_size to the loader (224 matches VGG16's default input)
EPOCHS = 20       # passed to model.fit
BATCH_SIZE = 32   # passed to model.fit
# Build the image triples from IMAGE_DIR, then load them into the two model
# inputs (lhs, rhs) and the label array y. The layout of a "triple" is defined
# in DataHandler, which is not visible here.
print('Creating triples...')
triples = dh.create_image_triples(IMAGE_DIR)

print('Loading images...')
lhs, rhs, y = dh.load_image_triplets(
    image_dir=IMAGE_DIR,
    image_triples=triples,
    image_size=IM_SIZE,
    shuffle=True,
)

# Print array shapes as a quick sanity check before training.
for _label, _array in (('y', y), ('lhs', lhs), ('rhs', rhs)):
    print(_label, _array.shape)
def _freeze_and_suffix(vgg, suffix):
    """Freeze every layer of ``vgg`` and append ``suffix`` to each layer name.

    Two VGG16 instances carry identical layer names (e.g. "flatten"), so the
    names are made distinct before both towers are joined into one Model.

    Args:
        vgg: A Keras model whose ``layers`` are modified in place.
        suffix: String appended to every layer name (e.g. "_1").
    """
    for layer in vgg.layers:
        layer.trainable = False  # keep the pretrained VGG16 weights fixed
        new_name = layer.name + suffix
        try:
            layer.name = new_name
        except AttributeError:
            # tf.keras-backed Keras exposes ``name`` as a read-only property
            # (backed by ``_name``); older standalone Keras allowed assignment.
            layer._name = new_name


# Two independent ImageNet-pretrained VGG16 towers, one per input image.
# Both are instantiated before either is renamed.
vgg_1 = VGG16(weights='imagenet', include_top=True)
vgg_2 = VGG16(weights='imagenet', include_top=True)
_freeze_and_suffix(vgg_1, "_1")
_freeze_and_suffix(vgg_2, "_2")

print('_'*12, 'VGG16', '-'*12)
vgg_1.summary()

# Each tower's (renamed) "flatten" output is the feature vector for its image;
# bm.sim_model (defined in BaseModel) maps the two feature tensors to the
# prediction tensor.
v1 = vgg_1.get_layer("flatten_1").output
v2 = vgg_2.get_layer("flatten_2").output
pred = bm.sim_model(v1, v2)
model = Model(inputs=[vgg_1.input, vgg_2.input], outputs=pred)
print('_'*12, 'SIAMESE', '-'*12)
model.summary()

# NOTE(review): the header comments label pairs 1/0 (binary). Categorical
# cross-entropy only fits if y is one-hot and bm.sim_model ends in a softmax
# over 2+ units; a single-unit sigmoid output would need 'binary_crossentropy'.
# Confirm against BaseModel / DataHandler.
model.compile(
    loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'],
)

# Keras takes the validation_split from the trailing samples before any
# shuffling, so this relies on the data having been shuffled at load time
# (shuffle=True in load_image_triplets).
model.fit(
    x=[lhs, rhs],
    y=y,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=0.2,
)
# Ensure the output directory exists before writing to it.
if not os.path.isdir(MODEL_DIR):
    os.makedirs(MODEL_DIR)
    print('Model directory created!')

# The architecture (JSON) and the weights (.h5) are saved as two separate
# files, both named after model_id inside MODEL_DIR.
model_json = model.to_json()
model_id = "trained"
model_name = "{}.json".format(model_id)
weights_name = "{}.h5".format(model_id)
model_path, weights_path = (
    os.path.join(MODEL_DIR, file_name) for file_name in (model_name, weights_name)
)

with open(model_path, "w") as json_file:
    json_file.write(model_json)
model.save_weights(weights_path)
print('Model Saved')