|
2 | 2 | using System; |
3 | 3 | using System.Collections.Generic; |
4 | 4 | using System.Text; |
| 5 | +using Tensorflow.Framework; |
5 | 6 |
|
6 | 7 | namespace Tensorflow |
7 | 8 | { |
@@ -39,9 +40,41 @@ public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, s |
39 | 40 | public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) |
40 | 41 | { |
41 | 42 | var r = _ReductionDims(input_tensor, axis); |
42 | | - var m = gen_math_ops.mean(input_tensor, (int[]) r, keepdims, name); |
43 | | - return _may_reduce_to_scalar(keepdims,axis, m); |
| 43 | + if (axis == null) |
| 44 | + { |
| 45 | + var m = gen_math_ops.mean(input_tensor, r, keepdims, name); |
| 46 | + return _may_reduce_to_scalar(keepdims, axis, m); |
| 47 | + } |
| 48 | + else |
| 49 | + { |
| 50 | + var m = gen_math_ops.mean(input_tensor, axis, keepdims, name); |
| 51 | + return _may_reduce_to_scalar(keepdims, axis, m); |
| 52 | + } |
| 53 | + } |
| 54 | + |
| 55 | + /// <summary> |
| 56 | + /// Computes the product of elements across dimensions of a tensor. |
| 57 | + /// </summary> |
| 58 | + /// <param name="input_tensor"></param> |
| 59 | + /// <param name="axis"></param> |
| 60 | + /// <param name="keepdims"></param> |
| 61 | + /// <param name="name"></param> |
| 62 | + /// <returns></returns> |
| 63 | + public static Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) |
| 64 | + { |
| 65 | + var r = _ReductionDims(input_tensor, axis); |
| 66 | + if (axis == null) |
| 67 | + { |
| 68 | + var m = gen_math_ops.prod(input_tensor, r, keepdims, name); |
| 69 | + return _may_reduce_to_scalar(keepdims, axis, m); |
| 70 | + } |
| 71 | + else |
| 72 | + { |
| 73 | + var m = gen_math_ops.prod(input_tensor, axis, keepdims, name); |
| 74 | + return _may_reduce_to_scalar(keepdims, axis, m); |
| 75 | + } |
44 | 76 | } |
| 77 | + |
45 | 78 | /// <summary> |
46 | 79 | /// Returns (x - y)(x - y) element-wise. |
47 | 80 | /// </summary> |
@@ -134,7 +167,10 @@ public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bo |
134 | 167 |
|
135 | 168 | public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) |
136 | 169 | { |
137 | | - return _may_reduce_to_scalar(keepdims, axis, gen_math_ops._max(input_tensor, (int[])_ReductionDims(input_tensor, axis), keepdims, name)); |
| 170 | + var r = _ReductionDims(input_tensor, axis); |
| 171 | + var max = (axis != null) ? gen_math_ops._max(input_tensor, axis, keepdims, name) : |
| 172 | + gen_math_ops._max(input_tensor, r, keepdims, name); |
| 173 | + return _may_reduce_to_scalar(keepdims, axis, max); |
138 | 174 | } |
139 | 175 |
|
140 | 176 | /// <summary> |
@@ -197,18 +233,19 @@ private static Tensor _ReductionDims(Tensor x, Tensor axis) |
197 | 233 | } |
198 | 234 | } |
199 | 235 |
|
200 | | - private static object _ReductionDims(Tensor x, int[] axis) |
| 236 | + private static Tensor _ReductionDims(Tensor x, int[] axis) |
201 | 237 | { |
202 | 238 | if (axis != null) |
203 | 239 | { |
204 | | - return axis; |
| 240 | + // should return axis. or check before. |
| 241 | + return null; |
205 | 242 | } |
206 | 243 | else |
207 | 244 | { |
208 | | - var rank = array_ops.rank(x); |
| 245 | + var rank = common_shapes.rank(x); |
209 | 246 | if (rank != null) |
210 | 247 | { |
211 | | - return constant_op.constant(np.arange(rank), TF_DataType.TF_INT32); |
| 248 | + return constant_op.constant(np.arange(rank.Value), TF_DataType.TF_INT32); |
212 | 249 | } |
213 | 250 | return range(0, rank, 1); |
214 | 251 | } |
@@ -303,5 +340,20 @@ public static Tensor conj(Tensor x, string name = null) |
303 | 340 | return x; |
304 | 341 | }); |
305 | 342 | } |
| 343 | + |
| 344 | + public static Tensor truediv(Tensor x, Tensor y, string name = null) |
| 345 | + => _truediv_python3(x, y, name); |
| 346 | + |
| 347 | + public static Tensor _truediv_python3(Tensor x, Tensor y, string name = null) |
| 348 | + { |
| 349 | + return with(ops.name_scope(name, "truediv", new { x, y }), scope => |
| 350 | + { |
| 351 | + name = scope; |
| 352 | + var x_dtype = x.dtype.as_base_dtype(); |
| 353 | + var y_dtype = y.dtype.as_base_dtype(); |
| 354 | + |
| 355 | + return gen_math_ops.real_div(x, y, name: name); |
| 356 | + }); |
| 357 | + } |
306 | 358 | } |
307 | 359 | } |
0 commit comments