Skip to content

Commit d19e513

Browse files
committed
implementing normal.py.cs API log_prob() for NB classifier prediction
1 parent c26fccf commit d19e513

File tree

4 files changed

+55
-4
lines changed

4 files changed

+55
-4
lines changed

src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,37 @@ public class Distribution : _BaseDistribution
2727
public List<Tensor> _graph_parents {get;set;}
2828
public string _name {get;set;}
2929

30+
31+
/// <summary>
32+
/// Log probability density/mass function.
33+
/// </summary>
34+
/// <param name="value"> `Tensor`.</param>
35+
/// <param name="name"> Python `str` prepended to names of ops created by this function.</param>
36+
/// <returns>log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`.</returns>
37+
38+
/*
39+
public Tensor log_prob(Tensor value, string name = "log_prob")
40+
{
41+
return _call_log_prob(value, name);
42+
}
43+
44+
private Tensor _call_log_prob (Tensor value, string name)
45+
{
46+
with(ops.name_scope(name, "moments", new { value }), scope =>
47+
{
48+
value = _convert_to_tensor(value, "value", _dtype);
49+
});
50+
51+
throw new NotImplementedException();
52+
53+
}
54+
55+
private Tensor _convert_to_tensor(Tensor value, string name = null, TF_DataType preferred_dtype)
56+
{
57+
throw new NotImplementedException();
58+
}
59+
*/
60+
3061
/// <summary>
3162
/// Constructs the `Distribution'
3263
/// **This is a private method for subclass use.**
@@ -47,7 +78,7 @@ public class Distribution : _BaseDistribution
4778
/// <param name = "name"> Name prefixed to Ops created by this class. Default: subclass name.</param>
4879
/// <returns> Two `Tensor` objects: `mean` and `variance`.</returns>
4980

50-
/*
81+
/*
5182
private Distribution (
5283
TF_DataType dtype,
5384
ReparameterizationType reparameterization_type,
@@ -66,6 +97,10 @@ private Distribution (
6697
this._name = name;
6798
}
6899
*/
100+
101+
102+
103+
69104
}
70105

71106
/// <summary>

src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,19 @@ public Tensor _batch_shape()
7878
return array_ops.broadcast_static_shape(new Tensor(_loc.shape), new Tensor(_scale.shape));
7979
}
8080

81+
private Tensor _log_prob(Tensor x)
82+
{
83+
return _log_unnormalized_prob(_z(x));
84+
}
85+
86+
private Tensor _log_unnormalized_prob (Tensor x)
87+
{
88+
return -0.5 * math_ops.square(_z(x));
89+
}
90+
91+
private Tensor _z (Tensor x)
92+
{
93+
return (x - this._loc) / this._scale;
94+
}
8195
}
8296
}

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ public static Tensor square_difference(Tensor x, Tensor y, string name = null)
5555
return m;
5656
}
5757

58+
public static Tensor square(Tensor x, string name = null)
59+
{
60+
throw new NotImplementedException();
61+
}
62+
5863
/// <summary>
5964
/// Helper function for reduction ops.
6065
/// </summary>

test/TensorFlowNET.Examples/NaiveBayesClassifier.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,10 @@ public class NaiveBayesClassifier : Python, IExample
1515
public void Run()
1616
{
1717
np.array<float>(1.0f, 1.0f);
18-
// var X = np.array<float>(np.array<float>(1.0f, 1.0f), np.array<float>(2.0f, 2.0f), np.array<float>(1.0f, -1.0f), np.array<float>(2.0f, -2.0f), np.array<float>(-1.0f, -1.0f), np.array<float>(-1.0f, 1.0f),);
19-
// var X = np.array<float[]>(new float[][] { new float[] { 1.0f, 1.0f}, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, });
2018
var X = np.array<float>(new float[][] { new float[] { 1.0f, 1.0f }, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, });
2119
var y = np.array<int>(0,0,1,1,2,2);
2220
fit(X, y);
2321
// Create a regular grid and classify each point
24-
2522
}
2623

2724
public void fit(NDArray X, NDArray y)

0 commit comments

Comments
 (0)