@@ -291,8 +291,26 @@ def test_ihfft(x, data):
291291 assert_n_axis_shape ("ihfft" , x = x , n = n , axis = axis , out = out , size_gt_1 = True )
292292
293293
294- # TODO:
295- # fftfreq
296- # rfftfreq
297- # fftshift
298- # ifftshift
294+ @given ( n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
295+ def test_fftfreq (n , kw ):
296+ out = xp .fft .fftfreq (n , ** kw )
297+ ph .assert_shape ("fftfreq" , out_shape = out .shape , expected = (n ,), kw = {"n" : n })
298+
299+
300+ @given (n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
301+ def test_rfftfreq (n , kw ):
302+ out = xp .fft .rfftfreq (n , ** kw )
303+ ph .assert_shape ("rfftfreq" , out_shape = out .shape , expected = (n // 2 + 1 ,), kw = {"n" : n })
304+
305+
306+ @pytest .mark .parametrize ("func_name" , ["fftshift" , "ifftshift" ])
307+ @given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ), data = st .data ())
308+ def test_shift_func (func_name , x , data ):
309+ func = getattr (xp .fft , func_name )
310+ axes = data .draw (
311+ st .none () | st .lists (st .sampled_from (list (range (x .ndim ))), min_size = 1 , unique = True ),
312+ label = "axes" ,
313+ )
314+ out = func (x , axes = axes )
315+ ph .assert_dtype (func_name , in_dtype = x .dtype , out_dtype = out .dtype )
316+ ph .assert_shape (func_name , out_shape = out .shape , expected = x .shape )
0 commit comments