33import array as py_array
44import ctypes
55from dataclasses import dataclass
6+ from typing import Any
67
78from arrayfire import backend , safe_call # TODO refactoring
89from arrayfire .array import _in_display_dims_limit # TODO refactoring
910
10- from ._dtypes import CShape , Dtype , c_dim_t , float32 , supported_dtypes
11- from ._utils import Device , PointerSource , to_str
11+ from ._dtypes import CShape , Dtype
12+ from ._dtypes import bool as af_bool
13+ from ._dtypes import c_dim_t
14+ from ._dtypes import complex64 as af_complex64
15+ from ._dtypes import complex128 as af_complex128
16+ from ._dtypes import float32 as af_float32
17+ from ._dtypes import float64 as af_float64
18+ from ._dtypes import int64 as af_int64
19+ from ._dtypes import supported_dtypes
20+ from ._dtypes import uint64 as af_uint64
21+ from ._utils import PointerSource , is_number , to_str
1222
1323ShapeType = tuple [int , ...]
24+ _bcast_var = False # HACK, TODO replace for actual bcast_var after refactoring
1425
1526
1627@dataclass
@@ -40,7 +51,7 @@ def __init__(
4051
4152 if dtype is None :
4253 _no_initial_dtype = True
43- dtype = float32
54+ dtype = af_float32
4455
4556 if x is None :
4657 if not shape : # shape is None or empty tuple
@@ -134,15 +145,47 @@ def __neg__(self) -> Array:
134145 """
135146 Return -self
136147 """
137- # return 0 - self
138- raise NotImplementedError
148+ return 0 - self
139149
140150 def __add__ (self , other : int | float | Array , / ) -> Array :
151+ # TODO discuss either we need to support complex and bool as other input type
141152 """
142153 Return self + other.
143154 """
144- # return _binary_func(self, other, backend.get().af_add) # TODO
145- raise NotImplementedError
155+ return _process_c_function (self , other , backend .get ().af_add )
156+
157+ def __sub__ (self , other : int | float | bool | complex | Array , / ) -> Array :
158+ """
159+ Return self - other.
160+ """
161+ return _process_c_function (self , other , backend .get ().af_sub )
162+
163+ def __mul__ (self , other : int | float | bool | complex | Array , / ) -> Array :
164+ """
165+ Return self * other.
166+ """
167+ return _process_c_function (self , other , backend .get ().af_mul )
168+
169+ def __truediv__ (self , other : int | float | bool | complex | Array , / ) -> Array :
170+ """
171+ Return self / other.
172+ """
173+ return _process_c_function (self , other , backend .get ().af_div )
174+
175+ def __floordiv__ (self , other : int | float | bool | complex | Array , / ) -> Array :
176+ return NotImplemented
177+
178+ def __mod__ (self , other : int | float | bool | complex | Array , / ) -> Array :
179+ """
180+ Return self % other.
181+ """
182+ return _process_c_function (self , other , backend .get ().af_mod )
183+
184+ def __pow__ (self , other : int | float | bool | complex | Array , / ) -> Array :
185+ """
186+ Return self ** other.
187+ """
188+ return _process_c_function (self , other , backend .get ().af_pow )
146189
147190 @property
148191 def dtype (self ) -> Dtype :
@@ -151,7 +194,7 @@ def dtype(self) -> Dtype:
151194 return _c_api_value_to_dtype (out .value )
152195
153196 @property
154- def device (self ) -> Device :
197+ def device (self ) -> Any :
155198 raise NotImplementedError
156199
157200 @property
@@ -232,41 +275,66 @@ def _str_to_dtype(value: int) -> Dtype:
232275
233276 raise TypeError ("There is no supported dtype that matches passed dtype typecode." )
234277
235- # TODO
236- # def _binary_func(lhs: int | float | Array, rhs: int | float | Array, c_func: Any) -> Array: # TODO replace Any
237- # out = Array()
238- # other = rhs
239-
240- # if is_number(rhs):
241- # ldims = _fill_dim4_tuple(lhs.shape)
242- # rty = implicit_dtype(rhs, lhs.type())
243- # other = Array()
244- # other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty.value)
245- # elif not isinstance(rhs, Array):
246- # raise TypeError("Invalid parameter to binary function")
247-
248- # safe_call(c_func(c_pointer(out.arr), lhs.arr, other.arr, _bcast_var.get()))
249-
250- # return out
251-
252-
253- # TODO replace candidate below
254- # def dim4_to_tuple(shape: ShapeType, default: int=1) -> ShapeType:
255- # assert(isinstance(dims, tuple))
256-
257- # if (default is not None):
258- # assert(is_number(default))
259-
260- # out = [default]*4
261-
262- # for i, dim in enumerate(dims):
263- # out[i] = dim
264-
265- # return tuple(out)
266-
267- # def _fill_dim4_tuple(shape: ShapeType) -> tuple[int, ...]:
268- # out = tuple([1 if value is None else value for value in shape])
269- # if len(out) == 4:
270- # return out
271278
272- # return out + (1,)*(4-len(out))
279+ def _process_c_function (
280+ target : Array , other : int | float | bool | complex | Array , c_function : Any ) -> Array :
281+ out = Array ()
282+
283+ if isinstance (other , Array ):
284+ safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other .arr , _bcast_var ))
285+ elif is_number (other ):
286+ target_c_shape = CShape (* target .shape )
287+ other_dtype = _implicit_dtype (other , target .dtype )
288+ other_array = _constant_array (other , target_c_shape , other_dtype )
289+ safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other_array .arr , _bcast_var ))
290+ else :
291+ raise TypeError (f"{ type (other )} is not supported and can not be passed to C binary function." )
292+
293+ return out
294+
295+
296+ def _implicit_dtype (value : int | float | bool | complex , array_dtype : Dtype ) -> Dtype :
297+ if isinstance (value , bool ):
298+ value_dtype = af_bool
299+ if isinstance (value , int ):
300+ value_dtype = af_int64
301+ elif isinstance (value , float ):
302+ value_dtype = af_float64
303+ elif isinstance (value , complex ):
304+ value_dtype = af_complex128
305+ else :
306+ raise TypeError (f"{ type (value )} is not supported and can not be converted to af.Dtype." )
307+
308+ if not (array_dtype == af_float32 or array_dtype == af_complex64 ):
309+ return value_dtype
310+
311+ if value_dtype == af_float64 :
312+ return af_float32
313+
314+ if value_dtype == af_complex128 :
315+ return af_complex64
316+
317+ return value_dtype
318+
319+
320+ def _constant_array (value : int | float | bool | complex , shape : CShape , dtype : Dtype ) -> Array :
321+ out = Array ()
322+
323+ if isinstance (value , complex ):
324+ if dtype != af_complex64 and dtype != af_complex128 :
325+ dtype = af_complex64
326+
327+ safe_call (backend .get ().af_constant_complex (
328+ ctypes .pointer (out .arr ), ctypes .c_double (value .real ), ctypes .c_double (value .imag ), 4 ,
329+ ctypes .pointer (shape .c_array ), dtype ))
330+ elif dtype == af_int64 :
331+ safe_call (backend .get ().af_constant_long (
332+ ctypes .pointer (out .arr ), ctypes .c_longlong (value .real ), 4 , ctypes .pointer (shape .c_array )))
333+ elif dtype == af_uint64 :
334+ safe_call (backend .get ().af_constant_ulong (
335+ ctypes .pointer (out .arr ), ctypes .c_ulonglong (value .real ), 4 , ctypes .pointer (shape .c_array )))
336+ else :
337+ safe_call (backend .get ().af_constant (
338+ ctypes .pointer (out .arr ), ctypes .c_double (value ), 4 , ctypes .pointer (shape .c_array ), dtype ))
339+
340+ return out
0 commit comments