Skip to content

Commit f5c4e26

Browse files
authored
Merge pull request #476 from grlee77/swt_enhancements
swt normalization and option to trim the approximation coefficients
2 parents 9563a43 + 1652e84 commit f5c4e26

File tree

6 files changed

+454
-133
lines changed

6 files changed

+454
-133
lines changed

pywt/_dwt.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def dwt_coeff_len(data_len, filter_len, mode):
9292
Data length.
9393
filter_len : int
9494
Filter length.
95-
mode : str, optional (default: 'symmetric')
96-
Signal extension mode, see Modes
95+
mode : str, optional
96+
Signal extension mode, see :ref:`Modes <ref-modes>`.
9797
9898
Returns
9999
-------
@@ -130,7 +130,7 @@ def dwt(data, wavelet, mode='symmetric', axis=-1):
130130
wavelet : Wavelet object or name
131131
Wavelet to use
132132
mode : str, optional
133-
Signal extension mode, see Modes
133+
Signal extension mode, see :ref:`Modes <ref-modes>`.
134134
axis: int, optional
135135
Axis over which to compute the DWT. If not given, the
136136
last axis is used.
@@ -199,14 +199,14 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
199199
----------
200200
cA : array_like or None
201201
Approximation coefficients. If None, will be set to array of zeros
202-
with same shape as `cD`.
202+
with same shape as ``cD``.
203203
cD : array_like or None
204204
Detail coefficients. If None, will be set to array of zeros
205-
with same shape as `cA`.
205+
with same shape as ``cA``.
206206
wavelet : Wavelet object or name
207207
Wavelet to use
208208
mode : str, optional (default: 'symmetric')
209-
Signal extension mode, see Modes
209+
Signal extension mode, see :ref:`Modes <ref-modes>`.
210210
axis: int, optional
211211
Axis over which to compute the inverse DWT. If not given, the
212212
last axis is used.
@@ -224,7 +224,7 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
224224
>>> pywt.idwt(cA, cD, 'db2', 'smooth')
225225
array([ 1., 2., 3., 4., 5., 6.])
226226
227-
One of the neat features of `idwt` is that one of the ``cA`` and ``cD``
227+
One of the neat features of ``idwt`` is that one of the ``cA`` and ``cD``
228228
arguments can be set to None. In that situation the reconstruction will be
229229
performed using only the other one. Mathematically speaking, this is
230230
equivalent to passing a zero-filled array as one of the arguments.
@@ -300,7 +300,7 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1):
300300
301301
Partial Discrete Wavelet Transform data decomposition.
302302
303-
Similar to `pywt.dwt`, but computes only one set of coefficients.
303+
Similar to ``pywt.dwt``, but computes only one set of coefficients.
304304
Useful when you need only approximation or only details at the given level.
305305
306306
Parameters
@@ -316,7 +316,7 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1):
316316
wavelet : Wavelet object or name
317317
Wavelet to use
318318
mode : str, optional
319-
Signal extension mode, see `Modes`. Default is 'symmetric'.
319+
Signal extension mode, see :ref:`Modes <ref-modes>`.
320320
level : int, optional
321321
Decomposition level. Default is 1.
322322

pywt/_extensions/_swt.pyx

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#cython: boundscheck=False, wraparound=False
22
from . cimport common
33
from . cimport c_wt
4+
from cpython cimport bool
45

56
import warnings
67
import numpy as np
@@ -9,6 +10,7 @@ cimport numpy as np
910
from .common cimport pywt_index_t
1011
from ._pywt cimport c_wavelet_from_object, cdata_t, Wavelet, _check_dtype
1112

13+
1214
include "config.pxi"
1315

1416
def swt_max_level(size_t input_len):
@@ -47,7 +49,8 @@ def swt_max_level(size_t input_len):
4749
return max_level
4850

4951

50-
def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level):
52+
def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level,
53+
bool trim_approx=False):
5154
cdef cdata_t[::1] cA, cD
5255
cdef Wavelet w
5356
cdef int retval
@@ -142,14 +145,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level):
142145
raise RuntimeError("C swt failed.")
143146

144147
data = cA
145-
ret.append((cA, cD))
148+
if not trim_approx:
149+
ret.append((np.asarray(cA), np.asarray(cD)))
150+
else:
151+
ret.append(np.asarray(cD))
146152

153+
if trim_approx:
154+
ret.append(np.asarray(cA))
147155
ret.reverse()
148156
return ret
149157

150158

151159
cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
152-
size_t start_level, unsigned int axis=0):
160+
size_t start_level, unsigned int axis=0,
161+
bool trim_approx=False):
153162
# memory-views do not support n-dimensional arrays, use np.ndarray instead
154163
cdef common.ArrayInfo data_info, output_info
155164
cdef np.ndarray cD, cA
@@ -289,13 +298,19 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
289298
if retval == -5:
290299
raise TypeError("Array must be floating point, not {}"
291300
.format(data.dtype))
292-
ret.append((cA, cD))
301+
if not trim_approx:
302+
ret.append((cA, cD))
303+
else:
304+
ret.append(cD)
293305

294306
# previous approx coeffs are the data for the next level
295307
data = cA
296308
# update data_info to match the new data array
297309
data_info.strides = <pywt_index_t *> data.strides
298310
data_info.shape = <size_t *> data.shape
299311

312+
if trim_approx:
313+
ret.append(cA)
314+
300315
ret.reverse()
301316
return ret

pywt/_multidim.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def dwt2(data, wavelet, mode='symmetric', axes=(-2, -1)):
3333
Wavelet to use. This can also be a tuple containing a wavelet to
3434
apply along each axis in ``axes``.
3535
mode : str or 2-tuple of strings, optional
36-
Signal extension mode, see Modes (default: 'symmetric'). This can
36+
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
3737
also be a tuple of modes specifying the mode to use on each axis in
3838
``axes``.
3939
axes : 2-tuple of ints, optional
@@ -84,13 +84,13 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
8484
----------
8585
coeffs : tuple
8686
(cA, (cH, cV, cD)) A tuple with approximation coefficients and three
87-
details coefficients 2D arrays like from `dwt2`. If any of these
87+
details coefficients 2D arrays like from ``dwt2``. If any of these
8888
components are set to ``None``, it will be treated as zeros.
8989
wavelet : Wavelet object or name string, or 2-tuple of wavelets
9090
Wavelet to use. This can also be a tuple containing a wavelet to
9191
apply along each axis in ``axes``.
9292
mode : str or 2-tuple of strings, optional
93-
Signal extension mode, see Modes (default: 'symmetric'). This can
93+
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
9494
also be a tuple of modes specifying the mode to use on each axis in
9595
``axes``.
9696
axes : 2-tuple of ints, optional
@@ -131,7 +131,7 @@ def dwtn(data, wavelet, mode='symmetric', axes=None):
131131
apply along each axis in ``axes``.
132132
mode : str or tuple of string, optional
133133
Signal extension mode used in the decomposition,
134-
see Modes (default: 'symmetric'). This can also be a tuple of modes
134+
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
135135
specifying the mode to use on each axis in ``axes``.
136136
axes : sequence of ints, optional
137137
Axes over which to compute the DWT. Repeated elements mean the DWT will
@@ -233,7 +233,7 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
233233
apply along each axis in ``axes``.
234234
mode : str or list of string, optional
235235
Signal extension mode used in the decomposition,
236-
see Modes (default: 'symmetric'). This can also be a tuple of modes
236+
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
237237
specifying the mode to use on each axis in ``axes``.
238238
axes : sequence of ints, optional
239239
Axes over which to compute the IDWT. Repeated elements mean the IDWT

0 commit comments

Comments
 (0)