Skip to content

Commit 967fc43

Browse files
authored
Merge pull request #275 from PppBr/master
fix bug in normal distribution: _log_prob
2 parents d2e7078 + 41bb848 commit 967fc43

File tree

6 files changed

+25
-4
lines changed

6 files changed

+25
-4
lines changed

TensorFlow.NET.sln

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
Microsoft Visual Studio Solution File, Format Version 12.00
3-
# Visual Studio Version 16
4-
VisualStudioVersion = 16.0.28803.452
3+
# Visual Studio 15
4+
VisualStudioVersion = 15.0.28307.168
55
MinimumVisualStudioVersion = 10.0.40219.1
66
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.UnitTest", "test\TensorFlowNET.UnitTest\TensorFlowNET.UnitTest.csproj", "{029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}"
77
EndProject

data/nb_example.npy

14.2 KB
Binary file not shown.

src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public Tensor _batch_shape()
8282

8383
protected override Tensor _log_prob(Tensor x)
8484
{
85-
var log_prob = _log_unnormalized_prob(_z(x));
85+
var log_prob = _log_unnormalized_prob(x);
8686
var log_norm = _log_normalization();
8787
return tf.sub(log_prob, log_norm);
8888
}

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,10 @@ Docs: https://tensorflownet.readthedocs.io</Description>
6262
<Folder Include="Keras\Initializers\" />
6363
</ItemGroup>
6464

65+
<ItemGroup>
66+
<Reference Include="NumSharp.Core">
67+
<HintPath>..\..\..\..\NumSharp\src\NumSharp.Core\bin\Debug\netstandard2.0\NumSharp.Core.dll</HintPath>
68+
</Reference>
69+
</ItemGroup>
70+
6571
</Project>

test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using NumSharp;
66
using System.Linq;
77
using static Tensorflow.Python;
8+
using System.IO;
9+
using TensorFlowNET.Examples.Utility;
810

911
namespace TensorFlowNET.Examples
1012
{
@@ -34,7 +36,10 @@ public bool Run()
3436
var (xx, yy) = np.meshgrid(np.linspace(x_min, x_max, 30), np.linspace(y_min, y_max, 30));
3537
with(tf.Session(), sess =>
3638
{
37-
var samples = np.hstack<float>(xx.ravel().reshape(xx.size, 1), yy.ravel().reshape(yy.size, 1));
39+
//var samples = np.vstack<float>(xx.ravel(), yy.ravel());
40+
//samples = np.transpose(samples);
41+
var array = np.Load<double[,]>(Path.Join("nb", "nb_example.npy"));
42+
var samples = np.array(array).astype(np.float32);
3843
var Z = sess.run(predict(samples));
3944
});
4045

@@ -167,6 +172,10 @@ public void PrepareData()
167172
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
168173
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
169174
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2);
175+
176+
177+
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/nb_example.npy";
178+
Web.Download(url, "nb", "nb_example.npy");
170179
#endregion
171180
}
172181

test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,10 @@
2424
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
2525
</ItemGroup>
2626

27+
<ItemGroup>
28+
<Reference Include="NumSharp.Core">
29+
<HintPath>..\..\..\..\NumSharp\src\NumSharp.Core\bin\Debug\netstandard2.0\NumSharp.Core.dll</HintPath>
30+
</Reference>
31+
</ItemGroup>
32+
2733
</Project>

0 commit comments

Comments
 (0)