@@ -298,5 +298,45 @@ public unsafe void tensor_resize()
298298
299299 tf . compat . v1 . disable_eager_execution ( ) ;
300300 }
301+
302+ /// <summary>
303+ /// Assign tensor to slice of other tensor.
304+ /// </summary>
305+ [ TestMethod ]
306+ public void TestAssignOfficial ( )
307+ {
308+ // example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__
309+
310+ // python
311+ // import tensorflow as tf
312+ // A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
313+ // with tf.compat.v1.Session() as sess:
314+ // sess.run(tf.compat.v1.global_variables_initializer())
315+ // print(sess.run(A[:2, :2])) # => [[1,2], [4,5]]
316+
317+ // op = A[:2,:2].assign(22. * tf.ones((2, 2)))
318+ // print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]]
319+
320+ // C#
321+ // [[1,2,3], [4,5,6], [7,8,9]]
322+ double [ ] [ ] initial = new double [ ] [ ]
323+ {
324+ new double [ ] { 1 , 2 , 3 } ,
325+ new double [ ] { 4 , 5 , 6 } ,
326+ new double [ ] { 7 , 8 , 9 }
327+ } ;
328+ Tensor A = tf . Variable ( initial , dtype : tf . float32 ) ;
329+ // Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]]
330+ Tensor result1 = A [ ":2" , ":2" ] ;
331+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 1 , 2 } , result1 [ 0 ] . ToArray < double > ( ) ) ) ;
332+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 4 , 5 } , result1 [ 1 ] . ToArray < double > ( ) ) ) ;
333+
334+ // An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible'
335+ Tensor op = A [ ":2" , ":2" ] . assign ( 22.0 * tf . ones ( ( 2 , 2 ) ) ) ;
336+ // Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]]
337+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 22 , 22 , 3 } , op [ 0 ] . ToArray < double > ( ) ) ) ;
338+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 22 , 22 , 6 } , op [ 1 ] . ToArray < double > ( ) ) ) ;
339+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 7 , 8 , 9 } , op [ 2 ] . ToArray < double > ( ) ) ) ;
340+ }
301341 }
302342}
0 commit comments