Skip to content

Commit 0e2488c

Browse files
committed
math.reduce_sum, tf.variables_initializer
1 parent 9e0d8c8 commit 0e2488c

File tree

17 files changed

+943
-705
lines changed

17 files changed

+943
-705
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ limitations under the License.
1717
using NumSharp;
1818
using System;
1919
using System.Collections.Generic;
20+
using System.Diagnostics;
2021
using System.Linq;
22+
using static Tensorflow.Binding;
2123

2224
namespace Tensorflow
2325
{
@@ -76,7 +78,14 @@ public Tensor check_numerics(Tensor tensor, string message, string name = null)
7678
public Tensor concat(IList<Tensor> values, int axis, string name = "concat")
7779
{
7880
if (values.Count == 1)
79-
throw new NotImplementedException("tf.concat length is 1");
81+
{
82+
return tf_with(ops.name_scope(name), scope =>
83+
{
84+
var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32);
85+
Debug.Assert(tensor.TensorShape.ndim == 0);
86+
return identity(values[0], name: scope);
87+
});
88+
}
8089

8190
return gen_array_ops.concat_v2(values.ToArray(), axis, name: name);
8291
}
@@ -111,7 +120,7 @@ public Tensor fill<T>(Tensor dims, T value, string name = null)
111120
/// <param name="input"></param>
112121
/// <param name="name"></param>
113122
/// <returns></returns>
114-
public static Tensor identity(Tensor input, string name = null)
123+
public Tensor identity(Tensor input, string name = null)
115124
=> array_ops.identity(input, name: name);
116125

117126
/// <summary>
@@ -150,10 +159,10 @@ public Tensor transpose<T1>(T1 a, int[] perm = null, string name = "transpose",
150159
/// <param name="axis"></param>
151160
/// <param name="name"></param>
152161
/// <returns></returns>
153-
public static Tensor reverse(Tensor tensor, int[] axis, string name = null)
162+
public Tensor reverse(Tensor tensor, int[] axis, string name = null)
154163
=> gen_array_ops.reverse(tensor, axis, name: name);
155164

156-
public static Tensor reverse(Tensor tensor, Tensor axis, string name = null)
165+
public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
157166
=> gen_array_ops.reverse(tensor, axis, name: name);
158167

159168
/// <summary>
@@ -277,5 +286,14 @@ public Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name
277286
/// <returns>A `Tensor` with all elements set to zero.</returns>
278287
public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
279288
=> array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize);
289+
290+
/// <summary>
291+
/// Stops gradient computation.
292+
/// </summary>
293+
/// <param name="x"></param>
294+
/// <param name="name"></param>
295+
/// <returns></returns>
296+
public Tensor stop_gradient(Tensor x, string name = null)
297+
=> gen_array_ops.stop_gradient(x, name: name);
280298
}
281299
}

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,14 @@ public Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims
434434
public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null,
435435
bool keepdims = false, string name = null)
436436
{
437-
if(!axis.HasValue && reduction_indices.HasValue)
437+
if (!axis.HasValue && reduction_indices.HasValue && !keepdims)
438438
return math_ops.reduce_sum(input, reduction_indices.Value);
439-
else if (axis.HasValue && !reduction_indices.HasValue)
439+
else if (axis.HasValue && !reduction_indices.HasValue && !keepdims)
440440
return math_ops.reduce_sum(input, axis.Value);
441-
return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
441+
else if (axis.HasValue && !reduction_indices.HasValue && keepdims)
442+
return math_ops.reduce_sum(input, keepdims: keepdims, axis: axis.Value, name: name);
443+
else
444+
return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
442445
}
443446

444447
public Tensor reduce_sum(Tensor input, TensorShape axis, int? reduction_indices = null,
@@ -471,6 +474,9 @@ public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name =
471474
public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
472475
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);
473476

477+
public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
478+
=> math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name);
479+
474480
public Tensor round(Tensor x, string name = null)
475481
=> gen_math_ops.round(x, name: name);
476482

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,10 @@ public Tensor random_shuffle(Tensor value, int? seed = null, string name = null)
6565

6666
public void set_random_seed(int seed)
6767
=> ops.get_default_graph().seed = seed;
68+
69+
public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
70+
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)
71+
=> random_ops.multinomial(logits, num_samples, seed: seed,
72+
name: name, output_dtype: output_dtype);
6873
}
6974
}

src/TensorFlowNET.Core/APIs/tf.train.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System.Collections.Generic;
18+
using Tensorflow.Keras.Optimizers;
1819
using Tensorflow.Train;
1920

2021
namespace Tensorflow
@@ -73,6 +74,26 @@ public string latest_checkpoint(string checkpoint_dir, string latest_filename =
7374

7475
public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
7576
=> checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename);
77+
78+
public Tensor polynomial_decay(float learning_rate,
79+
RefVariable global_step,
80+
float decay_steps,
81+
float end_learning_rate = 0.0001f,
82+
float power = 1.0f,
83+
bool cycle = false,
84+
string name = null)
85+
{
86+
var decayed = new PolynomialDecay(learning_rate,
87+
decay_steps,
88+
end_learning_rate: end_learning_rate,
89+
power: power,
90+
cycle: cycle,
91+
name: name);
92+
93+
var decayed_lr = decayed.__call__(global_step);
94+
95+
return decayed_lr;
96+
}
7697
}
7798
}
7899
}

src/TensorFlowNET.Core/APIs/tf.variable.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ public VariableV1[] global_variables(string scope = null)
2727
.ToArray();
2828
}
2929

30+
/// <summary>
31+
/// Returns an Op that initializes a list of variables.
32+
/// </summary>
33+
/// <param name="var_list">List of `Variable` objects to initialize.</param>
34+
/// <param name="name">Optional name for the returned operation.</param>
35+
/// <returns>An Op that run the initializers of all the specified variables.</returns>
36+
public Operation variables_initializer(VariableV1[] var_list, string name = "init")
37+
=> variables.variables_initializer(var_list, name: name);
38+
3039
public Operation global_variables_initializer()
3140
{
3241
var g = variables.global_variables();

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ public static IEnumerable<int> range(int start, int end)
115115
return instance;
116116
}
117117

118+
[DebuggerStepThrough]
118119
[DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception
119120
public static void tf_with(IObjectLife py, Action<IObjectLife> action)
120121
{
@@ -273,7 +274,10 @@ public static double sum(IEnumerable enumerable)
273274
return sum;
274275
}
275276

276-
public static double sum(IEnumerable<int> enumerable)
277+
public static float sum(IEnumerable<float> enumerable)
278+
=> enumerable.Sum();
279+
280+
public static int sum(IEnumerable<int> enumerable)
277281
=> enumerable.Sum();
278282

279283
public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
7+
namespace Tensorflow.Keras.Optimizers
8+
{
9+
public class LearningRateSchedule
10+
{
11+
public LearningRateSchedule()
12+
{
13+
14+
}
15+
}
16+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.Keras.Optimizers
9+
{
10+
/// <summary>
11+
/// A LearningRateSchedule that uses a polynomial decay schedule.
12+
/// </summary>
13+
public class PolynomialDecay : LearningRateSchedule
14+
{
15+
float initial_learning_rate;
16+
float decay_steps;
17+
float end_learning_rate;
18+
float power;
19+
bool cycle;
20+
string name;
21+
22+
public PolynomialDecay(float initial_learning_rate,
23+
float decay_steps,
24+
float end_learning_rate = 0.0001f,
25+
float power = 1.0f,
26+
bool cycle = false,
27+
string name = null) : base()
28+
{
29+
this.initial_learning_rate = initial_learning_rate;
30+
this.decay_steps = decay_steps;
31+
this.end_learning_rate = end_learning_rate;
32+
this.power = power;
33+
this.cycle = cycle;
34+
this.name = name;
35+
}
36+
37+
public Tensor __call__(RefVariable step)
38+
{
39+
tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope =>
40+
{
41+
name = scope;
42+
var initial_learning_rate_tensor = ops.convert_to_tensor(initial_learning_rate, name: "initial_learning_rate");
43+
var dtype = initial_learning_rate_tensor.dtype;
44+
var end_learning_rate_tensor = math_ops.cast(end_learning_rate, dtype);
45+
var power_tensor = math_ops.cast(power, dtype);
46+
47+
var global_step_recomp = math_ops.cast(step, dtype);
48+
var decay_steps_recomp = math_ops.cast(decay_steps, dtype);
49+
50+
if(cycle)
51+
{
52+
throw new NotImplementedException("PolynomialDecay cycle");
53+
}
54+
else
55+
{
56+
57+
}
58+
});
59+
throw new NotImplementedException("");
60+
}
61+
}
62+
}

src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ namespace Tensorflow.Operations.Initializers
1919
public class GlorotUniform : VarianceScaling
2020
{
2121
public GlorotUniform(float scale = 1.0f,
22-
string mode = "fan_avg",
23-
string distribution = "uniform",
22+
string mode = "FAN_AVG",
2423
int? seed = null,
2524
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
2625
mode: mode,
@@ -36,7 +35,6 @@ public object get_config()
3635
{
3736
scale = _scale,
3837
mode = _mode,
39-
distribution = _distribution,
4038
seed = _seed,
4139
dtype = _dtype
4240
};

src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public class VarianceScaling : IInitializer
3030
protected string _distribution;
3131
protected int? _seed;
3232
protected TF_DataType _dtype;
33+
protected bool _uniform;
3334

3435
public VarianceScaling(float factor = 2.0f,
3536
string mode = "FAN_IN",
@@ -49,31 +50,31 @@ public VarianceScaling(float factor = 2.0f,
4950
_mode = mode;
5051
_seed = seed;
5152
_dtype = dtype;
53+
_uniform = uniform;
5254
}
5355

5456
public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
5557
{
58+
float n = 0;
5659
var (fan_in, fan_out) = _compute_fans(shape);
57-
if (_mode == "fan_in")
58-
_scale /= Math.Max(1, fan_in);
59-
else if (_mode == "fan_out")
60-
_scale /= Math.Max(1, fan_out);
61-
else
62-
_scale /= Math.Max(1, (fan_in + fan_out) / 2);
60+
if (_mode == "FAN_IN")
61+
n = fan_in;
62+
else if (_mode == "FAN_OUT")
63+
n = fan_out;
64+
else if(_mode == "FAN_AVG")
65+
n = (fan_in + fan_out) / 2.0f;
6366

64-
if (_distribution == "normal" || _distribution == "truncated_normal")
65-
{
66-
float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f;
67-
return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed);
68-
}
69-
else if (_distribution == "untruncated_normal")
67+
if(_uniform)
7068
{
71-
throw new NotImplementedException("truncated_normal");
69+
var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
70+
return random_ops.random_uniform(shape, -limit, limit,
71+
dtype, seed: _seed);
7272
}
7373
else
7474
{
75-
var limit = Math.Sqrt(3.0f * _scale);
76-
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
75+
var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
76+
return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype,
77+
seed: _seed);
7778
}
7879
}
7980

@@ -106,6 +107,7 @@ public virtual object get_config()
106107
mode = _mode,
107108
distribution = _distribution,
108109
seed = _seed,
110+
uniform = _uniform,
109111
dtype = _dtype
110112
};
111113
}

0 commit comments

Comments
 (0)