@@ -52,6 +52,28 @@ def _default_reduction_dtype(inp_dt, q):
5252 return res_dt
5353
5454
55+ def _default_reduction_dtype_fp_types (inp_dt , q ):
56+ """Gives default output data type for given input data
57+ type `inp_dt` when reduction is performed on queue `q`
58+ and the reduction supports only floating-point data types
59+ """
60+ inp_kind = inp_dt .kind
61+ if inp_kind in "biu" :
62+ res_dt = dpt .dtype (ti .default_device_fp_type (q ))
63+ can_cast_v = dpt .can_cast (inp_dt , res_dt )
64+ if not can_cast_v :
65+ _fp64 = q .sycl_device .has_aspect_fp64
66+ res_dt = dpt .float64 if _fp64 else dpt .float32
67+ elif inp_kind in "f" :
68+ res_dt = dpt .dtype (ti .default_device_fp_type (q ))
69+ if res_dt .itemsize < inp_dt .itemsize :
70+ res_dt = inp_dt
71+ elif inp_kind in "c" :
72+ raise TypeError ("reduction not defined for complex types" )
73+
74+ return res_dt
75+
76+
5577def _reduction_over_axis (
5678 x ,
5779 axis ,
@@ -91,12 +113,15 @@ def _reduction_over_axis(
91113 res_shape = res_shape + (1 ,) * red_nd
92114 inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
93115 res_shape = tuple (res_shape [i ] for i in inv_perm )
94- return dpt .full (
95- res_shape ,
96- _identity ,
97- dtype = res_dt ,
98- usm_type = res_usm_type ,
99- sycl_queue = q ,
116+ return dpt .astype (
117+ dpt .full (
118+ res_shape ,
119+ _identity ,
120+ dtype = _default_reduction_type_fn (inp_dt , q ),
121+ usm_type = res_usm_type ,
122+ sycl_queue = q ,
123+ ),
124+ res_dt ,
100125 )
101126 if red_nd == 0 :
102127 return dpt .astype (x , res_dt , copy = False )
@@ -116,7 +141,7 @@ def _reduction_over_axis(
116141 "Automatically determined reduction data type does not "
117142 "have direct implementation"
118143 )
119- tmp_dt = _default_reduction_dtype (inp_dt , q )
144+ tmp_dt = _default_reduction_type_fn (inp_dt , q )
120145 tmp = dpt .empty (
121146 res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
122147 )
@@ -161,13 +186,13 @@ def sum(x, axis=None, dtype=None, keepdims=False):
161186 the returned array will have the default real-valued
162187 floating-point data type for the device where input
163188 array `x` is allocated.
164- * If x` has signed integral data type, the returned array
189+ * If ` x` has signed integral data type, the returned array
165190 will have the default signed integral type for the device
166191 where input array `x` is allocated.
167192 * If `x` has unsigned integral data type, the returned array
168193 will have the default unsigned integral type for the device
169194 where input array `x` is allocated.
170- * If `x` has a complex-valued floating-point data typee ,
195+ * If `x` has a complex-valued floating-point data type ,
171196 the returned array will have the default complex-valued
172197 floating-pointer data type for the device where input
173198 array `x` is allocated.
@@ -222,13 +247,13 @@ def prod(x, axis=None, dtype=None, keepdims=False):
222247 the returned array will have the default real-valued
223248 floating-point data type for the device where input
224249 array `x` is allocated.
225- * If x` has signed integral data type, the returned array
250+ * If ` x` has signed integral data type, the returned array
226251 will have the default signed integral type for the device
227252 where input array `x` is allocated.
228253 * If `x` has unsigned integral data type, the returned array
229254 will have the default unsigned integral type for the device
230255 where input array `x` is allocated.
231- * If `x` has a complex-valued floating-point data typee ,
256+ * If `x` has a complex-valued floating-point data type ,
232257 the returned array will have the default complex-valued
233258 floating-pointer data type for the device where input
234259 array `x` is allocated.
@@ -263,6 +288,118 @@ def prod(x, axis=None, dtype=None, keepdims=False):
263288 )
264289
265290
291+ def logsumexp (x , axis = None , dtype = None , keepdims = False ):
292+ """logsumexp(x, axis=None, dtype=None, keepdims=False)
293+
294+ Calculates the logarithm of the sum of exponentials of elements in the
295+ input array `x`.
296+
297+ Args:
298+ x (usm_ndarray):
299+ input array.
300+ axis (Optional[int, Tuple[int, ...]]):
301+ axis or axes along which values must be computed. If a tuple
302+ of unique integers, values are computed over multiple axes.
303+ If `None`, the result is computed over the entire array.
304+ Default: `None`.
305+ dtype (Optional[dtype]):
306+ data type of the returned array. If `None`, the default data
307+ type is inferred from the "kind" of the input array data type.
308+ * If `x` has a real-valued floating-point data type,
309+ the returned array will have the default real-valued
310+ floating-point data type for the device where input
311+ array `x` is allocated.
312+ * If `x` has a boolean or integral data type, the returned array
313+ will have the default floating point data type for the device
314+ where input array `x` is allocated.
315+ * If `x` has a complex-valued floating-point data type,
316+ an error is raised.
317+ If the data type (either specified or resolved) differs from the
318+ data type of `x`, the input array elements are cast to the
319+ specified data type before computing the result. Default: `None`.
320+ keepdims (Optional[bool]):
321+ if `True`, the reduced axes (dimensions) are included in the result
322+ as singleton dimensions, so that the returned array remains
323+ compatible with the input arrays according to Array Broadcasting
324+ rules. Otherwise, if `False`, the reduced axes are not included in
325+ the returned array. Default: `False`.
326+ Returns:
327+ usm_ndarray:
328+ an array containing the results. If the result was computed over
329+ the entire array, a zero-dimensional array is returned. The returned
330+ array has the data type as described in the `dtype` parameter
331+ description above.
332+ """
333+ return _reduction_over_axis (
334+ x ,
335+ axis ,
336+ dtype ,
337+ keepdims ,
338+ ti ._logsumexp_over_axis ,
339+ lambda inp_dt , res_dt , * _ : ti ._logsumexp_over_axis_dtype_supported (
340+ inp_dt , res_dt
341+ ),
342+ _default_reduction_dtype_fp_types ,
343+ _identity = - dpt .inf ,
344+ )
345+
346+
347+ def reduce_hypot (x , axis = None , dtype = None , keepdims = False ):
348+ """reduce_hypot(x, axis=None, dtype=None, keepdims=False)
349+
350+ Calculates the square root of the sum of squares of elements in the input
351+ array `x`.
352+
353+ Args:
354+ x (usm_ndarray):
355+ input array.
356+ axis (Optional[int, Tuple[int, ...]]):
357+ axis or axes along which values must be computed. If a tuple
358+ of unique integers, values are computed over multiple axes.
359+ If `None`, the result is computed over the entire array.
360+ Default: `None`.
361+ dtype (Optional[dtype]):
362+ data type of the returned array. If `None`, the default data
363+ type is inferred from the "kind" of the input array data type.
364+ * If `x` has a real-valued floating-point data type,
365+ the returned array will have the default real-valued
366+ floating-point data type for the device where input
367+ array `x` is allocated.
368+ * If `x` has a boolean or integral data type, the returned array
369+ will have the default floating point data type for the device
370+ where input array `x` is allocated.
371+ * If `x` has a complex-valued floating-point data type,
372+ an error is raised.
373+ If the data type (either specified or resolved) differs from the
374+ data type of `x`, the input array elements are cast to the
375+ specified data type before computing the result. Default: `None`.
376+ keepdims (Optional[bool]):
377+ if `True`, the reduced axes (dimensions) are included in the result
378+ as singleton dimensions, so that the returned array remains
379+ compatible with the input arrays according to Array Broadcasting
380+ rules. Otherwise, if `False`, the reduced axes are not included in
381+ the returned array. Default: `False`.
382+ Returns:
383+ usm_ndarray:
384+ an array containing the results. If the result was computed over
385+ the entire array, a zero-dimensional array is returned. The returned
386+ array has the data type as described in the `dtype` parameter
387+ description above.
388+ """
389+ return _reduction_over_axis (
390+ x ,
391+ axis ,
392+ dtype ,
393+ keepdims ,
394+ ti ._hypot_over_axis ,
395+ lambda inp_dt , res_dt , * _ : ti ._hypot_over_axis_dtype_supported (
396+ inp_dt , res_dt
397+ ),
398+ _default_reduction_dtype_fp_types ,
399+ _identity = 0 ,
400+ )
401+
402+
266403def _comparison_over_axis (x , axis , keepdims , _reduction_fn ):
267404 if not isinstance (x , dpt .usm_ndarray ):
268405 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
0 commit comments