Skip to content

Commit acf8fbd

Browse files
Mascha, PhilippOceania2018
authored andcommitted
Fixed error in type checking for generic get_collection<T>.
1 parent b87081c commit acf8fbd

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -424,18 +424,23 @@ public List<object> get_collection(string name, string scope = null)
424424
return get_collection<object>(name, scope);
425425
}
426426

427-
private IEnumerable<object> findObjects(string name, string scope)
427+
428+
public List<T> get_collection<T>(string name, string scope = null)
428429
{
429-
return (from c in _collections where c.name == name && (scope == null || c.scope == scope) select c.item);
430+
431+
return (from c in _collections
432+
where c.name == name &&
433+
(scope == null || c.scope == scope) &&
434+
implementationOf<T>(c.item)
435+
select (T)(c.item)).ToList();
436+
437+
}
438+
439+
private static bool implementationOf<T>(object item)
440+
{
441+
return (item.GetType() == typeof(T) || item.GetType().IsSubclassOf(typeof(T)));
430442
}
431443

432-
public List<T> get_collection<T>(string name, string scope = null)
433-
{
434-
435-
return (from c in findObjects(name, scope) where c.GetType().IsSubclassOf(typeof(T)) select (T)c).ToList();
436-
437-
}
438-
439444
public List<T> get_collection_ref<T>(string name)
440445
{
441446
return get_collection<T>(name);

src/TensorFlowNET.Core/Variables/variable_scope.py.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,6 @@ public static VariableScope get_variable_scope()
229229
return get_variable_scope_store().current_scope;
230230
}
231231

232-
233-
// TODO: Misses RefVariable as possible value type?
234232
public static _VariableScopeStore get_variable_scope_store()
235233
{
236234
var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault();

0 commit comments

Comments
 (0)