"""main.py: load the data, build the model, fit it with an early-stopping callback, and save the weights."""
# Imports dependencies
import tensorflow as tf
import tensorflow_addons as tfa
import pandas as pd
import numpy as np
import random
from dataloader import load_data
from model import build_model
# Training hyperparameters, named so the early-stopping patience can be
# checked against the epoch budget at a glance.
EPOCHS = 10
BATCH_SIZE = 8
# Keep PATIENCE strictly below EPOCHS. The first epoch always sets the
# baseline, so with patience == epochs the callback can never count enough
# non-improving epochs to stop training. On Keras releases that only restore
# the best weights when training was stopped early, restore_best_weights
# would then never take effect either.
# NOTE(review): 3 is a conservative choice; tune it for this dataset.
PATIENCE = 3

# Loads data. The test split is unpacked here but is not used further down
# in this script.
train_in, train_out, val_in, val_out, test_in, test_out = load_data()

# Builds model and loads weights if present (handled inside build_model,
# which is defined in model.py and not visible here).
model = build_model()

# Callbacks: stop once validation binary accuracy has not improved for
# PATIENCE consecutive epochs, and roll back to the best-scoring weights.
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_binary_accuracy',
    mode='max',
    patience=PATIENCE,
    restore_best_weights=True,
)

# Fits model and saves weights at the end
history = model.fit(
    train_in,
    train_out,
    validation_data=(val_in, val_out),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=[callback],
    verbose=1,
)
model.save_weights("./my_model/ckpt")