1+ import functools
12import numpy as np
23from . import fftw
34
@@ -371,28 +372,37 @@ class FFTNumPy(FFTBase): #pragma: no cover
371372
372373 """
373374
374- def __init__ (self , shape , axes = None , dtype = float , padding = False , ** kw ):
375+ def __init__ (self , shape , axes = None , dtype = float , padding = False ,
376+ transforms = None , ** kw ):
375377 FFTBase .__init__ (self , shape , axes , dtype , padding )
376378 typecode = self .dtype .char
377379
378380 self .sizes = list (np .take (self .shape , self .axes ))
379381 arrayA = np .zeros (self .shape , self .dtype )
380- if self .real_transform :
381- axis = self .axes [- 1 ]
382- self .shape [axis ] = self .shape [axis ]// 2 + 1
383- arrayB = np .zeros (self .shape , typecode .upper ())
384- fwd = np .fft .rfftn
385- bck = np .fft .irfftn
382+ transforms = {} if transforms is None else transforms
383+ if tuple (self .axes ) in transforms :
384+ fwd , bck = transforms [tuple (self .axes )]
385+ arrayB = fwd (arrayA , axes = self .axes ).astype (typecode )
386+ self .fwd = functools .partial (fwd , shape = self .sizes )
387+ self .bck = functools .partial (bck , shape = self .sizes )
388+
386389 else :
387- arrayB = np .zeros (self .shape , typecode )
388- fwd = np .fft .fftn
389- bck = np .fft .ifftn
390+ if self .real_transform :
391+ fwd = np .fft .rfftn
392+ bck = np .fft .irfftn
393+ arrayB = fwd (arrayA , s = self .sizes , axes = self .axes ).astype (typecode .upper ())
394+ self .shape = arrayB .shape
395+ else :
396+ fwd = np .fft .fftn
397+ bck = np .fft .ifftn
398+ arrayB = fwd (arrayA , s = self .sizes , axes = self .axes ).astype (typecode )
399+ self .fwd = functools .partial (fwd , s = self .sizes )
400+ self .bck = functools .partial (bck , s = self .sizes )
390401
391- fwd .input_array = arrayA
392- fwd .output_array = arrayB
393- bck .input_array = arrayB
394- bck .output_array = arrayA
395- self .fwd , self .bck = fwd , bck
402+ self .fwd_input_array = arrayA
403+ self .fwd_output_array = arrayB
404+ self .bck_input_array = arrayB
405+ self .bck_output_array = arrayA
396406
397407 self .padding_factor = 1
398408 if padding is not False :
@@ -407,14 +417,14 @@ def __init__(self, shape, axes=None, dtype=float, padding=False, **kw):
407417 self .backward = _Xfftn_wrap (self ._backward , arrayB , arrayA )
408418
409419 def _forward (self , ** kw ):
410- self .fwd . output_array [:] = self .fwd (self .fwd . input_array , s = self . sizes ,
420+ self .fwd_output_array [:] = self .fwd (self .fwd_input_array ,
411421 axes = self .axes , ** kw )
412- self ._truncation_forward (self .fwd . output_array , self .forward .output_array )
422+ self ._truncation_forward (self .fwd_output_array , self .forward .output_array )
413423 return self .forward .output_array
414424
415425 def _backward (self , ** kw ):
416- self ._padding_backward (self .backward .input_array , self .bck . input_array )
417- self .backward .output_array [:] = self .bck (self .bck . input_array , s = self . sizes ,
426+ self ._padding_backward (self .backward .input_array , self .bck_input_array )
427+ self .backward .output_array [:] = self .bck (self .bck_input_array ,
418428 axes = self .axes , ** kw )
419429 return self .backward .output_array
420430
0 commit comments