@@ -905,13 +905,29 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
905905 var ( a_reshape , a_free_dims , a_free_dims_static ) = _tensordot_reshape ( a , a_axes ) ;
906906 var ( b_reshape , b_free_dims , b_free_dims_static ) = _tensordot_reshape ( b , b_axes , true ) ;
907907 var ab_matmul = matmul ( a_reshape , b_reshape ) ;
908- var dims = new List < int > ( ) ;
909- dims . AddRange ( a_free_dims ) ;
910- dims . AddRange ( b_free_dims ) ;
911- if ( ab_matmul . shape . Equals ( dims ) )
912- return ab_matmul ;
908+ if ( a_free_dims is int [ ] a_free_dims_list && b_free_dims is int [ ] b_free_dims_list )
909+ {
910+ var total_free_dims = a_free_dims_list . Concat ( b_free_dims_list ) . ToArray ( ) ;
911+ if ( ab_matmul . shape . IsFullyDefined && ab_matmul . shape . as_int_list ( ) . SequenceEqual ( total_free_dims ) )
912+ {
913+ return ab_matmul ;
914+ }
915+ else
916+ {
917+ return array_ops . reshape ( ab_matmul , ops . convert_to_tensor ( total_free_dims ) , name ) ;
918+ }
919+ }
913920 else
914- return array_ops . reshape ( ab_matmul , tf . constant ( dims . ToArray ( ) ) , name : name ) ;
921+ {
922+ var a_free_dims_tensor = ops . convert_to_tensor ( a_free_dims , dtype : dtypes . int32 ) ;
923+ var b_free_dims_tensor = ops . convert_to_tensor ( b_free_dims , dtype : dtypes . int32 ) ;
924+ var product = array_ops . reshape ( ab_matmul , array_ops . concat ( new [ ] { a_free_dims_tensor , b_free_dims_tensor } , 0 ) , name ) ;
925+ if ( a_free_dims_static is not null && b_free_dims_static is not null )
926+ {
927+ product . shape = new Shape ( a_free_dims_static . Concat ( b_free_dims_static ) . ToArray ( ) ) ;
928+ }
929+ return product ;
930+ }
915931 } ) ;
916932 }
917933
@@ -927,14 +943,42 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
927943 return ( Binding . range ( a . shape . ndim - axe , a . shape . ndim ) . ToArray ( ) ,
928944 Binding . range ( 0 , axe ) . ToArray ( ) ) ;
929945 }
930- else
946+ else if ( axes . rank == 1 )
931947 {
948+ if ( axes . shape [ 0 ] != 2 )
949+ {
950+ throw new ValueError ( $ "`axes` must be an integer or have length 2. Received { axes } .") ;
951+ }
932952 ( int a_axe , int b_axe ) = ( axes [ 0 ] , axes [ 1 ] ) ;
933953 return ( new [ ] { a_axe } , new [ ] { b_axe } ) ;
934954 }
955+ else if ( axes . rank == 2 )
956+ {
957+ if ( axes . shape [ 0 ] != 2 )
958+ {
959+ throw new ValueError ( $ "`axes` must be an integer or have length 2. Received { axes } .") ;
960+ }
961+ int [ ] a_axes = new int [ axes . shape [ 1 ] ] ;
962+ int [ ] b_axes = new int [ axes . shape [ 1 ] ] ;
963+ for ( int i = 0 ; i < a_axes . Length ; i ++ )
964+ {
965+ a_axes [ i ] = axes [ 0 , i ] ;
966+ b_axes [ i ] = axes [ 1 , i ] ;
967+ if ( a_axes [ i ] == - 1 || b_axes [ i ] == - 1 )
968+ {
969+ throw new ValueError ( $ "Different number of contraction axes `a` and `b`," +
970+ $ "{ len ( a_axes ) } != { len ( b_axes ) } .") ;
971+ }
972+ }
973+ return ( a_axes , b_axes ) ;
974+ }
975+ else
976+ {
977+ throw new ValueError ( $ "Invalid rank { axes . rank } to make tensor dot.") ;
978+ }
935979 }
936980
937- static ( Tensor , int [ ] , int [ ] ) _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
981+ static ( Tensor , object , int [ ] ) _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
938982 {
939983 if ( a . shape . IsFullyDefined && isinstance ( axes , ( typeof ( int [ ] ) , typeof ( Tuple ) ) ) )
940984 {
@@ -977,6 +1021,58 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
9771021 var reshaped_a = array_ops . reshape ( a_trans , new_shape ) ;
9781022 return ( reshaped_a , free_dims , free_dims ) ;
9791023 }
1024+ else
1025+ {
1026+ int [ ] free_dims_static ;
1027+ Tensor converted_shape_a , converted_axes , converted_free ;
1028+ if ( a . shape . ndim != - 1 )
1029+ {
1030+ var shape_a = a . shape . as_int_list ( ) ;
1031+ for ( int i = 0 ; i < axes . Length ; i ++ )
1032+ {
1033+ if ( axes [ i ] < 0 )
1034+ {
1035+ axes [ i ] += shape_a . Length ;
1036+ }
1037+ }
1038+ var free = Enumerable . Range ( 0 , shape_a . Length ) . Where ( i => ! axes . Contains ( i ) ) . ToArray ( ) ;
1039+
1040+ var axes_dims = axes . Select ( i => shape_a [ i ] ) ;
1041+ var free_dims = free . Select ( i => shape_a [ i ] ) . ToArray ( ) ;
1042+ free_dims_static = free_dims ;
1043+ converted_axes = ops . convert_to_tensor ( axes , dtypes . int32 , "axes" ) ;
1044+ converted_free = ops . convert_to_tensor ( free , dtypes . int32 , "free" ) ;
1045+ converted_shape_a = array_ops . shape ( a ) ;
1046+ }
1047+ else
1048+ {
1049+ free_dims_static = null ;
1050+ converted_shape_a = array_ops . shape ( a ) ;
1051+ var rank_a = array_ops . rank ( a ) ;
1052+ converted_axes = ops . convert_to_tensor ( axes , dtypes . int32 , "axes" ) ;
1053+ converted_axes = array_ops . where_v2 ( converted_axes >= 0 , converted_axes , converted_axes + rank_a ) ;
1054+ ( converted_free , var _ ) = gen_ops . list_diff ( gen_math_ops . range ( ops . convert_to_tensor ( 0 ) , rank_a , ops . convert_to_tensor ( 1 ) ) ,
1055+ converted_axes , dtypes . int32 ) ;
1056+ }
1057+ var converted_free_dims = array_ops . gather ( converted_shape_a , converted_free ) ;
1058+ var converted_axes_dims = array_ops . gather ( converted_shape_a , converted_axes ) ;
1059+ var prod_free_dims = reduce_prod ( converted_free_dims ) ;
1060+ var prod_axes_dims = reduce_prod ( converted_axes_dims ) ;
1061+ Tensor reshaped_a ;
1062+ if ( flipped )
1063+ {
1064+ var perm = array_ops . concat ( new [ ] { converted_axes , converted_free } , 0 ) ;
1065+ var new_shape = array_ops . stack ( new [ ] { prod_axes_dims , prod_free_dims } ) ;
1066+ reshaped_a = array_ops . reshape ( array_ops . transpose ( a , perm ) , new_shape ) ;
1067+ }
1068+ else
1069+ {
1070+ var perm = array_ops . concat ( new [ ] { converted_free , converted_axes } , 0 ) ;
1071+ var new_shape = array_ops . stack ( new [ ] { prod_free_dims , prod_axes_dims } ) ;
1072+ reshaped_a = array_ops . reshape ( array_ops . transpose ( a , perm ) , new_shape ) ;
1073+ }
1074+ return ( reshaped_a , converted_free_dims , free_dims_static ) ;
1075+ }
9801076
9811077 throw new NotImplementedException ( "_tensordot_reshape" ) ;
9821078 }
0 commit comments