@@ -66,14 +66,7 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -
6666 if axes is None :
6767 s_strat = st .none () | s_strat
6868 s = data .draw (s_strat , label = "s" )
69- if size_gt_1 :
70- _s = x .shape if s is None else s
71- for i in range (x .ndim ):
72- if i in _axes :
73- side = _s [_axes .index (i )]
74- else :
75- side = x .shape [i ]
76- assume (side > 1 )
69+
7770 norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
7871 kwargs = data .draw (
7972 hh .specified_kwargs (
@@ -86,14 +79,14 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -
8679 return s , axes , norm , kwargs
8780
8881
89- def assert_fft_dtype (func_name : str , * , in_dtype : DataType , out_dtype : DataType ):
82+ def assert_float_to_complex_dtype (
83+ func_name : str , * , in_dtype : DataType , out_dtype : DataType
84+ ):
9085 if in_dtype == xp .float32 :
9186 expected = xp .complex64
92- elif in_dtype == xp .float64 :
93- expected = xp .complex128
9487 else :
95- assert dh . is_float_dtype ( in_dtype ) # sanity check
96- expected = in_dtype
88+ assert in_dtype == xp . float64 # sanity check
89+ expected = xp . complex128
9790 ph .assert_dtype (
9891 func_name , in_dtype = in_dtype , out_dtype = out_dtype , expected = expected
9992 )
@@ -106,14 +99,10 @@ def assert_n_axis_shape(
10699 n : Optional [int ],
107100 axis : int ,
108101 out : Array ,
109- size_gt_1 : bool = False ,
110102):
111103 _axis = len (x .shape ) - 1 if axis == - 1 else axis
112104 if n is None :
113- if size_gt_1 :
114- axis_side = 2 * (x .shape [_axis ] - 1 )
115- else :
116- axis_side = x .shape [_axis ]
105+ axis_side = x .shape [_axis ]
117106 else :
118107 axis_side = n
119108 expected = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
@@ -127,7 +116,6 @@ def assert_s_axes_shape(
127116 s : Optional [List [int ]],
128117 axes : Optional [List [int ]],
129118 out : Array ,
130- size_gt_1 : bool = False ,
131119):
132120 _axes = sh .normalise_axis (axes , x .ndim )
133121 _s = x .shape if s is None else s
@@ -138,88 +126,78 @@ def assert_s_axes_shape(
138126 else :
139127 side = x .shape [i ]
140128 expected .append (side )
141- if size_gt_1 :
142- last_axis = _axes [- 1 ]
143- expected [last_axis ] = 2 * (expected [last_axis ] - 1 )
144- assume (expected [last_axis ] > 0 ) # TODO: generate valid examples
145129 ph .assert_shape (func_name , out_shape = out .shape , expected = tuple (expected ))
146130
147131
148- @given (
149- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
150- data = st .data (),
151- )
132+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
152133def test_fft (x , data ):
153134 n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
154135
155136 out = xp .fft .fft (x , ** kwargs )
156137
157- assert_fft_dtype ("fft" , in_dtype = x .dtype , out_dtype = out .dtype )
138+ ph . assert_dtype ("fft" , in_dtype = x .dtype , out_dtype = out .dtype )
158139 assert_n_axis_shape ("fft" , x = x , n = n , axis = axis , out = out )
159140
160141
161- @given (
162- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
163- data = st .data (),
164- )
142+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
165143def test_ifft (x , data ):
166144 n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
167145
168146 out = xp .fft .ifft (x , ** kwargs )
169147
170- assert_fft_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
148+ ph . assert_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
171149 assert_n_axis_shape ("ifft" , x = x , n = n , axis = axis , out = out )
172150
173151
174- @given (
175- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
176- data = st .data (),
177- )
152+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
178153def test_fftn (x , data ):
179154 s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
180155
181156 out = xp .fft .fftn (x , ** kwargs )
182157
183- assert_fft_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
158+ ph . assert_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
184159 assert_s_axes_shape ("fftn" , x = x , s = s , axes = axes , out = out )
185160
186161
187- @given (
188- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
189- data = st .data (),
190- )
162+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
191163def test_ifftn (x , data ):
192164 s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
193165
194166 out = xp .fft .ifftn (x , ** kwargs )
195167
196- assert_fft_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
168+ ph . assert_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
197169 assert_s_axes_shape ("ifftn" , x = x , s = s , axes = axes , out = out )
198170
199171
200- @given (
201- x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
202- data = st .data (),
203- )
172+ @given (x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ), data = st .data ())
204173def test_rfft (x , data ):
205174 n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
206175
207176 out = xp .fft .rfft (x , ** kwargs )
208177
209- assert_fft_dtype ("rfft" , in_dtype = x .dtype , out_dtype = out .dtype )
210- assert_n_axis_shape ("rfft" , x = x , n = n , axis = axis , out = out )
178+ assert_float_to_complex_dtype ("rfft" , in_dtype = x .dtype , out_dtype = out .dtype )
179+
180+ _axis = x .ndim - 1 if axis == - 1 else axis
181+ if n is None :
182+ axis_side = x .shape [_axis ] // 2 + 1
183+ else :
184+ axis_side = n // 2 + 1
185+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
186+ ph .assert_shape ("rfft" , out_shape = out .shape , expected = expected_shape )
211187
212188
213- @given (
214- x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
215- data = st .data (),
216- )
189+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
217190def test_irfft (x , data ):
218191 n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
219192
220193 out = xp .fft .irfft (x , ** kwargs )
221194
222- assert_fft_dtype ("irfft" , in_dtype = x .dtype , out_dtype = out .dtype )
195+ ph .assert_dtype (
196+ "irfft" ,
197+ in_dtype = x .dtype ,
198+ out_dtype = out .dtype ,
199+ expected = dh .dtype_components [x .dtype ],
200+ )
223201
224202 _axis = x .ndim - 1 if axis == - 1 else axis
225203 if n is None :
@@ -230,17 +208,25 @@ def test_irfft(x, data):
230208 ph .assert_shape ("irfft" , out_shape = out .shape , expected = expected_shape )
231209
232210
233- @given (
234- x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
235- data = st .data (),
236- )
211+ @given (x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ), data = st .data ())
237212def test_rfftn (x , data ):
238213 s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
239214
240215 out = xp .fft .rfftn (x , ** kwargs )
241216
242- assert_fft_dtype ("rfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
243- assert_s_axes_shape ("rfftn" , x = x , s = s , axes = axes , out = out )
217+ assert_float_to_complex_dtype ("rfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
218+
219+ _axes = sh .normalise_axis (axes , x .ndim )
220+ _s = x .shape if s is None else s
221+ expected = []
222+ for i in range (x .ndim ):
223+ if i in _axes :
224+ side = _s [_axes .index (i )]
225+ else :
226+ side = x .shape [i ]
227+ expected .append (side )
228+ expected [_axes [- 1 ]] = _s [- 1 ] // 2 + 1
229+ ph .assert_shape ("rfftn" , out_shape = out .shape , expected = tuple (expected ))
244230
245231
246232@given (
@@ -250,24 +236,44 @@ def test_rfftn(x, data):
250236 data = st .data (),
251237)
252238def test_irfftn (x , data ):
253- s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data , size_gt_1 = True )
239+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
254240
255241 out = xp .fft .irfftn (x , ** kwargs )
256242
257- assert_fft_dtype ("irfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
258- assert_s_axes_shape ("rfftn" , x = x , s = s , axes = axes , out = out , size_gt_1 = True )
259-
243+ ph .assert_dtype (
244+ "irfftn" ,
245+ in_dtype = x .dtype ,
246+ out_dtype = out .dtype ,
247+ expected = dh .dtype_components [x .dtype ],
248+ )
260249
261- @given (
262- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
263- data = st .data (),
264- )
250+ # TODO: assert shape correctly
251+ # _axes = sh.normalise_axis(axes, x.ndim)
252+ # _s = x.shape if s is None else s
253+ # expected = []
254+ # for i in range(x.ndim):
255+ # if i in _axes:
256+ # side = _s[_axes.index(i)]
257+ # else:
258+ # side = x.shape[i]
259+ # expected.append(side)
260+ # last_axis = max(_axes)
261+ # expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
262+ # ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
263+
264+
265+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
265266def test_hfft (x , data ):
266267 n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
267268
268269 out = xp .fft .hfft (x , ** kwargs )
269270
270- assert_fft_dtype ("hfft" , in_dtype = x .dtype , out_dtype = out .dtype )
271+ ph .assert_dtype (
272+ "hfft" ,
273+ in_dtype = x .dtype ,
274+ out_dtype = out .dtype ,
275+ expected = dh .dtype_components [x .dtype ],
276+ )
271277
272278 _axis = x .ndim - 1 if axis == - 1 else axis
273279 if n is None :
@@ -278,20 +284,24 @@ def test_hfft(x, data):
278284 ph .assert_shape ("hfft" , out_shape = out .shape , expected = expected_shape )
279285
280286
281- @given (
282- x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
283- data = st .data (),
284- )
287+ @given (x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ), data = st .data ())
285288def test_ihfft (x , data ):
286289 n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
287290
288291 out = xp .fft .ihfft (x , ** kwargs )
289292
290- assert_fft_dtype ("ihfft" , in_dtype = x .dtype , out_dtype = out .dtype )
291- assert_n_axis_shape ("ihfft" , x = x , n = n , axis = axis , out = out , size_gt_1 = True )
293+ assert_float_to_complex_dtype ("ihfft" , in_dtype = x .dtype , out_dtype = out .dtype )
294+
295+ _axis = x .ndim - 1 if axis == - 1 else axis
296+ if n is None :
297+ axis_side = x .shape [_axis ] // 2 + 1
298+ else :
299+ axis_side = n // 2 + 1
300+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
301+ ph .assert_shape ("ihfft" , out_shape = out .shape , expected = expected_shape )
292302
293303
294- @given ( n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
304+ @given (n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
295305def test_fftfreq (n , kw ):
296306 out = xp .fft .fftfreq (n , ** kw )
297307 ph .assert_shape ("fftfreq" , out_shape = out .shape , expected = (n ,), kw = {"n" : n })
@@ -300,15 +310,18 @@ def test_fftfreq(n, kw):
300310@given (n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
301311def test_rfftfreq (n , kw ):
302312 out = xp .fft .rfftfreq (n , ** kw )
303- ph .assert_shape ("rfftfreq" , out_shape = out .shape , expected = (n // 2 + 1 ,), kw = {"n" : n })
313+ ph .assert_shape (
314+ "rfftfreq" , out_shape = out .shape , expected = (n // 2 + 1 ,), kw = {"n" : n }
315+ )
304316
305317
306318@pytest .mark .parametrize ("func_name" , ["fftshift" , "ifftshift" ])
307319@given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ), data = st .data ())
308320def test_shift_func (func_name , x , data ):
309321 func = getattr (xp .fft , func_name )
310322 axes = data .draw (
311- st .none () | st .lists (st .sampled_from (list (range (x .ndim ))), min_size = 1 , unique = True ),
323+ st .none ()
324+ | st .lists (st .sampled_from (list (range (x .ndim ))), min_size = 1 , unique = True ),
312325 label = "axes" ,
313326 )
314327 out = func (x , axes = axes )
0 commit comments