@@ -161,6 +161,114 @@ def test_not_implemented_kwargs(self, kwargs):
161161 dpnp .clip (a , 1 , 5 , ** kwargs )
162162
163163
164+ class TestCumLogSumExp :
165+ def _assert_arrays (self , res , exp , axis , include_initial ):
166+ if include_initial :
167+ if axis != None :
168+ res_initial = dpnp .take (res , dpnp .array ([0 ]), axis = axis )
169+ res_no_initial = dpnp .take (
170+ res , dpnp .array (range (1 , res .shape [axis ])), axis = axis
171+ )
172+ else :
173+ res_initial = res [0 ]
174+ res_no_initial = res [1 :]
175+ assert_dtype_allclose (res_no_initial , exp )
176+ assert (res_initial == - dpnp .inf ).all ()
177+ else :
178+ assert_dtype_allclose (res , exp )
179+
180+ def _get_exp_array (self , a , axis , dtype ):
181+ np_a = dpnp .asnumpy (a )
182+ if axis != None :
183+ return numpy .logaddexp .accumulate (np_a , axis = axis , dtype = dtype )
184+ return numpy .logaddexp .accumulate (np_a .ravel (), dtype = dtype )
185+
186+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_complex = True ))
187+ @pytest .mark .parametrize ("axis" , [None , 2 , - 1 ])
188+ @pytest .mark .parametrize ("include_initial" , [True , False ])
189+ def test_basic (self , dtype , axis , include_initial ):
190+ a = dpnp .ones ((3 , 4 , 5 , 6 , 7 ), dtype = dtype )
191+ res = dpnp .cumlogsumexp (a , axis = axis , include_initial = include_initial )
192+
193+ exp_dt = None
194+ if dtype == dpnp .bool :
195+ exp_dt = dpnp .default_float_type (a .device )
196+
197+ exp = self ._get_exp_array (a , axis , exp_dt )
198+ self ._assert_arrays (res , exp , axis , include_initial )
199+
200+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_complex = True ))
201+ @pytest .mark .parametrize ("axis" , [None , 2 , - 1 ])
202+ @pytest .mark .parametrize ("include_initial" , [True , False ])
203+ def test_out (self , dtype , axis , include_initial ):
204+ a = dpnp .ones ((3 , 4 , 5 , 6 , 7 ), dtype = dtype )
205+
206+ if dpnp .issubdtype (a , dpnp .float32 ):
207+ exp_dt = dpnp .float32
208+ else :
209+ exp_dt = dpnp .default_float_type (a .device )
210+
211+ if axis != None :
212+ if include_initial :
213+ norm_axis = numpy .core .numeric .normalize_axis_index (
214+ axis , a .ndim , "axis"
215+ )
216+ out_sh = (
217+ a .shape [:norm_axis ]
218+ + (a .shape [norm_axis ] + 1 ,)
219+ + a .shape [norm_axis + 1 :]
220+ )
221+ else :
222+ out_sh = a .shape
223+ else :
224+ out_sh = (a .size + int (include_initial ),)
225+ out = dpnp .empty_like (a , shape = out_sh , dtype = exp_dt )
226+ res = dpnp .cumlogsumexp (
227+ a , axis = axis , include_initial = include_initial , out = out
228+ )
229+
230+ exp = self ._get_exp_array (a , axis , exp_dt )
231+
232+ assert res is out
233+ self ._assert_arrays (res , exp , axis , include_initial )
234+
235+ def test_axis_tuple (self ):
236+ a = dpnp .ones ((3 , 4 ))
237+ assert_raises (TypeError , dpnp .cumlogsumexp , a , axis = (0 , 1 ))
238+
239+ @pytest .mark .parametrize (
240+ "in_dtype" , get_all_dtypes (no_bool = True , no_complex = True )
241+ )
242+ @pytest .mark .parametrize ("out_dtype" , get_all_dtypes (no_bool = True ))
243+ def test_dtype (self , in_dtype , out_dtype ):
244+ a = dpnp .ones (100 , dtype = in_dtype )
245+ res = dpnp .cumlogsumexp (a , dtype = out_dtype )
246+ exp = numpy .logaddexp .accumulate (dpnp .asnumpy (a ))
247+ exp = exp .astype (out_dtype )
248+
249+ assert_allclose (res , exp , rtol = 1e-06 )
250+
251+ @pytest .mark .usefixtures ("suppress_invalid_numpy_warnings" )
252+ @pytest .mark .parametrize (
253+ "arr_dt" , get_all_dtypes (no_none = True , no_complex = True )
254+ )
255+ @pytest .mark .parametrize (
256+ "out_dt" , get_all_dtypes (no_none = True , no_complex = True )
257+ )
258+ @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
259+ def test_out_dtype (self , arr_dt , out_dt , dtype ):
260+ a = numpy .arange (10 , 20 ).reshape ((2 , 5 )).astype (dtype = arr_dt )
261+ out = numpy .zeros_like (a , dtype = out_dt )
262+
263+ ia = dpnp .array (a )
264+ iout = dpnp .array (out )
265+
266+ result = dpnp .cumlogsumexp (ia , out = iout , dtype = dtype , axis = 1 )
267+ exp = numpy .logaddexp .accumulate (a , out = out , axis = 1 )
268+ assert_allclose (result , exp .astype (dtype ), rtol = 1e-06 )
269+ assert result is iout
270+
271+
164272class TestCumProd :
165273 @pytest .mark .parametrize (
166274 "arr, axis" ,
0 commit comments