Commit 16a59af

assign_sub, smart_cond
1 parent 1597a0b commit 16a59af

23 files changed, +284 −52 lines changed


src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 26 additions & 0 deletions
@@ -100,6 +100,32 @@ public static Tensor batch_normalization(Tensor inputs,
 
                 return layer.apply(inputs, training: training);
             }
+
+            /// <summary>
+            /// Max pooling layer for 2D inputs (e.g. images).
+            /// </summary>
+            /// <param name="inputs">The tensor over which to pool. Must have rank 4.</param>
+            /// <param name="pool_size"></param>
+            /// <param name="strides"></param>
+            /// <param name="padding"></param>
+            /// <param name="data_format"></param>
+            /// <param name="name"></param>
+            /// <returns></returns>
+            public static Tensor max_pooling2d(Tensor inputs,
+                int[] pool_size,
+                int[] strides,
+                string padding = "valid",
+                string data_format = "channels_last",
+                string name = null)
+            {
+                var layer = new MaxPooling2D(pool_size: pool_size,
+                    strides: strides,
+                    padding: padding,
+                    data_format: data_format,
+                    name: name);
+
+                return layer.apply(inputs);
+            }
         }
     }
 }
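
For context, a hedged usage sketch of the new tf.layers.max_pooling2d overload; the rank-4 NHWC input tensor x is assumed to come from elsewhere and is not part of this commit:

// Usage sketch only: `x` is assumed to be a rank-4 NHWC tensor,
// e.g. shape (batch, 28, 28, 1), produced elsewhere.
var pooled = tf.layers.max_pooling2d(x,
    pool_size: new int[] { 2, 2 },
    strides: new int[] { 2, 2 },
    padding: "same");   // halves the spatial dimensions to (batch, 14, 14, 1)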

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 16 additions & 13 deletions
@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using Tensorflow.Operations;
 using Tensorflow.Operations.Activation;
 
 namespace Tensorflow
@@ -27,19 +28,21 @@ public static Tensor embedding_lookup(RefVariable @params,
 
             public static IActivation relu => new relu();
 
-            public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
-                RefVariable scale,
-                RefVariable offset,
-                Tensor mean = null,
-                Tensor variance = null,
-                float epsilon = 0.001f,
-                string data_format = "NHWC",
-                bool is_training = true,
-                string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
-                    epsilon: epsilon,
-                    data_format: data_format,
-                    is_training: is_training,
-                    name: name);
+            public static Tensor[] fused_batch_norm(Tensor x,
+                RefVariable scale,
+                RefVariable offset,
+                Tensor mean = null,
+                Tensor variance = null,
+                float epsilon = 0.001f,
+                string data_format = "NHWC",
+                bool is_training = true,
+                string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
+                    epsilon: epsilon,
+                    data_format: data_format,
+                    is_training: is_training,
+                    name: name);
+
+            public static Tensor max_pool() => gen_nn_ops.max_pool();
         }
     }
 }
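
fused_batch_norm now returns a Tensor[] instead of a (Tensor, Tensor, Tensor) tuple, so callers index the result rather than deconstructing it. A minimal sketch, assuming x, scale and offset are defined elsewhere:

// Sketch only: x, scale and offset are assumed to exist elsewhere.
var results = tf.nn.fused_batch_norm(x, scale, offset, is_training: true);
var output = results[0];          // normalized tensor
var batch_mean = results[1];      // batch mean
var batch_variance = results[2];  // batch variance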

src/TensorFlowNET.Core/Framework/smart_module.cs

Lines changed: 7 additions & 4 deletions
@@ -6,9 +6,9 @@ namespace Tensorflow.Framework
 {
     public class smart_module
     {
-        public static object smart_cond(Tensor pred,
-            Func<(Tensor, Tensor, Tensor)> true_fn = null,
-            Func<(Tensor, Tensor, Tensor)> false_fn = null,
+        public static Tensor[] smart_cond<T>(Tensor pred,
+            Func<T[]> true_fn = null,
+            Func<T[]> false_fn = null,
             string name = null)
         {
             return control_flow_ops.cond(pred,
@@ -17,9 +17,12 @@ public static object smart_cond(Tensor pred,
                 name: name);
         }
 
-        public static bool smart_constant_value(Tensor pred)
+        public static bool? smart_constant_value(Tensor pred)
         {
             var pred_value = tensor_util.constant_value(pred);
+            if (pred_value is null)
+                return null;
+
             return pred_value;
         }
     }
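
smart_constant_value now returns bool? so callers can distinguish a predicate that is statically known at graph-construction time from one that is only known when the graph runs. A hedged sketch of the intended three-way handling (illustration only; `training` is assumed to be a boolean Tensor):

// Illustration only, not part of this commit.
bool? static_value = smart_module.smart_constant_value(training);
if (static_value.HasValue)
    // predicate is a graph-time constant: the caller can pick one branch directly
    Console.WriteLine(static_value.Value ? "training branch" : "inference branch");
else
    // predicate is dynamic: fall back to smart_cond, which emits control_flow_ops.cond
    Console.WriteLine("build both branches via smart_cond");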

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Tensorflow.Keras.Utils;
 
@@ -34,6 +35,7 @@ public class Layer : CheckpointableBase
         protected string _name;
         protected string _base_name;
         protected bool _compute_previous_mask;
+        protected List<Operation> _updates;
 
         public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
         {
@@ -45,6 +47,7 @@ public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_D
             _init_set_name(name);
             _trainable_weights = new List<RefVariable>();
             _compute_previous_mask = false;
+            _updates = new List<Operation>();
         }
 
         public Tensor __call__(Tensor inputs,
@@ -142,6 +145,12 @@ protected virtual RefVariable add_weight(string name,
             return variable;
         }
 
+        protected virtual void add_update(Tensor[] updates, bool inputs = false)
+        {
+            var updates_op = updates.Select(x => x.op).ToArray();
+            _updates.AddRange(updates_op);
+        }
+
         protected virtual void _init_set_name(string name)
         {
             string base_name = name;
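
The new add_update hook lets a Layer subclass record update operations produced during its call(); BatchNormalization (below) uses it for its moving-average assignments. A hypothetical subclass snippet, illustration only (batch_mean and momentum_tensor are assumed names):

// Register an update tensor so it is later copied into UPDATE_OPS
// by Tensorflow.Layers.Layer.__call__ (see Layers/Layer.cs below).
var mean_update = _assign_moving_average(moving_mean, batch_mean, momentum_tensor);
add_update(new Tensor[] { mean_update }, inputs: true);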

src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs

Lines changed: 37 additions & 4 deletions
@@ -132,6 +132,7 @@ protected override Tensor call(Tensor inputs, Tensor training = null)
             if (fused)
             {
                 outputs = _fused_batch_norm(inputs, training: training);
+                return outputs;
             }
 
             throw new NotImplementedException("BatchNormalization call");
@@ -142,7 +143,7 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
             var beta = this.beta;
             var gamma = this.gamma;
 
-            Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () =>
+            Func<Tensor[]> _fused_batch_norm_training = () =>
             {
                 return tf.nn.fused_batch_norm(
                     inputs,
@@ -152,7 +153,7 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
                     data_format: _data_format);
             };
 
-            Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () =>
+            Func<Tensor[]> _fused_batch_norm_inference = () =>
             {
                 return tf.nn.fused_batch_norm(
                     inputs,
@@ -165,9 +166,41 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
                     data_format: _data_format);
             };
 
-            tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
+            var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
+            var (output, mean, variance) = (results[0], results[1], results[2]);
+            var training_value = tf_utils.constant_value(training);
 
-            throw new NotImplementedException("_fused_batch_norm");
+            Tensor momentum_tensor;
+            if (training_value == null)
+            {
+                momentum_tensor = tf_utils.smart_cond(training,
+                    () => new float[] { momentum }, () => new float[] { 1.0f })[0];
+            }
+            else
+            {
+                momentum_tensor = ops.convert_to_tensor(momentum);
+            }
+
+            if(training_value == null)
+            {
+                var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor);
+                var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor);
+                add_update(new Tensor[] { mean_update }, inputs: true);
+                add_update(new Tensor[] { variance_update }, inputs: true);
+            }
+
+            return output;
+        }
+
+        public Tensor _assign_moving_average(RefVariable variable, Tensor value, Tensor momentum)
+        {
+            return Python.with(ops.name_scope(null, "AssignMovingAvg", new { variable, value, momentum }), scope =>
+            {
+                // var cm = ops.colocate_with(variable);
+                var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay");
+                var update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay;
+                return state_ops.assign_sub(variable, update_delta, name: scope);
+            });
         }
     }
 }
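
For reference, the assign_sub in _assign_moving_average implements the usual exponential moving average. Writing decay = 1 − momentum, the update is

\hat{\mu} \leftarrow \hat{\mu} - (\hat{\mu} - \mu_{\text{batch}})\,(1 - \text{momentum}) = \text{momentum}\cdot\hat{\mu} + (1 - \text{momentum})\cdot\mu_{\text{batch}}

and analogously for the moving variance.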

src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.tf;
+
+namespace Tensorflow.Keras.Layers
+{
+    public class MaxPooling2D : Pooling2D
+    {
+        public MaxPooling2D(
+            int[] pool_size,
+            int[] strides,
+            string padding = "valid",
+            string data_format = null,
+            string name = null) : base(nn.max_pool, pool_size,
+                strides,
+                padding: padding,
+                data_format: data_format,
+                name: name)
+        {
+
+        }
+    }
+}

src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Utils;
+
+namespace Tensorflow.Keras.Layers
+{
+    public class Pooling2D : Tensorflow.Layers.Layer
+    {
+        private Func<Tensor> pool_function;
+        private int[] pool_size;
+        private int[] strides;
+        private string padding;
+        private string data_format;
+        private InputSpec input_spec;
+
+        public Pooling2D(Func<Tensor> pool_function,
+            int[] pool_size,
+            int[] strides,
+            string padding = "valid",
+            string data_format = null,
+            string name = null) : base(name: name)
+        {
+            this.pool_function = pool_function;
+            this.pool_size = conv_utils.normalize_tuple(pool_size, 2, "pool_size");
+            this.strides = conv_utils.normalize_tuple(strides, 2, "strides");
+            this.padding = conv_utils.normalize_padding(padding);
+            this.data_format = conv_utils.normalize_data_format(data_format);
+            this.input_spec = new InputSpec(ndim: 4);
+        }
+    }
+}

src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs

Lines changed: 15 additions & 0 deletions
@@ -29,5 +29,20 @@ public static string convert_data_format(string data_format, int ndim)
             else
                 throw new ValueError($"Invalid data_format: {data_format}");
         }
+
+        public static int[] normalize_tuple(int[] value, int n, string name)
+        {
+            return value;
+        }
+
+        public static string normalize_padding(string value)
+        {
+            return value.ToLower();
+        }
+
+        public static string normalize_data_format(string value)
+        {
+            return value.ToLower();
+        }
     }
 }
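
normalize_tuple is a pass-through for now, since the C# API already takes int[2] arrays. For comparison, the upstream Keras helper also accepts a single value and broadcasts it to an n-tuple; a hedged sketch of that fuller behavior (not what this commit implements, and it would additionally need using System.Linq):

// Hypothetical fuller version, illustration only -- the committed code simply returns `value`.
public static int[] normalize_tuple(int[] value, int n, string name)
{
    if (value.Length == 1)
        return Enumerable.Repeat(value[0], n).ToArray();   // broadcast a single value to n entries
    if (value.Length != n)
        throw new ValueError($"The `{name}` argument must have {n} elements.");
    return value;
}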

src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs

Lines changed: 8 additions & 3 deletions
@@ -13,14 +13,19 @@ public static bool are_all_symbolic_tensors(Tensor[] tensors)
             return tensors.Select(x => is_symbolic_tensor(x)).Count() == tensors.Length;
         }
 
+        public static bool? constant_value(Tensor pred)
+        {
+            return smart_module.smart_constant_value(pred);
+        }
+
         public static bool is_symbolic_tensor(Tensor tensor)
         {
             return true;
         }
 
-        public static object smart_cond(Tensor pred,
-            Func<(Tensor, Tensor, Tensor)> true_fn = null,
-            Func<(Tensor, Tensor, Tensor)> false_fn = null,
+        public static Tensor[] smart_cond<T>(Tensor pred,
+            Func<T[]> true_fn = null,
+            Func<T[]> false_fn = null,
             string name = null)
         {
             return smart_module.smart_cond(pred,
src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using Tensorflow.Keras.Engine;
56

@@ -55,11 +56,23 @@ public Tensor __call__(Tensor inputs,
5556
var outputs = base.__call__(inputs, training: training);
5657

5758
// Update global default collections.
58-
//_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS);
59+
_add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS });
5960

6061
return outputs;
6162
}
6263

64+
protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)
65+
{
66+
foreach(var name in collection_list)
67+
{
68+
var collection = ops.get_collection_ref(name) as List<object>;
69+
70+
foreach (var element in elements)
71+
if (!collection.Contains(element))
72+
collection.Add(element);
73+
}
74+
}
75+
6376
protected virtual RefVariable add_weight(string name,
6477
int[] shape,
6578
TF_DataType dtype = TF_DataType.DtInvalid,
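
With the layer's update ops now copied into ops.GraphKeys.UPDATE_OPS, a training loop can fetch that collection and run the moving-average assignments alongside the optimizer step, following the usual TensorFlow convention. A hedged sketch; sess and train_op are assumed to be built elsewhere:

// Sketch only, not part of this commit.
var update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS) as List<object>;
foreach (Operation update_op in update_ops)
    sess.run(update_op);   // apply the assign_sub moving-average updates
sess.run(train_op);        // then take the optimizer step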
