
Commit ee7bb7c

np.linalg.norm
1 parent 5a3ec5a commit ee7bb7c

6 files changed, +51 -1 lines changed


src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ public Tensor diag(Tensor diagonal, string name = null)
         public Tensor matmul(Tensor a, Tensor b)
             => math_ops.matmul(a, b);

+        public Tensor norm(Tensor a, string ord = "euclidean", Axis axis = null, string name = null)
+            => ops.norm(a, ord: ord, axis: axis, name: name);
+
         public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
             => math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name);
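
For orientation, a minimal usage sketch of the new tf.linalg.norm surface (not part of the commit; assumes eager mode and the usual `using static Tensorflow.Binding;` entry point):

    // hedged example: Euclidean norm of a small vector
    var t = tf.constant(new[] { 3.0f, 4.0f });
    var n = tf.linalg.norm(t);   // sqrt(3^2 + 4^2) = 5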

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 2 additions & 0 deletions
@@ -166,6 +166,8 @@ public static int len(object a)
                     return arr.Count;
                 case IEnumerable enumerable:
                     return enumerable.OfType<object>().Count();
+                case Axis axis:
+                    return axis.size;
                 case Shape arr:
                     return arr.ndim;
             }
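
The new `case Axis axis` branch lets len() report how many dimensions an axis specification names, which is what linalg_ops.norm below uses to tell vector norms apart from (not yet implemented) matrix norms. A rough illustration, assuming Axis converts implicitly from int and int[] as it does elsewhere in TensorFlow.NET:

    Axis vectorAxis = 1;                 // one dimension -> vector norm; len(vectorAxis) == 1
    Axis matrixAxes = new[] { 0, 1 };    // two dimensions -> matrix norm; len(matrixAxes) == 2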
Lines changed: 13 additions & 1 deletion
@@ -1,14 +1,26 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using static Tensorflow.Binding;

 namespace Tensorflow.NumPy
 {
     public class LinearAlgebraImpl
     {
+        [AutoNumPy]
         public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn")
+            => new NDArray(tf.linalg.lstsq(a, b));
+
+        [AutoNumPy]
+        public NDArray norm(NDArray a, Axis axis = null)
         {
-            return a;
+            if (a.dtype.is_integer())
+            {
+                var float_a = math_ops.cast(a, dtype: tf.float32);
+                return new NDArray(tf.linalg.norm(float_a, axis: axis));
+            }
+
+            return new NDArray(tf.linalg.norm(a, axis: axis));
         }
     }
 }
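
A sketch of the NumPy-style call this enables (illustrative only; integer inputs take the float32 cast path shown above):

    // hedged example: integer NDArray is cast to float32 before tf.linalg.norm
    var v = np.arange(5);        // [0, 1, 2, 3, 4], integer dtype
    var n = np.linalg.norm(v);   // sqrt(0 + 1 + 4 + 9 + 16) = sqrt(30) ≈ 5.477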

src/TensorFlowNET.Core/Operations/linalg_ops.cs

Lines changed: 15 additions & 0 deletions
@@ -52,6 +52,21 @@ public Tensor matrix_solve_ls(Tensor matrix, Tensor rhs,
             return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer);
         }

+        public Tensor norm(Tensor tensor, string ord = "euclidean", Axis axis = null, string name = null, bool keepdims = true)
+        {
+            var is_matrix_norm = axis != null && len(axis) == 2;
+            return tf_with(ops.name_scope(name, default_name: "norm", tensor), scope =>
+            {
+                if (is_matrix_norm)
+                    throw new NotImplementedException("");
+                var result = math_ops.sqrt(math_ops.reduce_sum(tensor * math_ops.conj(tensor), axis, keepdims: true));
+
+                if (!keepdims)
+                    result = array_ops.squeeze(result, axis);
+                return result;
+            });
+        }
+
         Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
         {
             Shape matrix_shape = matrix.shape.dims.Skip(matrix.shape.ndim - 2).ToArray();
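
The vector branch is a direct transcription of the Euclidean norm (the Frobenius norm when the input is treated as flattened):

    \lVert x \rVert_2 = \sqrt{\sum_i x_i \overline{x_i}} = \sqrt{\sum_i \lvert x_i \rvert^2}

conj is a no-op for real dtypes, and the trailing squeeze drops the reduced dimension when keepdims is false. Note that ord is accepted but not consulted in this reduction, two-axis (matrix) norms throw NotImplementedException, and keepdims defaults to true, which the tf.linalg.norm wrapper above does not override.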

src/TensorFlowNET.Core/Operations/nn_impl.py.cs

Lines changed: 10 additions & 0 deletions
@@ -109,6 +109,16 @@ public static (Tensor, Tensor) moments(Tensor x,
             });
         }

+        public static Tensor normalize(Tensor tensor, string ord = "euclidean", Axis axis = null, string name = null)
+        {
+            return tf_with(ops.name_scope(name, "normalize", tensor), scope =>
+            {
+                var norm = tf.linalg.norm(tensor, ord: ord, axis: axis, name: name);
+                var normalized = tensor / norm;
+                return normalized;
+            });
+        }
+
         public static Tensor batch_normalization(Tensor x,
             Tensor mean,
             Tensor variance,
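
normalize divides the tensor by the norm it just computed, so the result has unit norm along the reduced axes. A hedged illustration (the direct nn_impl.normalize call is for demonstration only; how the method is surfaced publicly is not part of this commit):

    // dividing [3, 4] by its Euclidean norm (5) should give [0.6, 0.8]
    var t = tf.constant(new[] { 3.0f, 4.0f });
    var unit = nn_impl.normalize(t);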

test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs

Lines changed: 8 additions & 0 deletions
@@ -19,5 +19,13 @@ public void lstsq()
         {

         }
+
+        [TestMethod]
+        public void norm()
+        {
+            var x = np.arange(9) - 4;
+            var y = x.reshape((3, 3));
+            var norm = np.linalg.norm(y);
+        }
     }
 }
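
The test exercises the call but asserts nothing yet. For reference, y holds the integers -4..4, so the expected Frobenius norm is sqrt(60) ≈ 7.745967; a hedged assertion one might add (the NDArray-to-float conversion is an assumption):

    // sum of squares of -4..4 is 2 * (16 + 9 + 4 + 1) = 60
    var expected = Math.Sqrt(60.0);   // ≈ 7.745967
    // Assert.AreEqual(expected, (double)(float)norm, 1e-5);   // assumes an explicit float conversion on NDArray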
