@@ -163,24 +163,87 @@ def sync(device=None):
163163 safe_call (backend .get ().af_sync (dev ))
164164
165165def __eval (* args ):
166- for A in args :
167- if isinstance (A , tuple ):
168- __eval (* A )
169- if isinstance (A , list ):
170- __eval (* A )
171- if isinstance (A , Array ):
172- safe_call (backend .get ().af_eval (A .arr ))
166+ nargs = len (args )
167+ if (nargs == 1 ):
168+ safe_call (backend .get ().af_eval (args [0 ].arr ))
169+ else :
170+ c_void_p_n = ct .c_void_p * nargs
171+ arrs = c_void_p_n ()
172+ for n in range (nargs ):
173+ arrs [n ] = args [n ].arr
174+ safe_call (backend .get ().af_eval_multiple (ct .c_int (nargs ), ct .pointer (arrs )))
175+ return
173176
174177def eval (* args ):
175178 """
176- Evaluate the input
179+ Evaluate one or more inputs together
177180
178181 Parameters
179182 -----------
180183 args : arguments to be evaluated
184+
185+ Note
186+ -----
187+
188+ All the input arrays to this function should be of the same size.
189+
190+ Examples
191+ --------
192+
193+ >>> a = af.constant(1, 3, 3)
194+ >>> b = af.constant(2, 3, 3)
195+ >>> c = a + b
196+ >>> d = a - b
197+ >>> af.eval(c, d) # A single kernel is launched here
198+ >>> c
199+ arrayfire.Array()
200+ Type: float
201+ [3 3 1 1]
202+ 3.0000 3.0000 3.0000
203+ 3.0000 3.0000 3.0000
204+ 3.0000 3.0000 3.0000
205+
206+ >>> d
207+ arrayfire.Array()
208+ Type: float
209+ [3 3 1 1]
210+ -1.0000 -1.0000 -1.0000
211+ -1.0000 -1.0000 -1.0000
212+ -1.0000 -1.0000 -1.0000
213+ """
214+ for arg in args :
215+ if not isinstance (arg , Array ):
216+ raise RuntimeError ("All inputs to eval must be of type arrayfire.Array" )
217+
218+ __eval (* args )
219+
220+ def set_manual_eval_flag (flag ):
221+ """
222+ Tells the backend JIT engine to disable heuristics for determining when to evaluate a JIT tree.
223+
224+ Parameters
225+ ----------
226+
227+ flag : optional: bool.
228+ - Specifies if the heuristic evaluation of the JIT tree needs to be disabled.
229+
230+ Note
231+ ----
232+ This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
181233 """
234+ safe_call (backend .get ().af_set_manual_eval_flag (flag ))
182235
183- __eval (args )
236+ def get_manual_eval_flag ():
237+ """
238+ Query the backend JIT engine to see if the user disabled heuristic evaluation of the JIT tree.
239+
240+ Note
241+ ----
242+ This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
243+ """
244+ res = ct .c_bool (False )
245+ safe_call (backend .get ().af_get_manual_eval_flag (ct .pointer (res )))
246+ return res .value
184247
185248def device_mem_info ():
186249 """
@@ -258,10 +321,27 @@ def lock_array(a):
258321
259322 Note
260323 -----
261- - The device pointer of `a` is not freed by memory manager until `unlock_device_ptr ()` is called.
324+ - The device pointer of `a` is not freed by memory manager until `unlock_array ()` is called.
262325 """
263326 safe_call (backend .get ().af_lock_array (a .arr ))
264327
328+ def is_locked_array (a ):
329+ """
330+ Check if the input array is locked by the user.
331+
332+ Parameters
333+ ----------
334+ a: af.Array
335+ - A multi dimensional arrayfire array.
336+
337+ Returns
338+ -----------
339+ A bool specifying if the input array is locked.
340+ """
341+ res = ct .c_bool (False )
342+ safe_call (backend .get ().af_is_locked_array (ct .pointer (res ), a .arr ))
343+ return res .value
344+
265345def unlock_device_ptr (a ):
266346 """
267347 This functions is deprecated. Please use unlock_array instead.
0 commit comments