Skip to content

Commit cb4b248

Browse files
committed
np.random.normal
1 parent e859b20 commit cb4b248

File tree

5 files changed

+22
-1
lines changed

5 files changed

+22
-1
lines changed

src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ public NDArray randint(int low, int? high = null, Shape size = null, TF_DataType
3737
}
3838

3939
public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null)
40-
=> throw new NotImplementedException("");
40+
=> new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale));
4141
}
4242
}

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public NDArray this[NDArray mask]
4747
return GetData(mask.ToArray<int>());
4848
else if (mask.dtype == TF_DataType.TF_INT64)
4949
return GetData(mask.ToArray<long>().Select(x => Convert.ToInt32(x)).ToArray());
50+
else if (mask.dtype == TF_DataType.TF_FLOAT)
51+
return GetData(mask.ToArray<float>().Select(x => Convert.ToInt32(x)).ToArray());
5052

5153
throw new NotImplementedException("");
5254
}

src/TensorFlowNET.Core/NumPy/Numpy.Math.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ public partial class np
1818
[AutoNumPy]
1919
public static NDArray log(NDArray x) => new NDArray(tf.log(x));
2020

21+
[AutoNumPy]
22+
public static NDArray mean(NDArray x) => new NDArray(math_ops.reduce_mean(x));
23+
2124
[AutoNumPy]
2225
public static NDArray multiply(NDArray x1, NDArray x2) => new NDArray(tf.multiply(x1, x2));
2326

test/TensorFlowNET.UnitTest/EagerModeTestBase.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ public void TestInit()
1414
tf.Context.ensure_initialized();
1515
}
1616

17+
public bool Equal(float f1, float f2)
18+
{
19+
var tolerance = .000001f;
20+
return Math.Abs(f1 - f2) <= tolerance;
21+
}
22+
1723
public bool Equal(float[] f1, float[] f2)
1824
{
1925
bool ret = false;

test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,15 @@ public void permutation()
2323
Assert.AreEqual(x.shape, 10);
2424
Assert.AreNotEqual(x.ToArray<int>(), y.ToArray<int>());
2525
}
26+
27+
/// <summary>
28+
/// https://numpy.org/doc/stable/reference/random/generated/numpy.random.normal.html
29+
/// </summary>
30+
[TestMethod]
31+
public void normal()
32+
{
33+
var x = np.random.normal(0, 0.1f, 1000);
34+
Equal(np.mean(x), 0f);
35+
}
2636
}
2737
}

0 commit comments

Comments
 (0)