22
33import array as py_array
44import ctypes
5- import math
65from dataclasses import dataclass
76
87from arrayfire import backend , safe_call # TODO refactoring
98from arrayfire .array import _in_display_dims_limit # TODO refactoring
109
11- from ._dtypes import Dtype , c_dim_t , float32 , supported_dtypes
10+ from ._dtypes import CShape , Dtype , c_dim_t , float32 , supported_dtypes
1211from ._utils import Device , PointerSource , to_str
1312
1413ShapeType = tuple [None | int , ...]
@@ -28,7 +27,6 @@ class Array:
2827 __array_priority__ = 30
2928
3029 # Initialisation
31- _array_buffer = _ArrayBuffer ()
3230 arr = ctypes .c_void_p (0 )
3331
3432 def __init__ (
@@ -46,12 +44,12 @@ def __init__(
4644 if x is None :
4745 if not shape : # shape is None or empty tuple
4846 safe_call (backend .get ().af_create_handle (
49- ctypes .pointer (self .arr ), 0 , ctypes .pointer (dim4 () ), dtype .c_api_value ))
47+ ctypes .pointer (self .arr ), 0 , ctypes .pointer (CShape (). c_array ), dtype .c_api_value ))
5048 return
5149
5250 # NOTE: applies inplace changes for self.arr
5351 safe_call (backend .get ().af_create_handle (
54- ctypes .pointer (self .arr ), len (shape ), ctypes .pointer (dim4 (* shape )), dtype .c_api_value ))
52+ ctypes .pointer (self .arr ), len (shape ), ctypes .pointer (CShape (* shape ). c_array ), dtype .c_api_value ))
5553 return
5654
5755 if isinstance (x , Array ):
@@ -61,19 +59,16 @@ def __init__(
6159 if isinstance (x , py_array .array ):
6260 _type_char = x .typecode
6361 _array_buffer = _ArrayBuffer (* x .buffer_info ())
64- numdims , idims = _get_info (shape , _array_buffer .length )
6562
6663 elif isinstance (x , list ):
6764 _array = py_array .array ("f" , x ) # BUG [True, False] -> dtype: f32 # TODO add int and float
6865 _type_char = _array .typecode
6966 _array_buffer = _ArrayBuffer (* _array .buffer_info ())
70- numdims , idims = _get_info (shape , _array_buffer .length )
7167
7268 elif isinstance (x , int ) or isinstance (x , ctypes .c_void_p ): # TODO
7369 _array_buffer = _ArrayBuffer (x if not isinstance (x , ctypes .c_void_p ) else x .value )
74- numdims , idims = _get_info (shape , _array_buffer .length )
7570
76- if not math . prod ( idims ) :
71+ if not shape :
7772 raise RuntimeError ("Expected to receive the initial shape due to the x being a data pointer." )
7873
7974 if _no_initial_dtype :
@@ -84,34 +79,37 @@ def __init__(
8479 else :
8580 raise TypeError ("Passed object x is an object of unsupported class." )
8681
82+ _cshape = _get_cshape (shape , _array_buffer .length )
83+
8784 if not _no_initial_dtype and dtype .typecode != _type_char :
8885 raise TypeError ("Can not create array of requested type from input data type" )
8986
9087 if not (offset or strides ):
9188 if pointer_source == PointerSource .host :
9289 safe_call (backend .get ().af_create_array (
93- ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), numdims ,
94- ctypes .pointer (dim4 ( * idims ) ), dtype .c_api_value ))
90+ ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), _cshape . original_shape ,
91+ ctypes .pointer (_cshape . c_array ), dtype .c_api_value ))
9592 return
9693
9794 safe_call (backend .get ().af_device_array (
98- ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), numdims ,
99- ctypes .pointer (dim4 ( * idims ) ), dtype .c_api_value ))
95+ ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), _cshape . original_shape ,
96+ ctypes .pointer (_cshape . c_array ), dtype .c_api_value ))
10097 return
10198
102- if offset is None : # TODO
99+ if offset is None :
103100 offset = c_dim_t (0 )
104101
105- if strides is None : # TODO
106- strides = (1 , idims [0 ], idims [0 ]* idims [1 ], idims [0 ]* idims [1 ]* idims [2 ])
102+ if strides is None :
103+ strides = (1 , _cshape [0 ], _cshape [0 ]* _cshape [1 ], _cshape [0 ]* _cshape [1 ]* _cshape [2 ])
107104
108105 if len (strides ) < 4 :
109106 strides += (strides [- 1 ], ) * (4 - len (strides ))
110- strides_dim4 = dim4 (* strides )
107+ strides_cshape = CShape (* strides ). c_array
111108
112109 safe_call (backend .get ().af_create_strided_array (
113- ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), offset , numdims ,
114- ctypes .pointer (dim4 (* idims )), ctypes .pointer (strides_dim4 ), dtype .c_api_value , pointer_source .value ))
110+ ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), offset , _cshape .original_shape ,
111+ ctypes .pointer (_cshape .c_array ), ctypes .pointer (strides_cshape ), dtype .c_api_value ,
112+ pointer_source .value ))
115113
116114 def __str__ (self ) -> str : # FIXME
117115 if not _in_display_dims_limit (self .shape ):
@@ -126,7 +124,7 @@ def __len__(self) -> int:
126124 return self .shape [0 ] if self .shape else 0 # type: ignore[return-value]
127125
128126 def __pos__ (self ) -> Array :
129- """y
127+ """
130128 Return +self
131129 """
132130 return self
@@ -190,8 +188,7 @@ def shape(self) -> ShapeType:
190188 d3 = c_dim_t (0 )
191189 safe_call (backend .get ().af_get_dims (
192190 ctypes .pointer (d0 ), ctypes .pointer (d1 ), ctypes .pointer (d2 ), ctypes .pointer (d3 ), self .arr ))
193- dims = (d0 .value , d1 .value , d2 .value , d3 .value )
194- return dims [:self .ndim ] # FIXME An array dimension must be None if and only if a dimension is unknown
191+ return (d0 .value , d1 .value , d2 .value , d3 .value )[:self .ndim ] # Skip passing None values
195192
196193 def _as_str (self ) -> str :
197194 arr_str = ctypes .c_char_p (0 )
@@ -201,30 +198,6 @@ def _as_str(self) -> str:
201198 safe_call (backend .get ().af_free_host (arr_str ))
202199 return py_str
203200
204- # def _get_metadata_str(self, show_dims: bool = True) -> str:
205- # return (
206- # "arrayfire.Array()\n"
207- # f"Type: {self.dtype.typename}\n"
208- # f"Dims: {str(self._dims) if show_dims else ''}")
209-
210- # @property
211- # def dtype(self) -> ...:
212- # dty = ctypes.c_int()
213- # safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
214-
215- # @safe_call
216- # def backend()
217- # ...
218-
219- # @backend(safe=True)
220- # def af_get_type(arr) -> ...:
221- # dty = ctypes.c_int()
222- # safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
223- # return dty
224-
225- # def new_dtype():
226- # return af_get_type(self.arr)
227-
228201
229202def _metadata_string (dtype : Dtype , dims : None | ShapeType = None ) -> str :
230203 return (
@@ -233,20 +206,14 @@ def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
233206 f"Dims: { str (dims ) if dims else '' } " )
234207
235208
236- def _get_info (shape : None | tuple [int ], buffer_length : int ) -> tuple [int , list [int ]]:
237- # TODO refactor
209+ def _get_cshape (shape : None | tuple [int ], buffer_length : int ) -> CShape :
238210 if shape :
239- numdims = len (shape )
240- idims = [1 ]* 4
241- for i in range (numdims ):
242- idims [i ] = shape [i ]
243- elif (buffer_length != 0 ):
244- idims = [buffer_length , 1 , 1 , 1 ]
245- numdims = 1
246- else :
247- raise RuntimeError ("Invalid size" )
211+ return CShape (* shape )
212+
213+ if buffer_length != 0 :
214+ return CShape (buffer_length )
248215
249- return numdims , idims
216+ raise RuntimeError ( "Shape and buffer length are size invalid." )
250217
251218
252219def _c_api_value_to_dtype (value : int ) -> Dtype :
@@ -282,16 +249,6 @@ def _str_to_dtype(value: int) -> Dtype:
282249# return out
283250
284251
285- def dim4 (d0 : int = 1 , d1 : int = 1 , d2 : int = 1 , d3 : int = 1 ): # type: ignore # FIXME
286- c_dim4 = c_dim_t * 4 # ctypes.c_int | ctypes.c_longlong * 4
287- out = c_dim4 (1 , 1 , 1 , 1 )
288-
289- for i , dim in enumerate ((d0 , d1 , d2 , d3 )):
290- if dim is not None :
291- out [i ] = c_dim_t (dim )
292-
293- return out
294-
295252# TODO replace candidate below
296253# def dim4_to_tuple(shape: ShapeType, default: int=1) -> ShapeType:
297254# assert(isinstance(dims, tuple))
0 commit comments