diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs index 78882e82d..ba0332836 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -1,5 +1,6 @@ using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; namespace Tensorflow.Keras.ArgsDefinition { @@ -16,5 +17,7 @@ public class DataAdapterArgs: IKerasConfig public int Worker { get; set; } public bool UseMultiprocessing { get; set; } public IModel Model { get; set; } + public Dictionary ClassWeight { get; set; } = null; /* not consumed yet: fit(NDArray, ...) throws NotImplementedException for class_weight */ + public NDArray SampleWeight { get; set; } = null; /* optional per-sample weights; null means none were supplied */ } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs index 82530e950..72d0bb811 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -1,5 +1,6 @@ using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; namespace Tensorflow.Keras.ArgsDefinition { @@ -18,5 +19,7 @@ public class DataHandlerArgs: IKerasConfig public bool UseMultiprocessing { get; set; } = false; public IModel Model { get; set; } public IVariableV1 StepsPerExecution { get; set; } + public Dictionary ClassWeight { get; set; } = null; /* not consumed yet: fit(NDArray, ...) throws NotImplementedException for class_weight */ + public NDArray SampleWeight { get; set; } = null; /* optional per-sample weights; forwarded to DataAdapterArgs.SampleWeight by DataHandler */ } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs index 19f3df9ba..1840f88b9 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs @@ -3,6 +3,7 @@ using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Saving; using Tensorflow.NumPy; +using Tensorflow.Util; namespace Tensorflow.Keras.Engine; @@ -22,8 +23,10 @@ ICallback fit(NDArray x, NDArray y, int verbose = 1, List callbacks = null, float validation_split = 0f, - (NDArray
val_x, NDArray val_y)? validation_data = null, + ValidationDataPack validation_data = null, bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, @@ -35,8 +38,10 @@ ICallback fit(IEnumerable x, NDArray y, int verbose = 1, List callbacks = null, float validation_split = 0f, - (IEnumerable val_x, NDArray val_y)? validation_data = null, + ValidationDataPack validation_data = null, bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, @@ -63,6 +68,8 @@ void load_weights(string filepath, Dictionary evaluate(NDArray x, NDArray y, int batch_size = -1, int verbose = 1, + NDArray sample_weight = null, + int steps = -1, int max_queue_size = 10, int workers = 1, diff --git a/src/TensorFlowNET.Core/Util/Data.cs b/src/TensorFlowNET.Core/Util/Data.cs new file mode 100644 index 000000000..a14c69b18 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/Data.cs @@ -0,0 +1,66 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Util +{ + /// + /// ValidationDataPack is used to pass validation data to the fit method. + /// It can receive a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
+ /// + public class ValidationDataPack + { + public NDArray val_x; + public NDArray val_y; + public NDArray val_sample_weight = null; + + public ValidationDataPack((NDArray, NDArray) validation_data) + { + this.val_x = validation_data.Item1; + this.val_y = validation_data.Item2; + } + + public ValidationDataPack((NDArray, NDArray, NDArray) validation_data) + { + this.val_x = validation_data.Item1; + this.val_y = validation_data.Item2; + this.val_sample_weight = validation_data.Item3; + } + + public ValidationDataPack((IEnumerable, NDArray) validation_data) + { + this.val_x = validation_data.Item1.ToArray()[0]; + this.val_y = validation_data.Item2; + } + + public ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) + { + this.val_x = validation_data.Item1.ToArray()[0]; + this.val_y = validation_data.Item2; + this.val_sample_weight = validation_data.Item3; + } + + public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) + => new ValidationDataPack(validation_data); + + public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data) + => new ValidationDataPack(validation_data); + + public static implicit operator ValidationDataPack((IEnumerable, NDArray) validation_data) + => new ValidationDataPack(validation_data); + + public static implicit operator ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) + => new ValidationDataPack(validation_data); + + public void Deconstruct(out NDArray val_x, out NDArray val_y) + { + val_x = this.val_x; + val_y = this.val_y; + } + + public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) + { + val_x = this.val_x; + val_y = this.val_y; + val_sample_weight = this.val_sample_weight; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs index 6c7d53b2f..b2750496a 100644 --- 
a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Util; namespace Tensorflow.Keras.Engine.DataAdapters { @@ -34,9 +35,67 @@ public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y) return (x, y); } + public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight) + { + for (int i = 0; i < x.Length; i++) + { + if (x[i].shape.ndim == 1) + x[i] = array_ops.expand_dims(x[i], axis: -1); + } + for (int i = 0; i < y.Length; i++) + { + if (y[i].shape.ndim == 1) + y[i] = array_ops.expand_dims(y[i], axis: -1); + } + for (int i = 0; i < (sample_weight?.Length ?? 0); i++) /* test_step passes null when the batch carries no weights; skip it and return it unchanged */ + { + if (sample_weight[i].shape.ndim == 1) + sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1); + } + return (x, y, sample_weight); + } + public virtual bool ShouldRecreateIterator() { return true; } + + public static ((NDArray, NDArray, NDArray),ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split) + { + var x = x_y_sample_weight.Item1; + var y = x_y_sample_weight.Item2; + var sample_weight = x_y_sample_weight.Item3; + int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); + var train_x = x[new Slice(0, train_count)]; + var train_y = y[new Slice(0, train_count)]; + ValidationDataPack validation_data; + if (sample_weight != null) + { + validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]); + sample_weight = sample_weight[new Slice(0, train_count)]; + } + else + { + validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]); + } + + return ((train_x, train_y, sample_weight), validation_data); + } + + public static ((IEnumerable, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable, NDArray,
NDArray) x_y_sample_weight, float validation_split) + { + var x = x_y_sample_weight.Item1; + var y = x_y_sample_weight.Item2; + var sample_weight = x_y_sample_weight.Item3; + int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); + var train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray); + var train_y = y[new Slice(0, train_count)]; + var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); + var val_y = y[new Slice(train_count)]; + NDArray tmp_sample_weight = sample_weight; /* keep the unsliced weights so the validation tail can still be taken below */ + sample_weight = sample_weight?[new Slice(0, train_count)]; /* sample_weight is optional (fit defaults it to null), so slice only when present */ + ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight?[new Slice(train_count)]); + return ((train_x, train_y, sample_weight), validation_data); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index 4723222f2..a5ee75c93 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using static Tensorflow.Binding; +using Tensorflow.Keras.Utils; namespace Tensorflow.Keras.Engine.DataAdapters { @@ -28,6 +29,7 @@ public class DataHandler public DataHandler(DataHandlerArgs args) { this.args = args; + if (args.StepsPerExecution == null) { _steps_per_execution = tf.Variable(1L); @@ -48,6 +50,7 @@ public DataHandler(DataHandlerArgs args) BatchSize = args.BatchSize, Steps = args.StepsPerEpoch, Epochs = args.Epochs - args.InitialEpoch, + SampleWeight = args.SampleWeight, Shuffle = args.Shuffle, MaxQueueSize = args.MaxQueueSize, Worker = args.Workers, diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs index 4bdc49795..bb71b0a2d 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs +++
b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs @@ -17,6 +17,8 @@ public interface IDataAdapter IDatasetV2 GetDataset(); int GetSize(); (Tensors, Tensors) Expand1d(Tensors x, Tensors y); + (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight); + bool ShouldRecreateIterator(); } } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 16e646a35..978a3f51c 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -20,7 +20,7 @@ public class TensorLikeDataAdapter : DataAdapter, IDataAdapter public TensorLikeDataAdapter(DataAdapterArgs args) { this.args = args; - _process_tensorlike(); + Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null; num_samples = (int)args.X.shape[0]; var batch_size = args.BatchSize == -1 ? 
32 : args.BatchSize; _batch_size = batch_size; @@ -37,6 +37,8 @@ public TensorLikeDataAdapter(DataAdapterArgs args) inputs.AddRange(args.X); if (args.Y != null) inputs.AddRange(args.Y); + if (sample_weight_tensor != null) + inputs.Add(sample_weight_tensor); dataset = slice_inputs(indices_dataset, inputs); dataset.FirstInputTensorCount = args.X.Length; } @@ -94,8 +96,9 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements) public override bool ShouldRecreateIterator() => false; - void _process_tensorlike() + Tensor _process_tensorlike(NDArray sample_weights) { + return tf.convert_to_tensor(sample_weights); } } } diff --git a/src/TensorFlowNET.Keras/Engine/LossesContainer.cs b/src/TensorFlowNET.Keras/Engine/LossesContainer.cs index 6a91450de..c06fca593 100644 --- a/src/TensorFlowNET.Keras/Engine/LossesContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/LossesContainer.cs @@ -26,11 +26,11 @@ public LossesContainer(ILossFunc losses, string[] output_names = null) /// /// /// - public Tensor Call(Tensor y_true, Tensor y_pred) + public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) { if (!_built) Build(y_pred); - var loss_value = _losses.Call(y_true, y_pred); + var loss_value = _losses.Call(y_true, y_pred, sample_weight:sample_weight); var loss_metric_value = loss_value; var batch_dim = array_ops.shape(y_true)[0]; diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index a74a77f18..626d7fcad 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -30,6 +30,7 @@ public partial class Model public Dictionary evaluate(NDArray x, NDArray y, int batch_size = -1, int verbose = 1, + NDArray sample_weight = null, int steps = -1, int max_queue_size = 10, int workers = 1, @@ -51,6 +52,7 @@ public Dictionary evaluate(NDArray x, NDArray y, StepsPerEpoch = steps, InitialEpoch = 0, Epochs = 1, + SampleWeight = sample_weight, 
MaxQueueSize = max_queue_size, Workers = workers, UseMultiprocessing = use_multiprocessing, @@ -140,7 +142,8 @@ Dictionary evaluate(DataHandler data_handler, CallbackList callba Dictionary test_function(DataHandler data_handler, OwnedIterator iterator) { var data = iterator.next(); - var outputs = test_step(data_handler, data[0], data[1]); + var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) : + test_step(data_handler, data[0], data[1], data[2]); tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); return outputs; } @@ -149,17 +152,23 @@ Dictionary test_step_multi_inputs_function(DataHandler data_handl { var data = iterator.next(); var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; - var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray()); + var outputs = data.Length == 2 ? + test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) : + test_step( + data_handler, + new Tensors(data.Take(x_size).ToArray()), + new Tensors(data.Skip(x_size).Take(x_size).ToArray()), + new Tensors(data.Skip(2 * x_size).ToArray())); tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); return outputs; } - Dictionary test_step(DataHandler data_handler, Tensors x, Tensors y) + Dictionary test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null) { - (x, y) = data_handler.DataAdapter.Expand1d(x, y); + (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight); var y_pred = Apply(x, training: false); - var loss = compiled_loss.Call(y, y_pred); + var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight); compiled_metrics.update_state(y, y_pred); return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2); } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs 
b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index d6f89d8be..23c53b707 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -6,10 +6,12 @@ using Tensorflow.Keras.Engine.DataAdapters; using System.Diagnostics; using Tensorflow.Keras.Callbacks; -using System.Data; +using Tensorflow.Util; namespace Tensorflow.Keras.Engine { + + public partial class Model { /// @@ -19,19 +21,29 @@ public partial class Model /// /// /// - /// /// + /// /// /// /// + /// + /// + /// + /// + /// + /// + /// + /// public ICallback fit(NDArray x, NDArray y, int batch_size = -1, int epochs = 1, int verbose = 1, List callbacks = null, float validation_split = 0f, - (NDArray val_x, NDArray val_y)? validation_data = null, + ValidationDataPack validation_data = null, bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, @@ -43,21 +55,25 @@ public ICallback fit(NDArray x, NDArray y, $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); } - var train_x = x; - var train_y = y; + // The default dtype in NDArray is double, so sample_weight is cast to float so that it can be multiplied with the loss, whose dtype is float.
+ sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT); if (validation_split != 0f && validation_data == null) { - int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); - train_x = x[new Slice(0, train_count)]; - train_y = y[new Slice(0, train_count)]; - validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]); + ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split); + } + + // TODO(Wanglongzhi2001) + if (class_weight != null) + { + throw new NotImplementedException("class_weight is not implemented"); } var data_handler = new DataHandler(new DataHandlerArgs { - X = train_x, - Y = train_y, + X = x, + Y = y, + SampleWeight = sample_weight, BatchSize = batch_size, InitialEpoch = initial_epoch, Epochs = epochs, @@ -73,14 +89,17 @@ public ICallback fit(NDArray x, NDArray y, train_step_func: train_step_function); } + public ICallback fit(IEnumerable x, NDArray y, int batch_size = -1, int epochs = 1, int verbose = 1, List callbacks = null, float validation_split = 0f, - (IEnumerable val_x, NDArray val_y)? 
validation_data = null, + ValidationDataPack validation_data = null, bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, @@ -95,27 +114,23 @@ public ICallback fit(IEnumerable x, NDArray y, } } - var train_x = x; - var train_y = y; + sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT); + if (validation_split != 0f && validation_data == null) { - int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); - train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray); - train_y = y[new Slice(0, train_count)]; - var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); - var val_y = y[new Slice(train_count)]; - validation_data = (val_x, val_y); + ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split); } var data_handler = new DataHandler(new DataHandlerArgs { - X = new Tensors(train_x.ToArray()), - Y = train_y, + X = new Tensors(x.ToArray()), + Y = y, BatchSize = batch_size, InitialEpoch = initial_epoch, Epochs = epochs, Shuffle = shuffle, + SampleWeight = sample_weight, MaxQueueSize = max_queue_size, Workers = workers, UseMultiprocessing = use_multiprocessing, @@ -142,8 +157,10 @@ public History fit(IDatasetV2 dataset, int verbose = 1, List callbacks = null, IDatasetV2 validation_data = null, - int validation_step = 10, // 间隔多少次会进行一次验证 + int validation_step = 10, bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, @@ -210,7 +227,7 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i { if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0) continue; - + var val_logs = evaluate(validation_data); foreach(var log in val_logs) { @@ -233,7 +250,7 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i return 
callbacks.History; } - History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, (NDArray, NDArray)? validation_data, + History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, ValidationDataPack validation_data, Func> train_step_func) { stop_training = false; @@ -274,7 +291,8 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, (IEnumerable, NDArray)? validation_data, - Func> train_step_func) - { - stop_training = false; - _train_counter.assign(0); - var callbacks = new CallbackList(new CallbackParams - { - Model = this, - Verbose = verbose, - Epochs = epochs, - Steps = data_handler.Inferredsteps - }); - - if (callbackList != null) - { - foreach (var callback in callbackList) - callbacks.callbacks.add(callback); - } - - callbacks.on_train_begin(); - - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) - { - reset_metrics(); - callbacks.on_epoch_begin(epoch); - // data_handler.catch_stop_iteration(); - var logs = new Dictionary(); - long End_step = 0; - foreach (var step in data_handler.steps()) - { - callbacks.on_train_batch_begin(step); - logs = train_step_func(data_handler, iterator); - var end_step = step + data_handler.StepIncrement; - End_step = end_step; - callbacks.on_train_batch_end(end_step, logs); - } - - if (validation_data != null) - { - var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2); - foreach (var log in val_logs) - { - logs["val_" + log.Key] = log.Value; - callbacks.on_train_batch_end(End_step, logs); - } - } - - callbacks.on_epoch_end(epoch, logs); - - GC.Collect(); - GC.WaitForPendingFinalizers(); - if (stop_training) - { - break; - } - } - - return callbacks.History; - } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index ad3c70d2d..8f1ec808c 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ 
b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -12,7 +12,9 @@ public partial class Model Dictionary train_step_function(DataHandler data_handler, OwnedIterator iterator) { var data = iterator.next(); - var outputs = train_step(data_handler, data[0], data[1]); + // whether have sample_weight + var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) : + train_step(data_handler, data[0], data[1], data[2]); tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); return outputs; } @@ -21,7 +23,13 @@ Dictionary train_step_multi_inputs_function(DataHandler data_hand { var data = iterator.next(); var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; - var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); + var outputs = data.Length == 2 ? + train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) : + train_step( + data_handler, + new Tensors(data.Take(x_size).ToArray()), + new Tensors(data.Skip(x_size).Take(x_size).ToArray()), + new Tensors(data.Skip(2 * x_size).ToArray())); tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); return outputs; } @@ -61,6 +69,34 @@ Dictionary train_step(DataHandler data_handler, Tensors x, Tensor }); return dict; } + Dictionary train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null) + { + (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight); + using var tape = tf.GradientTape(); + var y_pred = Apply(x, training: true); + var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight); + + // For custom training steps, users can just write: + // trainable_variables = self.trainable_variables + // gradients = tape.gradient(loss, trainable_variables) + // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + // The _minimize call does a 
few extra steps unnecessary in most cases, + // such as loss scaling and gradient clipping. + _minimize(tape, optimizer, loss, TrainableVariables); + compiled_metrics.update_state(y, y_pred); + + var dict = new Dictionary(); + metrics.ToList().ForEach(x => + { + var r = x.result(); + if (r.ndim > 0) + { + r = tf.reduce_mean(r); + } + dict[x.Name] = (float)r; + }); + return dict; + } void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List trainable_variables) { diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs index dbf5cae1e..67e2b0464 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs @@ -74,8 +74,8 @@ public void TrainLSTMWithMnist() OneHot = true, ValidationSize = 55000, }).Result; - - model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1); + var sample_weight = np.ones(((int)dataset.Train.Data.shape[0])); + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight:sample_weight); } [TestMethod]