Skip to content

Commit dd399e3

Browse files
committed
name_scope._name_stack incorrect #110
1 parent c6f9ec6 commit dd399e3

File tree

6 files changed

+133
-36
lines changed

6 files changed

+133
-36
lines changed

src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow
@@ -19,14 +20,51 @@ public static void gradients(object ys,
1920

2021
public static void _GradientsHelper(object ys,
2122
object xs,
22-
List<Tensor> grad_ys = null,
23+
object grad_ys = null,
2324
string name = "gradients",
2425
bool colocate_gradients_with_ops = false,
2526
bool gate_gradients = false,
27+
object stop_gradients = null,
2628
Graph src_graph = null)
2729
{
2830
if (src_graph == null)
2931
src_graph = ops.get_default_graph();
32+
33+
var ys1 = _AsList(ys);
34+
var xs1 = _AsList(xs);
35+
List<Tensor> grad_ys1 = null;
36+
List<Tensor> stop_gradients1 = stop_gradients == null ? new List<Tensor>() : _AsList(stop_gradients);
37+
if (grad_ys == null)
38+
grad_ys1 = ys1.Select(x => new Tensor(IntPtr.Zero)).ToList();
39+
else
40+
grad_ys = _AsList(grad_ys);
41+
42+
var all = new List<Tensor>();
43+
all.AddRange(ys1);
44+
all.AddRange(xs1);
45+
all.AddRange(stop_gradients1);
46+
all.AddRange(grad_ys1);
47+
48+
string grad_scope = "";
49+
using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all))
50+
grad_scope = namescope;
51+
}
52+
53+
private static List<Tensor> _AsList(object ys)
54+
{
55+
List<Tensor> ret = null;
56+
57+
switch (ys)
58+
{
59+
case Tensor value:
60+
ret = new List<Tensor> { value };
61+
break;
62+
case List<Tensor> value:
63+
ret = value;
64+
break;
65+
}
66+
67+
return ret;
3068
}
3169
}
3270
}

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ public partial class Graph : IDisposable
2323
private int _next_id_counter;
2424
private List<String> _unfetchable_ops = new List<string>();
2525

26-
private string _name_stack;
26+
public string _name_stack = "";
27+
public string old_stack = "";
2728
public string _graph_key;
2829
public Status Status { get; }
2930

@@ -168,41 +169,68 @@ public string get_name_scope()
168169

169170
public string name_scope(string name)
170171
{
172+
old_stack = _name_stack;
173+
171174
string new_stack = "";
172175

176+
173177
if (name.EndsWith("/"))
174-
{
175178
new_stack = ops._name_from_scope_name(name);
176-
}
177179
else
178-
{
179180
new_stack = unique_name(name);
180-
}
181181

182182
_name_stack = new_stack;
183183

184184
return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
185185
}
186186

187-
public string unique_name(string name)
187+
public string unique_name(string name, bool mark_as_used = true)
188188
{
189189
if (!String.IsNullOrEmpty(_name_stack))
190190
{
191191
name = _name_stack + "/" + name;
192192
}
193193

194194
var name_key = name.ToLower();
195+
int i = 0;
195196
if (_names_in_use.ContainsKey(name_key))
196197
{
197-
_names_in_use[name_key]++;
198+
foreach (var item in _names_in_use)
199+
{
200+
if (item.Key == name_key)
201+
{
202+
i = _names_in_use[name_key];
203+
break;
204+
}
205+
206+
i++;
207+
}
198208
}
199-
else
209+
210+
if (mark_as_used)
211+
if (_names_in_use.ContainsKey(name_key))
212+
_names_in_use[name_key]++;
213+
else
214+
_names_in_use[name_key] = i + 1;
215+
216+
if (i > 0)
200217
{
201-
_names_in_use[name_key] = 1;
202-
return name;
218+
var base_name_key = name_key;
219+
220+
// Make sure the composed name key is not already used.
221+
if (_names_in_use.ContainsKey(name_key))
222+
{
223+
name_key = $"{base_name_key}_{i}";
224+
i += 1;
225+
}
226+
227+
if (mark_as_used)
228+
_names_in_use[name_key] = 1;
229+
230+
name = $"{name}_{i - 1}";
203231
}
204232

205-
return $"{name}_{_names_in_use[name_key]}";
233+
return name;
206234
}
207235

208236
public TF_Output[] ReturnOutputs(IntPtr results)

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ public Operation _apply_op_helper(string op_type_name, string name = "", Diction
2020
name = op_type_name;
2121
}
2222

23-
string scope = new ops.name_scope(name);
23+
string scope = "";
24+
using (var namescope = new ops.name_scope<object>(name))
25+
scope = namescope;
2426

2527
var default_type_attr_map = new Dictionary<string, object>();
2628
foreach (var attr_def in op_def.Attr)

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,36 @@ private void _init_from_args(object initial_value,
5151
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
5252

5353
ops.init_scope();
54-
name = new ops.name_scope(name, "Variable", init_from_fn ? new List<object>() : new List<object> { initial_value });
55-
if (init_from_fn)
54+
var values = init_from_fn ? new List<object>() : new List<object> { initial_value };
55+
using (var namescope = new ops.name_scope<object>(name, "Variable", values))
5656
{
57+
name = namescope;
5758

58-
}
59-
else
60-
{
61-
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
62-
}
59+
if (init_from_fn)
60+
{
6361

64-
var shape = _initial_value.shape;
65-
dtype = _initial_value.dtype;
66-
_variable = gen_state_ops.variable_v2(shape, dtype, name);
62+
}
63+
else
64+
{
65+
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
66+
}
6767

68-
// Manually overrides the variable's shape with the initial value's.
69-
if (validate_shape)
70-
{
71-
var initial_value_shape = _initial_value.shape;
72-
}
68+
var shape = _initial_value.shape;
69+
dtype = _initial_value.dtype;
70+
_variable = gen_state_ops.variable_v2(shape, dtype, name);
71+
72+
// Manually overrides the variable's shape with the initial value's.
73+
if (validate_shape)
74+
{
75+
var initial_value_shape = _initial_value.shape;
76+
}
7377

74-
// If 'initial_value' makes use of other variables, make sure we don't
75-
// have an issue if these other variables aren't initialized first by
76-
// using their initialized_value() method.
78+
// If 'initial_value' makes use of other variables, make sure we don't
79+
// have an issue if these other variables aren't initialized first by
80+
// using their initialized_value() method.
7781

78-
ops.add_to_collections(collections, this);
82+
ops.add_to_collections(collections, this);
83+
}
7984
}
8085

8186
public Tensor _ref()

src/TensorFlowNET.Core/ops.name_scope.cs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@ namespace Tensorflow
66
{
77
public partial class ops
88
{
9-
public class name_scope
9+
public class name_scope<T> : IDisposable
1010
{
1111
public string _name;
1212
public string _default_name;
1313
public object _values;
1414
public Context _ctx;
1515
public string _name_scope;
16+
private object _g_manager;
1617

17-
public name_scope(string name, string default_name = "", List<object> values = null)
18+
public name_scope(string name, string default_name = "", List<T> values = null)
1819
{
1920
_name = name;
2021
_default_name = default_name;
@@ -31,11 +32,23 @@ public string __enter__()
3132
_name = _default_name;
3233
}
3334

35+
Graph g = null;
36+
if (_values is List<Tensor> values)
37+
g = _get_graph_from_inputs(values);
38+
39+
if (g == null)
40+
g = get_default_graph();
41+
42+
return g.name_scope(_name); ;
43+
}
44+
45+
public void Dispose()
46+
{
3447
var g = get_default_graph();
35-
return g.name_scope(_name);
48+
g._name_stack = g.old_stack;
3649
}
3750

38-
public static implicit operator string(name_scope ns)
51+
public static implicit operator string(name_scope<T> ns)
3952
{
4053
return ns._name_scope;
4154
}

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ public static Graph get_default_graph()
3434
return tf.Graph();
3535
}
3636

37+
public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
38+
{
39+
foreach(var op_input in op_input_list)
40+
{
41+
// Determine if this is a valid graph_element.
42+
var graph_element = op_input;
43+
}
44+
45+
return get_default_graph();
46+
}
47+
3748
public static Tensor convert_to_tensor(object value, string name = "")
3849
{
3950
var nd = tensor_util.convert_to_numpy_ndarray(value);

0 commit comments

Comments
 (0)