Skip to content

Commit 43e59ca

Browse files
committed
np.argsort
1 parent 9c5692b commit 43e59ca

File tree

4 files changed

+80
-2
lines changed

4 files changed

+80
-2
lines changed

src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public static NDArray argmax(NDArray a, Axis axis = null)
1414

1515
[AutoNumPy]
1616
public static NDArray argsort(NDArray a, Axis axis = null)
17-
=> new NDArray(math_ops.argmax(a, axis ?? -1));
17+
=> new NDArray(sort_ops.argsort(a, axis: axis ?? -1));
1818

1919
[AutoNumPy]
2020
public static (NDArray, NDArray) unique(NDArray a)

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor
281281
data_format
282282
}));
283283

284-
public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null)
284+
public static Tensor[] top_kv2<T>(Tensor input, T k, bool sorted = true, string name = null)
285285
{
286286
var _op = tf.OpDefLib._apply_op_helper("TopKV2", name: name, args: new
287287
{
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using Tensorflow.Operations;
18+
using static Tensorflow.Binding;
19+
20+
namespace Tensorflow
21+
{
22+
public class sort_ops
23+
{
24+
public static Tensor argsort(Tensor values, Axis axis = null, string direction = "ASCENDING", bool stable = false, string name = null)
25+
{
26+
axis = axis ?? new Axis(-1);
27+
var k = array_ops.shape(values)[axis];
28+
values = -values;
29+
var (_, indices) = tf.Context.ExecuteOp("TopKV2", name,
30+
new ExecuteOpArgs(values, k).SetAttributes(new
31+
{
32+
sorted = true
33+
}));
34+
return indices;
35+
}
36+
37+
public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null)
38+
=> tf.Context.ExecuteOp("MatrixInverse", name,
39+
new ExecuteOpArgs(input).SetAttributes(new
40+
{
41+
adjoint
42+
}));
43+
}
44+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using Tensorflow;
7+
using Tensorflow.NumPy;
8+
using static Tensorflow.Binding;
9+
10+
namespace TensorFlowNET.UnitTest.NumPy
11+
{
12+
/// <summary>
13+
/// https://numpy.org/doc/stable/user/basics.indexing.html
14+
/// </summary>
15+
[TestClass]
16+
public class ArraySortingTest : EagerModeTestBase
17+
{
18+
/// <summary>
19+
/// https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
20+
/// </summary>
21+
[TestMethod]
22+
public void argsort()
23+
{
24+
var x = np.array(new[] { 3, 1, 2 });
25+
var ind = np.argsort(x);
26+
Assert.AreEqual(ind, new[] { 1, 2, 0 });
27+
28+
var y = np.array(new[,] { { 0, 3 }, { 2, 2 } });
29+
ind = np.argsort(y, axis: 0);
30+
Assert.AreEqual(ind[0], new[] { 0, 1 });
31+
Assert.AreEqual(ind[1], new[] { 1, 0 });
32+
}
33+
}
34+
}

0 commit comments

Comments
 (0)