@@ -100,10 +100,12 @@ public partial class Graph : DisposableObject
100100 /// </summary>
101101 private bool _finalized = false ;
102102
103+
103104 /// <summary>
104- /// Arbitrary collections of objects.
105+ /// Arbitrary collections of objects inside the graph.
106+ /// TODO: Access might be slow (-> O(n)) depending on size.
105107 /// </summary>
106- private Dictionary < string , object > _collections = new Dictionary < string , object > ( ) ;
108+ private readonly ICollection < ( string name , string scope , object item ) > _collections = new List < ( string name , string scope , object item ) > ( ) ;
107109
108110 public bool building_function ;
109111
@@ -228,16 +230,14 @@ private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tenso
228230 throw new Exception ( $ "Can not convert a { obj . GetType ( ) . Name } into a { types_str } .") ;
229231 }
230232
231- public void add_to_collection < T > ( string name , T value )
233+ public void add_to_collection ( string name , object value )
232234 {
233235 _check_not_finalized ( ) ;
234- if ( _collections . ContainsKey ( name ) )
235- ( _collections [ name ] as List < T > ) . Add ( value ) ;
236- else
237- _collections [ name ] = new List < T > { value } ;
236+ _collections . Add ( ( name , null , value ) ) ;
238237 }
239238
240- public void add_to_collections < T > ( List < string > names , T value )
239+
240+ public void add_to_collections ( List < string > names , object value )
241241 {
242242 foreach ( string name in names )
243243 add_to_collection ( name , value ) ;
@@ -278,12 +278,6 @@ public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes
278278
279279 _create_op_helper ( op , true ) ;
280280
281- /*Console.Write($"create_op: {op_type} '{node_def.Name}'");
282- Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
283- Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
284- Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
285- Console.WriteLine();*/
286-
287281 return op ;
288282 }
289283
@@ -422,46 +416,34 @@ public TF_Output[] ReturnOutputs(IntPtr results)
422416
423417 public string [ ] get_all_collection_keys ( )
424418 {
425- return _collections . Keys . Where ( x => ! x . StartsWith ( "__" ) ) . ToArray ( ) ;
419+ return ( from c in _collections where ! c . name . StartsWith ( "__" ) select c . name ) . ToArray ( ) ;
426420 }
427421
428- public object get_collection ( string name , string scope = null )
422+ public List < object > get_collection ( string name , string scope = null )
429423 {
430- return _collections . ContainsKey ( name ) ? _collections [ name ] : null ;
431- }
432-
424+ return get_collection < object > ( name , scope ) ;
425+ }
426+
427+
433428 public List < T > get_collection < T > ( string name , string scope = null )
434- {
435- List < T > t = default ;
436- var collection = _collections . ContainsKey ( name ) ? _collections [ name ] : new List < T > ( ) ;
437- switch ( collection )
438- {
439- case List < VariableV1 > list :
440- t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
441- break ;
442- case List < ResourceVariable > list :
443- t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
444- break ;
445- case List < RefVariable > list :
446- t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
447- break ;
448- case List < Tensor > list :
449- t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
450- break ;
451- case List < Operation > list :
452- t = list . Select ( x => ( T ) ( object ) x ) . ToList ( ) ;
453- break ;
454- default :
455- throw new NotImplementedException ( $ "get_collection<{ typeof ( T ) . FullName } >") ;
456- }
457- return t ;
458- }
459-
429+ {
430+
431+ return ( from c in _collections
432+ where c . name == name &&
433+ ( scope == null || c . scope == scope ) &&
434+ implementationOf < T > ( c . item )
435+ select ( T ) ( c . item ) ) . ToList ( ) ;
436+
437+ }
438+
439+ private static bool implementationOf < T > ( object item )
440+ {
441+ return ( item . GetType ( ) == typeof ( T ) || item . GetType ( ) . IsSubclassOf ( typeof ( T ) ) ) ;
442+ }
443+
460444 public List < T > get_collection_ref < T > ( string name )
461445 {
462- if ( ! _collections . ContainsKey ( name ) )
463- _collections [ name ] = new List < T > ( ) ;
464- return _collections [ name ] as List < T > ;
446+ return get_collection < T > ( name ) ;
465447 }
466448
467449 public void prevent_feeding ( Tensor tensor )
0 commit comments