11#!/usr/bin/env python
2- # Copyright (c) 2019, Intel Corporation
2+ # Copyright (c) 2019-2020 , Intel Corporation
33#
44# Redistribution and use in source and binary forms, with or without
55# modification, are permitted provided that the following conditions are met:
2525# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
2727from . import _pydfti
28+ from . import _float_utils
2829import mkl
2930
3031import scipy .fft as _fft
3738 get_workers , set_workers
3839)
3940
41+ from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
42+
4043__all__ = ['fft' , 'ifft' , 'fft2' , 'ifft2' , 'fftn' , 'ifftn' ,
4144 'rfft' , 'irfft' , 'rfft2' , 'irfft2' , 'rfftn' , 'irfftn' ,
4245 'hfft' , 'ihfft' , 'hfft2' , 'ihfft2' , 'hfftn' , 'ihfftn' ,
@@ -54,6 +57,7 @@ def __ua_function__(method, args, kwargs):
5457 return NotImplemented
5558 return fn (* args , ** kwargs )
5659
60+
5761def _implements (scipy_func ):
5862 """Decorator adds function to the dictionary of implemented UA functions"""
5963 def inner (func ):
@@ -70,25 +74,54 @@ def _unitary(norm):
7074 return norm is not None
7175
7276
77+ def _cook_nd_args (a , s = None , axes = None , invreal = 0 ):
78+ if s is None :
79+ shapeless = 1
80+ if axes is None :
81+ s = list (a .shape )
82+ else :
83+ s = take (a .shape , axes )
84+ else :
85+ shapeless = 0
86+ s = list (s )
87+ if axes is None :
88+ axes = list (range (- len (s ), 0 ))
89+ if len (s ) != len (axes ):
90+ raise ValueError ("Shape and axes have different lengths." )
91+ if invreal and shapeless :
92+ s [- 1 ] = (a .shape [axes [- 1 ]] - 1 ) * 2
93+ return s , axes
94+
95+
96+ def _tot_size (x , axes ):
97+ s = x .shape
98+ if axes is None :
99+ return x .size
100+ return prod ([s [ai ] for ai in axes ])
101+
102+
73103@_implements (_fft .fft )
74- def fft (x , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
104+ def fft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
105+ x = _float_utils .__upcast_float16_array (a )
75106 output = _pydfti .fft (x , n = n , axis = axis , overwrite_x = overwrite_x )
76107 if _unitary (norm ):
77108 output *= 1 / sqrt (output .shape [axis ])
78109 return output
79110
80111
81112@_implements (_fft .ifft )
82- def ifft (x , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
113+ def ifft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
114+ x = _float_utils .__upcast_float16_array (a )
83115 output = _pydfti .ifft (x , n = n , axis = axis , overwrite_x = overwrite_x )
84116 if _unitary (norm ):
85117 output *= sqrt (output .shape [axis ])
86118 return output
87119
88120
89121@_implements (_fft .fft2 )
90- def fft2 (x , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
91- output = _pydfti .fftn (x , s = s , axis = axis , overwrite_x = overwrite_x )
122+ def fft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
123+ x = _float_utils .__upcast_float16_array (a )
124+ output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
92125 if _unitary (norm ):
93126 factor = 1
94127 for axis in axes :
@@ -98,8 +131,9 @@ def fft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
98131
99132
100133@_implements (_fft .ifft2 )
101- def ifft2 (x , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
102- output = _pydfti .ifftn (x , s = s , axis = axis , overwrite_x = overwrite_x )
134+ def ifft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
135+ x = _float_utils .__upcast_float16_array (a )
136+ output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
103137 if _unitary (norm ):
104138 factor = 1
105139 _axes = range (output .ndim ) if axes is None else axes
@@ -110,8 +144,9 @@ def ifft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
110144
111145
112146@_implements (_fft .fftn )
113- def fftn (x , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
114- output = _pydfti .fftn (x , s = s , axis = axis , overwrite_x = overwrite_x )
147+ def fftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
148+ x = _float_utils .__upcast_float16_array (a )
149+ output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
115150 if _unitary (norm ):
116151 factor = 1
117152 _axes = range (output .ndim ) if axes is None else axes
@@ -122,12 +157,77 @@ def fftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
122157
123158
124159@_implements (_fft .ifftn )
125- def ifftn (x , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
126- output = _pydfti .ifftn (x , s = s , axis = axis , overwrite_x = overwrite_x )
160+ def ifftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
161+ x = _float_utils .__upcast_float16_array (a )
162+ output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
127163 if _unitary (norm ):
128164 factor = 1
129165 _axes = range (output .ndim ) if axes is None else axes
130166 for axis in _axes :
131167 factor *= sqrt (output .shape [axis ])
132168 output *= factor
133169 return output
170+
171+
172+ @_implements (_fft .rfft )
173+ def rfft (a , n = None , axis = - 1 , norm = None ):
174+ x = _float_utils .__upcast_float16_array (a )
175+ unitary = _unitary (norm )
176+ x = _float_utils .__downcast_float128_array (x )
177+ if unitary and n is None :
178+ x = asarray (x )
179+ n = x .shape [axis ]
180+ output = _pydfti .rfft_numpy (x , n = n , axis = axis )
181+ if unitary :
182+ output *= 1 / sqrt (n )
183+ return output
184+
185+
186+ @_implements (_fft .irfft )
187+ def irfft (a , n = None , axis = - 1 , norm = None ):
188+ x = _float_utils .__upcast_float16_array (a )
189+ x = _float_utils .__downcast_float128_array (x )
190+ output = _pydfti .irfft_numpy (x , n = n , axis = axis )
191+ if _unitary (norm ):
192+ output *= sqrt (output .shape [axis ])
193+ return output
194+
195+
196+ @_implements (_fft .rfft2 )
197+ def rfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None ):
198+ x = _float_utils .__upcast_float16_array (a )
199+ x = _float_utils .__downcast_float128_array (a )
200+ return rfftn (x , s , axes , norm )
201+
202+
203+ @_implements (_fft .irfft2 )
204+ def irfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None ):
205+ x = _float_utils .__upcast_float16_array (a )
206+ x = _float_utils .__downcast_float128_array (x )
207+ return irfftn (x , s , axes , norm )
208+
209+
210+ @_implements (_fft .rfftn )
211+ def rfftn (a , s = None , axes = None , norm = None ):
212+ unitary = _unitary (norm )
213+ x = _float_utils .__upcast_float16_array (a )
214+ x = _float_utils .__downcast_float128_array (x )
215+ if unitary :
216+ x = asarray (x )
217+ s , axes = _cook_nd_args (x , s , axes )
218+
219+ output = _pydfti .rfftn_numpy (x , s , axes )
220+ if unitary :
221+ n_tot = prod (asarray (s , dtype = output .dtype ))
222+ output *= 1 / sqrt (n_tot )
223+ return output
224+
225+
226+ @_implements (_fft .irfftn )
227+ def irfftn (a , s = None , axes = None , norm = None ):
228+ x = _float_utils .__upcast_float16_array (a )
229+ x = _float_utils .__downcast_float128_array (x )
230+ output = _pydfti .irfftn_numpy (x , s , axes )
231+ if _unitary (norm ):
232+ output *= sqrt (_tot_size (output , axes ))
233+ return output
0 commit comments