
Commit de23831

tf.keras.layers #355
1 parent 08862d2 commit de23831

31 files changed (+297, -163 lines)

src/TensorFlowNET.Core/Gradients/GradientTape.cs

Lines changed: 3 additions & 3 deletions
@@ -107,19 +107,19 @@ public Tensor gradient(Tensor target, Tensor source)

         public Tensor gradient(Tensor target, ResourceVariable source)
         {
-            var results = gradient(target, new[] { source });
+            var results = gradient(target, new List<IVariableV1> { source });

             return results[0];
         }

         public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
         {
-            var results = gradient(target, new[] { sources.Item1, sources.Item2 });
+            var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 });

             return (results[0], results[1]);
         }

-        public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
+        public Tensor[] gradient(Tensor target, List<IVariableV1> sources)
         {
             if (_recording)
             {
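
For context, a minimal usage sketch of the new List<IVariableV1> overload (not part of this commit). It assumes the usual TensorFlow.NET entry points (tf.Variable, tf.GradientTape, tf.reduce_sum) and the library's operator overloads on variables; the point is only the shape of the sources argument:

using System.Collections.Generic;
using Tensorflow;
using static Tensorflow.Binding;

// Sketch only: record a computation on the tape, then differentiate w.r.t. two variables.
var w = tf.Variable(2.0f);
var b = tf.Variable(0.5f);

using var tape = tf.GradientTape();
var loss = tf.reduce_sum(w * 3.0f + b);   // any computation touching w and b while recording

// Sources are now passed as List<IVariableV1> instead of an array or IEnumerable<IVariableV1>.
Tensor[] grads = tape.gradient(loss, new List<IVariableV1> { w, b });
var dw = grads[0];   // gradient w.r.t. w
var db = grads[1];   // gradient w.r.t. b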

src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs

Lines changed: 19 additions & 1 deletion
@@ -54,7 +54,16 @@ public Tensor[] ComputeGradient(long[] target_tensor_ids,
                 var id = trace.output_tensor_info[i].GetID();
                 if (!gradients.find(id, out var grad_it))
                 {
-                    throw new NotImplementedException("FunctionsAcceptingNoneForIndicesMap");
+                    if (FunctionsAcceptingNoneForIndicesMap().find(trace.op_type, out var func_name_it) &&
+                        func_name_it.find(i))
+                    {
+                        out_gradients.Add(null);
+                    }
+                    else
+                    {
+                        out_gradients.Add(null);
+                        zero_indices.Add(i);
+                    }
                 }
                 else
                 {
@@ -184,6 +193,15 @@ public Tensor[] ComputeGradient(long[] target_tensor_ids,
             return result.ToArray();
         }

+        UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap()
+        {
+            var m = new UnorderedMap<string, UnorderedSet<int>>();
+            m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
+            m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
+            m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
+            return m;
+        }
+
         UnorderedMapEnumerable<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
             UnorderedMap<long, TapeTensor> sources_that_are_targets,
             Tensor[] output_gradients,

src/TensorFlowNET.Core/Keras/Activations.cs renamed to src/TensorFlowNET.Core/Keras/Activations/Activations.Linear.cs

Lines changed: 2 additions & 4 deletions
@@ -5,13 +5,11 @@

 namespace Tensorflow.Keras
 {
-    public delegate Tensor Activation(Tensor x);
-
-    public class Activations
+    public partial class Activations
     {
         /// <summary>
         /// Linear activation function (pass-through).
         /// </summary>
-        public Activation Linear = x => x;
+        public Activation Linear = (features, name) => features;
     }
 }
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Operations;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras
+{
+    public partial class Activations
+    {
+        public Activation Relu = (features, name) =>
+        {
+            if (tf.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "Relu", name,
+                    null,
+                    features);
+
+                return results[0];
+            }
+
+            throw new NotImplementedException("");
+        };
+    }
+}
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras
+{
+    public delegate Tensor Activation(Tensor features, string name = null);
+}
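
A short usage sketch of the new delegate shape (not part of this commit). It assumes eager execution and that the Activations instance is reachable as tf.keras.activations, as the Dense helper added later in this commit already assumes:

using Tensorflow;
using Tensorflow.Keras;
using static Tensorflow.Binding;

// The delegate now carries an optional op name alongside the input tensor.
Activation relu = tf.keras.activations.Relu;

var x = tf.constant(new float[] { -1f, 0f, 2f });
var y1 = relu(x);             // name defaults to null
var y2 = relu(x, "my_relu");  // explicit op name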

src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ public class NodeArgs
         public Layer[] InboundLayers { get; set; }
         public int[] NodeIndices { get; set; }
         public int[] TensorIndices { get; set; }
-        public Tensor[] InputTensors { get; set; }
-        public Tensor[] Outputs { get; set; }
+        public Tensor InputTensors { get; set; }
+        public Tensor Outputs { get; set; }
     }
 }
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Layers;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Engine
+{
+    public partial class Layer
+    {
+        protected List<Layer> _layers = new List<Layer>();
+
+        protected Layer Dense(int units,
+            Activation activation = null,
+            TensorShape input_shape = null)
+        {
+            var layer = new Dense(new DenseArgs
+            {
+                Units = units,
+                Activation = activation ?? tf.keras.activations.Linear,
+                InputShape = input_shape
+            });
+
+            _layers.Add(layer);
+            return layer;
+        }
+    }
+}
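
A hypothetical subclass sketch (not from this commit) of what the protected Dense helper enables: a derived layer registers sub-layers in its constructor, and trainable_variables (see the Layer.cs change below) aggregates their weights. The class name, unit counts and input shape are illustrative only:

using Tensorflow;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

class MyModel : Layer
{
    public MyModel() : base(new LayerArgs())
    {
        // Each call creates a Dense layer and adds it to the protected _layers list.
        Dense(64, activation: tf.keras.activations.Relu, input_shape: new TensorShape(784));
        Dense(10);   // Activation defaults to Linear inside the helper
    }
}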

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 39 additions & 68 deletions
@@ -18,11 +18,9 @@ limitations under the License.
 using System.Collections.Generic;
 using System.Linq;
 using System.Threading;
-using Tensorflow.Contexts;
 using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Layers;
 using Tensorflow.Keras.Utils;
-using Tensorflow.Operations.Activation;
 using Tensorflow.Train;
 using static Tensorflow.Binding;

@@ -34,7 +32,7 @@ namespace Tensorflow.Keras.Engine
     /// as convolution, batch norm, etc. These operations require managing weights,
     /// losses, updates, and inter-layer connectivity.
     /// </summary>
-    public abstract class Layer : AutoTrackable
+    public abstract partial class Layer : AutoTrackable
     {
         /// <summary>
         /// Arguments initialize layer.
@@ -60,8 +58,19 @@ public abstract class Layer : AutoTrackable
         protected InputSpec inputSpec;
         public bool SupportsMasking { get; set; }
         protected List<IVariableV1> trainableWeights;
-        public List<IVariableV1> TrainableVariables => trainableWeights;
+        public List<IVariableV1> trainable_variables
+        {
+            get
+            {
+                if(trainableWeights.Count == 0)
+                    _layers.ForEach(x => trainableWeights.AddRange(x.trainableWeights));
+
+                return trainableWeights;
+            }
+        }
+
         protected List<IVariableV1> nonTrainableWeights;
+        public List<IVariableV1> non_trainable_variables => nonTrainableWeights;

         string name;
         public string Name => name;
@@ -112,20 +121,20 @@ public Layer(LayerArgs args)
         /// <param name="input"></param>
         /// <param name="is_training"></param>
         /// <returns></returns>
-        public Tensor[] Apply(Tensor[] inputs, bool is_training = false)
+        public Tensor Apply(Tensor inputs, bool is_training = false)
         {
-            var input = inputs[0];
-            Tensor[] outputs = null;
+            Tensor outputs = null;

             callContext = callContext ?? new ThreadLocal<CallContext>()
             {
                 Value = new CallContext()
             };

+            var eager = tf.executing_eagerly();
             using var ctxManager = CallContext.enter();

             string nameScope = "";
-            if (tf.executing_eagerly())
+            if (eager)
             {
                 nameScope = name;
             }
@@ -134,7 +143,7 @@ public Tensor[] Apply(Tensor[] inputs, bool is_training = false)
                 throw new NotImplementedException("");
             }

-            using var graph = tf.keras.backend.get_graph().as_default();
+            // using var graph = tf.keras.backend.get_graph().as_default();

             tf_with(ops.name_scope(nameScope), scope =>
             {
@@ -143,82 +152,44 @@ public Tensor[] Apply(Tensor[] inputs, bool is_training = false)

                 outputs = call(inputs, is_training: is_training);

-                (input, outputs) = _set_connectivity_metadata_(input, outputs);
-                _handle_activity_regularization(inputs[0], outputs);
-                _set_mask_metadata(inputs[0], outputs, null);
+                outputs = _set_connectivity_metadata_(inputs, outputs);
+                _handle_activity_regularization(inputs, outputs);
+                _set_mask_metadata(inputs, outputs, null);
             });

             return outputs;
         }

-        [Obsolete("User Apply()")]
-        public Tensor[] __call__(Tensor[] inputs,
-            Tensor training = null,
-            Tensor state = null,
-            VariableScope scope = null)
+        private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
         {
-            var input_list = inputs;
-            var input = inputs[0];
-            Tensor[] outputs = null;
-
-            // We will attempt to build a TF graph if & only if all inputs are symbolic.
-            // This is always the case in graph mode. It can also be the case in eager
-            // mode when all inputs can be traced back to `keras.Input()` (when building
-            // models using the functional API).
-            bool build_graph = tf_utils.are_all_symbolic_tensors(input_list);
-
-            if (build_graph)
-            {
-                // Only create Keras history if at least one tensor originates from a
-                // `keras.Input`. Otherwise this Layer may be being used outside the Keras
-                // framework.
-                // base_layer_utils.create_keras_history(inputs)
-            }
-
-            // with base_layer_utils.call_context(self):
-
-            // Handle Keras mask propagation from previous layer to current layer.
-            // with base_layer_utils.call_context(self):
-            // Check input assumptions set after layer building, e.g. input shape.
-            if (build_graph)
+            /*var returnOutputs = new List<Tensor>();
+            foreach(var x in outputs)
             {
-                // Symbolic execution on symbolic tensors. We will attempt to build
-                // the corresponding TF subgraph inside `backend.get_graph()`
-                var graph = tf.keras.backend.get_graph().as_default();
-                tf_with(ops.name_scope(_name_scope()), delegate
+                if (inputs.Contains(x))
                 {
-                    // Build layer if applicable (if the `build` method has been
-                    // overridden).
-                    MaybeBuild(inputs);
-
-                    outputs = call(inputs,
-                        // training: training,
-                        state: state);

-                    (input, outputs) = _set_connectivity_metadata_(input, outputs);
-                    _handle_activity_regularization(inputs[0], outputs);
-                    _set_mask_metadata(inputs[0], outputs, null);
-                });
-            }
+                }
+                returnOutputs.Add(x);
+            }*/

-            return outputs;
-        }
+            new Node(this, new NodeArgs
+            {
+                Outputs = outputs
+            });

-        private (Tensor, Tensor[]) _set_connectivity_metadata_(Tensor inputs, Tensor[] outputs)
-        {
             //_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
-            return (inputs, outputs);
+            return outputs;
         }

-        private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs)
+        private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
         {
             //if(_activity_regularizer != null)
             {

             }
         }

-        private void _set_mask_metadata(Tensor inputs, Tensor[] outputs, Tensor previous_mask)
+        private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
         {

         }
@@ -228,7 +199,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
             return null;
         }

-        protected virtual Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null)
+        protected virtual Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
         {
             throw new NotImplementedException("");
         }
@@ -238,15 +209,15 @@ protected virtual string _name_scope()
             return Name;
         }

-        protected void MaybeBuild(Tensor[] inputs)
+        protected void MaybeBuild(Tensor inputs)
         {
             // Check input assumptions set before layer building, e.g. input rank.
             if (built)
                 return;
             if (DType == TF_DataType.DtInvalid)
-                args.DType = inputs[0].dtype;
+                args.DType = inputs.dtype;

-            var input_shapes = inputs[0].TensorShape;
+            var input_shapes = inputs.TensorShape;
             build(input_shapes);
             built = true;
         }
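
To illustrate the move from Tensor[] to Tensor in the call path, a hypothetical custom layer written against the new signatures (not part of this commit; the class name and squaring op are illustrative):

using System;
using Tensorflow;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

// Overrides the new single-Tensor call(); consumers invoke it through the new single-Tensor Apply().
class Square : Layer
{
    public Square() : base(new LayerArgs()) { }

    protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
        => inputs * inputs;   // element-wise square
}

// Usage sketch (eager mode): var y = new Square().Apply(tf.constant(new float[] { 1f, 2f, 3f }));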

src/TensorFlowNET.Core/Keras/Engine/Node.cs

Lines changed: 3 additions & 3 deletions
@@ -35,13 +35,13 @@ public class Node

         public int[] node_indices;
         public int[] tensor_indices;
-        public Tensor[] input_tensors;
-        public Tensor[] Outputs => args.Outputs;
+        public Tensor input_tensors;
+        public Tensor Outputs => args.Outputs;
         public TensorShape[] input_shapes;
         public TensorShape[] output_shapes;
         List<Layer> kerasInputs;

-        public Node(InputLayer layer, NodeArgs args)
+        public Node(Layer layer, NodeArgs args)
         {
             this.args = args;
