Skip to content

Commit 126ed27

Browse files
committed
reduce_sum in progress
1 parent 2fc45b9 commit 126ed27

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,22 @@ public static Tensor pow(Tensor x, double y)
2626
return gen_math_ops.pow(x, y);
2727
}
2828

29+
/// <summary>
30+
/// Computes the sum of elements across dimensions of a tensor.
31+
/// </summary>
32+
/// <param name="input"></param>
33+
/// <param name="axis"></param>
34+
/// <returns></returns>
2935
public static Tensor reduce_sum(Tensor input, int[] axis = null)
3036
{
31-
return gen_math_ops.sum(input, axis);
37+
Tensor rank;
38+
using (var namescop = new ops.name_scope<Tensor>("", "Rank", new List<Tensor> { input }))
39+
{
40+
string name = namescop;
41+
rank = gen_array_ops.rank(input, name);
42+
}
43+
var s = gen_math_ops.sum(input, rank);
44+
return gen_math_ops.range(0, s);
3245
}
3346
}
3447
}

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,15 @@ public static Tensor identity(Tensor input, string name = "")
4141

4242
return _op.outputs[0];
4343
}
44+
45+
public static Tensor rank(Tensor input, string name = "")
46+
{
47+
var keywords = new Dictionary<string, object>();
48+
keywords.Add("input", input);
49+
50+
var _op = _op_def_lib._apply_op_helper("Rank", name: name, keywords: keywords);
51+
52+
return _op.outputs[0];
53+
}
4454
}
4555
}

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,43 @@ public static Tensor pow(Tensor x, double y)
7878
return new Tensor(_op, 0, _op.OutputType(0));
7979
}
8080

81-
public static Tensor sum(Tensor input, int[] axis = null)
81+
public static Tensor sum(Tensor input, Tensor axis = null)
8282
{
83-
if(axis == null) axis = new int[0];
8483
var keywords = new Dictionary<string, object>();
8584
keywords.Add("input", input);
86-
keywords.Add("reduction_indices", constant_op.Constant(axis));
85+
keywords.Add("reduction_indices", axis);
8786
keywords.Add("keep_dims", false);
8887

8988
var _op = _op_def_lib._apply_op_helper("Sum", keywords: keywords);
9089

9190
return new Tensor(_op, 0, _op.OutputType(0));
9291
}
92+
93+
/// <summary>
94+
/// Creates a sequence of numbers.
95+
/// </summary>
96+
/// <param name="start"></param>
97+
/// <param name="limit"></param>
98+
/// <param name="delta"></param>
99+
/// <param name="name"></param>
100+
/// <returns></returns>
101+
public static Tensor range(int start, Tensor limit, int delta = 1)
102+
{
103+
using (var namescope = new ops.name_scope<Tensor>("", "Range", new List<Tensor> { start, limit, delta }))
104+
{
105+
var start1 = ops.convert_to_tensor(start, "start");
106+
var limit1 = ops.convert_to_tensor(limit, "limit");
107+
var delta1 = ops.convert_to_tensor(delta, "delta");
108+
109+
var keywords = new Dictionary<string, object>();
110+
keywords.Add("start", start1);
111+
keywords.Add("limit", limit1);
112+
keywords.Add("delta", delta1);
113+
114+
var _op = _op_def_lib._apply_op_helper("Range", namescope, keywords);
115+
116+
return _op.outputs[0];
117+
}
118+
}
93119
}
94120
}

0 commit comments

Comments
 (0)