Skip to content

Commit c26fccf

Browse files
committed
fix name issue in placeholder.
control_flow_ops.switch.
1 parent b2b083a commit c26fccf

File tree

12 files changed

+188
-16
lines changed

12 files changed

+188
-16
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Framework
6+
{
7+
/// <summary>
/// Helpers for building conditionals from a predicate tensor.
/// </summary>
public class smart_module
{
    /// <summary>
    /// Builds a conditional on <paramref name="pred"/> by forwarding every
    /// argument to <c>control_flow_ops.cond</c>.
    /// </summary>
    /// <param name="pred">Predicate tensor handed to <c>control_flow_ops.cond</c>.</param>
    /// <param name="true_fn">Forwarded unchanged as the <c>true_fn</c> argument.</param>
    /// <param name="false_fn">Forwarded unchanged as the <c>false_fn</c> argument.</param>
    /// <param name="name">Optional name forwarded to <c>control_flow_ops.cond</c>.</param>
    /// <returns>The value returned by <c>control_flow_ops.cond</c>, boxed as <see cref="object"/>.</returns>
    public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
        => control_flow_ops.cond(pred, true_fn: true_fn, false_fn: false_fn, name: name);

    /// <summary>
    /// Returns whatever <c>tensor_util.constant_value</c> yields for
    /// <paramref name="pred"/>, converted to <see cref="bool"/>.
    /// </summary>
    /// <param name="pred">Tensor passed to <c>tensor_util.constant_value</c>.</param>
    public static bool smart_constant_value(Tensor pred)
        => tensor_util.constant_value(pred);
}
23+
}

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public partial class Graph : IPython, IDisposable
2020
private Dictionary<string, int> _names_in_use;
2121
public int _version;
2222
private int _next_id_counter;
23-
private List<String> _unfetchable_ops = new List<string>();
23+
private List<Operation> _unfetchable_ops = new List<Operation>();
2424
private List<Tensor> _unfeedable_tensors = new List<Tensor>();
2525

2626
public string _name_stack = "";
@@ -228,13 +228,13 @@ public int _next_id()
228228

229229
public bool is_fetchable<T>(T tensor_or_op)
230230
{
231-
if (tensor_or_op is Tensor)
231+
if (tensor_or_op is Tensor tensor)
232232
{
233-
return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ;
233+
return !_unfetchable_ops.Contains(tensor); ;
234234
}
235-
else if (tensor_or_op is Operation)
235+
else if (tensor_or_op is Operation op)
236236
{
237-
return !_unfetchable_ops.Contains((tensor_or_op as Operation).name);
237+
return !_unfetchable_ops.Contains(op);
238238
}
239239

240240
return false;
@@ -372,6 +372,11 @@ public void prevent_feeding(Tensor tensor)
372372
_unfeedable_tensors.Add(tensor);
373373
}
374374

375+
/// <summary>
/// Marks <paramref name="op"/> as unfetchable by appending it to
/// <c>_unfetchable_ops</c>, the list that <c>is_fetchable</c> checks.
/// </summary>
/// <param name="op">Operation to add to the unfetchable list.</param>
public void prevent_fetching(Operation op) => _unfetchable_ops.Add(op);
379+
375380
public void Dispose()
376381
{
377382
c_api.TF_DeleteGraph(_handle);

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_D
4848
}
4949

5050
public Tensor __call__(Tensor inputs,
51+
Tensor training = null,
5152
VariableScope scope = null)
5253
{
5354
var input_list = new Tensor[] { inputs };
@@ -73,7 +74,7 @@ public Tensor __call__(Tensor inputs,
7374
// Symbolic execution on symbolic tensors. We will attempt to build
7475
// the corresponding TF subgraph inside `backend.get_graph()`
7576
var graph = backend.get_graph();
76-
outputs = call(inputs);
77+
outputs = call(inputs, training: training);
7778
_handle_activity_regularization(inputs, outputs);
7879
_set_mask_metadata(inputs, outputs, null);
7980
}
@@ -100,7 +101,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
100101
return null;
101102
}
102103

103-
protected virtual Tensor call(Tensor inputs)
104+
protected virtual Tensor call(Tensor inputs, Tensor training = null)
104105
{
105106
throw new NotImplementedException("Layer.call");
106107
}
@@ -143,13 +144,15 @@ protected virtual RefVariable add_weight(string name,
143144

144145
/// <summary>
/// Initializes <c>_name</c> and <c>_base_name</c> from an optional
/// caller-supplied layer name.
/// </summary>
/// <param name="name">
/// Explicit layer name, or <c>null</c> to have both values produced by
/// <c>_make_unique_name</c>.
/// </param>
protected virtual void _init_set_name(string name)
{
    // An explicit name is used as-is for both fields. Without the _name
    // assignment an explicitly named layer would be left with _name unset.
    _name = name;
    _base_name = name;

    if (name == null)
    {
        (_name, _base_name) = _make_unique_name();
    }
}
149152

150153
protected virtual (string, string) _make_unique_name()
151154
{
152-
string base_name = "conv2d";
155+
string base_name = generic_utils.to_snake_case(this.GetType().Name);
153156
string name = base_layer_utils.unique_layer_name(base_name);
154157
return (name, base_name);
155158
}

src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using Tensorflow.Keras.Utils;
56
using Tensorflow.Layers;
67

78
namespace Tensorflow.Keras.Layers
@@ -25,6 +26,7 @@ public class BatchNormalization : Layer
2526
private RefVariable gamma;
2627
private RefVariable beta;
2728
private RefVariable moving_mean;
29+
private RefVariable moving_variance;
2830

2931
public BatchNormalization(int axis = -1,
3032
float momentum = 0.99f,
@@ -103,7 +105,56 @@ protected override void build(TensorShape input_shape)
103105

104106
moving_mean = add_weight("moving_mean",
105107
param_shape,
106-
dtype: param_dtype);
108+
dtype: param_dtype,
109+
initializer: moving_mean_initializer,
110+
synchronization: VariableSynchronization.ON_READ,
111+
trainable: false,
112+
aggregation: VariableAggregation.MEAN);
113+
114+
moving_variance = add_weight("moving_variance",
115+
shape: param_shape,
116+
dtype: param_dtype,
117+
initializer: moving_variance_initializer,
118+
synchronization: VariableSynchronization.ON_READ,
119+
trainable: false,
120+
aggregation: VariableAggregation.MEAN);
121+
122+
if (renorm)
123+
throw new NotImplementedException("build when renorm is true");
124+
125+
built = true;
126+
}
127+
128+
/// <summary>
/// Runs the layer on <paramref name="inputs"/>. Only the fused path is wired
/// up; it delegates to <c>_fused_batch_norm</c>.
/// </summary>
/// <param name="inputs">Input tensor.</param>
/// <param name="training">
/// Forwarded to <c>_fused_batch_norm</c>. NOTE(review): presumably a predicate
/// tensor selecting training vs. inference behaviour — confirm.
/// </param>
/// <returns>The result of <c>_fused_batch_norm</c> when <c>fused</c> is set.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when <c>fused</c> is false; the non-fused path is not implemented.
/// </exception>
protected override Tensor call(Tensor inputs, Tensor training = null)
{
    if (fused)
    {
        // Return the fused result instead of discarding it and falling
        // through to the unconditional throw below.
        return _fused_batch_norm(inputs, training: training);
    }

    throw new NotImplementedException("BatchNormalization call");
}
139+
140+
/// <summary>
/// Work-in-progress fused batch norm. It registers two empty branch callbacks
/// with <c>tf_utils.smart_cond</c> and then always throws.
/// </summary>
/// <param name="inputs">Input tensor (not used by the current body).</param>
/// <param name="training">Predicate tensor passed to <c>tf_utils.smart_cond</c>.</param>
/// <exception cref="NotImplementedException">Always thrown after the conditional is built.</exception>
private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
{
    // Local copies of the scale/offset fields; not consumed yet.
    var beta = this.beta;
    var gamma = this.gamma;

    // Placeholder branch for training mode.
    Action training_branch = () =>
    {
    };

    // Placeholder branch for inference mode.
    Action inference_branch = () =>
    {
    };

    tf_utils.smart_cond(training, training_branch, inference_branch);

    throw new NotImplementedException("_fused_batch_norm");
}
108159
}
109160
}

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ protected override void build(TensorShape input_shape)
9191
built = true;
9292
}
9393

94-
protected override Tensor call(Tensor inputs)
94+
protected override Tensor call(Tensor inputs, Tensor training = null)
9595
{
9696
var outputs = _convolution_op.__call__(inputs, kernel);
9797
if (use_bias)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Keras.Utils
7+
{
8+
/// <summary>
/// Miscellaneous Keras helper routines.
/// </summary>
public class generic_utils
{
    /// <summary>
    /// Converts a CamelCase identifier (typically a class name) to snake_case,
    /// following the two-pass regular-expression algorithm used by Keras'
    /// <c>generic_utils.to_snake_case</c>. Runs of capitals are kept together,
    /// so <c>"LSTMCell"</c> becomes <c>"lstm_cell"</c>, and a capital that
    /// follows a digit does not start a new word (<c>"Conv2D"</c> becomes
    /// <c>"conv2d"</c>).
    /// </summary>
    /// <param name="name">Identifier to convert.</param>
    /// <returns>
    /// The lower-cased snake_case form. A result that would start with
    /// <c>'_'</c> is prefixed with <c>"private"</c>, as Keras does. An empty
    /// input is returned unchanged.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="name"/> is null.</exception>
    public static string to_snake_case(string name)
    {
        if (name == null)
        {
            throw new ArgumentNullException(nameof(name));
        }

        if (name.Length == 0)
        {
            return name;
        }

        // Pass 1: split before a capitalised word ("BatchNorm" -> "Batch_Norm",
        // "LSTMCell" -> "LSTM_Cell").
        string intermediate = System.Text.RegularExpressions.Regex.Replace(
            name, "(.)([A-Z][a-z0-9]+)", "${1}_${2}");

        // Pass 2: split a lower-case letter from a following capital, then
        // lower-case culture-independently so identifiers do not vary by locale.
        string insecure = System.Text.RegularExpressions.Regex.Replace(
            intermediate, "([a-z])([A-Z])", "${1}_${2}").ToLowerInvariant();

        return insecure[0] != '_' ? insecure : "private" + insecure;
    }
}
20+
}

src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using Tensorflow.Framework;
56

67
namespace Tensorflow.Keras.Utils
78
{
@@ -16,5 +17,13 @@ public static bool is_symbolic_tensor(Tensor tensor)
1617
{
1718
return true;
1819
}
20+
21+
/// <summary>
/// Thin wrapper that forwards every argument to <c>smart_module.smart_cond</c>.
/// </summary>
/// <param name="pred">Predicate tensor.</param>
/// <param name="true_fn">Forwarded unchanged as the <c>true_fn</c> argument.</param>
/// <param name="false_fn">Forwarded unchanged as the <c>false_fn</c> argument.</param>
/// <param name="name">Optional name forwarded to <c>smart_module.smart_cond</c>.</param>
/// <returns>Whatever <c>smart_module.smart_cond</c> returns.</returns>
public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
    => smart_module.smart_cond(pred, true_fn: true_fn, false_fn: false_fn, name: name);
1928
}
2029
}

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ public Layer(bool trainable = true,
2929

3030
public virtual Tensor apply(Tensor inputs, Tensor training = null)
3131
{
32-
return __call__(inputs);
32+
return __call__(inputs, training: training);
3333
}
3434

3535
public Tensor __call__(Tensor inputs,
36+
Tensor training = null,
3637
VariableScope scope = null)
3738
{
3839
_set_scope(scope);
@@ -51,7 +52,7 @@ public Tensor __call__(Tensor inputs,
5152

5253
Python.with(scope_context_manager, scope2 => _current_scope = scope2);
5354
// Actually call layer
54-
var outputs = base.__call__(inputs);
55+
var outputs = base.__call__(inputs, training: training);
5556

5657
// Update global default collections.
5758
//_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS);
@@ -63,7 +64,9 @@ protected virtual RefVariable add_weight(string name,
6364
int[] shape,
6465
TF_DataType dtype = TF_DataType.DtInvalid,
6566
IInitializer initializer = null,
66-
bool? trainable = null)
67+
bool? trainable = null,
68+
VariableSynchronization synchronization = VariableSynchronization.AUTO,
69+
VariableAggregation aggregation = VariableAggregation.NONE)
6770
{
6871
var default_graph = ops.get_default_graph();
6972
Graph init_graph = null;

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,5 +135,53 @@ public static Tensor _Identity(Tensor data, string name = null)
135135
else
136136
return gen_array_ops.identity(data, name: name);
137137
}
138+
139+
/// <summary>
/// Starts building a conditional on <paramref name="pred"/>. Inside an
/// <c>ops.name_scope(name, "cond", ...)</c> scope it adds a switch on the
/// predicate, adds identity ops named "switch_t", "switch_f" and "pred_id",
/// and marks the ops of all those tensors as unfetchable.
/// </summary>
/// <param name="pred">Predicate tensor; used as both the data and the predicate of the switch.</param>
/// <param name="true_fn">Accepted but not used by the current body.</param>
/// <param name="false_fn">Accepted but not used by the current body.</param>
/// <param name="strict">Accepted but not used by the current body.</param>
/// <param name="name">Optional scope name; "cond" is the default passed to <c>ops.name_scope</c>.</param>
/// <returns>The two switch outputs, in the order returned by <c>@switch</c>.</returns>
public static (Tensor, Tensor) cond(Tensor pred,
    Action true_fn = null,
    Action false_fn = null,
    bool strict = false,
    string name = null)
{
    return with(ops.name_scope(name, "cond", new { pred }), delegate
    {
        // Add the Switch to the graph.
        var (second_out, first_out) = @switch(pred, pred);
        var first_pivot = array_ops.identity(first_out, name: "switch_t");
        var second_pivot = array_ops.identity(second_out, name: "switch_f");
        pred = array_ops.identity(pred, name: "pred_id");

        // Disable the fetching of tensors that are only on one branch of cond.
        var single_branch = new Tensor[] { first_out, second_out, first_pivot, second_pivot, pred };
        foreach (var item in single_branch)
        {
            item.op.graph.prevent_fetching(item.op);
        }

        return (second_out, first_out);
    });
}
160+
161+
/// <summary>
/// Forwards `data` to an output determined by `pred`.
/// </summary>
/// <param name="data">Tensor to forward; converted with <c>ops.internal_convert_to_tensor_or_indexed_slices</c> (as a ref).</param>
/// <param name="pred">Predicate selecting the output; converted with <c>ops.convert_to_tensor</c>.</param>
/// <param name="dtype">Optional dtype passed to the conversion of <paramref name="data"/>.</param>
/// <param name="name">Optional scope name; "Switch" is the default passed to <c>ops.name_scope</c>.</param>
/// <returns>
/// The pair returned by <c>gen_control_flow_ops.@switch</c>.
/// NOTE(review): presumably (output_false, output_true) — confirm against the generated op.
/// </returns>
public static (Tensor, Tensor) @switch(Tensor data,
    Tensor pred,
    TF_DataType dtype = TF_DataType.DtInvalid,
    string name = null)
    => with(ops.name_scope(name, "Switch", new { data, pred }), scope =>
    {
        // The resolved scope becomes the name of the generated op.
        name = scope;

        data = ops.internal_convert_to_tensor_or_indexed_slices(
            data, dtype: dtype, name: "data", as_ref: true);
        pred = ops.convert_to_tensor(pred, name: "pred");

        return gen_control_flow_ops.@switch(data, pred, name: name);
    });
138186
}
139187
}

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public static Tensor less<Tx, Ty>(Tx x, Ty y, string name = null)
4242

4343
public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
4444
{
45-
var _op = _op_def_lib._apply_op_helper("Placeholder", args: new { dtype, shape });
45+
var _op = _op_def_lib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape });
4646
var _result = _op.outputs;
4747
var _inputs_flat = _op.inputs;
4848

0 commit comments

Comments
 (0)