22from typing import List , Optional
33
44import pytest
5- from hypothesis import given
5+ from hypothesis import assume , given
66from hypothesis import strategies as st
77
88from array_api_tests .typing import Array , DataType
2424fft_shapes_strat = hh .shapes (min_dims = 1 ).filter (lambda s : math .prod (s ) > 1 )
2525
2626
27- def draw_n_axis_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
27+ def draw_n_axis_norm_kwargs (x : Array , data : st .DataObject , * , size_gt_1 = False ) -> tuple :
2828 size = math .prod (x .shape )
29- n = data .draw (st .none () | st .integers ((size // 2 ), math .ceil (size * 1.5 )), label = "n" )
29+ n = data .draw (
30+ st .none () | st .integers ((size // 2 ), math .ceil (size * 1.5 )), label = "n"
31+ )
3032 axis = data .draw (st .integers (- 1 , x .ndim - 1 ), label = "axis" )
33+ if size_gt_1 :
34+ _axis = x .ndim - 1 if axis == - 1 else axis
35+ assume (x .shape [_axis ] > 1 )
3136 norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
3237 kwargs = data .draw (
3338 hh .specified_kwargs (
@@ -40,7 +45,7 @@ def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
4045 return n , axis , norm , kwargs
4146
4247
43- def draw_s_axes_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
48+ def draw_s_axes_norm_kwargs (x : Array , data : st .DataObject , * , size_gt_1 = False ) -> tuple :
4449 all_axes = list (range (x .ndim ))
4550 axes = data .draw (
4651 st .none () | st .lists (st .sampled_from (all_axes ), min_size = 1 , unique = True ),
@@ -54,6 +59,14 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
5459 if axes is None :
5560 s_strat = st .none () | s_strat
5661 s = data .draw (s_strat , label = "s" )
62+ if size_gt_1 :
63+ _s = x .shape if s is None else s
64+ for i in range (x .ndim ):
65+ if i in _axes :
66+ side = _s [_axes .index (i )]
67+ else :
68+ side = x .shape [i ]
69+ assume (side > 1 )
5770 norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
5871 kwargs = data .draw (
5972 hh .specified_kwargs (
@@ -163,7 +176,7 @@ def test_ifftn(x, data):
163176
164177
165178@given (
166- x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
179+ x = xps .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
167180 data = st .data (),
168181)
169182def test_rfft (x , data ):
@@ -176,23 +189,70 @@ def test_rfft(x, data):
176189
177190
178191@given (
179- x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
192+ x = xps .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
180193 data = st .data (),
181194)
182195def test_irfft (x , data ):
183- n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
196+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
184197
185198 out = xp .fft .irfft (x , ** kwargs )
186199
187200 assert_fft_dtype ("irfft" , in_dtype = x .dtype , out_dtype = out .dtype )
188201 # TODO: assert shape
189202
190203
191- # TODO:
192- # test_rfftn
193- # test_irfftn
194- # test_hfft
195- # test_ihfft
204+ @given (
205+ x = xps .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
206+ data = st .data (),
207+ )
208+ def test_rfftn (x , data ):
209+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
210+
211+ out = xp .fft .rfftn (x , ** kwargs )
212+
213+ assert_fft_dtype ("rfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
214+ assert_s_axes_shape ("rfftn" , x = x , s = s , axes = axes , out = out )
215+
216+
217+ @given (
218+ x = xps .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
219+ data = st .data (),
220+ )
221+ def test_irfftn (x , data ):
222+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data , size_gt_1 = True )
223+
224+ out = xp .fft .irfftn (x , ** kwargs )
225+
226+ assert_fft_dtype ("irfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
227+ assert_s_axes_shape ("irfftn" , x = x , s = s , axes = axes , out = out )
228+
229+
230+ @given (
231+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
232+ data = st .data (),
233+ )
234+ def test_hfft (x , data ):
235+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
236+
237+ out = xp .fft .hfft (x , ** kwargs )
238+
239+ assert_fft_dtype ("hfft" , in_dtype = x .dtype , out_dtype = out .dtype )
240+ # TODO: shape
241+
242+
243+ @given (
244+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
245+ data = st .data (),
246+ )
247+ def test_ihfft (x , data ):
248+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
249+
250+ out = xp .fft .ihfft (x , ** kwargs )
251+
252+ assert_fft_dtype ("ihfft" , in_dtype = x .dtype , out_dtype = out .dtype )
253+ # TODO: shape
254+
255+
196256# fftfreq
197257# rfftfreq
198258# fftshift
0 commit comments