@@ -20,11 +20,12 @@ public Tensor eye(int num_rows,
2020 var diag_size = Math . Min ( num_rows , num_columns ) ;
2121 if ( batch_shape == null )
2222 batch_shape = new Shape ( new int [ 0 ] ) ;
23- var diag_shape = batch_shape . dims . concat ( new long [ ] { diag_size } ) ;
23+ var batch_shape_tensor = ops . convert_to_tensor ( batch_shape , dtype : tf . int32 , name : "shape" ) ;
24+ var diag_shape = array_ops . concat ( new [ ] { batch_shape_tensor , tf . constant ( new int [ ] { diag_size } ) } , axis : 0 ) ;
2425
25- long [ ] shape = null ;
26+ Tensor shape = null ;
2627 if ( ! is_square )
27- shape = batch_shape . dims . concat ( new long [ ] { num_rows , num_columns } ) ;
28+ shape = array_ops . concat ( new [ ] { batch_shape_tensor , tf . constant ( new int [ ] { num_rows , num_columns } ) } , axis : 0 ) ;
2829
2930 var diag_ones = array_ops . ones ( diag_shape , dtype : dtype ) ;
3031 if ( is_square )
@@ -36,5 +37,81 @@ public Tensor eye(int num_rows,
3637 }
3738 } ) ;
3839 }
40+
41+ public Tensor matrix_inverse ( Tensor input , bool adjoint = false , string name = null )
42+ => tf . Context . ExecuteOp ( "MatrixInverse" , name ,
43+ new ExecuteOpArgs ( input ) . SetAttributes ( new
44+ {
45+ adjoint
46+ } ) ) ;
47+
48+ public Tensor matrix_solve_ls ( Tensor matrix , Tensor rhs ,
49+ Tensor l2_regularizer = null , bool fast = true , string name = null )
50+ {
51+ return _composite_impl ( matrix , rhs , l2_regularizer : l2_regularizer ) ;
52+ }
53+
54+ Tensor _composite_impl ( Tensor matrix , Tensor rhs , Tensor l2_regularizer = null )
55+ {
56+ Shape matrix_shape = matrix . shape [ ^ 2 ..] ;
57+ if ( matrix_shape . IsFullyDefined )
58+ {
59+ if ( matrix_shape [ - 2 ] >= matrix_shape [ - 1 ] )
60+ return _overdetermined ( matrix , rhs , l2_regularizer ) ;
61+ else
62+ return _underdetermined ( matrix , rhs , l2_regularizer ) ;
63+ }
64+
65+ throw new NotImplementedException ( "" ) ;
66+ }
67+
68+ Tensor _overdetermined ( Tensor matrix , Tensor rhs , Tensor l2_regularizer = null )
69+ {
70+ var chol = _RegularizedGramianCholesky ( matrix , l2_regularizer : l2_regularizer , first_kind : true ) ;
71+ return cholesky_solve ( chol , math_ops . matmul ( matrix , rhs , adjoint_a : true ) ) ;
72+ }
73+
74+ Tensor _underdetermined ( Tensor matrix , Tensor rhs , Tensor l2_regularizer = null )
75+ {
76+ var chol = _RegularizedGramianCholesky ( matrix , l2_regularizer : l2_regularizer , first_kind : false ) ;
77+ return math_ops . matmul ( matrix , cholesky_solve ( chol , rhs ) , adjoint_a : true ) ;
78+ }
79+
80+ Tensor _RegularizedGramianCholesky ( Tensor matrix , Tensor l2_regularizer , bool first_kind )
81+ {
82+ var gramian = math_ops . matmul ( matrix , matrix , adjoint_a : first_kind , adjoint_b : ! first_kind ) ;
83+
84+ if ( l2_regularizer != null )
85+ {
86+ var matrix_shape = array_ops . shape ( matrix ) ;
87+ var batch_shape = matrix_shape [ ":-2" ] ;
88+ var small_dim = first_kind ? matrix_shape [ - 1 ] : matrix_shape [ - 2 ] ;
89+ var identity = eye ( small_dim . numpy ( ) , batch_shape : batch_shape . shape , dtype : matrix . dtype ) ;
90+ var small_dim_static = matrix . shape [ first_kind ? - 1 : - 2 ] ;
91+ identity . shape = matrix . shape [ ..^ 2 ] . concat ( new [ ] { small_dim_static , small_dim_static } ) ;
92+ gramian += l2_regularizer * identity ;
93+ }
94+
95+ return cholesky ( gramian ) ;
96+ }
97+
98+ public Tensor cholesky ( Tensor input , string name = null )
99+ => tf . Context . ExecuteOp ( "Cholesky" , name , new ExecuteOpArgs ( input ) ) ;
100+
101+ public Tensor cholesky_solve ( Tensor chol , Tensor rhs , string name = null )
102+ => tf_with ( ops . name_scope ( name , default_name : "eye" , new { chol , rhs } ) , scope =>
103+ {
104+ var y = matrix_triangular_solve ( chol , rhs , adjoint : false , lower : true ) ;
105+ var x = matrix_triangular_solve ( chol , y , adjoint : true , lower : true ) ;
106+ return x ;
107+ } ) ;
108+
109+ public Tensor matrix_triangular_solve ( Tensor matrix , Tensor rhs , bool lower = true , bool adjoint = false , string name = null )
110+ => tf . Context . ExecuteOp ( "MatrixTriangularSolve" , name ,
111+ new ExecuteOpArgs ( matrix , rhs ) . SetAttributes ( new
112+ {
113+ lower ,
114+ adjoint
115+ } ) ) ;
39116 }
40117}
0 commit comments