@@ -93,14 +93,24 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType)
9393
9494
9595def assert_n_axis_shape (
96- func_name : str , * , x : Array , n : Optional [int ], axis : int , out : Array
96+ func_name : str ,
97+ * ,
98+ x : Array ,
99+ n : Optional [int ],
100+ axis : int ,
101+ out : Array ,
102+ size_gt_1 = False ,
97103):
104+ _axis = len (x .shape ) - 1 if axis == - 1 else axis
98105 if n is None :
99- expected_shape = x .shape
106+ if size_gt_1 :
107+ axis_side = 2 * (x .shape [_axis ] - 1 )
108+ else :
109+ axis_side = x .shape [_axis ]
100110 else :
101- _axis = len ( x . shape ) - 1 if axis == - 1 else axis
102- expected_shape = x .shape [:_axis ] + (n ,) + x .shape [_axis + 1 :]
103- ph .assert_shape (func_name , out_shape = out .shape , expected = expected_shape )
111+ axis_side = n
112+ expected = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
113+ ph .assert_shape (func_name , out_shape = out .shape , expected = expected )
104114
105115
106116def assert_s_axes_shape (
@@ -198,7 +208,14 @@ def test_irfft(x, data):
198208 out = xp .fft .irfft (x , ** kwargs )
199209
200210 assert_fft_dtype ("irfft" , in_dtype = x .dtype , out_dtype = out .dtype )
201- # TODO: assert shape
211+
212+ _axis = x .ndim - 1 if axis == - 1 else axis
213+ if n is None :
214+ axis_side = 2 * (x .shape [_axis ] - 1 )
215+ else :
216+ axis_side = n
217+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
218+ ph .assert_shape ("irfft" , out_shape = out .shape , expected = expected_shape )
202219
203220
204221@given (
@@ -224,7 +241,7 @@ def test_irfftn(x, data):
224241 out = xp .fft .irfftn (x , ** kwargs )
225242
226243 assert_fft_dtype ("irfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
227- assert_s_axes_shape ( "irfftn" , x = x , s = s , axes = axes , out = out )
244+ # TODO: shape
228245
229246
230247@given (
@@ -237,7 +254,14 @@ def test_hfft(x, data):
237254 out = xp .fft .hfft (x , ** kwargs )
238255
239256 assert_fft_dtype ("hfft" , in_dtype = x .dtype , out_dtype = out .dtype )
240- # TODO: shape
257+
258+ _axis = x .ndim - 1 if axis == - 1 else axis
259+ if n is None :
260+ axis_side = 2 * (x .shape [_axis ] - 1 )
261+ else :
262+ axis_side = n
263+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
264+ ph .assert_shape ("hfft" , out_shape = out .shape , expected = expected_shape )
241265
242266
243267@given (
@@ -250,9 +274,10 @@ def test_ihfft(x, data):
250274 out = xp .fft .ihfft (x , ** kwargs )
251275
252276 assert_fft_dtype ("ihfft" , in_dtype = x .dtype , out_dtype = out .dtype )
253- # TODO: shape
277+ assert_n_axis_shape ( "ihfft" , x = x , n = n , axis = axis , out = out , size_gt_1 = True )
254278
255279
280+ # TODO:
256281# fftfreq
257282# rfftfreq
258283# fftshift
0 commit comments