From 92e933ed9a9fee998f255812ca227f9b822c5b7b Mon Sep 17 00:00:00 2001 From: Jan Ernsting Date: Wed, 23 Oct 2024 16:01:45 +0200 Subject: [PATCH] Fixed keras model load and save --- photonai/modelwrapper/keras_base_estimator.py | 27 ++++++++++++------- test/modelwrapper_tests/test_keras_basic.py | 23 +++++++++++----- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/photonai/modelwrapper/keras_base_estimator.py b/photonai/modelwrapper/keras_base_estimator.py index 08c37934..5c211217 100644 --- a/photonai/modelwrapper/keras_base_estimator.py +++ b/photonai/modelwrapper/keras_base_estimator.py @@ -1,4 +1,5 @@ import warnings +import os import tensorflow.keras as keras from sklearn.base import BaseEstimator @@ -72,20 +73,28 @@ def encode_targets(self, y): def save(self, filename): # serialize model to JSON + warnings.warn("Using json export for compatibility, will be deprecated in future.") model_json = self.model.to_json() with open(filename + ".json", "w") as json_file: - json_file.write(model_json) + json_file.write(model_json) # serialize weights to HDF5 self.model.save_weights(filename + ".weights.h5") + self.model.save(filename + ".keras") def load(self, filename): # load json and create model - json_file = open(filename + '.json', 'r') - loaded_model_json = json_file.read() - json_file.close() - loaded_model = keras.models.model_from_json(loaded_model_json) + if not os.path.exists(filename+'.keras'): + warnings.warn("Using json import for compatiblity, will be deprecated in future. " + "Please save your model to get a *.keras file") + json_file = open(filename + '.json', 'r') + loaded_model_json = json_file.read() + json_file.close() + loaded_model = keras.models.model_from_json(loaded_model_json) + + loaded_model.load_weights(filename + ".weights.h5") + self.model = loaded_model + self.init_weights = self.model.get_weights() + else: + # load weights into new model + self.model = keras.models.load_model(filename + '.keras') - # load weights into new model - loaded_model.load_weights(filename + ".weights.h5") - self.model = loaded_model - self.init_weights = self.model.get_weights() diff --git a/test/modelwrapper_tests/test_keras_basic.py b/test/modelwrapper_tests/test_keras_basic.py index b6d92f77..b84673e7 100644 --- a/test/modelwrapper_tests/test_keras_basic.py +++ b/test/modelwrapper_tests/test_keras_basic.py @@ -1,7 +1,7 @@ from sklearn.datasets import load_breast_cancer, load_diabetes import tensorflow as tf from tensorflow.keras.models import Sequential -from tensorflow.keras.layers import Dense, Dropout +from tensorflow.keras.layers import Dense, Dropout, Input, Activation import numpy as np import warnings import os @@ -16,7 +16,8 @@ def setUp(self): self.X, self.y = load_breast_cancer(return_X_y=True) self.model = Sequential() - self.model.add(Dense(3, input_dim=self.X.shape[1], activation='relu')) + self.model.add(Input(shape=[self.X.shape[1]])) + self.model.add(Dense(3, activation="relu")) self.model.add(Dropout(0.1)) self.model.add(Dense(2, activation='softmax')) self.model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) @@ -24,8 +25,8 @@ def setUp(self): self.estimator_type = KerasBaseClassifier inputs = tf.keras.Input(shape=(self.X.shape[1],)) - x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) - outputs = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(x) + x = tf.keras.layers.Dense(4, activation=tf.keras.activations.relu)(inputs) + outputs = tf.keras.layers.Dense(2, activation=tf.keras.activations.softmax)(x) self.tf_model = tf.keras.Model(inputs=inputs, outputs=outputs) self.tf_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) @@ -57,10 +58,18 @@ def test_tf_model(self): estimator.save("keras_example_saved_model") - reload_estinator = self.estimator_type() - reload_estinator.load("keras_example_saved_model") + reload_estimator = self.estimator_type() + reload_estimator.load("keras_example_saved_model") + + np.testing.assert_array_almost_equal(estimator.predict(self.X), reload_estimator.predict(self.X), decimal=3) + + # remove novel keras file and test legacy import + os.remove("keras_example_saved_model.keras") + + reload_estimator_legacy = self.estimator_type() + reload_estimator_legacy.load("keras_example_saved_model") - np.testing.assert_array_almost_equal(estimator.predict(self.X), reload_estinator.predict(self.X), decimal=3) + np.testing.assert_array_almost_equal(estimator.predict(self.X), reload_estimator.predict(self.X), decimal=3) # remove saved keras files for fname in os.listdir("."):