Skip to content

Commit

Permalink
add tf.reverse_v2 function
Browse files Browse the repository at this point in the history
tf.reverse param2 is "dims" type DT_BOOL,
tf.reverse_v2 params2 is "axis" type Tidx
  • Loading branch information
dogvane committed Jul 26, 2023
1 parent ab7b986 commit 2041d7e
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.array.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,20 @@ public Tensor reverse(Tensor tensor, Axis axis, string name = null)
return array_ops.reverse(tensor, axis, name: name);
}


/// <summary>
/// Reverses specific dimensions of a tensor.
/// </summary>
/// <param name="tensor"></param>
/// <param name="axis"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor reverse_v2(Tensor tensor, int[] axis, string name = null)
=> gen_array_ops.reverse_v2(tensor, ops.convert_to_tensor(axis), name: name);

public Tensor reverse_v2(Tensor tensor, Tensor axis, string name = null)
=> gen_array_ops.reverse_v2(tensor, axis, name: name);

/// <summary>
/// Returns the rank of a tensor.
/// </summary>
Expand Down
157 changes: 157 additions & 0 deletions test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,163 @@
using static Tensorflow.Binding;
using System.Linq;

namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
public class ArrayOpsTest : EagerModeTestBase
{
/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/slice
/// </summary>
[TestMethod]
public void Slice()
{
// Tests based on example code in TF documentation
var input_array = tf.constant(np.array(new int[] { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }).reshape((3, 2, 3)));
var indices = tf.constant(np.array(new int[] { 0, 2 }));

var r1 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 1, 3 }));
Assert.AreEqual(new Shape(1, 1, 3), r1.shape);
var r1np = r1.numpy();
Assert.AreEqual(r1np[0, 0, 0], 3);
Assert.AreEqual(r1np[0, 0, 1], 3);
Assert.AreEqual(r1np[0, 0, 2], 3);


var r2 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 2, 3 }));
Assert.AreEqual(new Shape(1, 2, 3), r2.shape);
var r2np = r2.numpy();
Assert.AreEqual(r2np[0, 0, 0], 3);
Assert.AreEqual(r2np[0, 0, 1], 3);
Assert.AreEqual(r2np[0, 0, 2], 3);
Assert.AreEqual(r2np[0, 1, 0], 4);
Assert.AreEqual(r2np[0, 1, 1], 4);
Assert.AreEqual(r2np[0, 1, 2], 4);

var r3 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 2, 1, 3 }));
Assert.AreEqual(new Shape(2, 1, 3), r3.shape);
var r3np = r3.numpy();
Assert.AreEqual(r3np[0, 0, 0], 3);
Assert.AreEqual(r3np[0, 0, 1], 3);
Assert.AreEqual(r3np[0, 0, 2], 3);
Assert.AreEqual(r3np[1, 0, 0], 5);
Assert.AreEqual(r3np[1, 0, 1], 5);
Assert.AreEqual(r3np[1, 0, 2], 5);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/gather
/// </summary>
[TestMethod]
public void Gather()
{
var input_array = tf.constant(np.arange(12).reshape((3, 4)).astype(np.float32));
var indices = tf.constant(np.array(new int[] { 0, 2 }));

var result = array_ops.gather(input_array, indices);
Assert.AreEqual(new Shape(2, 4), result.shape);
Assert.AreEqual(result.numpy()[0, 0], 0.0f);
Assert.AreEqual(result.numpy()[0, 1], 1.0f);
Assert.AreEqual(result.numpy()[1, 3], 11.0f);

// Tests based on example code in Python doc string for tf.gather()

var p1 = tf.random.normal(new Shape(5, 6, 7, 8));
var i1 = tf.random_uniform(new Shape(10, 11), maxval: 7, dtype: tf.int32);
var r1 = tf.gather(p1, i1, axis: 2);
Assert.AreEqual(new Shape(5, 6, 10, 11, 8), r1.shape);

var p2 = tf.random.normal(new Shape(4, 3));
var i2 = tf.constant(new int[,] { { 0, 2 } });
var r2 = tf.gather(p2, i2, axis: 0);
Assert.AreEqual(new Shape(1, 2, 3), r2.shape);

var r3 = tf.gather(p2, i2, axis: 1);
Assert.AreEqual(new Shape(4, 1, 2), r3.shape);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/TensorArray
/// </summary>
[TestMethod]
public void TensorArray()
{
var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false);
ta.write(0, 10);
ta.write(1, 20);
ta.write(2, 30);
Assert.AreEqual(ta.read(0).numpy(), 10f);
Assert.AreEqual(ta.read(1).numpy(), 20f);
Assert.AreEqual(ta.read(2).numpy(), 30f);
}

/// <summary>
///
/// </summary>
[TestMethod]
public void Reverse()
{
/*
* python run get test data code:
import tensorflow as tf
data=[[1, 2, 3], [4, 5, 6], [7,8,9]]
data2 = tf.constant(data)
print('test data shaper:', data2.shape)
print('test data:', data2)
axis = [-2,-1,0,1]
for i in axis:
print('')
print('axis:', i)
ax = tf.constant([i])
datar = tf.reverse(data2, ax)
datar2 = array_ops.reverse(data2, ax)
print(datar)
print(datar2)
* */
var inputData = np.array(new int[,] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } });
var expectedOutput = new[] {
// np.array(new int[,] { { 7, 8, 9 }, { 4, 5, 6 }, { 1, 2, 3 } }),
np.array(new int[,] { { 3, 2, 1 }, { 6, 5, 4 }, { 9, 8, 7 } }),
np.array(new int[,] { { 7, 8, 9 }, { 4, 5, 6 }, { 1, 2, 3 } }),
np.array(new int[,] { { 3, 2, 1 }, { 6, 5, 4 }, { 9, 8, 7 } })
};

var axes = new int [] {
-1,
0,
1 };
for (var i = 0; i < axes.Length; i++)
{
var axis = axes[i];
var expected = tf.constant(expectedOutput[i]).numpy();

var inputTensor = tf.constant(inputData);
var axisTrensor = tf.constant(new[] { axis });

var outputTensor = tf.reverse_v2(inputTensor, axisTrensor);
var npout = outputTensor.numpy();
Assert.IsTrue(Enumerable.SequenceEqual(npout, expected), $"axis:{axis}");

var outputTensor2 = tf.reverse_v2(inputTensor, new[] { axis } );
var npout2 = outputTensor2.numpy();
Assert.IsTrue(Enumerable.SequenceEqual(npout2, expected), $"axis:{axis}");

}
}
}
}
using Microsoft.VisualStudio.TestTools.UnitTesting;

Check failure on line 158 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 158 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 158 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 158 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations
using Tensorflow.NumPy;

Check failure on line 159 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 159 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 159 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 159 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations
using Tensorflow;

Check failure on line 160 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 160 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 160 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 160 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations
using static Tensorflow.Binding;

Check failure on line 161 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 161 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 161 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 161 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations
using System.Linq;

Check failure on line 162 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 162 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / linux

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 162 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations

Check failure on line 162 in test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

View workflow job for this annotation

GitHub Actions / windows

A using clause must precede all other elements defined in the namespace except extern alias declarations

namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
Expand Down

0 comments on commit 2041d7e

Please sign in to comment.