Commit 053f40b

nn_ops.bias_add, tf.layers.batch_normalization
1 parent b536026 commit 053f40b

File tree

17 files changed: +281 -5 lines changed

17 files changed

+281
-5
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 11 additions & 0 deletions
@@ -19,5 +19,16 @@ public static partial class tf
         /// </returns>
         public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
             => array_ops.expand_dims(input, axis, name, dim);
+
+        /// <summary>
+        /// Transposes `a`. Permutes the dimensions according to `perm`.
+        /// </summary>
+        /// <param name="a"></param>
+        /// <param name="perm"></param>
+        /// <param name="name"></param>
+        /// <param name="conjugate"></param>
+        /// <returns></returns>
+        public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
+            => array_ops.transpose(a, perm, name, conjugate);
     }
 }
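
The new overload simply forwards to array_ops.transpose. A rough usage sketch follows; the tf.constant helper and the exact input value are assumptions based on the rest of the library, not part of this commit:

    // Permute a 2 x 3 tensor into a 3 x 2 tensor.
    var x = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } });   // assumed constant helper
    var y = tf.transpose(x, perm: new[] { 1, 0 });                  // result shape: (3, 2)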

src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 39 additions & 0 deletions
@@ -46,6 +46,45 @@ public static Tensor conv2d(Tensor inputs,

                 return layer.apply(inputs);
             }
+
+            /// <summary>
+            /// Functional interface for the batch normalization layer.
+            /// http://arxiv.org/abs/1502.03167
+            /// </summary>
+            /// <param name="inputs"></param>
+            /// <param name="axis"></param>
+            /// <param name="momentum"></param>
+            /// <param name="epsilon"></param>
+            /// <param name="center"></param>
+            /// <param name="scale"></param>
+            /// <param name="beta_initializer"></param>
+            /// <param name="gamma_initializer"></param>
+            /// <param name="moving_mean_initializer"></param>
+            /// <param name="moving_variance_initializer"></param>
+            /// <param name="training"></param>
+            /// <param name="trainable"></param>
+            /// <param name="name"></param>
+            /// <param name="renorm"></param>
+            /// <param name="renorm_momentum"></param>
+            /// <returns></returns>
+            public static Tensor batch_normalization(Tensor inputs,
+                int axis = -1,
+                float momentum = 0.99f,
+                float epsilon = 0.001f,
+                bool center = true,
+                bool scale = true,
+                IInitializer beta_initializer = null,
+                IInitializer gamma_initializer = null,
+                IInitializer moving_mean_initializer = null,
+                IInitializer moving_variance_initializer = null,
+                Tensor training = null,
+                bool trainable = true,
+                string name = null,
+                bool renorm = false,
+                float renorm_momentum = 0.99f)
+            {
+                throw new NotImplementedException("batch_normalization");
+            }
         }
     }
 }
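
This functional wrapper mirrors TensorFlow's tf.layers.batch_normalization, but at this commit the body only throws NotImplementedException. A sketch of the intended call site, assuming a convolution output tensor named conv_out and a training-flag tensor named is_training (both hypothetical):

    // Intended usage once the layer is implemented; calling it at this commit throws.
    var bn = tf.layers.batch_normalization(conv_out,
        axis: -1,              // normalize over the channel dimension (NHWC)
        momentum: 0.99f,
        epsilon: 0.001f,
        training: is_training);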

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 35 additions & 1 deletion
@@ -30,6 +30,7 @@ public Tensor __call__(Tensor inputs,
             VariableScope scope = null)
         {
             var input_list = new Tensor[] { inputs };
+            Tensor outputs = null;

             // We will attempt to build a TF graph if & only if all inputs are symbolic.
             // This is always the case in graph mode. It can also be the case in eager
@@ -45,9 +46,42 @@ public Tensor __call__(Tensor inputs,
                     _maybe_build(inputs);
                     built = true;
                 }
+
+                if (build_graph)
+                {
+                    // Symbolic execution on symbolic tensors. We will attempt to build
+                    // the corresponding TF subgraph inside `backend.get_graph()`
+                    var graph = backend.get_graph();
+                    outputs = call(inputs);
+                    _handle_activity_regularization(inputs, outputs);
+                    _set_mask_metadata(inputs, outputs, null);
+                }
             });

-            throw new NotImplementedException("");
+            return outputs;
+        }
+
+        private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
+        {
+            //if(_activity_regularizer != null)
+            {
+
+            }
+        }
+
+        private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
+        {
+
+        }
+
+        private Tensor compute_mask(Tensor inputs, Tensor mask = null)
+        {
+            return null;
+        }
+
+        protected virtual Tensor call(Tensor inputs)
+        {
+            throw new NotImplementedException("Layer.call");
         }

         protected virtual string _name_scope()
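
With these changes, __call__ builds the layer, runs the virtual call() inside the default graph, and returns its outputs instead of throwing. A minimal sketch of a subclass that plugs into this hook; the class is hypothetical, constructor arguments are elided, and tf.multiply/tf.constant are assumed to exist in the API surface:

    // Hypothetical subclass: base.__call__ handles _maybe_build and graph scoping,
    // then dispatches to this override.
    class Doubler : Layer
    {
        protected override Tensor call(Tensor inputs)
        {
            return tf.multiply(inputs, tf.constant(2.0f));   // any graph-building op
        }
    }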

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 21 additions & 0 deletions
@@ -90,5 +90,26 @@ protected override void build(TensorShape input_shape)

             built = true;
         }
+
+        protected override Tensor call(Tensor inputs)
+        {
+            var outputs = _convolution_op.__call__(inputs, kernel);
+            if (use_bias)
+            {
+                if (data_format == "channels_first")
+                {
+                    throw new NotImplementedException("call channels_first");
+                }
+                else
+                {
+                    outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");
+                }
+            }
+
+            if (activation != null)
+                return activation.Activate(outputs);
+
+            return outputs;
+        }
     }
 }
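
For channels_last data the bias is applied with nn_ops.bias_add, which broadcasts a rank-1 bias over the trailing channel dimension of an NHWC tensor. Standalone sketch, assuming conv_out has shape (batch, height, width, C) and bias is a variable of shape (C):

    // bias_add adds bias[c] to every element in channel c of the NHWC input.
    var with_bias = nn_ops.bias_add(conv_out, bias, data_format: "NHWC");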

src/TensorFlowNET.Core/Keras/backend.cs

Lines changed: 5 additions & 0 deletions
@@ -10,5 +10,10 @@ public static void track_variable(RefVariable v)
         {

         }
+
+        public static Graph get_graph()
+        {
+            return ops.get_default_graph();
+        }
     }
 }

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 4 additions & 1 deletion
@@ -65,7 +65,10 @@ public Tensor __call__(Tensor inputs,
             // Actually call layer
             var outputs = base.__call__(inputs);

-            throw new NotImplementedException("");
+            // Update global default collections.
+            //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS);
+
+            return outputs;
         }

         protected virtual RefVariable add_weight(string name,

src/TensorFlowNET.Core/Operations/Activation/IActivation.cs

Lines changed: 1 addition & 1 deletion
@@ -6,6 +6,6 @@ namespace Tensorflow.Operations.Activation
 {
     public interface IActivation
     {
-
+        Tensor Activate(Tensor features, string name = null);
     }
 }

src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.py.cs

Lines changed: 10 additions & 0 deletions
@@ -6,6 +6,16 @@ namespace Tensorflow.Operations.Activation
 {
     public class relu : IActivation
     {
+        public Tensor Activate(Tensor features, string name = null)
+        {
+            OpDefLibrary _op_def_lib = new OpDefLibrary();

+            var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new
+            {
+                features
+            });
+
+            return _op.outputs[0];
+        }
     }
 }
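
relu now satisfies IActivation.Activate by emitting a Relu node through OpDefLibrary, so it can be passed wherever a layer expects an activation. Direct-use sketch; the logits tensor is a placeholder assumption:

    // Negative entries of `logits` are clamped to zero by the emitted Relu op.
    IActivation act = new relu();
    var activated = act.Activate(logits, name: "relu_out");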

src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs

Lines changed: 5 additions & 0 deletions
@@ -62,5 +62,10 @@ public _NonAtrousConvolution _build_op(int _, string padding)
                 strides: strides,
                 name: name);
         }
+
+        public Tensor __call__(Tensor inp, RefVariable filter)
+        {
+            return conv_op.__call__(inp, filter);
+        }
     }
 }

src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs

Lines changed: 13 additions & 0 deletions
@@ -52,5 +52,18 @@ public _NonAtrousConvolution(TensorShape input_shape,
                 throw new NotImplementedException("_NonAtrousConvolution conv_dims 3");
             }
         }
+
+        public Tensor __call__(Tensor inp, RefVariable filter)
+        {
+            return conv_op(new
+            {
+                input = inp,
+                filter,
+                strides,
+                padding,
+                data_format,
+                name
+            });
+        }
     }
 }
