1919from ._dtypes import int64 as af_int64
2020from ._dtypes import supported_dtypes
2121from ._dtypes import uint64 as af_uint64
22- from ._utils import PointerSource , is_number , to_str
22+ from ._utils import PointerSource , to_str
2323
2424ShapeType = tuple [int , ...]
2525_bcast_var = False # HACK, TODO replace for actual bcast_var after refactoring
@@ -286,25 +286,25 @@ def __radd__(self, other: Array, /) -> Array:
286286 """
287287 Return other + self.
288288 """
289- return _process_c_function (self , other , backend .get ().af_add )
289+ return _process_c_function (other , self , backend .get ().af_add )
290290
291291 def __rsub__ (self , other : Array , / ) -> Array :
292292 """
293293 Return other - self.
294294 """
295- return _process_c_function (self , other , backend .get ().af_sub )
295+ return _process_c_function (other , self , backend .get ().af_sub )
296296
297297 def __rmul__ (self , other : Array , / ) -> Array :
298298 """
299299 Return other * self.
300300 """
301- return _process_c_function (self , other , backend .get ().af_mul )
301+ return _process_c_function (other , self , backend .get ().af_mul )
302302
303303 def __rtruediv__ (self , other : Array , / ) -> Array :
304304 """
305305 Return other / self.
306306 """
307- return _process_c_function (self , other , backend .get ().af_div )
307+ return _process_c_function (other , self , backend .get ().af_div )
308308
309309 def __rfloordiv__ (self , other : Array , / ) -> Array :
310310 # TODO
@@ -314,13 +314,13 @@ def __rmod__(self, other: Array, /) -> Array:
314314 """
315315 Return other / self.
316316 """
317- return _process_c_function (self , other , backend .get ().af_mod )
317+ return _process_c_function (other , self , backend .get ().af_mod )
318318
319319 def __rpow__ (self , other : Array , / ) -> Array :
320320 """
321321 Return other ** self.
322322 """
323- return _process_c_function (self , other , backend .get ().af_pow )
323+ return _process_c_function (other , self , backend .get ().af_pow )
324324
325325 # Reflected Array Operators
326326
@@ -334,31 +334,31 @@ def __rand__(self, other: Array, /) -> Array:
334334 """
335335 Return other & self.
336336 """
337- return _process_c_function (self , other , backend .get ().af_bitand )
337+ return _process_c_function (other , self , backend .get ().af_bitand )
338338
339339 def __ror__ (self , other : Array , / ) -> Array :
340340 """
341341 Return other & self.
342342 """
343- return _process_c_function (self , other , backend .get ().af_bitor )
343+ return _process_c_function (other , self , backend .get ().af_bitor )
344344
345345 def __rxor__ (self , other : Array , / ) -> Array :
346346 """
347347 Return other ^ self.
348348 """
349- return _process_c_function (self , other , backend .get ().af_bitxor )
349+ return _process_c_function (other , self , backend .get ().af_bitxor )
350350
351351 def __rlshift__ (self , other : Array , / ) -> Array :
352352 """
353353 Return other << self.
354354 """
355- return _process_c_function (self , other , backend .get ().af_bitshiftl )
355+ return _process_c_function (other , self , backend .get ().af_bitshiftl )
356356
357357 def __rrshift__ (self , other : Array , / ) -> Array :
358358 """
359359 Return other >> self.
360360 """
361- return _process_c_function (self , other , backend .get ().af_bitshiftr )
361+ return _process_c_function (other , self , backend .get ().af_bitshiftr )
362362
363363 # In-place Arithmetic Operators
364364
@@ -614,20 +614,32 @@ def _str_to_dtype(value: int) -> Dtype:
614614
615615
616616def _process_c_function (
617- target : Array , other : int | float | bool | complex | Array , c_function : Any ) -> Array :
617+ lhs : int | float | bool | complex | Array , rhs : int | float | bool | complex | Array ,
618+ c_function : Any ) -> Array :
618619 out = Array ()
619620
620- # TODO discuss the difference between binary_func and binary_funcr
621- # because implementation looks like exectly the same.
622- # consider chaging to __iadd__ = __radd__ = __add__ interfce if no difference
623- if isinstance (other , Array ):
624- safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other .arr , _bcast_var ))
625- elif is_number (other ):
626- other_dtype = _implicit_dtype (other , target .dtype )
627- other_array = _constant_array (other , CShape (* target .shape ), other_dtype )
628- safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other_array .arr , _bcast_var ))
621+ if isinstance (lhs , Array ) and isinstance (rhs , Array ):
622+ lhs_array = lhs .arr
623+ rhs_array = rhs .arr
624+
625+ elif isinstance (lhs , Array ) and isinstance (rhs , int | float | bool | complex ):
626+ rhs_dtype = _implicit_dtype (rhs , lhs .dtype )
627+ rhs_constant_array = _constant_array (rhs , CShape (* lhs .shape ), rhs_dtype )
628+
629+ lhs_array = lhs .arr
630+ rhs_array = rhs_constant_array .arr
631+
632+ elif isinstance (lhs , int | float | bool | complex ) and isinstance (rhs , Array ):
633+ lhs_dtype = _implicit_dtype (lhs , rhs .dtype )
634+ lhs_constant_array = _constant_array (lhs , CShape (* rhs .shape ), lhs_dtype )
635+
636+ lhs_array = lhs_constant_array .arr
637+ rhs_array = rhs .arr
638+
629639 else :
630- raise TypeError (f"{ type (other )} is not supported and can not be passed to C binary function." )
640+ raise TypeError (f"{ type (rhs )} is not supported and can not be passed to C binary function." )
641+
642+ safe_call (c_function (ctypes .pointer (out .arr ), lhs_array , rhs_array , _bcast_var ))
631643
632644 return out
633645
0 commit comments