@@ -77,7 +77,7 @@ all variables that are created during the construction of a graph. The caller
7777 /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
7878 public partial class Graph : DisposableObject
7979#if ! SERIALIZABLE
80- , IEnumerable < Operation >
80+ , IEnumerable < Operation >
8181#endif
8282 {
8383 private Dictionary < int , ITensorOrOperation > _nodes_by_id ;
@@ -100,15 +100,13 @@ public partial class Graph : DisposableObject
100100 /// </summary>
101101 private bool _finalized = false ;
102102
103-
104103 /// <summary>
105- /// Arbitrary collections of objects inside the graph.
106- /// TODO: Access might be slow (-> O(n)) depending on size.
104+ /// Arbitrary collections of objects.
107105 /// </summary>
108- private readonly ICollection < ( string name , string scope , object item ) > _collections = new List < ( string name , string scope , object item ) > ( ) ;
106+ private Dictionary < string , object > _collections = new Dictionary < string , object > ( ) ;
109107
110- public bool building_function ;
111-
108+ public bool building_function ;
109+
112110 public Graph ( )
113111 {
114112 _handle = c_api . TF_NewGraph ( ) ;
@@ -230,14 +228,16 @@ private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tenso
230228 throw new Exception ( $ "Can not convert a { obj . GetType ( ) . Name } into a { types_str } .") ;
231229 }
232230
233- public void add_to_collection ( string name , object value )
231+ public void add_to_collection < T > ( string name , T value )
234232 {
235233 _check_not_finalized ( ) ;
236- _collections . Add ( ( name , null , value ) ) ;
234+ if ( _collections . ContainsKey ( name ) )
235+ ( _collections [ name ] as List < T > ) . Add ( value ) ;
236+ else
237+ _collections [ name ] = new List < T > { value } ;
237238 }
238239
239-
240- public void add_to_collections ( List < string > names , object value )
240+ public void add_to_collections < T > ( List < string > names , T value )
241241 {
242242 foreach ( string name in names )
243243 add_to_collection ( name , value ) ;
@@ -278,6 +278,12 @@ public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes
278278
279279 _create_op_helper ( op , true ) ;
280280
281+ /*Console.Write($"create_op: {op_type} '{node_def.Name}'");
282+ Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
283+ Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
284+ Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
285+ Console.WriteLine();*/
286+
281287 return op ;
282288 }
283289
@@ -394,7 +400,7 @@ public string unique_name(string name, bool mark_as_used = true)
394400 _names_in_use [ name_key ] = 1 ;
395401
396402 // Return the new name with the original capitalization of the given name.
397- name = $ "{ name } _{ i - 1 } ";
403+ name = $ "{ name } _{ i - 1 } ";
398404 }
399405 return name ;
400406 }
@@ -407,43 +413,55 @@ public TF_Output[] ReturnOutputs(IntPtr results)
407413 TF_Output [ ] return_outputs = new TF_Output [ num_return_outputs ] ;
408414 unsafe
409415 {
410- var tf_output_ptr = ( TF_Output * ) return_output_handle ;
411- for ( int i = 0 ; i < num_return_outputs ; i ++ )
416+ var tf_output_ptr = ( TF_Output * ) return_output_handle ;
417+ for ( int i = 0 ; i < num_return_outputs ; i ++ )
412418 return_outputs [ i ] = * ( tf_output_ptr + i ) ;
413419 return return_outputs ;
414420 }
415421 }
416422
417423 public string [ ] get_all_collection_keys ( )
418424 {
419- return ( from c in _collections where ! c . name . StartsWith ( "__" ) select c . name ) . ToArray ( ) ;
425+ return _collections . Keys . Where ( x => ! x . StartsWith ( "__" ) ) . ToArray ( ) ;
420426 }
421427
422- public List < object > get_collection ( string name , string scope = null )
428+ public object get_collection ( string name , string scope = null )
423429 {
424- return get_collection < object > ( name , scope ) ;
425- }
426-
427-
430+ return _collections . ContainsKey ( name ) ? _collections [ name ] : null ;
431+ }
432+
428433 public List < T > get_collection < T > ( string name , string scope = null )
429- {
430-
431- return ( from c in _collections
432- where c . name == name &&
433- ( scope == null || c . scope == scope ) &&
434- implementationOf < T > ( c . item )
435- select ( T ) ( c . item ) ) . ToList ( ) ;
436-
437- }
438-
439- private static bool implementationOf < T > ( object item )
440- {
441- return ( item . GetType ( ) == typeof ( T ) || item . GetType ( ) . IsSubclassOf ( typeof ( T ) ) ) ;
442- }
443-
434+ {
435+ List < T > t = default ;
436+ var collection = _collections . ContainsKey ( name ) ? _collections [ name ] : new List < T > ( ) ;
437+ switch ( collection )
438+ {
439+ case List < VariableV1 > list :
440+ t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
441+ break ;
442+ case List < ResourceVariable > list :
443+ t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
444+ break ;
445+ case List < RefVariable > list :
446+ t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
447+ break ;
448+ case List < Tensor > list :
449+ t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
450+ break ;
451+ case List < Operation > list :
452+ t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
453+ break ;
454+ default :
455+ throw new NotImplementedException ( $ "get_collection<{ typeof ( T ) . FullName } >") ;
456+ }
457+ return t ;
458+ }
459+
444460 public List < T > get_collection_ref < T > ( string name )
445461 {
446- return get_collection < T > ( name ) ;
462+ if ( ! _collections . ContainsKey ( name ) )
463+ _collections [ name ] = new List < T > ( ) ;
464+ return _collections [ name ] as List < T > ;
447465 }
448466
449467 public void prevent_feeding ( Tensor tensor )
@@ -497,7 +515,7 @@ public TensorShape GetTensorShape(TF_Output output)
497515 string debugString = string . Empty ;
498516 public override string ToString ( )
499517 {
500- return $ "{ graph_key } , ({ _handle } )";
518+ return $ "{ graph_key } , ({ _handle } )";
501519 /*if (string.IsNullOrEmpty(debugString))
502520 {
503521 int len = 0;
@@ -514,7 +532,7 @@ private IEnumerable<Operation> GetEnumerable()
514532 IEnumerator < Operation > IEnumerable < Operation > . GetEnumerator ( )
515533 => GetEnumerable ( ) . GetEnumerator ( ) ;
516534
517- IEnumerator IEnumerable . GetEnumerator ( )
535+ IEnumerator IEnumerable . GetEnumerator ( )
518536 => throw new NotImplementedException ( ) ;
519537#endif
520538
0 commit comments