3939class UnaryElementwiseFunc :
4040 """
4141 Class that implements unary element-wise functions.
42+
43+ Args:
44+ name (str):
45+ Name of the unary function
46+ result_type_resovler_fn (callable):
47+ Function that takes dtype of the input and
48+ returns the dtype of the result if the
49+ implementation functions supports it, or
50+ returns `None` otherwise.
51+ unary_dp_impl_fn (callable):
52+ Data-parallel implementation function with signature
53+ `impl_fn(src: usm_ndarray, dst: usm_ndarray,
54+ sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
55+ where the `src` is the argument array, `dst` is the
56+ array to be populated with function values, effectively
57+ evaluating `dst = func(src)`.
58+ The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
59+ The first event corresponds to data-management host tasks,
60+ including lifetime management of argument Python objects to ensure
61+ that their associated USM allocation is not freed before offloaded
62+ computational tasks complete execution, while the second event
63+ corresponds to computational tasks associated with function
64+ evaluation.
65+ docs (str):
66+ Documentation string for the unary function.
4267 """
4368
4469 def __init__ (self , name , result_type_resolver_fn , unary_dp_impl_fn , docs ):
@@ -55,8 +80,31 @@ def __str__(self):
5580 def __repr__ (self ):
5681 return f"<{ self .__name__ } '{ self .name_ } '>"
5782
83+ def get_implementation_function (self ):
84+ """Returns the implementation function for
85+ this elementwise unary function.
86+
87+ """
88+ return self .unary_fn_
89+
90+ def get_type_result_resolver_function (self ):
91+ """Returns the type resolver function for this
92+ elementwise unary function.
93+ """
94+ return self .result_type_resolver_fn_
95+
5896 @property
5997 def types (self ):
98+ """Returns information about types supported by
99+ implementation function, using NumPy's character
100+ encoding for data types, e.g.
101+
102+ :Example:
103+ .. code-block:: python
104+
105+ dpctl.tensor.sin.types
106+ # Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']
107+ """
60108 types = self .types_
61109 if not types :
62110 types = []
@@ -363,6 +411,56 @@ def _get_shape(o):
363411class BinaryElementwiseFunc :
364412 """
365413 Class that implements binary element-wise functions.
414+
415+ Args:
416+ name (str):
417+ Name of the unary function
418+ result_type_resovle_fn (callable):
419+ Function that takes dtypes of the input and
420+ returns the dtype of the result if the
421+ implementation functions supports it, or
422+ returns `None` otherwise.
423+ binary_dp_impl_fn (callable):
424+ Data-parallel implementation function with signature
425+ `impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
426+ sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
427+ where the `src1` and `src2` are the argument arrays, `dst` is the
428+ array to be populated with function values,
429+ i.e. `dst=func(src1, src2)`.
430+ The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
431+ The first event corresponds to data-management host tasks,
432+ including lifetime management of argument Python objects to ensure
433+ that their associated USM allocation is not freed before offloaded
434+ computational tasks complete execution, while the second event
435+ corresponds to computational tasks associated with function
436+ evaluation.
437+ docs (str):
438+ Documentation string for the unary function.
439+ binary_inplace_fn (callable, optional):
440+ Data-parallel implementation function with signature
441+ `impl_fn(src: usm_ndarray, dst: usm_ndarray,
442+ sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
443+ where the `src` is the argument array, `dst` is the
444+ array to be populated with function values,
445+ i.e. `dst=func(dst, src)`.
446+ The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
447+ The first event corresponds to data-management host tasks,
448+ including async lifetime management of Python arguments,
449+ while the second event corresponds to computational tasks
450+ associated with function evaluation.
451+ acceptance_fn (callable, optional):
452+ Function to influence type promotion behavior of this binary
453+ function. The function takes 6 arguments:
454+ arg1_dtype - Data type of the first argument
455+ arg2_dtype - Data type of the second argument
456+ ret_buf1_dtype - Data type the first argument would be cast to
457+ ret_buf2_dtype - Data type the second argument would be cast to
458+ res_dtype - Data type of the output array with function values
459+ sycl_dev - The :class:`dpctl.SyclDevice` where the function
460+ evaluation is carried out.
461+ The function is only called when both arguments of the binary
462+ function require casting, e.g. both arguments of
463+ `dpctl.tensor.logaddexp` are arrays with integral data type.
366464 """
367465
368466 def __init__ (
@@ -392,8 +490,60 @@ def __str__(self):
392490 def __repr__ (self ):
393491 return f"<{ self .__name__ } '{ self .name_ } '>"
394492
493+ def get_implementation_function (self ):
494+ """Returns the out-of-place implementation
495+ function for this elementwise binary function.
496+
497+ """
498+ return self .binary_fn_
499+
500+ def get_implementation_inplace_function (self ):
501+ """Returns the in-place implementation
502+ function for this elementwise binary function.
503+
504+ """
505+ return self .binary_inplace_fn_
506+
507+ def get_type_result_resolver_function (self ):
508+ """Returns the type resolver function for this
509+ elementwise binary function.
510+ """
511+ return self .result_type_resolver_fn_
512+
513+ def get_type_promotion_path_acceptance_function (self ):
514+ """Returns the acceptance function for this
515+ elementwise binary function.
516+
517+ Acceptance function influences the type promotion
518+ behavior of this binary function.
519+ The function takes 6 arguments:
520+ arg1_dtype - Data type of the first argument
521+ arg2_dtype - Data type of the second argument
522+ ret_buf1_dtype - Data type the first argument would be cast to
523+ ret_buf2_dtype - Data type the second argument would be cast to
524+ res_dtype - Data type of the output array with function values
525+ sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
526+ is carried out.
527+
528+ The acceptance function is only invoked if both input arrays must be
529+ cast to intermediary data types, as would happen during call of
530+ `dpctl.tensor.hypot` with both arrays being of integral data type.
531+ """
532+ return self .acceptance_fn_
533+
395534 @property
396535 def types (self ):
536+ """Returns information about types supported by
537+ implementation function, using NumPy's character
538+ encoding for data types, e.g.
539+
540+ :Example:
541+ .. code-block:: python
542+
543+ dpctl.tensor.divide.types
544+ # Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D',
545+ # 'Ff->F', 'FF->F', 'Dd->D', 'DD->D']
546+ """
397547 types = self .types_
398548 if not types :
399549 types = []
0 commit comments