Skip to content

Commit 9ba3a1c

Browse files
authored
Merge pull request #509 from grlee77/batch_cwt
add axis support to cwt
2 parents 20ab3c1 + 648a4ce commit 9ba3a1c

File tree

3 files changed

+158
-82
lines changed

3 files changed

+158
-82
lines changed

benchmarks/benchmarks/cwt_benchmarks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def setup(self, n, wavelet, max_scale, dtype, method):
2020
except ImportError:
2121
raise NotImplementedError("cwt not available")
2222
self.data = np.ones(n, dtype=dtype)
23+
self.batch_data = np.ones((5, n), dtype=dtype)
2324
self.scales = np.arange(1, max_scale + 1)
2425

2526

@@ -33,3 +34,12 @@ def time_cwt(self, n, wavelet, max_scale, dtype, method):
3334
raise NotImplementedError(
3435
"fft-based convolution not available.")
3536
pywt.cwt(self.data, self.scales, wavelet)
37+
38+
def time_cwt_batch(self, n, wavelet, max_scale, dtype, method):
39+
try:
40+
pywt.cwt(self.batch_data, self.scales, wavelet, method=method,
41+
axis=-1)
42+
except TypeError:
43+
# older PyWavelets does not support the axis argument
44+
raise NotImplementedError(
45+
"axis argument not available.")

pywt/_cwt.py

Lines changed: 87 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def next_fast_len(n):
3434
return 2**ceil(np.log2(n))
3535

3636

37-
def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
37+
def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
3838
"""
3939
cwt(data, scales, wavelet)
4040
@@ -66,12 +66,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
6666
The ``fft`` method is ``O(N * log2(N))`` with
6767
``N = len(scale) + len(data) - 1``. It is well suited for large size
6868
signals but slightly slower than ``conv`` on small ones.
69+
axis : int, optional
70+
Axis over which to compute the CWT. If not given, the last axis is
71+
used.
6972
7073
Returns
7174
-------
7275
coefs : array_like
7376
Continuous wavelet transform of the input signal for the given scales
74-
and wavelet
77+
and wavelet. The first axis of ``coefs`` corresponds to the scales.
78+
The remaining axes match the shape of ``data``.
7579
frequencies : array_like
7680
If the sampling period is given in units of seconds, then frequencies
7781
are in hertz. Otherwise, a sampling period of 1 is assumed.
@@ -112,62 +116,86 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
112116
wavelet = DiscreteContinuousWavelet(wavelet)
113117
if np.isscalar(scales):
114118
scales = np.array([scales])
115-
if data.ndim == 1:
116-
dt_out = dt_cplx if wavelet.complex_cwt else dt
117-
out = np.empty((np.size(scales), data.size), dtype=dt_out)
118-
precision = 10
119-
int_psi, x = integrate_wavelet(wavelet, precision=precision)
120-
121-
# convert int_psi, x to the same precision as the data
122-
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
123-
int_psi = np.asarray(int_psi, dtype=dt_psi)
124-
x = np.asarray(x, dtype=data.real.dtype)
125-
126-
if method == 'fft':
127-
size_scale0 = -1
128-
fft_data = None
129-
elif not method == 'conv':
130-
raise ValueError("method must be 'conv' or 'fft'")
131-
132-
for i, scale in enumerate(scales):
133-
step = x[1] - x[0]
134-
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
135-
j = j.astype(int) # floor
136-
if j[-1] >= int_psi.size:
137-
j = np.extract(j < int_psi.size, j)
138-
int_psi_scale = int_psi[j][::-1]
139-
140-
if method == 'conv':
119+
if not np.isscalar(axis):
120+
raise ValueError("axis must be a scalar.")
121+
122+
dt_out = dt_cplx if wavelet.complex_cwt else dt
123+
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
124+
precision = 10
125+
int_psi, x = integrate_wavelet(wavelet, precision=precision)
126+
127+
# convert int_psi, x to the same precision as the data
128+
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
129+
int_psi = np.asarray(int_psi, dtype=dt_psi)
130+
x = np.asarray(x, dtype=data.real.dtype)
131+
132+
if method == 'fft':
133+
size_scale0 = -1
134+
fft_data = None
135+
elif not method == 'conv':
136+
raise ValueError("method must be 'conv' or 'fft'")
137+
138+
if data.ndim > 1:
139+
# move axis to be transformed last (so it is contiguous)
140+
data = data.swapaxes(-1, axis)
141+
142+
# reshape to (n_batch, data.shape[-1])
143+
data_shape_pre = data.shape
144+
data = data.reshape((-1, data.shape[-1]))
145+
146+
for i, scale in enumerate(scales):
147+
step = x[1] - x[0]
148+
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
149+
j = j.astype(int) # floor
150+
if j[-1] >= int_psi.size:
151+
j = np.extract(j < int_psi.size, j)
152+
int_psi_scale = int_psi[j][::-1]
153+
154+
if method == 'conv':
155+
if data.ndim == 1:
141156
conv = np.convolve(data, int_psi_scale)
142157
else:
143-
# The padding is selected for:
144-
# - optimal FFT complexity
145-
# - to be larger than the two signals length to avoid circular
146-
# convolution
147-
size_scale = next_fast_len(data.size + int_psi_scale.size - 1)
148-
if size_scale != size_scale0:
149-
# Must recompute fft_data when the padding size changes.
150-
fft_data = fftmodule.fft(data, size_scale)
151-
size_scale0 = size_scale
152-
fft_wav = fftmodule.fft(int_psi_scale, size_scale)
153-
conv = fftmodule.ifft(fft_wav * fft_data)
154-
conv = conv[:data.size + int_psi_scale.size - 1]
155-
156-
coef = - np.sqrt(scale) * np.diff(conv)
157-
if out.dtype.kind != 'c':
158-
coef = coef.real
159-
d = (coef.size - data.size) / 2.
160-
if d > 0:
161-
out[i, :] = coef[floor(d):-ceil(d)]
162-
elif d == 0.:
163-
out[i, :] = coef
164-
else:
165-
raise ValueError(
166-
"Selected scale of {} too small.".format(scale))
167-
frequencies = scale2frequency(wavelet, scales, precision)
168-
if np.isscalar(frequencies):
169-
frequencies = np.array([frequencies])
170-
frequencies /= sampling_period
171-
return out, frequencies
172-
else:
173-
raise ValueError("Only dim == 1 supported")
158+
# batch convolution via loop
159+
conv_shape = list(data.shape)
160+
conv_shape[-1] += int_psi_scale.size - 1
161+
conv_shape = tuple(conv_shape)
162+
conv = np.empty(conv_shape, dtype=dt_out)
163+
for n in range(data.shape[0]):
164+
conv[n, :] = np.convolve(data[n], int_psi_scale)
165+
else:
166+
# The padding is selected for:
167+
# - optimal FFT complexity
168+
# - to be larger than the two signals length to avoid circular
169+
# convolution
170+
size_scale = next_fast_len(
171+
data.shape[-1] + int_psi_scale.size - 1
172+
)
173+
if size_scale != size_scale0:
174+
# Must recompute fft_data when the padding size changes.
175+
fft_data = fftmodule.fft(data, size_scale, axis=-1)
176+
size_scale0 = size_scale
177+
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
178+
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
179+
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
180+
181+
coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
182+
if out.dtype.kind != 'c':
183+
coef = coef.real
184+
# transform axis is always -1 due to the data reshape above
185+
d = (coef.shape[-1] - data.shape[-1]) / 2.
186+
if d > 0:
187+
coef = coef[..., floor(d):-ceil(d)]
188+
elif d < 0:
189+
raise ValueError(
190+
"Selected scale of {} too small.".format(scale))
191+
if data.ndim > 1:
192+
# restore original data shape and axis position
193+
coef = coef.reshape(data_shape_pre)
194+
coef = coef.swapaxes(axis, -1)
195+
out[i, ...] = coef
196+
197+
frequencies = scale2frequency(wavelet, scales, precision)
198+
if np.isscalar(frequencies):
199+
frequencies = np.array([frequencies])
200+
frequencies /= sampling_period
201+
return out, frequencies

pywt/tests/test_cwt_wavelets.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#!/usr/bin/env python
22
from __future__ import division, print_function, absolute_import
3+
from itertools import product
34

45
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
56
assert_raises, assert_equal)
7+
import pytest
68
import numpy as np
79
import pywt
810

@@ -344,29 +346,65 @@ def test_cwt_parameters_in_names():
344346
assert_raises(ValueError, func, 'fbsp1-1-1-1')
345347

346348

347-
def test_cwt_complex():
348-
for dtype, tol in [(np.float32, 1e-5), (np.float64, 1e-13)]:
349-
time, sst = pywt.data.nino()
350-
sst = np.asarray(sst, dtype=dtype)
351-
dt = time[1] - time[0]
352-
wavelet = 'cmor1.5-1.0'
353-
scales = np.arange(1, 32)
354-
355-
for method in ['conv', 'fft']:
356-
# real-valued transform as a reference
357-
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
358-
359-
# verify same precision
360-
assert_equal(cfs.real.dtype, sst.dtype)
361-
362-
# complex-valued transform equals sum of the transforms of the real
363-
# and imaginary components
364-
sst_complex = sst + 1j*sst
365-
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
366-
method=method)
367-
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
368-
# verify dtype is preserved
369-
assert_equal(cfs_complex.dtype, sst_complex.dtype)
349+
@pytest.mark.parametrize('dtype, tol, method',
350+
[(np.float32, 1e-5, 'conv'),
351+
(np.float32, 1e-5, 'fft'),
352+
(np.float64, 1e-13, 'conv'),
353+
(np.float64, 1e-13, 'fft')])
354+
def test_cwt_complex(dtype, tol, method):
355+
time, sst = pywt.data.nino()
356+
sst = np.asarray(sst, dtype=dtype)
357+
dt = time[1] - time[0]
358+
wavelet = 'cmor1.5-1.0'
359+
scales = np.arange(1, 32)
360+
361+
# real-valued transform as a reference
362+
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
363+
364+
# verify same precision
365+
assert_equal(cfs.real.dtype, sst.dtype)
366+
367+
# complex-valued transform equals sum of the transforms of the real
368+
# and imaginary components
369+
sst_complex = sst + 1j*sst
370+
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
371+
method=method)
372+
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
373+
# verify dtype is preserved
374+
assert_equal(cfs_complex.dtype, sst_complex.dtype)
375+
376+
377+
@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
378+
def test_cwt_batch(axis, method):
379+
dtype = np.float64
380+
time, sst = pywt.data.nino()
381+
n_batch = 8
382+
batch_axis = 1 - axis
383+
sst1 = np.asarray(sst, dtype=dtype)
384+
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
385+
dt = time[1] - time[0]
386+
wavelet = 'cmor1.5-1.0'
387+
scales = np.arange(1, 32)
388+
389+
# non-batch transform as reference
390+
[cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis)
391+
392+
shape_in = sst.shape
393+
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis)
394+
395+
# shape of input is not modified
396+
assert_equal(shape_in, sst.shape)
397+
398+
# verify same precision
399+
assert_equal(cfs.real.dtype, sst.dtype)
400+
401+
# verify expected shape
402+
assert_equal(cfs.shape[0], len(scales))
403+
assert_equal(cfs.shape[1 + batch_axis], n_batch)
404+
assert_equal(cfs.shape[1 + axis], sst.shape[axis])
405+
406+
# batch result on stacked input is the same as stacked 1d result
407+
assert_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1))
370408

371409

372410
def test_cwt_small_scales():

0 commit comments

Comments
 (0)