Skip to content

Commit 9563a43

Browse files
authored
Merge pull request #507 from grlee77/cwt_dtype_fix
preserve single precision in CWT
2 parents ac5793f + fb4b030 commit 9563a43

File tree

3 files changed

+37
-23
lines changed

3 files changed

+37
-23
lines changed

benchmarks/benchmarks/cwt_benchmarks.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,22 @@ class CwtTimeSuiteBase(object):
99
params = ([32, 128, 512, 2048],
1010
['cmor', 'cgau4', 'fbsp', 'gaus4', 'mexh', 'morl', 'shan'],
1111
[16, 64, 256],
12-
['conv', 'fft'])
13-
param_names = ('n', 'wavelet', 'max_scale', 'method')
12+
[np.float32, np.float64],
13+
['conv', 'fft'],
14+
)
15+
param_names = ('n', 'wavelet', 'max_scale', 'dtype', 'method')
1416

15-
def setup(self, n, wavelet, max_scale, method):
17+
def setup(self, n, wavelet, max_scale, dtype, method):
1618
try:
1719
from pywt import cwt
1820
except ImportError:
1921
raise NotImplementedError("cwt not available")
20-
self.data = np.ones(n, dtype='float')
21-
self.scales = np.arange(1, max_scale+1)
22+
self.data = np.ones(n, dtype=dtype)
23+
self.scales = np.arange(1, max_scale + 1)
2224

2325

2426
class CwtTimeSuite(CwtTimeSuiteBase):
25-
def time_cwt(self, n, wavelet, max_scale, method):
27+
def time_cwt(self, n, wavelet, max_scale, dtype, method):
2628
try:
2729
pywt.cwt(self.data, self.scales, wavelet, method=method)
2830
except TypeError:

pywt/_cwt.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,23 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
106106

107107
# accept array_like input; make a copy to ensure a contiguous array
108108
dt = _check_dtype(data)
109-
data = np.array(data, dtype=dt)
109+
data = np.asarray(data, dtype=dt)
110+
dt_cplx = np.result_type(dt, np.complex64)
110111
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
111112
wavelet = DiscreteContinuousWavelet(wavelet)
112113
if np.isscalar(scales):
113114
scales = np.array([scales])
114-
dt_out = None # TODO: fix in/out dtype consistency in a subsequent PR
115115
if data.ndim == 1:
116-
if wavelet.complex_cwt:
117-
dt_out = complex
116+
dt_out = dt_cplx if wavelet.complex_cwt else dt
118117
out = np.empty((np.size(scales), data.size), dtype=dt_out)
119118
precision = 10
120119
int_psi, x = integrate_wavelet(wavelet, precision=precision)
121120

121+
# convert int_psi, x to the same precision as the data
122+
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
123+
int_psi = np.asarray(int_psi, dtype=dt_psi)
124+
x = np.asarray(x, dtype=data.real.dtype)
125+
122126
if method == 'fft':
123127
size_scale0 = -1
124128
fft_data = None
@@ -150,8 +154,8 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
150154
conv = conv[:data.size + int_psi_scale.size - 1]
151155

152156
coef = - np.sqrt(scale) * np.diff(conv)
153-
if not np.iscomplexobj(out):
154-
coef = np.real(coef)
157+
if out.dtype.kind != 'c':
158+
coef = coef.real
155159
d = (coef.size - data.size) / 2.
156160
if d > 0:
157161
out[i, :] = coef[floor(d):-ceil(d)]

pywt/tests/test_cwt_wavelets.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import division, print_function, absolute_import
33

44
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
5-
assert_raises)
5+
assert_raises, assert_equal)
66
import numpy as np
77
import pywt
88

@@ -345,20 +345,28 @@ def test_cwt_parameters_in_names():
345345

346346

347347
def test_cwt_complex():
348-
for dtype in [np.float32, np.float64]:
348+
for dtype, tol in [(np.float32, 1e-5), (np.float64, 1e-13)]:
349349
time, sst = pywt.data.nino()
350350
sst = np.asarray(sst, dtype=dtype)
351351
dt = time[1] - time[0]
352352
wavelet = 'cmor1.5-1.0'
353353
scales = np.arange(1, 32)
354354

355-
# real-valued tranfsorm
356-
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt)
355+
for method in ['conv', 'fft']:
356+
# real-valued tranfsorm as a reference
357+
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
357358

358-
# complex-valued tranfsorm equals sum of the transforms of the real and
359-
# imaginary components
360-
[cfs_complex, f] = pywt.cwt(sst + 1j*sst, scales, wavelet, dt)
361-
assert_almost_equal(cfs + 1j*cfs, cfs_complex)
359+
# verify same precision
360+
assert_equal(cfs.real.dtype, sst.dtype)
361+
362+
# complex-valued transform equals sum of the transforms of the real
363+
# and imaginary components
364+
sst_complex = sst + 1j*sst
365+
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
366+
method=method)
367+
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
368+
# verify dtype is preserved
369+
assert_equal(cfs_complex.dtype, sst_complex.dtype)
362370

363371

364372
def test_cwt_small_scales():
@@ -377,12 +385,12 @@ def test_cwt_method_fft():
377385
rstate = np.random.RandomState(1)
378386
data = rstate.randn(50)
379387
data[15] = 1.
380-
scales = np.arange(1, 64)
381-
wavelet = 'cmor1.5-1.0'
388+
scales = np.arange(1, 64)
389+
wavelet = 'cmor1.5-1.0'
382390

383391
# build a reference cwt with the legacy np.conv() method
384392
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')
385393

386394
# compare with the fft based convolution
387-
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
395+
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
388396
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)

0 commit comments

Comments
 (0)