@@ -278,53 +278,37 @@ private static string div_or_truediv<Tx, Ty>(string name, Tx x, Ty y)
278278
279279 protected static Tensor BinaryOpWrapper < Tx , Ty > ( string name , Tx x , Ty y )
280280 {
281- TF_DataType dtype = TF_DataType . DtInvalid ;
282-
283- if ( x is Tensor tl )
284- {
285- dtype = tl . dtype . as_base_dtype ( ) ;
286- }
287-
288- if ( y is Tensor tr )
289- {
290- dtype = tr . dtype . as_base_dtype ( ) ;
291- }
292-
293281 return tf_with ( ops . name_scope ( null , name , new { x , y } ) , scope =>
294282 {
295- Tensor result ;
296- var x1 = ops . convert_to_tensor ( x , dtype : dtype , name : "x" ) ;
297- var y1 = ops . convert_to_tensor ( y , dtype : dtype , name : "y" ) ;
283+ var dtype = GetBestDType ( x , y ) ;
284+ var x1 = ops . convert_to_tensor ( x , name : "x" , dtype : dtype ) ;
285+ var y1 = ops . convert_to_tensor ( y , name : "y" , dtype : dtype ) ;
286+ string newname = scope ;
298287
299- switch ( name . ToLowerInvariant ( ) )
288+ return name . ToLowerInvariant ( ) switch
300289 {
301- case "add" :
302- result = math_ops . add_v2 ( x1 , y1 , name : scope ) ;
303- break ;
304- case "div" :
305- result = math_ops . div ( x1 , y1 , name : scope ) ;
306- break ;
307- case "floordiv" :
308- result = gen_math_ops . floor_div ( x1 , y1 , name : scope ) ;
309- break ;
310- case "truediv" :
311- result = math_ops . truediv ( x1 , y1 , name : scope ) ;
312- break ;
313- case "mul" :
314- result = math_ops . multiply ( x1 , y1 , name : scope ) ;
315- break ;
316- case "sub" :
317- result = gen_math_ops . sub ( x1 , y1 , name : scope ) ;
318- break ;
319- case "mod" :
320- result = gen_math_ops . floor_mod ( x1 , y1 , name : scope ) ;
321- break ;
322- default :
323- throw new NotImplementedException ( $ "BinaryOpWrapper: { name } - { typeof ( Tx ) . Name } , { typeof ( Ty ) . Name } ") ;
324- }
325-
326- return result ;
290+ "add" => math_ops . add_v2 ( x1 , y1 , name : newname ) ,
291+ "div" => math_ops . div ( x1 , y1 , name : newname ) ,
292+ "floordiv" => gen_math_ops . floor_div ( x1 , y1 , name : newname ) ,
293+ "truediv" => math_ops . truediv ( x1 , y1 , name : newname ) ,
294+ "mul" => math_ops . multiply ( x1 , y1 , name : newname ) ,
295+ "sub" => gen_math_ops . sub ( x1 , y1 , name : newname ) ,
296+ "mod" => gen_math_ops . floor_mod ( x1 , y1 , name : newname ) ,
297+ _ => throw new NotImplementedException ( $ "BinaryOpWrapper: { name } - { typeof ( Tx ) . Name } , { typeof ( Ty ) . Name } ")
298+ } ;
327299 } ) ;
328300 }
301+
302+ static TF_DataType GetBestDType < Tx , Ty > ( Tx x , Ty y )
303+ {
304+ var dtype1 = x . GetDataType ( ) ;
305+ var dtype2 = y . GetDataType ( ) ;
306+ if ( dtype1 . is_integer ( ) && dtype2 . is_floating ( ) )
307+ return dtype2 ;
308+ else if ( dtype1 . is_floating ( ) && dtype2 . is_integer ( ) )
309+ return dtype1 ;
310+ else
311+ return dtype1 ;
312+ }
329313 }
330314}
0 commit comments