Skip to content

Commit 0ff31a3

Browse files
committed
fix tf.transpose #673
1 parent b3f72cd commit 0ff31a3

File tree

4 files changed

+77
-16
lines changed

4 files changed

+77
-16
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public Tensor where<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
152152
/// <param name="name"></param>
153153
/// <param name="conjugate"></param>
154154
/// <returns></returns>
155-
public Tensor transpose<T1>(T1 a, int[] perm = null, string name = "transpose", bool conjugate = false)
155+
public Tensor transpose<T1>(T1 a, TensorShape perm = null, string name = "transpose", bool conjugate = false)
156156
=> array_ops.transpose(a, perm, name, conjugate);
157157

158158
/// <summary>

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,22 @@ public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null,
779779
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
780780
}
781781

782-
public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
782+
public static Tensor transpose<T1>(T1 a, TensorShape perm, string name = "transpose", bool conjugate = false)
783+
{
784+
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
785+
{
786+
var a_tensor = ops.convert_to_tensor(a);
787+
if(perm == null)
788+
{
789+
var rank = a_tensor.rank;
790+
perm = range(0, rank).OrderByDescending(x => x).ToArray();
791+
}
792+
793+
return gen_array_ops.transpose(a_tensor, perm, name: scope);
794+
});
795+
}
796+
797+
public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose", bool conjugate = false)
783798
{
784799
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
785800
{

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ public static Tensor tile<T>(Tensor input, T multiples, string name = null)
531531
input, multiples).FirstOrDefault(),
532532
input);
533533

534-
public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null)
534+
public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null)
535535
{
536536
if (tf.Context.executing_eagerly())
537537
{

test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,73 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
88
[TestClass]
99
public class TensorOperate
1010
{
11-
[TestMethod, Ignore]
11+
[TestMethod]
1212
public void TransposeTest()
1313
{
1414
// https://www.tensorflow.org/api_docs/python/tf/transpose#for_example_2
15-
var x = tf.constant(new int[,] {
15+
var x = tf.constant(new int[,]
16+
{
1617
{ 1, 2, 3 },
1718
{ 4, 5, 6 }
1819
});
1920
var transpose_x = tf.transpose(x);
20-
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 4 }, transpose_x[0].numpy().ToArray<int>()));
21-
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 5 }, transpose_x[1].numpy().ToArray<int>()));
22-
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 6 }, transpose_x[2].numpy().ToArray<int>()));
21+
Assert.AreEqual(new[] { 1, 4 }, transpose_x[0].numpy());
22+
Assert.AreEqual(new[] { 2, 5 }, transpose_x[1].numpy());
23+
Assert.AreEqual(new[] { 3, 6 }, transpose_x[2].numpy());
24+
25+
#region constant a
26+
var a = tf.constant(np.array(new[, , ,]
27+
{
28+
{
29+
{
30+
{ 1, 11, 2, 22 }
31+
},
32+
{
33+
{ 3, 33, 4, 44 }
34+
}
35+
},
36+
{
37+
{
38+
{ 5, 55, 6, 66 }
39+
},
40+
{
41+
{ 7, 77, 8, 88 }
42+
}
43+
}
44+
}));
45+
46+
#endregion
47+
var actual_transposed_a = tf.transpose(a, new[] { 3, 1, 2, 0 });
2348

24-
var a = tf.constant(np.array(new[, , ,] { { { { 1, 11, 2, 22 } }, { { 3, 33, 4, 44 } } },
25-
{ { { 5, 55, 6, 66 } }, { { 7, 77, 8, 88 } } } }));
26-
var b = tf.transpose(a, new[] { 3, 1, 2, 0 });
27-
var transpose_a = tf.constant(np.array(new[, , ,] { { { { 1, 5 } }, { { 3, 7 } } },
28-
{ { { 11, 55 } }, { { 33, 77 } } }, { { { 2, 6 } }, { { 4, 8 } } },
29-
{ { { 22, 66 } }, { { 44, 88 } } } }));
30-
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 4, 2, 1, 2 }, b.shape));
31-
Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray<int>(), b.numpy().ToArray<int>()));
49+
#region constant transpose_a
50+
var expected_transposed_a = tf.constant(np.array(new[, , ,]
51+
{
52+
{
53+
{ { 1, 5 } }, { { 3, 7 } }
54+
},
55+
{
56+
{ { 11, 55 } }, { { 33, 77 } }
57+
},
58+
{
59+
{
60+
{ 2, 6 }
61+
},
62+
{
63+
{ 4, 8 }
64+
}
65+
},
66+
{
67+
{
68+
{ 22, 66 }
69+
},
70+
{
71+
{ 44, 88 }
72+
}
73+
}
74+
}));
75+
#endregion
76+
Assert.AreEqual((4, 2, 1, 2 ), actual_transposed_a.TensorShape);
77+
Assert.AreEqual(expected_transposed_a.numpy(), actual_transposed_a.numpy());
3278
}
3379

3480
[TestMethod]

0 commit comments

Comments
 (0)