diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index 4d9c3da58..b529cd319 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -140,6 +140,16 @@ public Tensor identity(Tensor input, string name = null)
public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
=> array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis));
+ ///
+ /// Gather slices from `params` into a Tensor with shape specified by `indices`.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor gather_nd(Tensor @params, Tensor indices, string name = null)
+ => gen_array_ops.gather_nd(@params, indices, name: name);
+
///
/// Return the elements, either from `x` or `y`, depending on the `condition`.
///
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index 4b7027992..a4da60eed 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -381,5 +381,48 @@ public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads)
var axis = op.inputs[1];
return new Tensor[] { array_ops.reverse(grad, axis), null };
}
+
+ [RegisterGradient("Tile")]
+ public static Tensor[] _TileGrad(Operation op, Tensor[] grads)
+ {
+ var grad = grads[0];
+ var input_shape = array_ops.shape(op.inputs[0], out_type: op.inputs[1].dtype);
+ var split_shape = array_ops.reshape(array_ops.transpose(array_ops.stack(new Tensor[] { op.inputs[1], input_shape })), new Shape(-1));
+ var axes = math_ops.range(0, array_ops.size(split_shape), 2);
+
+ //# Sum reduces grad along the first dimension for IndexedSlices
+ //if isinstance(grad, indexed_slices_lib.IndexedSlices):
+ //input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
+ //grad = math_ops.unsorted_segment_sum(
+ // grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
+ //split_shape = array_ops.concat([[1], split_shape[1:]], axis = 0)
+
+ var input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes);
+ if (!tf.Context.executing_eagerly())
+ {
+ input_grad.set_shape(op.inputs[0].GetShape());
+ }
+ return new Tensor[] { input_grad, null };
+ }
+
+ [RegisterGradient("GatherNd")]
+ public static Tensor[] _GatherNdGrad(Operation op, Tensor[] grads)
+ {
+ var @ref = op.inputs[0];
+ var indices = op.inputs[1];
+ var grad = grads[0];
+ var ref_shape = array_ops.shape(@ref, out_type: indices.dtype);
+ Tensor ref_grad = null;
+ if (indices.shape.ndim == 2 && indices.shape.dims[indices.shape.Length - 1] == 1)
+ {
+ ref_grad = (Tensor)new IndexedSlices(grad, array_ops.squeeze(indices, axis: -1), ref_shape);
+ }
+ else
+ {
+ ref_grad = gen_array_ops.scatter_nd(indices, grad, ref_shape);
+ }
+ return new Tensor[] { ref_grad, null };
+ }
+
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
index d441dc828..1d215576f 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
@@ -4,10 +4,8 @@
namespace Tensorflow.Keras.ArgsDefinition
{
- public class GRUOptionalArgs
+ public class GRUOptionalArgs : RnnOptionalArgs
{
public string Identifier => "GRU";
-
- public Tensor Mask { get; set; } = null;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs
new file mode 100644
index 000000000..2829927c3
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
+{
+ public class LSTMOptionalArgs : RnnOptionalArgs
+ {
+ public string Identifier => "LSTM";
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs
new file mode 100644
index 000000000..a8b8caf06
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
+{
+ public class SimpleRNNOptionalArgs : RnnOptionalArgs
+ {
+ public string Identifier => "SimpleRNN";
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index fdc53cd7e..57af3b835 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -829,7 +829,7 @@ public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end,
/// A `Tensor`. Has the same type as `input`.
/// Contains the same data as `input`, but has one or more dimensions of
/// size 1 removed.
- public static Tensor squeeze(Tensor input, int[] axis = null, string name = null)
+ public static Tensor squeeze(Tensor input, Axis axis = null, string name = null)
=> gen_array_ops.squeeze(input, axis, name);
public static Tensor identity(Tensor input, string name = null)
@@ -990,7 +990,7 @@ public static Tensor gather(ResourceVariable @params, Tensor indices, string nam
return @params.sparse_read(indices, name);
}
- public static Tensor transpose(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
+ public static Tensor transpose(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
diff --git a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
index e41e1d617..1cfceb3e3 100644
--- a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
+++ b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
@@ -62,7 +62,7 @@ public void SquaredDifference_1D()
// Calcute the gradient of (x1-x2)^2
// by Automatic Differentiation in Eager mode
// Expected is 2*(abs(x1-x2))
- Tensor x1 = new NDArray( new float[] { 1, 3, 5, 21, 19, 17 });
+ Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
float[] expected = new float[]
{
@@ -173,5 +173,34 @@ public void ConditionalMultiply()
var result = grad(x, 4);
Assert.AreEqual((float)result, 4.0f);
}
+
+ [TestMethod]
+ public void Tile()
+ {
+ var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT);
+ var b = tf.constant(new int[] { 2 });
+ using (var tape = tf.GradientTape())
+ {
+ tape.watch(a);
+ var y = tf.tile(a, b);
+ var grad = tape.gradient(y, a);
+ Assert.AreEqual((float)grad.numpy(), 2.0f);
+ }
+ }
+
+ [TestMethod]
+ public void GatherNdTest()
+ {
+ var x = tf.constant(new float[,] { { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f } }, dtype: TF_DataType.TF_FLOAT);
+ var indices = tf.constant(new int[,] { { 0, 1 }, { 1, 1 }, { 2, 1 } }, dtype: TF_DataType.TF_INT32);
+ using (var tape = tf.GradientTape())
+ {
+ tape.watch(x);
+ var res = tf.gather_nd(x, indices);
+ var grad = tape.gradient(res, x);
+ var expected = np.array(new float[,] { { 0f, 1f, 0f }, { 0f, 1f, 0f }, { 0f, 1f, 0f } });
+ Assert.IsTrue(Enumerable.SequenceEqual(grad.ToArray(), expected.ToArray()));
+ }
+ }
}
}