Skip to content

Commit 1cf190f

Browse files
committed
tf.linalg.lstsq #823
1 parent 16ff7a3 commit 1cf190f

File tree

15 files changed

+206
-13
lines changed

15 files changed

+206
-13
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public Tensor where<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
152152
/// <param name="name"></param>
153153
/// <param name="conjugate"></param>
154154
/// <returns></returns>
155-
public Tensor transpose<T1>(T1 a, Shape perm = null, string name = "transpose", bool conjugate = false)
155+
public Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
156156
=> array_ops.transpose(a, perm, name, conjugate);
157157

158158
/// <summary>

src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ You may obtain a copy of the License at
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
******************************************************************************/
16+
using Tensorflow.NumPy;
1617
using static Tensorflow.Binding;
1718

1819
namespace Tensorflow
@@ -40,13 +41,20 @@ public Tensor matmul(Tensor a, Tensor b)
4041

4142
public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
4243
=> math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name);
44+
45+
public Tensor inv(Tensor input, bool adjoint = false, string name = null)
46+
=> ops.matrix_inverse(input, adjoint: adjoint, name: name);
47+
48+
public Tensor lstsq(Tensor matrix, Tensor rhs,
49+
NDArray l2_regularizer = null, bool fast = true, string name = null)
50+
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
4351
}
4452

4553
public Tensor diag(Tensor diagonal, string name = null)
4654
=> gen_array_ops.diag(diagonal, name: name);
4755

48-
public Tensor matmul(Tensor a, Tensor b)
49-
=> math_ops.matmul(a, b);
56+
public Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
57+
=> math_ops.matmul(a, b, transpose_a: transpose_a, transpose_b: transpose_b);
5058

5159
/// <summary>
5260
/// Multiply slices of the two matrices "x" and "y".

src/TensorFlowNET.Core/NumPy/Axis.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ public static implicit operator Axis(Shape axis)
5050

5151
public static implicit operator Tensor(Axis axis)
5252
=> constant_op.constant(axis);
53+
54+
public override string ToString()
55+
=> $"({string.Join(", ", axis)})";
5356
}
5457
}
5558

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.NumPy
6+
{
7+
public class LinearAlgebraImpl
8+
{
9+
public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn")
10+
{
11+
return a;
12+
}
13+
}
14+
}

src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public static implicit operator NDArray(double value)
4848
=> new NDArray(value);
4949

5050
public static implicit operator Tensor(NDArray nd)
51-
=> nd._tensor;
51+
=> nd?._tensor;
5252

5353
public static implicit operator NDArray(Tensor tensor)
5454
=> new NDArray(tensor);

src/TensorFlowNET.Core/Numpy/Numpy.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,5 +105,7 @@ public static NpzDictionary<T> Load_Npz<T>(byte[] bytes)
105105
{
106106
throw new NotImplementedException("");
107107
}
108+
109+
public static LinearAlgebraImpl linalg = new LinearAlgebraImpl();
108110
}
109111
}

src/TensorFlowNET.Core/Numpy/Shape.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ public long[] strides
3838
}
3939
}
4040

41+
#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
42+
public int Length => ndim;
43+
public long[] Slice(int start, int length)
44+
{
45+
var slice = new long[length];
46+
Array.Copy(_dims, start, slice, 0, length);
47+
return slice;
48+
}
49+
#endregion
50+
4151
private Shape()
4252
{
4353
}
@@ -107,7 +117,7 @@ public static implicit operator Tensor(Shape shape)
107117

108118
public long this[int n]
109119
{
110-
get => dims[n];
120+
get => n < 0 ? dims[ndim + n] : dims[n];
111121
set => dims[n] = value;
112122
}
113123

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -774,10 +774,10 @@ public static Tensor matrix_diag(Tensor diagonal,
774774
int k = 0,
775775
int num_rows = -1,
776776
int num_cols = -1,
777-
double padding_value = 0,
777+
float padding_value = 0f,
778778
string align = "RIGHT_LEFT")
779779
=> tf.Context.ExecuteOp("MatrixDiagV3", name,
780-
new ExecuteOpArgs(diagonal, k, num_rows, num_cols, padding_value)
780+
new ExecuteOpArgs(diagonal, k, num_rows, num_cols, ops.convert_to_tensor(padding_value, dtype: diagonal.dtype))
781781
.SetAttributes(new { align }));
782782

783783
public static Tensor matrix_set_diag(Tensor input,
@@ -900,7 +900,7 @@ public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null,
900900
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
901901
}
902902

903-
public static Tensor transpose<T1>(T1 a, Shape perm, string name = "transpose", bool conjugate = false)
903+
public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
904904
{
905905
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
906906
{

src/TensorFlowNET.Core/Operations/linalg_ops.cs

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ public Tensor eye(int num_rows,
2020
var diag_size = Math.Min(num_rows, num_columns);
2121
if (batch_shape == null)
2222
batch_shape = new Shape(new int[0]);
23-
var diag_shape = batch_shape.dims.concat(new long[] { diag_size });
23+
var batch_shape_tensor = ops.convert_to_tensor(batch_shape, dtype: tf.int32, name: "shape");
24+
var diag_shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { diag_size }) }, axis: 0);
2425

25-
long[] shape = null;
26+
Tensor shape = null;
2627
if (!is_square)
27-
shape = batch_shape.dims.concat(new long[] { num_rows, num_columns });
28+
shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { num_rows, num_columns }) }, axis: 0);
2829

2930
var diag_ones = array_ops.ones(diag_shape, dtype: dtype);
3031
if (is_square)
@@ -36,5 +37,81 @@ public Tensor eye(int num_rows,
3637
}
3738
});
3839
}
40+
41+
public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null)
42+
=> tf.Context.ExecuteOp("MatrixInverse", name,
43+
new ExecuteOpArgs(input).SetAttributes(new
44+
{
45+
adjoint
46+
}));
47+
48+
public Tensor matrix_solve_ls(Tensor matrix, Tensor rhs,
49+
Tensor l2_regularizer = null, bool fast = true, string name = null)
50+
{
51+
return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer);
52+
}
53+
54+
Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
55+
{
56+
Shape matrix_shape = matrix.shape[^2..];
57+
if (matrix_shape.IsFullyDefined)
58+
{
59+
if (matrix_shape[-2] >= matrix_shape[-1])
60+
return _overdetermined(matrix, rhs, l2_regularizer);
61+
else
62+
return _underdetermined(matrix, rhs, l2_regularizer);
63+
}
64+
65+
throw new NotImplementedException("");
66+
}
67+
68+
Tensor _overdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
69+
{
70+
var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: true);
71+
return cholesky_solve(chol, math_ops.matmul(matrix, rhs, adjoint_a: true));
72+
}
73+
74+
Tensor _underdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
75+
{
76+
var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: false);
77+
return math_ops.matmul(matrix, cholesky_solve(chol, rhs), adjoint_a: true);
78+
}
79+
80+
Tensor _RegularizedGramianCholesky(Tensor matrix, Tensor l2_regularizer, bool first_kind)
81+
{
82+
var gramian = math_ops.matmul(matrix, matrix, adjoint_a: first_kind, adjoint_b: !first_kind);
83+
84+
if (l2_regularizer != null)
85+
{
86+
var matrix_shape = array_ops.shape(matrix);
87+
var batch_shape = matrix_shape[":-2"];
88+
var small_dim = first_kind ? matrix_shape[-1] : matrix_shape[-2];
89+
var identity = eye(small_dim.numpy(), batch_shape: batch_shape.shape, dtype: matrix.dtype);
90+
var small_dim_static = matrix.shape[first_kind ? -1 : -2];
91+
identity.shape = matrix.shape[..^2].concat(new[] { small_dim_static, small_dim_static });
92+
gramian += l2_regularizer * identity;
93+
}
94+
95+
return cholesky(gramian);
96+
}
97+
98+
public Tensor cholesky(Tensor input, string name = null)
99+
=> tf.Context.ExecuteOp("Cholesky", name, new ExecuteOpArgs(input));
100+
101+
public Tensor cholesky_solve(Tensor chol, Tensor rhs, string name = null)
102+
=> tf_with(ops.name_scope(name, default_name: "eye", new { chol, rhs }), scope =>
103+
{
104+
var y = matrix_triangular_solve(chol, rhs, adjoint: false, lower: true);
105+
var x = matrix_triangular_solve(chol, y, adjoint: true, lower: true);
106+
return x;
107+
});
108+
109+
public Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool lower = true, bool adjoint = false, string name = null)
110+
=> tf.Context.ExecuteOp("MatrixTriangularSolve", name,
111+
new ExecuteOpArgs(matrix, rhs).SetAttributes(new
112+
{
113+
lower,
114+
adjoint
115+
}));
39116
}
40117
}

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,18 @@ public static Tensor matmul(Tensor a, Tensor b,
791791
if (transpose_b && adjoint_b)
792792
throw new ValueError("Only one of transpose_b and adjoint_b can be True.");
793793

794+
if(adjoint_a)
795+
{
796+
a = conj(a);
797+
transpose_a = true;
798+
}
799+
800+
if (adjoint_b)
801+
{
802+
b = conj(b);
803+
transpose_b = true;
804+
}
805+
794806
result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name);
795807
});
796808

0 commit comments

Comments
 (0)