Skip to content

Commit 38ad490

Browse files
committed
return array instead of tuple for layer.call
1 parent dd1b589 commit 38ad490

File tree

6 files changed

+20
-19
lines changed

6 files changed

+20
-19
lines changed

src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,14 @@ protected override void build(TensorShape input_shape)
139139
built = true;
140140
}
141141

142-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
142+
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
143143
{
144144
Tensor outputs = null;
145145

146146
if (fused)
147147
{
148148
outputs = _fused_batch_norm(inputs, training: training);
149-
return (outputs, outputs);
149+
return new[] { outputs, outputs };
150150
}
151151

152152
throw new NotImplementedException("BatchNormalization call");

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ protected override void build(TensorShape input_shape)
108108
built = true;
109109
}
110110

111-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
111+
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
112112
{
113113
var outputs = _convolution_op.__call__(inputs, kernel);
114114
if (use_bias)
@@ -126,7 +126,7 @@ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null,
126126
if (activation != null)
127127
outputs = activation.Activate(outputs);
128128

129-
return (outputs, outputs);
129+
return new[] { outputs, outputs };
130130
}
131131
}
132132
}

src/TensorFlowNET.Core/Keras/Layers/Dense.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ protected override void build(TensorShape input_shape)
7272
built = true;
7373
}
7474

75-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
75+
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
7676
{
7777
Tensor outputs = null;
7878
var rank = inputs.rank;
@@ -90,7 +90,7 @@ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null,
9090
if (activation != null)
9191
outputs = activation.Activate(outputs);
9292

93-
return (outputs, outputs);
93+
return new[] { outputs, outputs };
9494
}
9595
}
9696
}

src/TensorFlowNET.Core/Keras/Layers/Embedding.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ protected override void build(TensorShape input_shape)
5050
built = true;
5151
}
5252

53-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
53+
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
5454
{
5555
var dtype = inputs.dtype;
5656
if (dtype != tf.int32 && dtype != tf.int64)
5757
inputs = math_ops.cast(inputs, tf.int32);
5858

5959
var @out = embedding_ops.embedding_lookup(embeddings, inputs);
60-
return (@out, @out);
60+
return new[] { @out, @out };
6161
}
6262
}
6363
}

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,14 @@ public Layer(bool trainable = true,
103103
_inbound_nodes = new List<Node>();
104104
}
105105

106-
public (Tensor, Tensor) __call__(Tensor[] inputs,
106+
public Tensor[] __call__(Tensor[] inputs,
107107
Tensor training = null,
108108
Tensor state = null,
109109
VariableScope scope = null)
110110
{
111111
var input_list = inputs;
112112
var input = inputs[0];
113-
Tensor outputs = null;
113+
Tensor[] outputs = null;
114114

115115
// We will attempt to build a TF graph if & only if all inputs are symbolic.
116116
// This is always the case in graph mode. It can also be the case in eager
@@ -142,33 +142,34 @@ public Layer(bool trainable = true,
142142
// overridden).
143143
_maybe_build(inputs[0]);
144144

145-
(input, outputs) = call(inputs[0],
145+
outputs = call(inputs[0],
146146
training: training,
147147
state: state);
148+
148149
(input, outputs) = _set_connectivity_metadata_(input, outputs);
149150
_handle_activity_regularization(inputs[0], outputs);
150151
_set_mask_metadata(inputs[0], outputs, null);
151152
});
152153
}
153154

154-
return (input, outputs);
155+
return outputs;
155156
}
156157

157-
private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
158+
private (Tensor, Tensor[]) _set_connectivity_metadata_(Tensor inputs, Tensor[] outputs)
158159
{
159160
//_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
160161
return (inputs, outputs);
161162
}
162163

163-
private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
164+
private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs)
164165
{
165166
//if(_activity_regularizer != null)
166167
{
167168

168169
}
169170
}
170171

171-
private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
172+
private void _set_mask_metadata(Tensor inputs, Tensor[] outputs, Tensor previous_mask)
172173
{
173174

174175
}
@@ -178,9 +179,9 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
178179
return null;
179180
}
180181

181-
protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
182+
protected virtual Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
182183
{
183-
return (inputs, inputs);
184+
throw new NotImplementedException("");
184185
}
185186

186187
protected virtual string _name_scope()

src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public Pooling2D(IPoolFunction pool_function,
4343
this.input_spec = new InputSpec(ndim: 4);
4444
}
4545

46-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
46+
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
4747
{
4848
int[] pool_shape;
4949
if (data_format == "channels_last")
@@ -64,7 +64,7 @@ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null,
6464
padding: padding.ToUpper(),
6565
data_format: conv_utils.convert_data_format(data_format, 4));
6666

67-
return (outputs, outputs);
67+
return new[] { outputs, outputs };
6868
}
6969
}
7070
}

0 commit comments

Comments
 (0)