55from dataclasses import dataclass
66from typing import Any
77
8- from arrayfire import backend , safe_call # TODO refactoring
9- from arrayfire .array import _in_display_dims_limit # TODO refactoring
8+ from arrayfire import backend , safe_call # TODO refactor
9+ from arrayfire .algorithm import count # TODO refactor
10+ from arrayfire .array import _get_indices , _in_display_dims_limit # TODO refactor
1011
1112from ._dtypes import CShape , Dtype
1213from ._dtypes import bool as af_bool
@@ -37,15 +38,15 @@ class Array:
3738 # arrayfire's __radd__() instead of numpy's __add__()
3839 __array_priority__ = 30
3940
40- # Initialisation
41- arr = ctypes .c_void_p (0 )
42-
4341 def __init__ (
4442 self , x : None | Array | py_array .array | int | ctypes .c_void_p | list = None , dtype : None | Dtype = None ,
4543 pointer_source : PointerSource = PointerSource .host , shape : None | ShapeType = None ,
4644 offset : None | ctypes ._SimpleCData [int ] = None , strides : None | ShapeType = None ) -> None :
4745 _no_initial_dtype = False # HACK, FIXME
4846
47+ # Initialise array object
48+ self .arr = ctypes .c_void_p (0 )
49+
4950 if isinstance (dtype , str ):
5051 dtype = _str_to_dtype (dtype )
5152
@@ -127,7 +128,7 @@ def __str__(self) -> str: # FIXME
127128 if not _in_display_dims_limit (self .shape ):
128129 return _metadata_string (self .dtype , self .shape )
129130
130- return _metadata_string (self .dtype ) + self . _as_str ( )
131+ return _metadata_string (self .dtype ) + _array_as_str ( self )
131132
132133 def __repr__ (self ) -> str : # FIXME
133134 return _metadata_string (self .dtype , self .shape )
@@ -173,6 +174,7 @@ def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
173174 return _process_c_function (self , other , backend .get ().af_div )
174175
175176 def __floordiv__ (self , other : int | float | bool | complex | Array , / ) -> Array :
177+ # TODO
176178 return NotImplemented
177179
178180 def __mod__ (self , other : int | float | bool | complex | Array , / ) -> Array :
@@ -187,6 +189,25 @@ def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
187189 """
188190 return _process_c_function (self , other , backend .get ().af_pow )
189191
192+ def __matmul__ (self , other : Array , / ) -> Array :
193+ # TODO
194+ return NotImplemented
195+
196+ def __getitem__ (self , key : int | slice | tuple [int | slice ] | Array , / ) -> Array :
197+ # TODO: API Specification - key: int | slice | ellipsis | tuple[int | slice] | Array
198+ # TODO: refactor
199+ out = Array ()
200+ ndims = self .ndim
201+
202+ if isinstance (key , Array ) and key == af_bool .c_api_value :
203+ ndims = 1
204+ if count (key ) == 0 :
205+ return out
206+
207+ safe_call (backend .get ().af_index_gen (
208+ ctypes .pointer (out .arr ), self .arr , c_dim_t (ndims ), _get_indices (key ).pointer ))
209+ return out
210+
190211 @property
191212 def dtype (self ) -> Dtype :
192213 out = ctypes .c_int ()
@@ -234,13 +255,23 @@ def shape(self) -> ShapeType:
234255 ctypes .pointer (d0 ), ctypes .pointer (d1 ), ctypes .pointer (d2 ), ctypes .pointer (d3 ), self .arr ))
235256 return (d0 .value , d1 .value , d2 .value , d3 .value )[:self .ndim ] # Skip passing None values
236257
237- def _as_str (self ) -> str :
238- arr_str = ctypes .c_char_p (0 )
239- # FIXME add description to passed arguments
240- safe_call (backend .get ().af_array_to_string (ctypes .pointer (arr_str ), "" , self .arr , 4 , True ))
241- py_str = to_str (arr_str )
242- safe_call (backend .get ().af_free_host (arr_str ))
243- return py_str
258+ def scalar (self ) -> int | float | bool | complex :
259+ """
260+ Return the first element of the array
261+ """
262+ # BUG seg fault on empty array
263+ out = self .dtype .c_type ()
264+ safe_call (backend .get ().af_get_scalar (ctypes .pointer (out ), self .arr ))
265+ return out .value # type: ignore[no-any-return] # FIXME
266+
267+
268+ def _array_as_str (array : Array ) -> str :
269+ arr_str = ctypes .c_char_p (0 )
270+ # FIXME add description to passed arguments
271+ safe_call (backend .get ().af_array_to_string (ctypes .pointer (arr_str ), "" , array .arr , 4 , True ))
272+ py_str = to_str (arr_str )
273+ safe_call (backend .get ().af_free_host (arr_str ))
274+ return py_str
244275
245276
246277def _metadata_string (dtype : Dtype , dims : None | ShapeType = None ) -> str :
@@ -283,9 +314,8 @@ def _process_c_function(
283314 if isinstance (other , Array ):
284315 safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other .arr , _bcast_var ))
285316 elif is_number (other ):
286- target_c_shape = CShape (* target .shape )
287317 other_dtype = _implicit_dtype (other , target .dtype )
288- other_array = _constant_array (other , target_c_shape , other_dtype )
318+ other_array = _constant_array (other , CShape ( * target . shape ) , other_dtype )
289319 safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other_array .arr , _bcast_var ))
290320 else :
291321 raise TypeError (f"{ type (other )} is not supported and can not be passed to C binary function." )
@@ -326,7 +356,7 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
326356
327357 safe_call (backend .get ().af_constant_complex (
328358 ctypes .pointer (out .arr ), ctypes .c_double (value .real ), ctypes .c_double (value .imag ), 4 ,
329- ctypes .pointer (shape .c_array ), dtype ))
359+ ctypes .pointer (shape .c_array ), dtype . c_api_value ))
330360 elif dtype == af_int64 :
331361 safe_call (backend .get ().af_constant_long (
332362 ctypes .pointer (out .arr ), ctypes .c_longlong (value .real ), 4 , ctypes .pointer (shape .c_array )))
@@ -335,6 +365,6 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
335365 ctypes .pointer (out .arr ), ctypes .c_ulonglong (value .real ), 4 , ctypes .pointer (shape .c_array )))
336366 else :
337367 safe_call (backend .get ().af_constant (
338- ctypes .pointer (out .arr ), ctypes .c_double (value ), 4 , ctypes .pointer (shape .c_array ), dtype ))
368+ ctypes .pointer (out .arr ), ctypes .c_double (value ), 4 , ctypes .pointer (shape .c_array ), dtype . c_api_value ))
339369
340370 return out
0 commit comments