Skip to content

Commit c618680

Browse files
authored
Merge pull request #192 from PppBr/master
implemented naive bayes predict API
2 parents 1597a0b + f8b618c commit c618680

File tree

10 files changed

+229
-12
lines changed

10 files changed

+229
-12
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public static partial class tf
8+
{
9+
public static Tensor exp(Tensor x,
10+
string name = null) => gen_math_ops.exp(x, name);
11+
12+
}
13+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public static partial class tf
8+
{
9+
public static Tensor reduce_logsumexp(Tensor input_tensor,
10+
int[] axis = null,
11+
bool keepdims = false,
12+
string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name);
13+
14+
}
15+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public static partial class tf
8+
{
9+
public static Tensor reshape(Tensor tensor,
10+
Tensor shape,
11+
string name = null) => gen_array_ops.reshape(tensor, shape, name);
12+
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public static partial class tf
8+
{
9+
public static Tensor tile(Tensor input,
10+
Tensor multiples,
11+
string name = null) => gen_array_ops.tile(input, multiples, name);
12+
13+
}
14+
}

src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class Distribution : _BaseDistribution
3535
/// <param name="name"> Python `str` prepended to names of ops created by this function.</param>
3636
/// <returns>log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`.</returns>
3737

38-
/*
38+
3939
public Tensor log_prob(Tensor value, string name = "log_prob")
4040
{
4141
return _call_log_prob(value, name);
@@ -45,18 +45,39 @@ private Tensor _call_log_prob (Tensor value, string name)
4545
{
4646
with(ops.name_scope(name, "moments", new { value }), scope =>
4747
{
48-
value = _convert_to_tensor(value, "value", _dtype);
48+
try
49+
{
50+
return _log_prob(value);
51+
}
52+
catch (Exception e1)
53+
{
54+
try
55+
{
56+
return math_ops.log(_prob(value));
57+
} catch (Exception e2)
58+
{
59+
throw new NotImplementedException();
60+
}
61+
}
4962
});
63+
return null;
64+
}
5065

66+
private Tensor _log_prob(Tensor value)
67+
{
5168
throw new NotImplementedException();
52-
5369
}
5470

55-
private Tensor _convert_to_tensor(Tensor value, string name = null, TF_DataType preferred_dtype)
71+
private Tensor _prob(Tensor value)
5672
{
5773
throw new NotImplementedException();
5874
}
59-
*/
75+
76+
public TF_DataType dtype()
77+
{
78+
return this._dtype;
79+
}
80+
6081

6182
/// <summary>
6283
/// Constructs the `Distribution'

src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using System.Collections.Generic;
23
using Tensorflow;
34

@@ -80,7 +81,7 @@ public Tensor _batch_shape()
8081

8182
private Tensor _log_prob(Tensor x)
8283
{
83-
return _log_unnormalized_prob(_z(x));
84+
return _log_unnormalized_prob(_z(x)) -_log_normalization();
8485
}
8586

8687
private Tensor _log_unnormalized_prob (Tensor x)
@@ -92,5 +93,11 @@ private Tensor _z (Tensor x)
9293
{
9394
return (x - this._loc) / this._scale;
9495
}
96+
97+
private Tensor _log_normalization()
98+
{
99+
Tensor t = new Tensor(Math.Log(2.0 * Math.PI));
100+
return 0.5 * t + math_ops.log(scale());
101+
}
95102
}
96103
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ public static Tensor rank(Tensor input, string name = null)
6666
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
6767
=> ones_like_impl(tensor, dtype, name, optimize);
6868

69+
public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
70+
{
71+
return gen_array_ops.reshape(tensor, shape, null);
72+
}
73+
6974
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
7075
{
7176
return with(ops.name_scope(name, "ones_like", new { tensor }), scope =>

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,58 @@ public static Tensor squared_difference(Tensor x, Tensor y, string name = null)
4848
return _op.outputs[0];
4949
}
5050

51+
/// <summary>
52+
/// Computes square of x element-wise.
53+
/// </summary>
54+
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.</param>
55+
/// <param name="name"> A name for the operation (optional).</param>
56+
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
57+
public static Tensor square(Tensor x, string name = null)
58+
{
59+
var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x });
60+
61+
return _op.outputs[0];
62+
}
63+
64+
/// <summary>
65+
/// Returns which elements of x are finite.
66+
/// </summary>
67+
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`.</param>
68+
/// <param name="name"> A name for the operation (optional).</param>
69+
/// <returns> A `Tensor` of type `bool`.</returns>
70+
public static Tensor is_finite(Tensor x, string name = null)
71+
{
72+
var _op = _op_def_lib._apply_op_helper("IsFinite", name, args: new { x });
73+
74+
return _op.outputs[0];
75+
}
76+
77+
/// <summary>
78+
/// Computes exponential of x element-wise. \\(y = e^x\\).
79+
/// </summary>
80+
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.</param>
81+
/// <param name="name"> A name for the operation (optional).</param>
82+
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
83+
public static Tensor exp(Tensor x, string name = null)
84+
{
85+
var _op = _op_def_lib._apply_op_helper("Exp", name, args: new { x });
86+
87+
return _op.outputs[0];
88+
}
89+
90+
/// <summary>
91+
/// Computes natural logarithm of x element-wise.
92+
/// </summary>
93+
/// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.</param>
94+
/// <param name="name"> name: A name for the operation (optional).</param>
95+
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
96+
public static Tensor log(Tensor x, string name = null)
97+
{
98+
var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x });
99+
100+
return _op.outputs[0];
101+
}
102+
51103
public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "")
52104
{
53105
var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate });
@@ -134,6 +186,13 @@ public static Tensor maximum<T1, T2>(T1 x, T2 y, string name = null)
134186
return _op.outputs[0];
135187
}
136188

189+
public static Tensor _max(Tensor input, int[] axis, bool keep_dims=false, string name = null)
190+
{
191+
var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims });
192+
193+
return _op.outputs[0];
194+
}
195+
137196
public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null)
138197
{
139198
var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y });

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,12 @@ public static Tensor square_difference(Tensor x, Tensor y, string name = null)
5757

5858
public static Tensor square(Tensor x, string name = null)
5959
{
60-
throw new NotImplementedException();
60+
return gen_math_ops.square(x, name);
61+
}
62+
63+
public static Tensor log(Tensor x, string name = null)
64+
{
65+
return gen_math_ops.log(x, name);
6166
}
6267

6368
/// <summary>
@@ -82,6 +87,51 @@ public static Tensor reduced_shape(Tensor input_shape, Tensor axes)
8287
return gen_data_flow_ops.dynamic_stitch(a1, a2);
8388
}
8489

90+
/// <summary>
91+
/// Computes log(sum(exp(elements across dimensions of a tensor))).
92+
/// Reduces `input_tensor` along the dimensions given in `axis`.
93+
/// Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
94+
/// entry in `axis`. If `keepdims` is true, the reduced dimensions
95+
/// are retained with length 1.
96+
97+
/// If `axis` has no entries, all dimensions are reduced, and a
98+
/// tensor with a single element is returned.
99+
100+
/// This function is more numerically stable than log(sum(exp(input))). It avoids
101+
/// overflows caused by taking the exp of large inputs and underflows caused by
102+
/// taking the log of small inputs.
103+
/// </summary>
104+
/// <param name="input_tensor"> The tensor to reduce. Should have numeric type.</param>
105+
/// <param name="axis"> The dimensions to reduce. If `None` (the default), reduces all
106+
/// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param>
107+
/// <param name="keepdims"></param>
108+
/// <returns> The reduced tensor.</returns>
109+
public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
110+
{
111+
with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope =>
112+
{
113+
var raw_max = reduce_max(input_tensor, axis, true);
114+
var my_max = array_ops.stop_gradient(array_ops.where(gen_math_ops.is_finite(raw_max), raw_max, array_ops.zeros_like(raw_max)));
115+
var result = gen_math_ops.log(
116+
reduce_sum(
117+
gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
118+
new Tensor(axis),
119+
keepdims));
120+
if (!keepdims)
121+
{
122+
my_max = array_ops.reshape(my_max, array_ops.shape(result));
123+
}
124+
result = gen_math_ops.add(result, my_max);
125+
return _may_reduce_to_scalar(keepdims, axis, result);
126+
});
127+
return null;
128+
}
129+
130+
public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
131+
{
132+
return _may_reduce_to_scalar(keepdims, axis, gen_math_ops._max(input_tensor, (int[])_ReductionDims(input_tensor, axis), keepdims, name));
133+
}
134+
85135
/// <summary>
86136
/// Casts a tensor to type `int32`.
87137
/// </summary>

test/TensorFlowNET.Examples/NaiveBayesClassifier.cs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace TensorFlowNET.Examples
1212
/// </summary>
1313
public class NaiveBayesClassifier : Python, IExample
1414
{
15+
public Normal dist { get; set; }
1516
public void Run()
1617
{
1718
np.array<float>(1.0f, 1.0f);
@@ -72,16 +73,34 @@ public void fit(NDArray X, NDArray y)
7273
// Create a 3x2 univariate normal distribution with the
7374
// Known mean and variance
7475
var dist = tf.distributions.Normal(mean, tf.sqrt(variance));
75-
76+
this.dist = dist;
7677
}
7778

78-
public void predict (NDArray X)
79+
public Tensor predict (NDArray X)
7980
{
80-
// assert self.dist is not None
81-
// nb_classes, nb_features = map(int, self.dist.scale.shape)
81+
if (dist == null)
82+
{
83+
throw new ArgumentNullException("cant not find the model (normal distribution)!");
84+
}
85+
int nb_classes = (int) dist.scale().shape[0];
86+
int nb_features = (int)dist.scale().shape[1];
87+
88+
// Conditional probabilities log P(x|c) with shape
89+
// (nb_samples, nb_classes)
90+
Tensor tile = tf.tile(new Tensor(X), new Tensor(new int[] { -1, nb_classes, nb_features }));
91+
Tensor r = tf.reshape(tile, new Tensor(new int[] { -1, nb_classes, nb_features }));
92+
var cond_probs = tf.reduce_sum(dist.log_prob(r));
93+
// uniform priors
94+
var priors = np.log(np.array<double>((1.0 / nb_classes) * nb_classes));
8295

96+
// posterior log probability, log P(c) + log P(x|c)
97+
var joint_likelihood = tf.add(new Tensor(priors), cond_probs);
98+
// normalize to get (log)-probabilities
8399

84-
throw new NotFiniteNumberException();
100+
var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, true);
101+
var log_prob = joint_likelihood - norm_factor;
102+
// exp to get the actual probabilities
103+
return tf.exp(log_prob);
85104
}
86105
}
87106
}

0 commit comments

Comments
 (0)