@@ -144,6 +144,37 @@ public void assertAllClose(double value, NDArray array2, double eps = 1e-5)
144144 Assert . IsTrue ( np . allclose ( array1 , array2 , rtol : eps ) ) ;
145145 }
146146
147+ private class CollectionComparer : System . Collections . IComparer
148+ {
149+ private readonly double _epsilon ;
150+
151+ public CollectionComparer ( double eps = 1e-06 ) {
152+ _epsilon = eps ;
153+ }
154+ public int Compare ( object x , object y )
155+ {
156+ var a = ( double ) x ;
157+ var b = ( double ) y ;
158+
159+ double delta = Math . Abs ( a - b ) ;
160+ if ( delta < _epsilon )
161+ {
162+ return 0 ;
163+ }
164+ return a . CompareTo ( b ) ;
165+ }
166+ }
167+
168+ public void assertAllCloseAccordingToType < T > (
169+ T [ ] expected ,
170+ T [ ] given ,
171+ double eps = 1e-6 ,
172+ float float_eps = 1e-6f )
173+ {
174+ // TODO: check if any of arguments is not double and change toletance
175+ CollectionAssert . AreEqual ( expected , given , new CollectionComparer ( eps ) ) ;
176+ }
177+
147178 public void assertProtoEquals ( object toProto , object o )
148179 {
149180 throw new NotImplementedException ( ) ;
@@ -153,6 +184,20 @@ public void assertProtoEquals(object toProto, object o)
153184
154185 #region tensor evaluation and test session
155186
187+ private Session _cached_session = null ;
188+ private Graph _cached_graph = null ;
189+ private object _cached_config = null ;
190+ private bool _cached_force_gpu = false ;
191+
192+ private void _ClearCachedSession ( )
193+ {
194+ if ( self . _cached_session != null )
195+ {
196+ self . _cached_session . Dispose ( ) ;
197+ self . _cached_session = null ;
198+ }
199+ }
200+
156201 //protected object _eval_helper(Tensor[] tensors)
157202 //{
158203 // if (tensors == null)
@@ -218,9 +263,56 @@ public T evaluate<T>(Tensor tensor)
218263 }
219264
220265
221- public Session cached_session ( )
266+ ///Returns a TensorFlow Session for use in executing tests.
267+ public Session cached_session (
268+ Graph graph = null , object config = null , bool use_gpu = false , bool force_gpu = false )
222269 {
223- throw new NotImplementedException ( ) ;
270+ // This method behaves differently than self.session(): for performance reasons
271+ // `cached_session` will by default reuse the same session within the same
272+ // test.The session returned by this function will only be closed at the end
273+ // of the test(in the TearDown function).
274+
275+ // Use the `use_gpu` and `force_gpu` options to control where ops are run.If
276+ // `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if
277+ // `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
278+ // possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
279+ // the CPU.
280+
281+ // Example:
282+ // python
283+ // class MyOperatorTest(test_util.TensorFlowTestCase) :
284+ // def testMyOperator(self):
285+ // with self.cached_session() as sess:
286+ // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
287+ // result = MyOperator(valid_input).eval()
288+ // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
289+ // invalid_input = [-1.0, 2.0, 7.0]
290+ // with self.assertRaisesOpError("negative input not supported"):
291+ // MyOperator(invalid_input).eval()
292+
293+
294+ // Args:
295+ // graph: Optional graph to use during the returned session.
296+ // config: An optional config_pb2.ConfigProto to use to configure the
297+ // session.
298+ // use_gpu: If True, attempt to run as many ops as possible on GPU.
299+ // force_gpu: If True, pin all ops to `/device:GPU:0`.
300+
301+ // Yields:
302+ // A Session object that should be used as a context manager to surround
303+ // the graph building and execution code in a test case.
304+
305+
306+ // TODO:
307+ // if context.executing_eagerly():
308+ // return self._eval_helper(tensors)
309+ // else:
310+ {
311+ var sess = self . _get_cached_session (
312+ graph , config , force_gpu , crash_if_inconsistent_args : true ) ;
313+ using var cached = self . _constrain_devices_and_set_default ( sess , use_gpu , force_gpu ) ;
314+ return cached ;
315+ }
224316 }
225317
226318 //Returns a TensorFlow Session for use in executing tests.
@@ -268,6 +360,40 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
268360 return s . as_default ( ) ;
269361 }
270362
363+ private Session _constrain_devices_and_set_default ( Session sess , bool use_gpu , bool force_gpu )
364+ {
365+ // Set the session and its graph to global default and constrain devices."""
366+ if ( tf . executing_eagerly ( ) )
367+ return null ;
368+ else
369+ {
370+ sess . graph . as_default ( ) ;
371+ sess . as_default ( ) ;
372+ {
373+ if ( force_gpu )
374+ {
375+ // TODO:
376+
377+ // Use the name of an actual device if one is detected, or
378+ // '/device:GPU:0' otherwise
379+ /* var gpu_name = gpu_device_name();
380+ if (!gpu_name)
381+ gpu_name = "/device:GPU:0"
382+ using (sess.graph.device(gpu_name)) {
383+ yield return sess;
384+ }*/
385+ return sess ;
386+ }
387+ else if ( use_gpu )
388+ return sess ;
389+ else
390+ using ( sess . graph . device ( "/device:CPU:0" ) )
391+ return sess ;
392+ }
393+
394+ }
395+ }
396+
271397 // See session() for details.
272398 private Session _create_session ( Graph graph , object cfg , bool forceGpu )
273399 {
@@ -312,6 +438,54 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu)
312438 return new Session ( graph ) ; //, config = prepare_config(config))
313439 }
314440
441+ private Session _get_cached_session (
442+ Graph graph = null ,
443+ object config = null ,
444+ bool force_gpu = false ,
445+ bool crash_if_inconsistent_args = true )
446+ {
447+ // See cached_session() for documentation.
448+ if ( self . _cached_session == null )
449+ {
450+ var sess = self . _create_session ( graph , config , force_gpu ) ;
451+ self . _cached_session = sess ;
452+ self . _cached_graph = graph ;
453+ self . _cached_config = config ;
454+ self . _cached_force_gpu = force_gpu ;
455+ return sess ;
456+ }
457+ else
458+ {
459+
460+ if ( crash_if_inconsistent_args && ! self . _cached_graph . Equals ( graph ) )
461+ throw new ValueError ( @"The graph used to get the cached session is
462+ different than the one that was used to create the
463+ session. Maybe create a new session with
464+ self.session()" ) ;
465+ if ( crash_if_inconsistent_args && ! self . _cached_config . Equals ( config ) )
466+ {
467+ throw new ValueError ( @"The config used to get the cached session is
468+ different than the one that was used to create the
469+ session. Maybe create a new session with
470+ self.session()" ) ;
471+ }
472+ if ( crash_if_inconsistent_args && ! self . _cached_force_gpu . Equals ( force_gpu ) )
473+ {
474+ throw new ValueError ( @"The force_gpu value used to get the cached session is
475+ different than the one that was used to create the
476+ session. Maybe create a new session with
477+ self.session()" ) ;
478+ }
479+ return _cached_session ;
480+ }
481+ }
482+
483+ [ TestCleanup ]
484+ public void Cleanup ( )
485+ {
486+ _ClearCachedSession ( ) ;
487+ }
488+
315489 #endregion
316490
317491 public void AssetSequenceEqual < T > ( T [ ] a , T [ ] b )
0 commit comments