Skip to content

Commit c8643c4

Browse files
committed
Add tf.math.erf #738
1 parent 47f953b commit c8643c4

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ public class MathApi
2323
{
2424
public Tensor log(Tensor x, string name = null)
2525
=> gen_math_ops.log(x, name);
26+
27+
/// <summary>
28+
/// Computes the Gauss error function of `x` element-wise.
29+
/// </summary>
30+
/// <param name="x"></param>
31+
/// <param name="name"></param>
32+
/// <returns></returns>
33+
public Tensor erf(Tensor x, string name = null)
34+
=> math_ops.erf(x, name);
2635
}
2736

2837
public Tensor abs(Tensor x, string name = null)

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,29 @@ public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null)
265265
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
266266
=> gen_math_ops.equal(x, y, name: name);
267267

268+
/// <summary>
269+
/// Computes the Gauss error function of `x` element-wise.
270+
/// </summary>
271+
/// <param name="x"></param>
272+
/// <param name="name"></param>
273+
/// <returns></returns>
274+
public static Tensor erf(Tensor x, string name = null)
275+
=> tf.Context.RunInAutoMode2(
276+
() => tf.OpDefLib._apply_op_helper("Erf", name, new { x }).output,
277+
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
278+
"Erf", name,
279+
null,
280+
x).FirstOrDefault(),
281+
(op) =>
282+
{
283+
var attrs = new object[]
284+
{
285+
"T", op.get_attr<TF_DataType>("T")
286+
};
287+
tf.Runner.RecordGradient("Erf", op.inputs, attrs, op.outputs);
288+
},
289+
new Tensors(x));
290+
268291
public static Tensor sqrt(Tensor x, string name = null)
269292
=> gen_math_ops.sqrt(x, name: name);
270293

test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,14 @@ public void ReduceSum()
4848
var x5 = tf.reduce_sum(b, (0, 1));
4949
Assert.AreEqual(-4.7f, (float)x5);
5050
}
51+
52+
[TestMethod]
53+
public void Erf()
54+
{
55+
var erf = tf.math.erf(a, name: "erf");
56+
var expected = new float[] { 0.8427007f, -0.5204999f, 0.99999845f, -0.9970206f, 0f, -1f };
57+
var actual = erf.ToArray<float>();
58+
Assert.IsTrue(Equal(expected, actual));
59+
}
5160
}
5261
}

0 commit comments

Comments
 (0)