66using System . Linq ;
77using Tensorflow ;
88using static Tensorflow . Binding ;
9+ using System . Collections . Generic ;
910
1011namespace TensorFlowNET . UnitTest
1112{
@@ -144,6 +145,40 @@ public void assertAllClose(double value, NDArray array2, double eps = 1e-5)
144145 Assert . IsTrue ( np . allclose ( array1 , array2 , rtol : eps ) ) ;
145146 }
146147
148+ private class CollectionComparer : IComparer
149+ {
150+ private readonly double _epsilon ;
151+
152+ public CollectionComparer ( double eps = 1e-06 )
153+ {
154+ _epsilon = eps ;
155+ }
156+ public int Compare ( object x , object y )
157+ {
158+ var a = ( double ) x ;
159+ var b = ( double ) y ;
160+
161+ double delta = Math . Abs ( a - b ) ;
162+ if ( delta < _epsilon )
163+ {
164+ return 0 ;
165+ }
166+ return a . CompareTo ( b ) ;
167+ }
168+ }
169+
170+ public void assertAllCloseAccordingToType < T > (
171+ ICollection expected ,
172+ ICollection < T > given ,
173+ double eps = 1e-6 ,
174+ float float_eps = 1e-6f )
175+ {
176+ // TODO: check if any of arguments is not double and change toletance
177+ // remove givenAsDouble and cast expected instead
178+ var givenAsDouble = given . Select ( x => Convert . ToDouble ( x ) ) . ToArray ( ) ;
179+ CollectionAssert . AreEqual ( expected , givenAsDouble , new CollectionComparer ( eps ) ) ;
180+ }
181+
147182 public void assertProtoEquals ( object toProto , object o )
148183 {
149184 throw new NotImplementedException ( ) ;
@@ -153,6 +188,20 @@ public void assertProtoEquals(object toProto, object o)
153188
154189 #region tensor evaluation and test session
155190
191+ private Session _cached_session = null ;
192+ private Graph _cached_graph = null ;
193+ private object _cached_config = null ;
194+ private bool _cached_force_gpu = false ;
195+
196+ private void _ClearCachedSession ( )
197+ {
198+ if ( self . _cached_session != null )
199+ {
200+ self . _cached_session . Dispose ( ) ;
201+ self . _cached_session = null ;
202+ }
203+ }
204+
156205 //protected object _eval_helper(Tensor[] tensors)
157206 //{
158207 // if (tensors == null)
@@ -196,17 +245,25 @@ public T evaluate<T>(Tensor tensor)
196245 // return self._eval_helper(tensors)
197246 // else:
198247 {
199- var sess = tf . Session ( ) ;
248+ var sess = tf . get_default_session ( ) ;
200249 var ndarray = tensor . eval ( sess ) ;
201- if ( typeof ( T ) == typeof ( double ) )
250+ if ( typeof ( T ) == typeof ( double )
251+ || typeof ( T ) == typeof ( float )
252+ || typeof ( T ) == typeof ( int ) )
253+ {
254+ result = Convert . ChangeType ( ndarray , typeof ( T ) ) ;
255+ }
256+ else if ( typeof ( T ) == typeof ( double [ ] ) )
257+ {
258+ result = ndarray . ToMultiDimArray < double > ( ) ;
259+ }
260+ else if ( typeof ( T ) == typeof ( float [ ] ) )
202261 {
203- double x = ndarray ;
204- result = x ;
262+ result = ndarray . ToMultiDimArray < float > ( ) ;
205263 }
206- else if ( typeof ( T ) == typeof ( int ) )
264+ else if ( typeof ( T ) == typeof ( int [ ] ) )
207265 {
208- int x = ndarray ;
209- result = x ;
266+ result = ndarray . ToMultiDimArray < int > ( ) ;
210267 }
211268 else
212269 {
@@ -218,9 +275,56 @@ public T evaluate<T>(Tensor tensor)
218275 }
219276
220277
221- public Session cached_session ( )
278+ ///Returns a TensorFlow Session for use in executing tests.
279+ public Session cached_session (
280+ Graph graph = null , object config = null , bool use_gpu = false , bool force_gpu = false )
222281 {
223- throw new NotImplementedException ( ) ;
282+ // This method behaves differently than self.session(): for performance reasons
283+ // `cached_session` will by default reuse the same session within the same
284+ // test.The session returned by this function will only be closed at the end
285+ // of the test(in the TearDown function).
286+
287+ // Use the `use_gpu` and `force_gpu` options to control where ops are run.If
288+ // `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if
289+ // `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
290+ // possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
291+ // the CPU.
292+
293+ // Example:
294+ // python
295+ // class MyOperatorTest(test_util.TensorFlowTestCase) :
296+ // def testMyOperator(self):
297+ // with self.cached_session() as sess:
298+ // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
299+ // result = MyOperator(valid_input).eval()
300+ // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
301+ // invalid_input = [-1.0, 2.0, 7.0]
302+ // with self.assertRaisesOpError("negative input not supported"):
303+ // MyOperator(invalid_input).eval()
304+
305+
306+ // Args:
307+ // graph: Optional graph to use during the returned session.
308+ // config: An optional config_pb2.ConfigProto to use to configure the
309+ // session.
310+ // use_gpu: If True, attempt to run as many ops as possible on GPU.
311+ // force_gpu: If True, pin all ops to `/device:GPU:0`.
312+
313+ // Yields:
314+ // A Session object that should be used as a context manager to surround
315+ // the graph building and execution code in a test case.
316+
317+
318+ // TODO:
319+ // if context.executing_eagerly():
320+ // return self._eval_helper(tensors)
321+ // else:
322+ {
323+ var sess = self . _get_cached_session (
324+ graph , config , force_gpu , crash_if_inconsistent_args : true ) ;
325+ using var cached = self . _constrain_devices_and_set_default ( sess , use_gpu , force_gpu ) ;
326+ return cached ;
327+ }
224328 }
225329
226330 //Returns a TensorFlow Session for use in executing tests.
@@ -268,6 +372,40 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
268372 return s . as_default ( ) ;
269373 }
270374
375+ private Session _constrain_devices_and_set_default ( Session sess , bool use_gpu , bool force_gpu )
376+ {
377+ // Set the session and its graph to global default and constrain devices."""
378+ if ( tf . executing_eagerly ( ) )
379+ return null ;
380+ else
381+ {
382+ sess . graph . as_default ( ) ;
383+ sess . as_default ( ) ;
384+ {
385+ if ( force_gpu )
386+ {
387+ // TODO:
388+
389+ // Use the name of an actual device if one is detected, or
390+ // '/device:GPU:0' otherwise
391+ /* var gpu_name = gpu_device_name();
392+ if (!gpu_name)
393+ gpu_name = "/device:GPU:0"
394+ using (sess.graph.device(gpu_name)) {
395+ yield return sess;
396+ }*/
397+ return sess ;
398+ }
399+ else if ( use_gpu )
400+ return sess ;
401+ else
402+ using ( sess . graph . device ( "/device:CPU:0" ) )
403+ return sess ;
404+ }
405+
406+ }
407+ }
408+
271409 // See session() for details.
272410 private Session _create_session ( Graph graph , object cfg , bool forceGpu )
273411 {
@@ -312,6 +450,54 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu)
312450 return new Session ( graph ) ; //, config = prepare_config(config))
313451 }
314452
453+ private Session _get_cached_session (
454+ Graph graph = null ,
455+ object config = null ,
456+ bool force_gpu = false ,
457+ bool crash_if_inconsistent_args = true )
458+ {
459+ // See cached_session() for documentation.
460+ if ( self . _cached_session == null )
461+ {
462+ var sess = self . _create_session ( graph , config , force_gpu ) ;
463+ self . _cached_session = sess ;
464+ self . _cached_graph = graph ;
465+ self . _cached_config = config ;
466+ self . _cached_force_gpu = force_gpu ;
467+ return sess ;
468+ }
469+ else
470+ {
471+
472+ if ( crash_if_inconsistent_args && self . _cached_graph != null && ! self . _cached_graph . Equals ( graph ) )
473+ throw new ValueError ( @"The graph used to get the cached session is
474+ different than the one that was used to create the
475+ session. Maybe create a new session with
476+ self.session()" ) ;
477+ if ( crash_if_inconsistent_args && self . _cached_config != null && ! self . _cached_config . Equals ( config ) )
478+ {
479+ throw new ValueError ( @"The config used to get the cached session is
480+ different than the one that was used to create the
481+ session. Maybe create a new session with
482+ self.session()" ) ;
483+ }
484+ if ( crash_if_inconsistent_args && ! self . _cached_force_gpu . Equals ( force_gpu ) )
485+ {
486+ throw new ValueError ( @"The force_gpu value used to get the cached session is
487+ different than the one that was used to create the
488+ session. Maybe create a new session with
489+ self.session()" ) ;
490+ }
491+ return _cached_session ;
492+ }
493+ }
494+
495+ [ TestCleanup ]
496+ public void Cleanup ( )
497+ {
498+ _ClearCachedSession ( ) ;
499+ }
500+
315501 #endregion
316502
317503 public void AssetSequenceEqual < T > ( T [ ] a , T [ ] b )
0 commit comments