@@ -13,6 +13,7 @@ You may obtain a copy of the License at
1313 See the License for the specific language governing permissions and
1414 limitations under the License.
1515******************************************************************************/
16+ using static Tensorflow . Binding ;
1617
1718namespace Tensorflow
1819{
@@ -37,8 +38,8 @@ public Tensor diag(Tensor diagonal, string name = null)
3738 public Tensor matmul ( Tensor a , Tensor b )
3839 => math_ops . matmul ( a , b ) ;
3940
40- public Tensor batch_matmul ( Tensor x , Tensor y )
41- => gen_math_ops . batch_mat_mul ( x , y ) ;
41+ public Tensor batch_matmul ( Tensor x , Tensor y , bool adj_x = false , bool adj_y = false , string name = null )
42+ => math_ops . batch_matmul ( x , y , adj_x : adj_x , adj_y : adj_y , name : name ) ;
4243 }
4344
4445 public Tensor diag ( Tensor diagonal , string name = null )
@@ -47,7 +48,32 @@ public Tensor diag(Tensor diagonal, string name = null)
4748 public Tensor matmul ( Tensor a , Tensor b )
4849 => math_ops . matmul ( a , b ) ;
4950
50- public Tensor batch_matmul ( Tensor x , Tensor y )
51- => gen_math_ops . batch_mat_mul ( x , y ) ;
51+ /// <summary>
52+ /// Multiply slices of the two matrices "x" and "y".
53+ /// </summary>
54+ /// <remarks>
55+ /// The `BatchMatMul` operation is embedded into the
56+ /// `MatMul` operation on the DLL side. However the expected
57+ /// attributes are not the same, hence we need to expose this
58+ /// method to have the right args list on the `_apply_op_helper`
59+ /// function.
60+ ///
61+ /// For each rank > 2 the first rank - 2 dimensions are considered
62+ /// as fixed, and have to be consistent across the two matrices. A
63+ /// common matrix multiplication is then applied over the residual
64+ /// 2 dimensions.
65+ ///
66+ /// e.g.
67+ /// x is (3, 6, 12); y is (3, 12, 6)
68+ /// batch_matmul(x, y) ==> (3, 6, 6)
69+ /// </remarks>
70+ /// <param name="x"></param>
71+ /// <param name="y"></param>
72+ /// <param name="adj_x"></param>
73+ /// <param name="adj_y"></param>
74+ /// <param name="name"></param>
75+ /// <returns></returns>
76+ public Tensor batch_matmul ( Tensor x , Tensor y , bool adj_x = false , bool adj_y = false , string name = null )
77+ => math_ops . batch_matmul ( x , y , adj_x : adj_x , adj_y : adj_y , name : name ) ;
5278 }
5379}
0 commit comments