Skip to content

Commit 2815724

Browse files
committed
add numpy api of np.moveaxis #891
1 parent 425b258 commit 2815724

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,8 @@ public partial class np
2525

2626
[AutoNumPy]
2727
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));
28+
29+
[AutoNumPy]
30+
public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination));
2831
}
2932
}

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,26 @@ public static Tensor[] meshgrid<T>(T[] array, bool copy = true, bool sparse = fa
792792
});
793793
}
794794

795+
public static Tensor moveaxis(NDArray array, Axis source, Axis destination)
796+
{
797+
List<int> perm = null;
798+
source = source.axis.Select(x => x < 0 ? array.rank + x : x).ToArray();
799+
destination = destination.axis.Select(x => x < 0 ? array.rank + x : x).ToArray();
800+
801+
if (array.shape.rank > -1)
802+
{
803+
perm = range(0, array.rank).Where(i => !source.axis.Contains(i)).ToList();
804+
foreach (var (dest, src) in zip(destination.axis, source.axis).OrderBy(x => x.Item1))
805+
{
806+
perm.Insert(dest, src);
807+
}
808+
}
809+
else
810+
throw new NotImplementedException("");
811+
812+
return array_ops.transpose(array, perm.ToArray());
813+
}
814+
795815
/// <summary>
796816
/// Computes the shape of a broadcast given symbolic shapes.
797817
/// When shape_x and shape_y are Tensors representing shapes(i.e.the result of

test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,19 @@ public void expand_dims()
2424
y = np.expand_dims(x, axis: 1);
2525
Assert.AreEqual(y.shape, (2, 1));
2626
}
27+
28+
[TestMethod]
29+
public void moveaxis()
30+
{
31+
var x = np.zeros((3, 4, 5));
32+
var y = np.moveaxis(x, 0, -1);
33+
Assert.AreEqual(y.shape, (4, 5, 3));
34+
35+
y = np.moveaxis(x, (0, 1), (-1, -2));
36+
Assert.AreEqual(y.shape, (5, 4, 3));
37+
38+
y = np.moveaxis(x, (0, 1, 2), (-1, -2, -3));
39+
Assert.AreEqual(y.shape, (5, 4, 3));
40+
}
2741
}
2842
}

0 commit comments

Comments
 (0)