Skip to content

Commit 115d489

Browse files
committed
implemented _log_prob in normal.py
1 parent 4481d19 commit 115d489

File tree

4 files changed

+58
-6
lines changed

4 files changed

+58
-6
lines changed

src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class Distribution : _BaseDistribution
3535
/// <param name="name"> Python `str` prepended to names of ops created by this function.</param>
3636
/// <returns>log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`.</returns>
3737

38-
/*
38+
3939
public Tensor log_prob(Tensor value, string name = "log_prob")
4040
{
4141
return _call_log_prob(value, name);
@@ -45,18 +45,39 @@ private Tensor _call_log_prob (Tensor value, string name)
4545
{
4646
with(ops.name_scope(name, "moments", new { value }), scope =>
4747
{
48-
value = _convert_to_tensor(value, "value", _dtype);
48+
try
49+
{
50+
return _log_prob(value);
51+
}
52+
catch (Exception e1)
53+
{
54+
try
55+
{
56+
return math_ops.log(_prob(value));
57+
} catch (Exception e2)
58+
{
59+
throw new NotImplementedException();
60+
}
61+
}
4962
});
63+
return null;
64+
}
5065

66+
private Tensor _log_prob(Tensor value)
67+
{
5168
throw new NotImplementedException();
52-
5369
}
5470

55-
private Tensor _convert_to_tensor(Tensor value, string name = null, TF_DataType preferred_dtype)
71+
private Tensor _prob(Tensor value)
5672
{
5773
throw new NotImplementedException();
5874
}
59-
*/
75+
76+
public TF_DataType dtype()
77+
{
78+
return this._dtype;
79+
}
80+
6081

6182
/// <summary>
6283
/// Constructs the `Distribution'

src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using System.Collections.Generic;
23
using Tensorflow;
34

@@ -80,7 +81,7 @@ public Tensor _batch_shape()
8081

8182
private Tensor _log_prob(Tensor x)
8283
{
83-
return _log_unnormalized_prob(_z(x));
84+
return _log_unnormalized_prob(_z(x)) -_log_normalization();
8485
}
8586

8687
private Tensor _log_unnormalized_prob (Tensor x)
@@ -92,5 +93,11 @@ private Tensor _z (Tensor x)
9293
{
9394
return (x - this._loc) / this._scale;
9495
}
96+
97+
private Tensor _log_normalization()
98+
{
99+
Tensor t = new Tensor(Math.Log(2.0 * Math.PI));
100+
return 0.5 * t + math_ops.log(scale());
101+
}
95102
}
96103
}

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,32 @@ public static Tensor squared_difference(Tensor x, Tensor y, string name = null)
4848
return _op.outputs[0];
4949
}
5050

51+
/// <summary>
52+
/// Computes square of x element-wise.
53+
/// </summary>
54+
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.</param>
55+
/// <param name="name"> A name for the operation (optional).</param>
56+
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
5157
public static Tensor square(Tensor x, string name = null)
5258
{
5359
var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x });
5460

5561
return _op.outputs[0];
5662
}
5763

64+
/// <summary>
65+
/// Computes natural logarithm of x element-wise.
66+
/// </summary>
67+
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.</param>
68+
/// <param name="name"> name: A name for the operation (optional).</param>
69+
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
70+
public static Tensor log(Tensor x, string name = null)
71+
{
72+
var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x });
73+
74+
return _op.outputs[0];
75+
}
76+
5877
public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "")
5978
{
6079
var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate });

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ public static Tensor square(Tensor x, string name = null)
6060
return gen_math_ops.square(x, name);
6161
}
6262

63+
public static Tensor log(Tensor x, string name = null)
64+
{
65+
return gen_math_ops.log(x, name);
66+
}
67+
6368
/// <summary>
6469
/// Helper function for reduction ops.
6570
/// </summary>

0 commit comments

Comments
 (0)