@@ -113,6 +113,24 @@ def _ctype_to_lists(ctype_arr, dim, shape, offset=0):
113113 offset += shape [0 ]
114114 return res
115115
116+ def _slice_to_length (key , dim ):
117+ tkey = [key .start , key .stop , key .step ]
118+
119+ if tkey [0 ] is None :
120+ tkey [0 ] = 0
121+ elif tkey [0 ] < 0 :
122+ tkey [0 ] = dim - tkey [0 ]
123+
124+ if tkey [1 ] is None :
125+ tkey [1 ] = dim
126+ elif tkey [1 ] < 0 :
127+ tkey [1 ] = dim - tkey [1 ]
128+
129+ if tkey [2 ] is None :
130+ tkey [2 ] = 1
131+
132+ return int (((tkey [1 ] - tkey [0 ] - 1 ) / tkey [2 ]) + 1 )
133+
116134def _get_info (dims , buf_len ):
117135 elements = 1
118136 numdims = len (dims )
@@ -132,6 +150,102 @@ def _get_info(dims, buf_len):
132150 return numdims , idims
133151
134152
153+ def _get_indices (key ):
154+
155+ index_vec = Index * 4
156+ S = Index (slice (None ))
157+ inds = index_vec (S , S , S , S )
158+
159+ if isinstance (key , tuple ):
160+ n_idx = len (key )
161+ for n in range (n_idx ):
162+ inds [n ] = Index (key [n ])
163+ else :
164+ inds [0 ] = Index (key )
165+
166+ return inds
167+
168+ def _get_assign_dims (key , idims ):
169+
170+ dims = [1 ]* 4
171+
172+ for n in range (len (idims )):
173+ dims [n ] = idims [n ]
174+
175+ if is_number (key ):
176+ dims [0 ] = 1
177+ return dims
178+ elif isinstance (key , slice ):
179+ dims [0 ] = _slice_to_length (key , idims [0 ])
180+ return dims
181+ elif isinstance (key , ParallelRange ):
182+ dims [0 ] = _slice_to_length (key .S , idims [0 ])
183+ return dims
184+ elif isinstance (key , BaseArray ):
185+ dims [0 ] = key .elements ()
186+ return dims
187+ elif isinstance (key , tuple ):
188+ n_inds = len (key )
189+
190+ for n in range (n_inds ):
191+ if (is_number (key [n ])):
192+ dims [n ] = 1
193+ elif (isinstance (key [n ], BaseArray )):
194+ dims [n ] = key [n ].elements ()
195+ elif (isinstance (key [n ], slice )):
196+ dims [n ] = _slice_to_length (key [n ], idims [n ])
197+ elif (isinstance (key [n ], ParallelRange )):
198+ dims [n ] = _slice_to_length (key [n ].S , idims [n ])
199+ else :
200+ raise IndexError ("Invalid type while assigning to arrayfire.array" )
201+
202+ return dims
203+ else :
204+ raise IndexError ("Invalid type while assigning to arrayfire.array" )
205+
206+
207+ def transpose (a , conj = False ):
208+ """
209+ Perform the transpose on an input.
210+
211+ Parameters
212+ -----------
213+ a : af.Array
214+ Multi dimensional arrayfire array.
215+
216+ conj : optional: bool. default: False.
217+ Flag to specify if a complex conjugate needs to applied for complex inputs.
218+
219+ Returns
220+ --------
221+ out : af.Array
222+ Containing the tranpose of `a` for all batches.
223+
224+ """
225+ out = Array ()
226+ safe_call (backend .get ().af_transpose (ct .pointer (out .arr ), a .arr , conj ))
227+ return out
228+
229+ def transpose_inplace (a , conj = False ):
230+ """
231+ Perform inplace transpose on an input.
232+
233+ Parameters
234+ -----------
235+ a : af.Array
236+ - Multi dimensional arrayfire array.
237+ - Contains transposed values on exit.
238+
239+ conj : optional: bool. default: False.
240+ Flag to specify if a complex conjugate needs to applied for complex inputs.
241+
242+ Note
243+ -------
244+ Input `a` needs to be a square matrix or a batch of square matrices.
245+
246+ """
247+ safe_call (backend .get ().af_transpose_inplace (a .arr , conj ))
248+
135249class Array (BaseArray ):
136250
137251 """
@@ -757,7 +871,7 @@ def __getitem__(self, key):
757871 try :
758872 out = Array ()
759873 n_dims = self .numdims ()
760- inds = get_indices (key )
874+ inds = _get_indices (key )
761875
762876 safe_call (backend .get ().af_index_gen (ct .pointer (out .arr ),
763877 self .arr , ct .c_longlong (n_dims ), ct .pointer (inds )))
@@ -778,13 +892,13 @@ def __setitem__(self, key, val):
778892 n_dims = self .numdims ()
779893
780894 if (is_number (val )):
781- tdims = get_assign_dims (key , self .dims ())
895+ tdims = _get_assign_dims (key , self .dims ())
782896 other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
783897 else :
784898 other_arr = val .arr
785899
786900 out_arr = ct .c_void_p (0 )
787- inds = get_indices (key )
901+ inds = _get_indices (key )
788902
789903 safe_call (backend .get ().af_assign_gen (ct .pointer (out_arr ),
790904 self .arr , ct .c_longlong (n_dims ), ct .pointer (inds ),
0 commit comments