Difference in training with model.fit and with tf.GradientTape #19341

Open
markodjordjic opened this issue Mar 20, 2024 · 4 comments
Assignees
Labels
stat:awaiting keras-eng Awaiting response from Keras engineer type:bug/performance

Comments

@markodjordjic

OS: Windows 11
Python == 3.11.0 64 bit
Keras == 3.0.5

import time
import os
from math import floor
import tensorflow as tf
tf.config.experimental.enable_op_determinism()
os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def make_model_configuration():
    inputs = keras.layers.Input(shape=(feature_count,), name='input_layer')
    layer_1 = keras.layers.Dense(
        units=feature_count,
        kernel_initializer='lecun_normal',
        activation='selu',
        kernel_constraint=keras.constraints.MaxNorm(2),
        bias_constraint=keras.constraints.MaxNorm(2)
    )(inputs)
    layer_2 = keras.layers.Dense(
        units=feature_count,
        kernel_initializer='lecun_normal',
        activation='selu',
        kernel_constraint=keras.constraints.MaxNorm(2),
        bias_constraint=keras.constraints.MaxNorm(2)
    )(layer_1)
    output_layer = keras.layers.Dense(
        units=1,
        activation=None
    )(layer_2)
    model = keras.Model(inputs, output_layer)
    
    # Only model configuration is returned.
    return model.get_config() 


def get_optimizer(learning_rate: float) -> keras.Optimizer:

    return keras.optimizers.Adam(learning_rate=learning_rate)


def generate_raw_data(case_count, feature_count) -> tuple:
    """Generated data includes both features and targets, returned
    within the `tuple` object.

    """
    features = np.random.randint(low=0, high=100, size=(
        case_count, feature_count
    ))
    targets = np.random.gamma(shape=2, size=(case_count, 1))

    return features, targets


def split_data(propotion, features, targets):
    # Derive the split index from the data itself rather than a global.
    cv_split_index = floor(len(features)*propotion)  # Safe rounding.
    training_features = features[0:cv_split_index, :]
    training_targets = targets[0:cv_split_index, :]
    testing_features = features[cv_split_index:, ]
    testing_targets = targets[cv_split_index:, ]

    return training_features, training_targets, testing_features, \
        testing_targets


def prepare_data(training_features, 
                 training_targets, 
                 training_batch_size,
                 testing_features,
                 testing_targets,
                 validation_batch_size) -> tuple:
    # Prepare the training dataset.
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (training_features, training_targets)
    ).batch(batch_size=training_batch_size)

    # Prepare the validation dataset.
    val_dataset = tf.data.Dataset.from_tensor_slices(
        (testing_features, testing_targets)
    ).batch(batch_size=validation_batch_size)

    return train_dataset, val_dataset


def declare_loss_and_metrics() -> tuple:
    # Declare loss functions
    loss_function = keras.losses.MeanSquaredError()

    # Declare metrics to record loss and accuracy.
    training_loss = keras.metrics.MeanSquaredError()
    training_accuracy = keras.metrics.MeanAbsoluteError()
    testing_loss = keras.metrics.MeanSquaredError()
    testing_accuaracy = keras.metrics.MeanAbsoluteError()

    return loss_function, training_loss, training_accuracy, testing_loss, \
        testing_accuaracy


def train_w_fit(train_dataset) -> pd.DataFrame:

    keras.backend.clear_session()

    model = keras.Model.from_config(make_model_configuration())

    optimizer = get_optimizer(learning_rate=learning_rate)

    loss_function, training_loss, training_accuracy, testing_loss, \
        testing_accuaracy = declare_loss_and_metrics()

    model.compile(
        loss=loss_function,
        optimizer=optimizer,
        metrics=[training_accuracy],
        auto_scale_loss=False  # To make similar to `tf.GradientTape`
    )

    training_log = model.fit(
        x=train_dataset, 
        epochs=epochs
    )

    training_history = pd.DataFrame.from_records([
        training_log.history['loss'],
        training_log.history['mean_absolute_error']
    ]).T

    return training_history


def train_w_gradient_tape(train_dataset, val_dataset) -> pd.DataFrame:

    keras.backend.clear_session()

    model = keras.Model.from_config(make_model_configuration())

    optimizer = get_optimizer(learning_rate=learning_rate)

    loss_function, training_loss, training_accuracy, testing_loss, \
        testing_accuaracy = declare_loss_and_metrics()

    training_log = []
    for epoch in range(epochs):
        training_loss.reset_state()
        training_accuracy.reset_state()
        testing_loss.reset_state()
        testing_accuaracy.reset_state()
        print('Epoch: %s out of %s.' % (epoch+1, epochs))
        start_time = time.time()

        # Iterate over the batches of the dataset.
        for _, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                predictions = model(x_batch_train, training=True)
                loss_value = loss_function(
                    y_true=y_batch_train, y_pred=predictions
                )
            gradients = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply(gradients, model.trainable_weights)
            training_loss.update_state(y_batch_train, predictions)
            training_accuracy.update_state(y_batch_train, predictions)

        # Run a validation loop at the end of each epoch.
        for _, (x_batch_val, y_batch_val) in enumerate(val_dataset):
            validation_predictions = model(x_batch_val, training=False)
            testing_loss.update_state(y_batch_val, validation_predictions)
            testing_accuaracy.update_state(y_batch_val, validation_predictions)

        end_time = time.time()
        duration_in_minutes = round((end_time-start_time)/60, 2)
        print('Epoch training duration: %s min.' % duration_in_minutes)

        training_log.extend([
            [
                training_loss.result().numpy(),
                training_accuracy.result().numpy(),
                testing_loss.result().numpy(),
                testing_accuaracy.result().numpy()
            ]
        ])

    training_history = pd.DataFrame(training_log)
    training_history.columns = [
        'mse_training', 
        'mae_training',
        'mse_testing',
        'mae_testing'
    ]
    training_history.index = range(1, epochs+1)

    return training_history

    # predictions_gradient_tape = model(x_batch_val)
    # predictions_model_fit = model.predict(x_batch_val)


if __name__ == '__main__':

    seed = 0
    case_count = 100
    feature_count = 10
    propotion = .8  # Proportion of data to use for CV.
    learning_rate = 1e-3
    epochs = 10
    training_batch_size = 2

    keras.utils.set_random_seed(seed=seed)

    # Generate features and targets.  
    features, targets = generate_raw_data(
        case_count=case_count, feature_count=feature_count
    )

    # Split raw data.
    training_features, training_targets, testing_features, testing_targets = \
        split_data(
            propotion=propotion, 
            features=features, 
            targets=targets
        )

    # Wrap raw data in `tf.data.Dataset` objects.
    training_dataset, validation_dataset = prepare_data(
        training_features=training_features,
        training_targets=training_targets,
        testing_features=testing_features,
        testing_targets=testing_targets,
        training_batch_size=training_batch_size,
        validation_batch_size=case_count-floor(case_count*propotion)
    )

    # Training w/ fit.
    training_w_fit_report = train_w_fit(
        train_dataset=training_dataset
    )

    # Training w/ gradient tape.
    training_w_gradient_tape_report = train_w_gradient_tape(
        train_dataset=training_dataset,
        val_dataset=validation_dataset
    )

    # Comparison.
    np.testing.assert_almost_equal(
        actual=training_w_fit_report.values,
        desired=training_w_gradient_tape_report.iloc[:,0:2].values
    )

In the function that creates the model, get_config is used to return the model's configuration instead of the model itself. This has proven to work better than returning the model itself.

@james77777778
Contributor

Hi @markodjordjic
Just came across this issue :)

I think you need to make 2 modifications to ensure the same result:

  1. Ensure that the initial model weights are the same.
  2. Decorate the train step with tf.function(jit_compile=True) to align with model.fit(...).

import time
from math import floor

import numpy as np
import pandas as pd
import tensorflow as tf

import keras
from keras import losses
from keras import metrics


def make_model_configuration(feature_count):
    inputs = keras.layers.Input(shape=(feature_count,), name="input_layer")
    layer_1 = keras.layers.Dense(
        units=feature_count,
        kernel_initializer="lecun_normal",
        activation="selu",
        kernel_constraint=keras.constraints.MaxNorm(2),
        bias_constraint=keras.constraints.MaxNorm(2),
    )(inputs)
    layer_2 = keras.layers.Dense(
        units=feature_count,
        kernel_initializer="lecun_normal",
        activation="selu",
        kernel_constraint=keras.constraints.MaxNorm(2),
        bias_constraint=keras.constraints.MaxNorm(2),
    )(layer_1)
    output_layer = keras.layers.Dense(units=1, activation=None)(layer_2)
    model = keras.Model(inputs, output_layer)

    # Return the configuration together with the initial weights, so that
    # both training runs can start from the same values.
    return model.get_config(), model.get_weights()


def get_optimizer(learning_rate: float):
    return keras.optimizers.Adam(learning_rate=learning_rate)


def generate_raw_data(case_count, feature_count) -> tuple:
    """Generated data includes both features and targets, returned
    within the `tuple` object.

    """
    features = np.random.randint(
        low=0, high=100, size=(case_count, feature_count)
    )
    targets = np.random.gamma(shape=2, size=(case_count, 1))
    return features, targets


def split_data(propotion, features, targets):
    # Derive the split index from the data itself rather than a global.
    cv_split_index = floor(len(features) * propotion)  # Safe rounding.
    training_features = features[0:cv_split_index, :]
    training_targets = targets[0:cv_split_index, :]
    testing_features = features[cv_split_index:,]
    testing_targets = targets[cv_split_index:,]
    return (
        training_features,
        training_targets,
        testing_features,
        testing_targets,
    )


def prepare_data(
    training_features,
    training_targets,
    training_batch_size,
    testing_features,
    testing_targets,
    validation_batch_size,
) -> tuple:
    # Prepare the training dataset.
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (training_features, training_targets)
    ).batch(batch_size=training_batch_size)

    # Prepare the validation dataset.
    val_dataset = tf.data.Dataset.from_tensor_slices(
        (testing_features, testing_targets)
    ).batch(batch_size=validation_batch_size)
    return train_dataset, val_dataset


def declare_loss_and_metrics() -> tuple:
    # Declare loss functions
    loss_function = losses.MeanSquaredError()

    # Declare metrics to record loss and accuracy.
    training_loss = metrics.MeanSquaredError()
    training_accuracy = metrics.MeanAbsoluteError()
    testing_loss = metrics.MeanSquaredError()
    testing_accuaracy = metrics.MeanAbsoluteError()
    return (
        loss_function,
        training_loss,
        training_accuracy,
        testing_loss,
        testing_accuaracy,
    )


def train_w_fit(train_dataset, model_config, model_weights) -> pd.DataFrame:
    model = keras.Model.from_config(model_config)
    model.set_weights(model_weights)  # <---
    optimizer = get_optimizer(learning_rate=learning_rate)
    (
        loss_function,
        training_loss,
        training_accuracy,
        testing_loss,
        testing_accuaracy,
    ) = declare_loss_and_metrics()
    model.compile(
        loss=loss_function,
        optimizer=optimizer,
        metrics=[training_accuracy],
        auto_scale_loss=False,  # To make similar to `tf.GradientTape`
    )
    training_log = model.fit(x=train_dataset, epochs=epochs)
    training_history = pd.DataFrame.from_records(
        [
            training_log.history["loss"],
            training_log.history["mean_absolute_error"],
        ]
    ).T
    return training_history


def train_w_gradient_tape(
    train_dataset, val_dataset, model_config, model_weights
) -> pd.DataFrame:
    model = keras.Model.from_config(model_config)
    model.set_weights(model_weights)  # <---
    optimizer = get_optimizer(learning_rate=learning_rate)
    (
        loss_function,
        training_loss,
        training_accuracy,
        testing_loss,
        testing_accuaracy,
    ) = declare_loss_and_metrics()
    training_log = []

    @tf.function(jit_compile=True)  # <---
    def train_step(x_batch_train, y_batch_train):
        with tf.GradientTape() as tape:
            predictions = model(x_batch_train, training=True)
            loss_value = loss_function(y_true=y_batch_train, y_pred=predictions)
        gradients = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply(gradients, model.trainable_weights)
        return predictions

    for epoch in range(epochs):
        training_loss.reset_state()
        training_accuracy.reset_state()
        testing_loss.reset_state()
        testing_accuaracy.reset_state()
        print("Epoch: %s out of %s." % (epoch + 1, epochs))
        start_time = time.time()

        # Iterate over the batches of the dataset.
        for _, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            predictions = train_step(x_batch_train, y_batch_train)
            training_loss.update_state(y_batch_train, predictions)
            training_accuracy.update_state(y_batch_train, predictions)

        # Run a validation loop at the end of each epoch.
        for _, (x_batch_val, y_batch_val) in enumerate(val_dataset):
            validation_predictions = model(x_batch_val, training=False)
            testing_loss.update_state(y_batch_val, validation_predictions)
            testing_accuaracy.update_state(y_batch_val, validation_predictions)

        end_time = time.time()
        duration_in_minutes = round((end_time - start_time) / 60, 2)
        print("Epoch training duration: %s min." % duration_in_minutes)

        training_log.extend(
            [
                [
                    training_loss.result().numpy(),
                    training_accuracy.result().numpy(),
                    testing_loss.result().numpy(),
                    testing_accuaracy.result().numpy(),
                ]
            ]
        )

    training_history = pd.DataFrame(training_log)
    training_history.columns = [
        "mse_training",
        "mae_training",
        "mse_testing",
        "mae_testing",
    ]
    training_history.index = range(1, epochs + 1)
    return training_history


if __name__ == "__main__":
    seed = 0
    case_count = 100
    feature_count = 10
    propotion = 0.8  # Proportion of data to use for CV.
    learning_rate = 1e-3
    epochs = 10
    training_batch_size = 2

    keras.utils.set_random_seed(seed=seed)

    # Generate features and targets.
    features, targets = generate_raw_data(
        case_count=case_count, feature_count=feature_count
    )

    # Split raw data.
    training_features, training_targets, testing_features, testing_targets = (
        split_data(propotion=propotion, features=features, targets=targets)
    )

    # Wrap raw data in `tf.data.Dataset` objects.
    training_dataset, validation_dataset = prepare_data(
        training_features=training_features,
        training_targets=training_targets,
        testing_features=testing_features,
        testing_targets=testing_targets,
        training_batch_size=training_batch_size,
        validation_batch_size=case_count - floor(case_count * propotion),
    )

    # Get model weights
    model_config, model_weights = make_model_configuration(feature_count)

    # Training w/ fit.
    training_w_fit_report = train_w_fit(
        training_dataset, model_config, model_weights
    )

    # Training w/ gradient tape.
    training_w_gradient_tape_report = train_w_gradient_tape(
        training_dataset, validation_dataset, model_config, model_weights
    )

    # Comparison.
    np.testing.assert_almost_equal(
        actual=training_w_fit_report.values,
        desired=training_w_gradient_tape_report.iloc[:, 0:2].values,
    )

I passed the final assertion line on my machine.

@markodjordjic
Author

@james77777778 keras.utils.set_random_seed(seed=seed) should already initialize the weights to the same values. This requires further investigation, since adding the decorator should affect only performance, not the computed results. Also, having to return the weights and the configuration separately is concerning. Thank you for your contribution.
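
One plausible reading of why a single seed call does not by itself give both runs the same starting weights: seeding fixes the sequence of random draws, not the value of each individual draw, so whatever is initialized second consumes later numbers in that sequence. The sketch below is only a plain-Python analogy built on the standard `random` module. `init_weights` is a hypothetical stand-in for weight initialization, not a Keras function, and the actual Keras seeding internals may differ.

```python
import random


def init_weights(count):
    """Hypothetical stand-in for weight initialization (not a Keras call).

    Draws `count` values from the global random number generator.
    """
    return [random.gauss(0.0, 1.0) for _ in range(count)]


random.seed(0)               # Seed once, as the script in this issue does.
first = init_weights(4)      # Consumes the first draws of the sequence.
second = init_weights(4)     # Consumes the next draws, so the values differ.

random.seed(0)               # Re-seeding restarts the sequence.
reseeded = init_weights(4)   # Reproduces the same draws as `first`.

print(first == second)    # False
print(first == reseeded)  # True
```

Under this reading, two models built one after the other following a single seed call would start from different weights, while re-seeding before each build, or reusing one set of initial weights, would make the starting points match.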

@markodjordjic
Author

markodjordjic commented Mar 20, 2024

@james77777778 thanks for your contribution. I was also able to pass the assertion with the configuration and weights separated. The improvement comes not from the separation per se, but from the single initialization. The decorator was not necessary. Still, there should be more insight into this, and perhaps clearer documentation.
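
The "single initialization" idea can be sketched without Keras: draw the initial weights once, keep them as a snapshot, and let every training run start from a copy of that snapshot. This loosely mirrors the `get_weights` / `set_weights` pairing in the script above. Everything in the sketch (`draw_initial_weights`, `train`, the toy update rule) is hypothetical and only illustrates the pattern.

```python
import random


def draw_initial_weights(rng, count):
    """Hypothetical initializer: draws `count` weights from `rng`."""
    return [rng.uniform(-1.0, 1.0) for _ in range(count)]


def train(initial_weights, targets, learning_rate, epochs):
    """Toy training loop returning the squared error after each epoch."""
    weights = list(initial_weights)  # Copy, so the snapshot stays untouched.
    history = []
    for _ in range(epochs):
        # Gradient step on the squared error (w - t) ** 2 for each weight.
        weights = [
            w - learning_rate * 2.0 * (w - t)
            for w, t in zip(weights, targets)
        ]
        history.append(sum((w - t) ** 2 for w, t in zip(weights, targets)))
    return history


rng = random.Random(0)
targets = [0.5, -0.25, 1.0]

# Single initialization: draw once and reuse the snapshot for both runs.
snapshot = draw_initial_weights(rng, len(targets))
history_a = train(snapshot, targets, learning_rate=0.1, epochs=5)
history_b = train(snapshot, targets, learning_rate=0.1, epochs=5)

# Separate initialization: the second run draws fresh values instead.
fresh = draw_initial_weights(rng, len(targets))
history_c = train(fresh, targets, learning_rate=0.1, epochs=5)

print(history_a == history_b)  # True
print(history_a == history_c)  # False
```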

@markodjordjic
Author

It is necessary to reconstitute the model from the weights and configuration. If the whole model is returned, it seems to be returned only as a reference, not as a completely new object, and is therefore trained twice. This could have a serious side effect when training models within a loop. This should definitely be looked at.
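
The aliasing concern described here can be illustrated without Keras. The sketch below uses a hypothetical `ToyModel` class (not part of any library) to show that handing the same Python object to two training routines makes the second one continue from where the first one stopped, whereas a deep copy gives each routine its own starting state. Whether this is exactly what happened in the original script is an assumption.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a model: one weight, updated in place."""

    def __init__(self, weight):
        self.weight = weight

    def train(self, steps):
        # Each step moves the weight halfway toward the target value 1.0.
        for _ in range(steps):
            self.weight += 0.5 * (1.0 - self.weight)
        return self.weight


# Same object used twice: the second call continues from the first result.
shared = ToyModel(weight=0.0)
run_a = shared.train(steps=2)   # 0.0 -> 0.5 -> 0.75
run_b = shared.train(steps=2)   # 0.75 -> 0.875 -> 0.9375

# Independent copies: both runs start from the untouched template.
template = ToyModel(weight=0.0)
copy_a = copy.deepcopy(template).train(steps=2)
copy_b = copy.deepcopy(template).train(steps=2)

print(run_a, run_b)    # 0.75 0.9375
print(copy_a, copy_b)  # 0.75 0.75
```

Rebuilding the model from its configuration and a saved set of weights, as done in the script above, plays the same role as the deep copy here: each training run gets a fresh object with a known starting state.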

@sachinprasadhs sachinprasadhs added type:bug/performance keras-team-review-pending Pending review by a Keras team member. labels Mar 21, 2024
@sampathweb sampathweb removed the keras-team-review-pending Pending review by a Keras team member. label Mar 28, 2024
@sachinprasadhs sachinprasadhs added the stat:awaiting keras-eng Awaiting response from Keras engineer label Apr 9, 2024