Skip to content

Commit ded16ea

Browse files
committed
BatchNormalization return tuple for call
1 parent 7bc249f commit ded16ea

29 files changed

+176
-84
lines changed

src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public Tensor conv2d(Tensor inputs,
6363
trainable: trainable,
6464
name: name);
6565

66-
return layer.apply(inputs);
66+
return layer.apply(inputs).Item1;
6767
}
6868

6969
/// <summary>
@@ -117,7 +117,7 @@ public Tensor batch_normalization(Tensor inputs,
117117
trainable: trainable,
118118
name: name);
119119

120-
return layer.apply(inputs, training: training);
120+
return layer.apply(inputs, training: training).Item1;
121121
}
122122

123123
/// <summary>
@@ -143,7 +143,7 @@ public Tensor max_pooling2d(Tensor inputs,
143143
data_format: data_format,
144144
name: name);
145145

146-
return layer.apply(inputs);
146+
return layer.apply(inputs).Item1;
147147
}
148148

149149
/// <summary>
@@ -179,7 +179,7 @@ public Tensor dense(Tensor inputs,
179179
kernel_initializer: kernel_initializer,
180180
trainable: trainable);
181181

182-
return layer.apply(inputs);
182+
return layer.apply(inputs).Item1;
183183
}
184184

185185
/// <summary>

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public Tensor dropout(Tensor x, Tensor keep_prob = null, Tensor noise_shape = nu
7676
/// <param name="swap_memory"></param>
7777
/// <param name="time_major"></param>
7878
/// <returns>A pair (outputs, state)</returns>
79-
public (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
79+
public (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs,
8080
Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
8181
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
8282
=> rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,

src/TensorFlowNET.Core/Graphs/Graph.Control.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
using System.Diagnostics.CodeAnalysis;
1919
using System.Linq;
2020
using Tensorflow.Operations;
21+
using static Tensorflow.Binding;
2122

2223
namespace Tensorflow
2324
{

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,15 +262,11 @@ public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes
262262

263263
if (string.IsNullOrEmpty(name))
264264
name = op_type;
265+
265266
// If a names ends with a '/' it is a "name scope" and we use it as-is,
266267
// after removing the trailing '/'.
267268
name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
268269
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
269-
270-
if (name.Contains("define_loss/bigger_box_loss/mul_13"))
271-
{
272-
273-
}
274270

275271
var input_ops = inputs.Select(x => x.op).ToArray();
276272
var control_inputs = _control_dependencies_for_inputs(input_ops);
@@ -377,7 +373,11 @@ public string name_scope(string name)
377373
/// <returns>A string to be passed to `create_op()` that will be used
378374
/// to name the operation being created.</returns>
379375
public string unique_name(string name, bool mark_as_used = true)
380-
{
376+
{
377+
if (name.EndsWith("basic_r_n_n_cell"))
378+
{
379+
380+
}
381381
if (!String.IsNullOrEmpty(_name_stack))
382382
name = _name_stack + "/" + name;
383383
// For the sake of checking for names in use, we treat names as case
@@ -405,7 +405,7 @@ public string unique_name(string name, bool mark_as_used = true)
405405

406406
// Return the new name with the original capitalization of the given name.
407407
name = $"{name}_{i-1}";
408-
}
408+
}
409409
return name;
410410
}
411411

src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System.Collections.Generic;
1818
using Tensorflow.Operations;
19+
using static Tensorflow.Binding;
1920

2021
namespace Tensorflow
2122
{

src/TensorFlowNET.Core/Interfaces/IPackable.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
namespace Tensorflow
66
{
7-
public interface IPackable
7+
public interface IPackable<T>
88
{
9-
void Pack(object[] sequences);
9+
T Pack(object[] sequences);
1010
}
1111
}

src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,14 @@ protected override void build(TensorShape input_shape)
139139
built = true;
140140
}
141141

142-
protected override Tensor call(Tensor inputs, Tensor training = null)
142+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
143143
{
144144
Tensor outputs = null;
145145

146146
if (fused)
147147
{
148148
outputs = _fused_batch_norm(inputs, training: training);
149-
return outputs;
149+
return (outputs, outputs);
150150
}
151151

152152
throw new NotImplementedException("BatchNormalization call");

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ protected override void build(TensorShape input_shape)
108108
built = true;
109109
}
110110

111-
protected override Tensor call(Tensor inputs, Tensor training = null)
111+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
112112
{
113113
var outputs = _convolution_op.__call__(inputs, kernel);
114114
if (use_bias)
@@ -124,9 +124,9 @@ protected override Tensor call(Tensor inputs, Tensor training = null)
124124
}
125125

126126
if (activation != null)
127-
return activation.Activate(outputs);
127+
outputs = activation.Activate(outputs);
128128

129-
return outputs;
129+
return (outputs, outputs);
130130
}
131131
}
132132
}

src/TensorFlowNET.Core/Keras/Layers/Dense.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ protected override void build(TensorShape input_shape)
7272
built = true;
7373
}
7474

75-
protected override Tensor call(Tensor inputs, Tensor training = null)
75+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
7676
{
7777
Tensor outputs = null;
7878
var rank = inputs.rank;
@@ -88,9 +88,9 @@ protected override Tensor call(Tensor inputs, Tensor training = null)
8888
if (use_bias)
8989
outputs = tf.nn.bias_add(outputs, bias);
9090
if (activation != null)
91-
return activation.Activate(outputs);
91+
outputs = activation.Activate(outputs);
9292

93-
return outputs;
93+
return (outputs, outputs);
9494
}
9595
}
9696
}

src/TensorFlowNET.Core/Keras/Layers/Embedding.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ protected override void build(TensorShape input_shape)
5050
built = true;
5151
}
5252

53-
protected override Tensor call(Tensor inputs, Tensor training = null)
53+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
5454
{
5555
var dtype = inputs.dtype;
5656
if (dtype != tf.int32 && dtype != tf.int64)
5757
inputs = math_ops.cast(inputs, tf.int32);
5858

5959
var @out = embedding_ops.embedding_lookup(embeddings, inputs);
60-
return @out;
60+
return (@out, @out);
6161
}
6262
}
6363
}

0 commit comments

Comments
 (0)