Skip to content

Commit 945ac02

Browse files
committed
Rollback Fix collections typing #448
1 parent 8b38d0d commit 945ac02

File tree

7 files changed

+146
-97
lines changed

7 files changed

+146
-97
lines changed

src/TensorFlowNET.Core/APIs/tf.variable.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ public partial class tensorflow
2323
{
2424
public VariableV1[] global_variables(string scope = null)
2525
{
26-
return (ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope))
26+
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>)
2727
.ToArray();
2828
}
2929

3030
public Operation global_variables_initializer()
3131
{
3232
var g = variables.global_variables();
33-
return variables.variables_initializer(g?.ToArray());
33+
return variables.variables_initializer(g.ToArray());
3434
}
3535

3636
/// <summary>
@@ -54,9 +54,9 @@ public RefVariable get_variable(string name,
5454
{
5555
var scope = Tensorflow.variable_scope.get_variable_scope();
5656
var store = Tensorflow.variable_scope._get_default_variable_store();
57-
return scope.get_variable(store,
58-
name,
59-
shape: shape,
57+
return scope.get_variable(store,
58+
name,
59+
shape: shape,
6060
dtype: dtype,
6161
use_resource: use_resource,
6262
validate_shape: validate_shape,

src/TensorFlowNET.Core/Binding.FuncTools.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,23 @@ public static partial class Binding
1010
{
1111
public static class functools
1212
{
13-
public static Func<Tin, Tout> partial<Tin, Tout>(Func<Tin, Tout> func, Tin arg)
14-
=> (arg0) => func(arg0);
13+
public static PartialFunc<Tin, Tout> partial<Tin, Tout>(Func<Tin, Tout> func, Tin arg)
14+
=> new PartialFunc<Tin, Tout>
15+
{
16+
args = arg,
17+
invoke = func
18+
};
1519

1620
public static Func<Tin1, Tin2, Tout> partial<Tin1, Tin2, Tout>(Func<Tin1, Tin2, Tout> func, (Tin1, Tin2) args)
17-
=> (arg1, arg2) => func(arg1, arg2);
21+
=> (arg1, arg2) => func(args.Item1, args.Item2);
22+
}
23+
24+
public class PartialFunc<Tin, Tout>
25+
{
26+
public Tin args { get; set; }
27+
public object[] keywords { get; set; }
28+
29+
public Func<Tin, Tout> invoke { get; set; }
1830
}
1931
}
2032
}

src/TensorFlowNET.Core/Framework/meta_graph.cs

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
4646

4747
if (!string.IsNullOrEmpty(unbound_inputs_col_name))
4848
{
49-
foreach(var col in meta_graph_def.CollectionDef)
49+
foreach (var col in meta_graph_def.CollectionDef)
5050
{
51-
if(col.Key == unbound_inputs_col_name)
51+
if (col.Key == unbound_inputs_col_name)
5252
{
5353
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
5454
}
@@ -78,7 +78,7 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
7878

7979
// Restores all the other collections.
8080
var variable_objects = new Dictionary<ByteString, VariableV1>();
81-
foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
81+
foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
8282
{
8383
// Don't add unbound_inputs to the new graph.
8484
if (col.Key == unbound_inputs_col_name)
@@ -87,7 +87,7 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
8787
switch (col.Value.KindCase)
8888
{
8989
case KindOneofCase.NodeList:
90-
foreach(var value in col.Value.NodeList.Value)
90+
foreach (var value in col.Value.NodeList.Value)
9191
{
9292
var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
9393
graph.add_to_collection(col.Key, col_op);
@@ -115,7 +115,7 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
115115
}
116116
else
117117
{
118-
foreach(var value in col.Value.BytesList.Value)
118+
foreach (var value in col.Value.BytesList.Value)
119119
{
120120
switch (col.Key)
121121
{
@@ -139,7 +139,7 @@ public static (Dictionary<string, VariableV1>, ITensorOrOperation[]) import_scop
139139
}
140140
}
141141
}
142-
142+
143143
break;
144144
default:
145145
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
@@ -173,8 +173,8 @@ public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_
173173
string unbound_inputs_col_name = "unbound_inputs",
174174
bool clear_devices = false,
175175
SaverDef saver_def = null,
176-
bool clear_extraneous_savers= false,
177-
bool strip_default_attrs= false,
176+
bool clear_extraneous_savers = false,
177+
bool strip_default_attrs = false,
178178
byte[] meta_info_def = null)
179179
{
180180
var graph = ops.get_default_graph();
@@ -236,12 +236,12 @@ private static MetaGraphDef create_meta_graph_def(MetaInfoDef meta_info_def = nu
236236
meta_graph_def.GraphDef = graph_def;
237237

238238
// Fills in meta_info_def.stripped_op_list using the ops from graph_def.
239-
if (meta_graph_def.MetaInfoDef.StrippedOpList == null ||
239+
if (meta_graph_def.MetaInfoDef.StrippedOpList == null ||
240240
meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0)
241241
meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef);
242242

243243
var clist = graph.get_all_collection_keys();
244-
foreach(var ctype in clist)
244+
foreach (var ctype in clist)
245245
{
246246
if (clear_extraneous_savers)
247247
{
@@ -256,30 +256,34 @@ private static MetaGraphDef create_meta_graph_def(MetaInfoDef meta_info_def = nu
256256
return meta_graph_def;
257257
}
258258

259-
private static void add_collection_def(MetaGraphDef meta_graph_def,
260-
string key,
259+
private static void add_collection_def(MetaGraphDef meta_graph_def,
260+
string key,
261261
Graph graph = null,
262262
string export_scope = "")
263263
{
264264
if (!meta_graph_def.CollectionDef.ContainsKey(key))
265265
meta_graph_def.CollectionDef[key] = new CollectionDef();
266266
var col_def = meta_graph_def.CollectionDef[key];
267-
col_def.NodeList = new Types.NodeList();
268-
col_def.BytesList = new Types.BytesList();
269-
foreach (object value in graph.get_collection(key))
267+
268+
switch (graph.get_collection(key))
270269
{
271-
switch (value)
272-
{
273-
case RefVariable x:
270+
case List<RefVariable> collection_list:
271+
col_def.BytesList = new Types.BytesList();
272+
foreach (var x in collection_list)
273+
{
274274
var proto = x.to_proto(export_scope);
275275
col_def.BytesList.Value.Add(proto.ToByteString());
276-
break;
277-
case ITensorOrOperation x2:
278-
col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope));
279-
break;
280-
default:
281-
break;
282-
}
276+
}
277+
278+
break;
279+
case List<object> collection_list:
280+
col_def.NodeList = new Types.NodeList();
281+
foreach (var x in collection_list)
282+
if (x is ITensorOrOperation x2)
283+
col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope));
284+
break;
285+
case List<Operation> collection_list:
286+
break;
283287
}
284288
}
285289

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 56 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ all variables that are created during the construction of a graph. The caller
7777
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
7878
public partial class Graph : DisposableObject
7979
#if !SERIALIZABLE
80-
,IEnumerable<Operation>
80+
, IEnumerable<Operation>
8181
#endif
8282
{
8383
private Dictionary<int, ITensorOrOperation> _nodes_by_id;
@@ -100,15 +100,13 @@ public partial class Graph : DisposableObject
100100
/// </summary>
101101
private bool _finalized = false;
102102

103-
104103
/// <summary>
105-
/// Arbitrary collections of objects inside the graph.
106-
/// TODO: Access might be slow (-> O(n)) depending on size.
104+
/// Arbitrary collections of objects.
107105
/// </summary>
108-
private readonly ICollection<(string name, string scope, object item)> _collections = new List<(string name, string scope, object item)>();
106+
private Dictionary<string, object> _collections = new Dictionary<string, object>();
109107

110-
public bool building_function;
111-
108+
public bool building_function;
109+
112110
public Graph()
113111
{
114112
_handle = c_api.TF_NewGraph();
@@ -230,14 +228,16 @@ private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tenso
230228
throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}.");
231229
}
232230

233-
public void add_to_collection(string name, object value)
231+
public void add_to_collection<T>(string name, T value)
234232
{
235233
_check_not_finalized();
236-
_collections.Add((name, null, value));
234+
if (_collections.ContainsKey(name))
235+
(_collections[name] as List<T>).Add(value);
236+
else
237+
_collections[name] = new List<T> { value };
237238
}
238239

239-
240-
public void add_to_collections(List<string> names, object value)
240+
public void add_to_collections<T>(List<string> names, T value)
241241
{
242242
foreach (string name in names)
243243
add_to_collection(name, value);
@@ -278,6 +278,12 @@ public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes
278278

279279
_create_op_helper(op, true);
280280

281+
/*Console.Write($"create_op: {op_type} '{node_def.Name}'");
282+
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
283+
Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
284+
Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
285+
Console.WriteLine();*/
286+
281287
return op;
282288
}
283289

@@ -394,7 +400,7 @@ public string unique_name(string name, bool mark_as_used = true)
394400
_names_in_use[name_key] = 1;
395401

396402
// Return the new name with the original capitalization of the given name.
397-
name = $"{name}_{i-1}";
403+
name = $"{name}_{i - 1}";
398404
}
399405
return name;
400406
}
@@ -407,43 +413,55 @@ public TF_Output[] ReturnOutputs(IntPtr results)
407413
TF_Output[] return_outputs = new TF_Output[num_return_outputs];
408414
unsafe
409415
{
410-
var tf_output_ptr = (TF_Output*) return_output_handle;
411-
for (int i = 0; i < num_return_outputs; i++)
416+
var tf_output_ptr = (TF_Output*)return_output_handle;
417+
for (int i = 0; i < num_return_outputs; i++)
412418
return_outputs[i] = *(tf_output_ptr + i);
413419
return return_outputs;
414420
}
415421
}
416422

417423
public string[] get_all_collection_keys()
418424
{
419-
return (from c in _collections where !c.name.StartsWith("__") select c.name).ToArray();
425+
return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
420426
}
421427

422-
public List<object> get_collection(string name, string scope = null)
428+
public object get_collection(string name, string scope = null)
423429
{
424-
return get_collection<object>(name, scope);
425-
}
426-
427-
430+
return _collections.ContainsKey(name) ? _collections[name] : null;
431+
}
432+
428433
public List<T> get_collection<T>(string name, string scope = null)
429-
{
430-
431-
return (from c in _collections
432-
where c.name == name &&
433-
(scope == null || c.scope == scope) &&
434-
implementationOf<T>(c.item)
435-
select (T)(c.item)).ToList();
436-
437-
}
438-
439-
private static bool implementationOf<T>(object item)
440-
{
441-
return (item.GetType() == typeof(T) || item.GetType().IsSubclassOf(typeof(T)));
442-
}
443-
434+
{
435+
List<T> t = default;
436+
var collection = _collections.ContainsKey(name) ? _collections[name] : new List<T>();
437+
switch (collection)
438+
{
439+
case List<VariableV1> list:
440+
t = list.Select(x => (T)(object)x).ToList();
441+
break;
442+
case List<ResourceVariable> list:
443+
t = list.Select(x => (T)(object)x).ToList();
444+
break;
445+
case List<RefVariable> list:
446+
t = list.Select(x => (T)(object)x).ToList();
447+
break;
448+
case List<Tensor> list:
449+
t = list.Select(x => (T)(object)x).ToList();
450+
break;
451+
case List<Operation> list:
452+
t = list.Select(x => (T)(object)x).ToList();
453+
break;
454+
default:
455+
throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
456+
}
457+
return t;
458+
}
459+
444460
public List<T> get_collection_ref<T>(string name)
445461
{
446-
return get_collection<T>(name);
462+
if (!_collections.ContainsKey(name))
463+
_collections[name] = new List<T>();
464+
return _collections[name] as List<T>;
447465
}
448466

449467
public void prevent_feeding(Tensor tensor)
@@ -497,7 +515,7 @@ public TensorShape GetTensorShape(TF_Output output)
497515
string debugString = string.Empty;
498516
public override string ToString()
499517
{
500-
return $"{graph_key}, ({_handle})";
518+
return $"{graph_key}, ({_handle})";
501519
/*if (string.IsNullOrEmpty(debugString))
502520
{
503521
int len = 0;
@@ -514,7 +532,7 @@ private IEnumerable<Operation> GetEnumerable()
514532
IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
515533
=> GetEnumerable().GetEnumerator();
516534

517-
IEnumerator IEnumerable.GetEnumerator()
535+
IEnumerator IEnumerable.GetEnumerator()
518536
=> throw new NotImplementedException();
519537
#endif
520538

src/TensorFlowNET.Core/Training/TrainingUtil.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public static RefVariable create_global_step(Graph graph = null)
1616
// Create in proper graph and base name_scope.
1717
var g = graph.as_default();
1818
g.name_scope(null);
19-
var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64,
19+
var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new int[0], dtype: dtypes.int64,
2020
initializer: tf.zeros_initializer,
2121
trainable: false,
2222
aggregation: VariableAggregation.OnlyFirstReplica,

0 commit comments

Comments
 (0)