@@ -303,15 +303,14 @@ def test_rfftfreq(n, kw):
303303 ph .assert_shape ("rfftfreq" , out_shape = out .shape , expected = (n // 2 + 1 ,), kw = {"n" : n })
304304
305305
306- @given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ))
307- def test_fftshift (x ):
308- out = xp .fft .fftshift (x )
309- ph .assert_dtype ("fftshift" , in_dtype = x .dtype , out_dtype = out .dtype )
310- ph .assert_shape ("fftshift" , out_shape = out .shape , expected = x .shape )
311-
312-
313- @given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ))
314- def test_ifftshift (x ):
315- out = xp .fft .ifftshift (x )
316- ph .assert_dtype ("ifftshift" , in_dtype = x .dtype , out_dtype = out .dtype )
317- ph .assert_shape ("ifftshift" , out_shape = out .shape , expected = x .shape )
306+ @pytest .mark .parametrize ("func_name" , ["fftshift" , "ifftshift" ])
307+ @given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ), data = st .data ())
308+ def test_shift_func (func_name , x , data ):
309+ func = getattr (xp .fft , func_name )
310+ axes = data .draw (
311+ st .none () | st .lists (st .sampled_from (list (range (x .ndim ))), min_size = 1 , unique = True ),
312+ label = "axes" ,
313+ )
314+ out = func (x , axes = axes )
315+ ph .assert_dtype (func_name , in_dtype = x .dtype , out_dtype = out .dtype )
316+ ph .assert_shape (func_name , out_shape = out .shape , expected = x .shape )
0 commit comments