Skip to content

Commit

Permalink
Merge pull request #1263 from ASolomatin/master
Browse files Browse the repository at this point in the history
fix: Support for training a multi-input model using a dataset.
  • Loading branch information
Oceania2018 authored Jul 2, 2024
2 parents 7fb73cd + 93dda17 commit 6a2d7e1
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 2 deletions.
14 changes: 13 additions & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,19 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
Steps = data_handler.Inferredsteps
});

return evaluate(data_handler, callbacks, is_val, test_function);
Func<DataHandler, OwnedIterator, Dictionary<string, float>> testFunction;

if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
testFunction = test_step_multi_inputs_function;
}
else
{
testFunction = test_function;
}

return evaluate(data_handler, callbacks, is_val, testFunction);
}

/// <summary>
Expand Down
13 changes: 12 additions & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,20 @@ public ICallback fit(IDatasetV2 dataset,
StepsPerExecution = _steps_per_execution
});

Func<DataHandler, OwnedIterator, Dictionary<string, float>> trainStepFunction;

if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
trainStepFunction = train_step_multi_inputs_function;
}
else
{
trainStepFunction = train_step_function;
}

return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
train_step_func: trainStepFunction);
}

History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
Expand Down
82 changes: 82 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System;
using Tensorflow.Keras.Optimizers;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.UnitTest
Expand Down Expand Up @@ -54,10 +55,91 @@ public void LeNetModel()
var x = new NDArray[] { x1, x2 };
model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);

x1 = x1["0:8"];
x2 = x1;

x = new NDArray[] { x1, x2 };
var y = dataset.Train.Labels["0:8"];
(model as Engine.Model).evaluate(x, y);

x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
var pred = model.predict((x1, x2));
Console.WriteLine(pred);
}

[TestMethod]
public void LeNetModelDataset()
{
var inputs = keras.Input((28, 28, 1));
var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1);
var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1);
var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2);
var flat1 = keras.layers.Flatten().Apply(pool2);

var inputs_2 = keras.Input((28, 28, 1));
var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2);
var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2);
var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2);
var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2);
var flat1_2 = keras.layers.Flatten().Apply(pool2_2);

var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
var output = keras.layers.Softmax(-1).Apply(dense3);

var model = keras.Model((inputs, inputs_2), output);
model.summary();

var data_loader = new MnistModelLoader();

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 59900,
}).Result;

var loss = keras.losses.SparseCategoricalCrossentropy();
var optimizer = new Adam(0.001f);
model.compile(optimizer, loss, new string[] { "accuracy" });

NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));

var multiInputDataset = tf.data.Dataset.zip(
tf.data.Dataset.from_tensor_slices(x1),
tf.data.Dataset.from_tensor_slices(x1),
tf.data.Dataset.from_tensor_slices(dataset.Train.Labels)
).batch(8);
multiInputDataset.FirstInputTensorCount = 2;

model.fit(multiInputDataset, epochs: 3);

x1 = x1["0:8"];

multiInputDataset = tf.data.Dataset.zip(
tf.data.Dataset.from_tensor_slices(x1),
tf.data.Dataset.from_tensor_slices(x1),
tf.data.Dataset.from_tensor_slices(dataset.Train.Labels["0:8"])
).batch(8);
multiInputDataset.FirstInputTensorCount = 2;

(model as Engine.Model).evaluate(multiInputDataset);

x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
var x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);

multiInputDataset = tf.data.Dataset.zip(
tf.data.Dataset.from_tensor_slices(x1),
tf.data.Dataset.from_tensor_slices(x2)
).batch(8);
multiInputDataset.FirstInputTensorCount = 2;

var pred = model.predict(multiInputDataset);
Console.WriteLine(pred);
}
}
}

0 comments on commit 6a2d7e1

Please sign in to comment.