Skip to content

Commit

Permalink
Fixed keras model load and save
Browse files Browse the repository at this point in the history
  • Loading branch information
jernsting committed Oct 23, 2024
1 parent 91f5ad5 commit 92e933e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
27 changes: 18 additions & 9 deletions photonai/modelwrapper/keras_base_estimator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
import os
import tensorflow.keras as keras
from sklearn.base import BaseEstimator

Expand Down Expand Up @@ -72,20 +73,28 @@ def encode_targets(self, y):

def save(self, filename):
# serialize model to JSON
warnings.warn("Using json export for compatibility, will be deprecated in future.")
model_json = self.model.to_json()
with open(filename + ".json", "w") as json_file:
json_file.write(model_json)
json_file.write(model_json)
# serialize weights to HDF5
self.model.save_weights(filename + ".weights.h5")
self.model.save(filename + ".keras")

def load(self, filename):
# load json and create model
json_file = open(filename + '.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = keras.models.model_from_json(loaded_model_json)
if not os.path.exists(filename+'.keras'):
warnings.warn("Using json import for compatiblity, will be deprecated in future. "
"Please save your model to get a *.keras file")
json_file = open(filename + '.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = keras.models.model_from_json(loaded_model_json)

loaded_model.load_weights(filename + ".weights.h5")
self.model = loaded_model
self.init_weights = self.model.get_weights()
else:
# load weights into new model
self.model = keras.models.load_model(filename + '.keras')

# load weights into new model
loaded_model.load_weights(filename + ".weights.h5")
self.model = loaded_model
self.init_weights = self.model.get_weights()
23 changes: 16 additions & 7 deletions test/modelwrapper_tests/test_keras_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sklearn.datasets import load_breast_cancer, load_diabetes
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Dense, Dropout, Input, Activation
import numpy as np
import warnings
import os
Expand All @@ -16,16 +16,17 @@ def setUp(self):
self.X, self.y = load_breast_cancer(return_X_y=True)

self.model = Sequential()
self.model.add(Dense(3, input_dim=self.X.shape[1], activation='relu'))
self.model.add(Input(shape=[self.X.shape[1]]))
self.model.add(Dense(3, activation="relu"))
self.model.add(Dropout(0.1))
self.model.add(Dense(2, activation='softmax'))
self.model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

self.estimator_type = KerasBaseClassifier

inputs = tf.keras.Input(shape=(self.X.shape[1],))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(x)
x = tf.keras.layers.Dense(4, activation=tf.keras.activations.relu)(inputs)
outputs = tf.keras.layers.Dense(2, activation=tf.keras.activations.softmax)(x)
self.tf_model = tf.keras.Model(inputs=inputs, outputs=outputs)
self.tf_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

Expand Down Expand Up @@ -57,10 +58,18 @@ def test_tf_model(self):

estimator.save("keras_example_saved_model")

reload_estinator = self.estimator_type()
reload_estinator.load("keras_example_saved_model")
reload_estimator = self.estimator_type()
reload_estimator.load("keras_example_saved_model")

np.testing.assert_array_almost_equal(estimator.predict(self.X), reload_estimator.predict(self.X), decimal=3)

# remove novel keras file and test legacy import
os.remove("keras_example_saved_model.keras")

reload_estimator_legacy = self.estimator_type()
reload_estimator_legacy.load("keras_example_saved_model")

np.testing.assert_array_almost_equal(estimator.predict(self.X), reload_estinator.predict(self.X), decimal=3)
np.testing.assert_array_almost_equal(estimator.predict(self.X), reload_estimator.predict(self.X), decimal=3)

# remove saved keras files
for fname in os.listdir("."):
Expand Down

0 comments on commit 92e933e

Please sign in to comment.