Skip to content

Commit e859b20

Browse files
committed
np.expand_dims
1 parent c7ee230 commit e859b20

File tree

4 files changed

+32
-8
lines changed

4 files changed

+32
-8
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,12 @@ public Tensor concat(IEnumerable<Tensor> values, int axis, string name = "concat
9999
/// <param name="input"></param>
100100
/// <param name="axis"></param>
101101
/// <param name="name"></param>
102-
/// <param name="dim"></param>
103102
/// <returns>
104103
/// A `Tensor` with the same data as `input`, but its shape has an additional
105104
/// dimension of size 1 added.
106105
/// </returns>
107-
public Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
108-
=> array_ops.expand_dims(input, axis, name, dim);
106+
public Tensor expand_dims(Tensor input, int axis = -1, string name = null)
107+
=> array_ops.expand_dims(input, axis, name);
109108

110109
/// <summary>
111110
/// Creates a tensor filled with a scalar value.

src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public partial class np
1515
public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException("");
1616

1717
[AutoNumPy]
18-
public static NDArray expand_dims(NDArray a, Axis? axis = null) => throw new NotImplementedException("");
18+
public static NDArray expand_dims(NDArray a, Axis? axis = null) => new NDArray(array_ops.expand_dims(a, axis: axis ?? -1));
1919

2020
[AutoNumPy]
2121
public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape);

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,7 @@ public static Tensor _autopacking_helper(IEnumerable<object> list_or_tuple, TF_D
300300
return result;
301301
}
302302

303-
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
304-
=> expand_dims_v2(input, axis, name);
305-
306-
private static Tensor expand_dims_v2(Tensor input, int axis, string name = null)
303+
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null)
307304
=> gen_array_ops.expand_dims(input, axis, name);
308305

309306
/// <summary>
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using Tensorflow;
7+
using Tensorflow.NumPy;
8+
9+
namespace TensorFlowNET.UnitTest.NumPy
10+
{
11+
/// <summary>
12+
/// https://numpy.org/doc/stable/reference/routines.array-manipulation.html
13+
/// </summary>
14+
[TestClass]
15+
public class ManipulationTest : EagerModeTestBase
16+
{
17+
[TestMethod]
18+
public void expand_dims()
19+
{
20+
var x = np.array(new[] { 1, 2 });
21+
var y = np.expand_dims(x, axis: 0);
22+
Assert.AreEqual(y.shape, (1, 2));
23+
24+
y = np.expand_dims(x, axis: 1);
25+
Assert.AreEqual(y.shape, (2, 1));
26+
}
27+
}
28+
}

0 commit comments

Comments
 (0)