Skip to content

Commit c233e70

Browse files
hcurEsther2013
authored andcommitted
implement math_ops.real, reduce_variance, reduce_std
* implement math_ops.real, reduce_variance, reduce_std * add dtype.real_dtype() * add outward-facing api functions
1 parent a2ec9b3 commit c233e70

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,9 @@ public static Tensor truediv(Tensor x, Tensor y, string name = null)
422422
public Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
423423
=> math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name);
424424

425+
public Tensor real(Tensor input, string name = null)
426+
=> math_ops.real(input, name);
427+
425428
/// <summary>
426429
/// Computes the "logical or" of elements across dimensions of a tensor.
427430
/// </summary>
@@ -509,6 +512,12 @@ public Tensor reduce_max(Tensor input_tensor, int axis, bool keepdims = false, s
509512
public Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
510513
=> math_ops.reduce_min(input_tensor, axis, keepdims, name);
511514

515+
public Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
516+
=> math_ops.reduce_std(input_tensor, axis, keepdims, name);
517+
518+
public Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
519+
=> math_ops.reduce_variance(input_tensor, axis, keepdims, name);
520+
512521
public Tensor sigmoid<T>(T x, string name = null)
513522
=> math_ops.sigmoid(x, name: name);
514523

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,22 @@ public static Tensor not_equal<Tx, Ty>(Tx x, Ty y, string name = null)
229229

230230
public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null)
231231
=> gen_math_ops.mul_no_nan(x, y, name: name);
232+
233+
public static Tensor real(Tensor input, string name = null)
234+
{
235+
using (var name_ = ops.name_scope(name, "Real", new [] {input}))
236+
{
237+
input = ops.convert_to_tensor(input, name: "input");
238+
if (input.dtype.is_complex())
239+
{
240+
var real_dtype = input.dtype.real_dtype();
241+
return real(input, name: name);
242+
} else
243+
{
244+
return input;
245+
}
246+
}
247+
}
232248

233249
/// <summary>
234250
/// Computes the mean of elements across dimensions of a tensor.
@@ -295,6 +311,46 @@ public static Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool ke
295311
}
296312
}
297313

314+
public static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
315+
{
316+
if (name == null)
317+
name = "reduce_std";
318+
// else {name = name;}
319+
320+
using (ops.name_scope(name))
321+
{
322+
var variance = reduce_variance(input_tensor, axis: axis, keepdims: keepdims);
323+
return gen_math_ops.sqrt(variance);
324+
}
325+
}
326+
327+
public static Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
328+
{
329+
if (name == null)
330+
name = "reduce_variance";
331+
// else {name = name;}
332+
333+
using (ops.name_scope(name))
334+
{
335+
var means = reduce_mean(input_tensor, axis: axis, keepdims: true);
336+
if (means.dtype.is_integer())
337+
throw new TypeError("Input must be either real or complex");
338+
var diff = input_tensor - means;
339+
340+
Tensor squared_deviations;
341+
if (diff.dtype.is_complex())
342+
{
343+
var real_dtype = diff.dtype.real_dtype();
344+
squared_deviations = real(
345+
gen_math_ops.mul(conj(diff), diff));
346+
} else
347+
{
348+
squared_deviations = gen_math_ops.square(diff);
349+
}
350+
return reduce_mean(squared_deviations, axis: axis, keepdims: keepdims);
351+
}
352+
}
353+
298354
public static Tensor sigmoid<T>(T x, string name = null)
299355
=> tf_with(ops.name_scope(name, "Sigmoid", x), scope =>
300356
{

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,5 +278,16 @@ public static bool is_compatible_with(this TF_DataType self, TF_DataType other)
278278
{
279279
return self.as_datatype_enum() == other.as_datatype_enum();
280280
}
281+
282+
public static TF_DataType real_dtype(this TF_DataType self)
283+
{
284+
TF_DataType base_ = self.as_base_dtype();
285+
if (base_ == complex64)
286+
return float32;
287+
else if (base_ == complex128)
288+
return float64;
289+
else
290+
return self;
291+
}
281292
}
282293
}

0 commit comments

Comments
 (0)