@@ -66,12 +66,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
6666 The ``fft`` method is ``O(N * log2(N))`` with
6767 ``N = len(scale) + len(data) - 1``. It is well suited for large size
6868 signals but slightly slower than ``conv`` on small ones.
69+ axis: int, optional
70+ Axis over which to compute the CWT. If not given, the last axis is
71+ used.
6972
7073 Returns
7174 -------
7275 coefs : array_like
7376 Continuous wavelet transform of the input signal for the given scales
74- and wavelet
77+ and wavelet. The first axis of ``coefs`` corresponds to the scales.
78+ The remaining axes match the shape of ``data``.
7579 frequencies : array_like
7680 If the unit of sampling period are seconds and given, than frequencies
7781 are in hertz. Otherwise, a sampling period of 1 is assumed.
@@ -135,7 +139,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
135139 # move axis to be transformed last (so it is contiguous)
136140 data = data .swapaxes (- 1 , axis )
137141
138- # reshape to (n_batch, data.shape[axis ])
142+ # reshape to (n_batch, data.shape[-1 ])
139143 data_shape_pre = data .shape
140144 data = data .reshape ((- 1 , data .shape [- 1 ]))
141145
@@ -195,4 +199,3 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
195199 frequencies = np .array ([frequencies ])
196200 frequencies /= sampling_period
197201 return out , frequencies
198-
0 commit comments