@@ -34,7 +34,7 @@ def next_fast_len(n):
3434 return 2 ** ceil (np .log2 (n ))
3535
3636
37- def cwt (data , scales , wavelet , sampling_period = 1. , method = 'conv' ):
37+ def cwt (data , scales , wavelet , sampling_period = 1. , method = 'conv' , axis = - 1 ):
3838 """
3939 cwt(data, scales, wavelet)
4040
@@ -66,12 +66,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
6666 The ``fft`` method is ``O(N * log2(N))`` with
6767 ``N = len(scale) + len(data) - 1``. It is well suited for large size
6868 signals but slightly slower than ``conv`` on small ones.
69+ axis: int, optional
70+ Axis over which to compute the CWT. If not given, the last axis is
71+ used.
6972
7073 Returns
7174 -------
7275 coefs : array_like
7376 Continuous wavelet transform of the input signal for the given scales
74- and wavelet
77+ and wavelet. The first axis of ``coefs`` corresponds to the scales.
78+ The remaining axes match the shape of ``data``.
7579 frequencies : array_like
7680 If the unit of sampling period are seconds and given, than frequencies
7781 are in hertz. Otherwise, a sampling period of 1 is assumed.
@@ -112,62 +116,86 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
112116 wavelet = DiscreteContinuousWavelet (wavelet )
113117 if np .isscalar (scales ):
114118 scales = np .array ([scales ])
115- if data .ndim == 1 :
116- dt_out = dt_cplx if wavelet .complex_cwt else dt
117- out = np .empty ((np .size (scales ), data .size ), dtype = dt_out )
118- precision = 10
119- int_psi , x = integrate_wavelet (wavelet , precision = precision )
120-
121- # convert int_psi, x to the same precision as the data
122- dt_psi = dt_cplx if int_psi .dtype .kind == 'c' else dt
123- int_psi = np .asarray (int_psi , dtype = dt_psi )
124- x = np .asarray (x , dtype = data .real .dtype )
125-
126- if method == 'fft' :
127- size_scale0 = - 1
128- fft_data = None
129- elif not method == 'conv' :
130- raise ValueError ("method must be 'conv' or 'fft'" )
131-
132- for i , scale in enumerate (scales ):
133- step = x [1 ] - x [0 ]
134- j = np .arange (scale * (x [- 1 ] - x [0 ]) + 1 ) / (scale * step )
135- j = j .astype (int ) # floor
136- if j [- 1 ] >= int_psi .size :
137- j = np .extract (j < int_psi .size , j )
138- int_psi_scale = int_psi [j ][::- 1 ]
139-
140- if method == 'conv' :
119+ if not np .isscalar (axis ):
120+ raise ValueError ("axis must be a scalar." )
121+
122+ dt_out = dt_cplx if wavelet .complex_cwt else dt
123+ out = np .empty ((np .size (scales ),) + data .shape , dtype = dt_out )
124+ precision = 10
125+ int_psi , x = integrate_wavelet (wavelet , precision = precision )
126+
127+ # convert int_psi, x to the same precision as the data
128+ dt_psi = dt_cplx if int_psi .dtype .kind == 'c' else dt
129+ int_psi = np .asarray (int_psi , dtype = dt_psi )
130+ x = np .asarray (x , dtype = data .real .dtype )
131+
132+ if method == 'fft' :
133+ size_scale0 = - 1
134+ fft_data = None
135+ elif not method == 'conv' :
136+ raise ValueError ("method must be 'conv' or 'fft'" )
137+
138+ if data .ndim > 1 :
139+ # move axis to be transformed last (so it is contiguous)
140+ data = data .swapaxes (- 1 , axis )
141+
142+ # reshape to (n_batch, data.shape[-1])
143+ data_shape_pre = data .shape
144+ data = data .reshape ((- 1 , data .shape [- 1 ]))
145+
146+ for i , scale in enumerate (scales ):
147+ step = x [1 ] - x [0 ]
148+ j = np .arange (scale * (x [- 1 ] - x [0 ]) + 1 ) / (scale * step )
149+ j = j .astype (int ) # floor
150+ if j [- 1 ] >= int_psi .size :
151+ j = np .extract (j < int_psi .size , j )
152+ int_psi_scale = int_psi [j ][::- 1 ]
153+
154+ if method == 'conv' :
155+ if data .ndim == 1 :
141156 conv = np .convolve (data , int_psi_scale )
142157 else :
143- # The padding is selected for:
144- # - optimal FFT complexity
145- # - to be larger than the two signals length to avoid circular
146- # convolution
147- size_scale = next_fast_len (data .size + int_psi_scale .size - 1 )
148- if size_scale != size_scale0 :
149- # Must recompute fft_data when the padding size changes.
150- fft_data = fftmodule .fft (data , size_scale )
151- size_scale0 = size_scale
152- fft_wav = fftmodule .fft (int_psi_scale , size_scale )
153- conv = fftmodule .ifft (fft_wav * fft_data )
154- conv = conv [:data .size + int_psi_scale .size - 1 ]
155-
156- coef = - np .sqrt (scale ) * np .diff (conv )
157- if out .dtype .kind != 'c' :
158- coef = coef .real
159- d = (coef .size - data .size ) / 2.
160- if d > 0 :
161- out [i , :] = coef [floor (d ):- ceil (d )]
162- elif d == 0. :
163- out [i , :] = coef
164- else :
165- raise ValueError (
166- "Selected scale of {} too small." .format (scale ))
167- frequencies = scale2frequency (wavelet , scales , precision )
168- if np .isscalar (frequencies ):
169- frequencies = np .array ([frequencies ])
170- frequencies /= sampling_period
171- return out , frequencies
172- else :
173- raise ValueError ("Only dim == 1 supported" )
158+ # batch convolution via loop
159+ conv_shape = list (data .shape )
160+ conv_shape [- 1 ] += int_psi_scale .size - 1
161+ conv_shape = tuple (conv_shape )
162+ conv = np .empty (conv_shape , dtype = dt_out )
163+ for n in range (data .shape [0 ]):
164+ conv [n , :] = np .convolve (data [n ], int_psi_scale )
165+ else :
166+ # The padding is selected for:
167+ # - optimal FFT complexity
168+ # - to be larger than the two signals length to avoid circular
169+ # convolution
170+ size_scale = next_fast_len (
171+ data .shape [- 1 ] + int_psi_scale .size - 1
172+ )
173+ if size_scale != size_scale0 :
174+ # Must recompute fft_data when the padding size changes.
175+ fft_data = fftmodule .fft (data , size_scale , axis = - 1 )
176+ size_scale0 = size_scale
177+ fft_wav = fftmodule .fft (int_psi_scale , size_scale , axis = - 1 )
178+ conv = fftmodule .ifft (fft_wav * fft_data , axis = - 1 )
179+ conv = conv [..., :data .shape [- 1 ] + int_psi_scale .size - 1 ]
180+
181+ coef = - np .sqrt (scale ) * np .diff (conv , axis = - 1 )
182+ if out .dtype .kind != 'c' :
183+ coef = coef .real
184+ # transform axis is always -1 due to the data reshape above
185+ d = (coef .shape [- 1 ] - data .shape [- 1 ]) / 2.
186+ if d > 0 :
187+ coef = coef [..., floor (d ):- ceil (d )]
188+ elif d < 0 :
189+ raise ValueError (
190+ "Selected scale of {} too small." .format (scale ))
191+ if data .ndim > 1 :
192+ # restore original data shape and axis position
193+ coef = coef .reshape (data_shape_pre )
194+ coef = coef .swapaxes (axis , - 1 )
195+ out [i , ...] = coef
196+
197+ frequencies = scale2frequency (wavelet , scales , precision )
198+ if np .isscalar (frequencies ):
199+ frequencies = np .array ([frequencies ])
200+ frequencies /= sampling_period
201+ return out , frequencies
0 commit comments