@@ -868,133 +868,92 @@ public static Tensor conj(Tensor x, string name = null)
868868 public static Tensor tanh ( Tensor x , string name = null )
869869 => gen_math_ops . tanh ( x , name ) ;
870870
871- public static Tensor tensordot ( Tensor x , Tensor y , int [ ] axes , string name = null )
871+ public static Tensor tensordot ( Tensor a , Tensor b , NDArray axes , string name = null )
872872 {
873- Tensor _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
873+ return tf_with ( ops . name_scope ( name , "Tensordot" , new { a , b , axes } ) , scope =>
874874 {
875- if ( a . shape . IsFullyDefined && isinstance ( axes , ( typeof ( List < object > ) , typeof ( Tuple ) ) ) )
876- {
877- var shape_a = a . shape . dims ;
878-
879- // axes
880- int iter = 0 ;
881- foreach ( int i in axes )
882- {
883- if ( i >= 0 )
884- axes [ 0 + iter ] = i ;
885- else
886- axes [ 0 + iter ] = i + len ( shape_a ) ;
887- iter ++ ;
888- }
889-
890- // free
891- int [ ] free = { } ;
892- iter = 0 ;
893- foreach ( int i in Enumerable . Range ( 0 , len ( axes ) ) )
894- if ( ! Array . Exists ( axes , i => i == i ) )
895- free [ free . Length ] = i ;
896-
897- // free_dims
898- int [ ] free_dims = { } ;
899- foreach ( int i in free )
900- free_dims [ free_dims . Length ] = ( int ) shape_a [ i ] ;
901-
902- int prod_free = ( int ) np . prod ( free_dims ) ;
903-
904- // prod_axes
905- int [ ] prod_axes_pre = { } ;
906- foreach ( int i in axes )
907- prod_axes_pre [ prod_axes_pre . Length ] = ( int ) shape_a [ i ] ;
908- int prod_axes = ( int ) np . prod ( prod_axes_pre ) ;
909-
910- // perm
911- Tensor perm ;
912- if ( flipped )
913- perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free ) ;
914- else
915- perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free )
916- + ops . convert_to_tensor ( list ( axes ) ) ;
917-
918- // new_shape
919- Shape new_shape ;
920- if ( flipped )
921- new_shape = new Shape ( new int [ ] { prod_axes , prod_free } ) ;
922- else
923- new_shape = new Shape ( new int [ ] { prod_free , prod_axes } ) ;
924- }
875+ name = scope ;
876+ var ( a_axes , b_axes ) = _tensordot_axes ( a , axes ) ;
877+ var ( a_reshape , a_free_dims , a_free_dims_static ) = _tensordot_reshape ( a , a_axes ) ;
878+ var ( b_reshape , b_free_dims , b_free_dims_static ) = _tensordot_reshape ( b , b_axes , true ) ;
879+ var ab_matmul = matmul ( a_reshape , b_reshape ) ;
880+ var dims = new List < int > ( ) ;
881+ dims . AddRange ( a_free_dims ) ;
882+ dims . AddRange ( b_free_dims ) ;
883+ if ( ab_matmul . shape . Equals ( dims ) )
884+ return ab_matmul ;
885+ else
886+ return array_ops . reshape ( ab_matmul , tf . constant ( dims . ToArray ( ) ) , name : name ) ;
887+ } ) ;
888+ }
925889
926- throw new NotImplementedException ( "_tensordot_reshape" ) ;
890+ static ( int [ ] , int [ ] ) _tensordot_axes ( Tensor a , NDArray axes )
891+ {
892+ if ( axes . rank == 0 )
893+ {
894+ int axe = axes ;
895+ if ( axe > a . shape . ndim )
896+ throw new ValueError ( "`axes` must not be larger than the number of " +
897+ $ "dimensions of tensor { a } . Received { axes } , vs " +
898+ $ "tensor dimensions { a . ndim } .") ;
899+ return ( Binding . range ( a . shape . ndim - axe , a . shape . ndim ) . ToArray ( ) ,
900+ Binding . range ( 0 , axe ) . ToArray ( ) ) ;
901+ }
902+ else
903+ {
904+ ( int a_axe , int b_axe ) = ( axes [ 0 ] , axes [ 1 ] ) ;
905+ return ( new [ ] { a_axe } , new [ ] { b_axe } ) ;
927906 }
928-
929- throw new NotImplementedException ( "tensordot" ) ;
930907 }
931908
932- public static Tensor tensordot ( Tensor x , Tensor y , Tensor axes , string name = null )
909+ static ( Tensor , int [ ] , int [ ] ) _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
933910 {
934- Tensor _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
911+ if ( a . shape . IsFullyDefined && isinstance ( axes , ( typeof ( int [ ] ) , typeof ( Tuple ) ) ) )
935912 {
936- if ( a . shape . IsFullyDefined && isinstance ( axes , ( typeof ( List < object > ) , typeof ( Tuple ) ) ) )
937- {
938- var shape_a = a . shape . dims ;
913+ var shape_a = a . shape . as_int_list ( ) ;
939914
940- // axes
941- int iter = 0 ;
942- foreach ( int i in axes )
943- {
944- if ( i >= 0 )
945- axes [ 0 + iter ] = i ;
946- else
947- axes [ 0 + iter ] = i + len ( shape_a ) ;
948- iter ++ ;
949- }
915+ // axes
916+ axes = axes . Select ( i => i >= 0 ? i : i + len ( shape_a ) ) . ToArray ( ) ;
917+
918+ // free
919+ int [ ] free = Binding . range ( a . shape . ndim ) . Where ( i => ! axes . Contains ( i ) ) . ToArray ( ) ;
920+
921+ // free_dims
922+ int [ ] free_dims = free . Select ( i => shape_a [ i ] ) . ToArray ( ) ;
950923
951- // free
952- int [ ] free = { } ;
953- iter = 0 ;
954- foreach ( int i in Enumerable . Range ( 0 , len ( axes ) ) )
955- if ( ! Array . Exists ( axes , i => i == i ) )
956- free [ free . Length ] = i ;
957-
958- // free_dims
959- int [ ] free_dims = { } ;
960- foreach ( int i in free )
961- free_dims [ free_dims . Length ] = ( int ) shape_a [ i ] ;
962-
963- int prod_free = ( int ) np . prod ( free_dims ) ;
964-
965- // prod_axes
966- int [ ] prod_axes_pre = { } ;
967- foreach ( int i in axes )
968- prod_axes_pre [ prod_axes_pre . Length ] = ( int ) shape_a [ i ] ;
969- int prod_axes = ( int ) np . prod ( prod_axes_pre ) ;
970-
971- // perm
972- Tensor perm ;
973- if ( flipped )
974- perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free ) ;
975- else
976- perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free )
977- + ops . convert_to_tensor ( list ( axes ) ) ;
978-
979- // new_shape
980- Shape new_shape ;
981- if ( flipped )
982- new_shape = new Shape ( new int [ ] { prod_axes , prod_free } ) ;
983- else
984- new_shape = new Shape ( new int [ ] { prod_free , prod_axes } ) ;
924+ int prod_free = np . prod ( free_dims ) ;
925+
926+ // prod_axes
927+ int prod_axes = np . prod ( axes . Select ( i => shape_a [ i ] ) . ToArray ( ) ) ;
928+
929+ // perm
930+ List < int > perm = new List < int > ( ) ;
931+ if ( flipped )
932+ {
933+ perm . AddRange ( axes ) ;
934+ perm . AddRange ( free ) ;
935+ }
936+ else
937+ {
938+ perm . AddRange ( free ) ;
939+ perm . AddRange ( axes ) ;
985940 }
986941
987- throw new NotImplementedException ( "_tensordot_reshape" ) ;
942+ // new_shape
943+ Shape new_shape ;
944+ if ( flipped )
945+ new_shape = new Shape ( new int [ ] { prod_axes , prod_free } ) ;
946+ else
947+ new_shape = new Shape ( new int [ ] { prod_free , prod_axes } ) ;
948+ var a_trans = a ;
949+ var reshaped_a = array_ops . reshape ( a_trans , new_shape ) ;
950+ return ( reshaped_a , free_dims , free_dims ) ;
988951 }
989952
990- return tf_with ( ops . name_scope ( name , "Tensordot" , new { x , y , axes } ) , scope =>
991- {
992- name = scope ;
993- var ( a_axes , b_axes ) = ( axes [ 0 ] , axes [ 1 ] ) ;
994- return x ;
995- } ) ;
953+ throw new NotImplementedException ( "_tensordot_reshape" ) ;
996954 }
997955
956+
998957 public static Tensor truediv ( Tensor x , Tensor y , string name = null )
999958 => _truediv_python3 ( x , y , name ) ;
1000959
0 commit comments