Skip to content

Commit 387ae4c

Browse files
committed
overload tf.reduce_mean
1 parent ca39124 commit 387ae4c

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name =
474474
public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
475475
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);
476476

477-
public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
477+
public Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
478478
=> math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name);
479479

480480
public Tensor round(Tensor x, string name = null)

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,16 @@ private static Tensor expand_dims_v2(Tensor input, int axis, string name = null)
228228
public static Tensor rank(Tensor input, string name = null)
229229
=> rank_internal(input, name, optimize: true);
230230

231+
public static Tensor rank(Tensor[] inputs, string name = null)
232+
{
233+
return tf_with(ops.name_scope(name, "Rank", new { inputs }), scope =>
234+
{
235+
name = scope;
236+
var input_tensor = ops.convert_to_tensor(inputs);
237+
return constant_op.constant(input_tensor.NDims, dtype: tf.int32, name: name);
238+
});
239+
}
240+
231241
public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true)
232242
{
233243
return tf_with(ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,19 @@ public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool ke
219219
}
220220
}
221221

222-
public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
222+
public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
223223
{
224-
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
225-
return _may_reduce_to_scalar(keepdims, axis, m);
224+
if(axis == null)
225+
{
226+
var r = _ReductionDims(input_tensors, axis);
227+
var m = gen_math_ops.mean(input_tensors, r, keepdims, name);
228+
return _may_reduce_to_scalar(keepdims, axis, m);
229+
}
230+
else
231+
{
232+
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
233+
return _may_reduce_to_scalar(keepdims, axis, m);
234+
}
226235
}
227236

228237
/// <summary>
@@ -492,7 +501,7 @@ private static Tensor _may_reduce_to_scalar(bool keepdims, int[] axis, Tensor ou
492501
return output;
493502
}
494503

495-
private static Tensor _may_reduce_to_scalar(bool keepdims, int axis, Tensor output)
504+
private static Tensor _may_reduce_to_scalar(bool keepdims, int? axis, Tensor output)
496505
{
497506
return output;
498507
}
@@ -515,6 +524,11 @@ private static int _ReductionDims(Tensor x, int axis)
515524
return axis;
516525
}
517526

527+
private static Tensor _ReductionDims(Tensor[] x, int? axis = null, string name = null)
528+
{
529+
return range(0, array_ops.rank(x));
530+
}
531+
518532
private static Tensor _ReductionDims(Tensor x, int[] axis)
519533
{
520534
if (axis != null)

0 commit comments

Comments
 (0)