@@ -929,6 +929,72 @@ Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
929929 throw new NotImplementedException ( "tensordot" ) ;
930930 }
931931
932+ public static Tensor tensordot ( Tensor x , Tensor y , Tensor axes , string name = null )
933+ {
934+ Tensor _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
935+ {
936+ if ( a . shape . IsFullyDefined && isinstance ( axes , ( typeof ( List < object > ) , typeof ( Tuple ) ) ) )
937+ {
938+ var shape_a = a . shape . dims ;
939+
940+ // axes
941+ int iter = 0 ;
942+ foreach ( int i in axes )
943+ {
944+ if ( i >= 0 )
945+ axes [ 0 + iter ] = i ;
946+ else
947+ axes [ 0 + iter ] = i + len ( shape_a ) ;
948+ iter ++ ;
949+ }
950+
951+ // free
952+ int [ ] free = { } ;
953+ iter = 0 ;
954+ foreach ( int i in Enumerable . Range ( 0 , len ( axes ) ) )
955+ if ( ! Array . Exists ( axes , i => i == i ) )
956+ free [ free . Length ] = i ;
957+
958+ // free_dims
959+ int [ ] free_dims = { } ;
960+ foreach ( int i in free )
961+ free_dims [ free_dims . Length ] = ( int ) shape_a [ i ] ;
962+
963+ int prod_free = ( int ) np . prod ( free_dims ) ;
964+
965+ // prod_axes
966+ int [ ] prod_axes_pre = { } ;
967+ foreach ( int i in axes )
968+ prod_axes_pre [ prod_axes_pre . Length ] = ( int ) shape_a [ i ] ;
969+ int prod_axes = ( int ) np . prod ( prod_axes_pre ) ;
970+
971+ // perm
972+ Tensor perm ;
973+ if ( flipped )
974+ perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free ) ;
975+ else
976+ perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free )
977+ + ops . convert_to_tensor ( list ( axes ) ) ;
978+
979+ // new_shape
980+ Shape new_shape ;
981+ if ( flipped )
982+ new_shape = new Shape ( new int [ ] { prod_axes , prod_free } ) ;
983+ else
984+ new_shape = new Shape ( new int [ ] { prod_free , prod_axes } ) ;
985+ }
986+
987+ throw new NotImplementedException ( "_tensordot_reshape" ) ;
988+ }
989+
990+ return tf_with ( ops . name_scope ( name , "Tensordot" , new { x , y , axes } ) , scope =>
991+ {
992+ name = scope ;
993+ var ( a_axes , b_axes ) = ( axes [ 0 ] , axes [ 1 ] ) ;
994+ return x ;
995+ } ) ;
996+ }
997+
932998 public static Tensor truediv ( Tensor x , Tensor y , string name = null )
933999 => _truediv_python3 ( x , y , name ) ;
9341000
0 commit comments