@@ -8,27 +8,73 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
88 [ TestClass ]
99 public class TensorOperate
1010 {
11- [ TestMethod , Ignore ]
11+ [ TestMethod ]
1212 public void TransposeTest ( )
1313 {
1414 // https://www.tensorflow.org/api_docs/python/tf/transpose#for_example_2
15- var x = tf . constant ( new int [ , ] {
15+ var x = tf . constant ( new int [ , ]
16+ {
1617 { 1 , 2 , 3 } ,
1718 { 4 , 5 , 6 }
1819 } ) ;
1920 var transpose_x = tf . transpose ( x ) ;
20- Assert . IsTrue ( Enumerable . SequenceEqual ( new int [ ] { 1 , 4 } , transpose_x [ 0 ] . numpy ( ) . ToArray < int > ( ) ) ) ;
21- Assert . IsTrue ( Enumerable . SequenceEqual ( new int [ ] { 2 , 5 } , transpose_x [ 1 ] . numpy ( ) . ToArray < int > ( ) ) ) ;
22- Assert . IsTrue ( Enumerable . SequenceEqual ( new int [ ] { 3 , 6 } , transpose_x [ 2 ] . numpy ( ) . ToArray < int > ( ) ) ) ;
21+ Assert . AreEqual ( new [ ] { 1 , 4 } , transpose_x [ 0 ] . numpy ( ) ) ;
22+ Assert . AreEqual ( new [ ] { 2 , 5 } , transpose_x [ 1 ] . numpy ( ) ) ;
23+ Assert . AreEqual ( new [ ] { 3 , 6 } , transpose_x [ 2 ] . numpy ( ) ) ;
24+
25+ #region constant a
26+ var a = tf . constant ( np . array ( new [ , , , ]
27+ {
28+ {
29+ {
30+ { 1 , 11 , 2 , 22 }
31+ } ,
32+ {
33+ { 3 , 33 , 4 , 44 }
34+ }
35+ } ,
36+ {
37+ {
38+ { 5 , 55 , 6 , 66 }
39+ } ,
40+ {
41+ { 7 , 77 , 8 , 88 }
42+ }
43+ }
44+ } ) ) ;
45+
46+ #endregion
47+ var actual_transposed_a = tf . transpose ( a , new [ ] { 3 , 1 , 2 , 0 } ) ;
2348
24- var a = tf . constant ( np . array ( new [ , , , ] { { { { 1 , 11 , 2 , 22 } } , { { 3 , 33 , 4 , 44 } } } ,
25- { { { 5 , 55 , 6 , 66 } } , { { 7 , 77 , 8 , 88 } } } } ) ) ;
26- var b = tf . transpose ( a , new [ ] { 3 , 1 , 2 , 0 } ) ;
27- var transpose_a = tf . constant ( np . array ( new [ , , , ] { { { { 1 , 5 } } , { { 3 , 7 } } } ,
28- { { { 11 , 55 } } , { { 33 , 77 } } } , { { { 2 , 6 } } , { { 4 , 8 } } } ,
29- { { { 22 , 66 } } , { { 44 , 88 } } } } ) ) ;
30- Assert . IsTrue ( Enumerable . SequenceEqual ( new [ ] { 4 , 2 , 1 , 2 } , b . shape ) ) ;
31- Assert . IsTrue ( Enumerable . SequenceEqual ( transpose_a . numpy ( ) . ToArray < int > ( ) , b . numpy ( ) . ToArray < int > ( ) ) ) ;
49+ #region constant transpose_a
50+ var expected_transposed_a = tf . constant ( np . array ( new [ , , , ]
51+ {
52+ {
53+ { { 1 , 5 } } , { { 3 , 7 } }
54+ } ,
55+ {
56+ { { 11 , 55 } } , { { 33 , 77 } }
57+ } ,
58+ {
59+ {
60+ { 2 , 6 }
61+ } ,
62+ {
63+ { 4 , 8 }
64+ }
65+ } ,
66+ {
67+ {
68+ { 22 , 66 }
69+ } ,
70+ {
71+ { 44 , 88 }
72+ }
73+ }
74+ } ) ) ;
75+ #endregion
76+ Assert . AreEqual ( ( 4 , 2 , 1 , 2 ) , actual_transposed_a . TensorShape ) ;
77+ Assert . AreEqual ( expected_transposed_a . numpy ( ) , actual_transposed_a . numpy ( ) ) ;
3278 }
3379
3480 [ TestMethod ]
0 commit comments