Skip to content

Commit 3cb5436

Browse files
committed
move pad from _doc_utils.py to _dwt.py and add docstring
1 parent d872ef6 commit 3cb5436

File tree

3 files changed

+83
-60
lines changed

3 files changed

+83
-60
lines changed

pywt/_doc_utils.py

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import numpy as np
55
from matplotlib import pyplot as plt
66

7+
from ._dwt import pad
8+
79
__all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis',
8-
'draw_2d_fswavedecn_basis', 'pad', 'boundary_mode_subplot']
10+
'draw_2d_fswavedecn_basis', 'boundary_mode_subplot']
911

1012

1113
def wavedec_keys(level):
@@ -149,63 +151,6 @@ def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs={}, ax=None,
149151
return fig, ax
150152

151153

152-
def pad(x, pad_widths, mode):
153-
"""Extend a 1D signal using a given boundary mode.
154-
155-
Like numpy.pad but supports all PyWavelets boundary modes.
156-
"""
157-
if np.isscalar(pad_widths):
158-
pad_widths = (pad_widths, pad_widths)
159-
160-
if x.ndim > 1:
161-
raise ValueError("This padding function is only for 1D signals.")
162-
163-
if mode in ['symmetric', 'reflect']:
164-
xp = np.pad(x, pad_widths, mode=mode)
165-
elif mode in ['periodic', 'periodization']:
166-
if mode == 'periodization' and x.size % 2 == 1:
167-
raise ValueError("periodization expects an even length signal.")
168-
xp = np.pad(x, pad_widths, mode='wrap')
169-
elif mode == 'zeros':
170-
xp = np.pad(x, pad_widths, mode='constant', constant_values=0)
171-
elif mode == 'constant':
172-
xp = np.pad(x, pad_widths, mode='edge')
173-
elif mode == 'smooth':
174-
xp = np.pad(x, pad_widths, mode='linear_ramp',
175-
end_values=(x[0] + pad_widths[0] * (x[0] - x[1]),
176-
x[-1] + pad_widths[1] * (x[-1] - x[-2])))
177-
elif mode == 'antisymmetric':
178-
# implement by flipping portions symmetric padding
179-
npad_l, npad_r = pad_widths
180-
xp = np.pad(x, pad_widths, mode='symmetric')
181-
r_edge = npad_l + x.size - 1
182-
l_edge = npad_l
183-
# width of each reflected segment
184-
seg_width = x.size
185-
# flip reflected segments on the right of the original signal
186-
n = 1
187-
while r_edge <= xp.size:
188-
segment_slice = slice(r_edge + 1,
189-
min(r_edge + 1 + seg_width, xp.size))
190-
if n % 2:
191-
xp[segment_slice] *= -1
192-
r_edge += seg_width
193-
n += 1
194-
195-
# flip reflected segments on the left of the original signal
196-
n = 1
197-
while l_edge >= 0:
198-
segment_slice = slice(max(0, l_edge - seg_width), l_edge)
199-
if n % 2:
200-
xp[segment_slice] *= -1
201-
l_edge -= seg_width
202-
n += 1
203-
elif mode == 'antireflect':
204-
npad_l, npad_r = pad_widths
205-
xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd')
206-
return xp
207-
208-
209154
def boundary_mode_subplot(x, mode, ax, symw=True):
210155
"""Plot an illustration of the boundary mode in a subplot axis."""
211156

pywt/_dwt.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
15-
"dwt_coeff_len"]
15+
"dwt_coeff_len", "pad"]
1616

1717

1818
def dwt_max_level(data_len, filter_len):
@@ -401,3 +401,80 @@ def upcoef(part, coeffs, wavelet, level=1, take=0):
401401
if part not in 'ad':
402402
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
403403
return np.asarray(_upcoef(part == 'a', coeffs, wavelet, level, take))
404+
405+
406+
def pad(x, pad_widths, mode):
407+
"""Extend a 1D signal using a given boundary mode.
408+
409+
This is like `numpy.pad` but supports all PyWavelets boundary modes.
410+
411+
Parameters
412+
----------
413+
x : ndarray
414+
The array to pad
415+
pad_widths : {sequence, array_like, int}
416+
Number of values padded to the edges of each axis.
417+
((before_1, after_1), … (before_N, after_N)) unique pad widths for each
418+
axis. ((before, after),) yields same before and after pad for each
419+
axis. (pad,) or int is a shortcut for before = after = pad width for
420+
all axes.
421+
mode : str, optional
422+
Signal extension mode, see Modes.
423+
424+
Returns
425+
-------
426+
pad : ndarray
427+
Padded array of rank equal to array with shape increased according to
428+
`pad_width`.
429+
430+
"""
431+
if np.isscalar(pad_widths):
432+
pad_widths = (pad_widths, pad_widths)
433+
434+
if x.ndim > 1:
435+
raise ValueError("This padding function is only for 1D signals.")
436+
437+
if mode in ['symmetric', 'reflect']:
438+
xp = np.pad(x, pad_widths, mode=mode)
439+
elif mode in ['periodic', 'periodization']:
440+
if mode == 'periodization' and x.size % 2 == 1:
441+
raise ValueError("periodization expects an even length signal.")
442+
xp = np.pad(x, pad_widths, mode='wrap')
443+
elif mode == 'zeros':
444+
xp = np.pad(x, pad_widths, mode='constant', constant_values=0)
445+
elif mode == 'constant':
446+
xp = np.pad(x, pad_widths, mode='edge')
447+
elif mode == 'smooth':
448+
xp = np.pad(x, pad_widths, mode='linear_ramp',
449+
end_values=(x[0] + pad_widths[0] * (x[0] - x[1]),
450+
x[-1] + pad_widths[1] * (x[-1] - x[-2])))
451+
elif mode == 'antisymmetric':
452+
# implement by flipping portions symmetric padding
453+
npad_l, npad_r = pad_widths
454+
xp = np.pad(x, pad_widths, mode='symmetric')
455+
r_edge = npad_l + x.size - 1
456+
l_edge = npad_l
457+
# width of each reflected segment
458+
seg_width = x.size
459+
# flip reflected segments on the right of the original signal
460+
n = 1
461+
while r_edge <= xp.size:
462+
segment_slice = slice(r_edge + 1,
463+
min(r_edge + 1 + seg_width, xp.size))
464+
if n % 2:
465+
xp[segment_slice] *= -1
466+
r_edge += seg_width
467+
n += 1
468+
469+
# flip reflected segments on the left of the original signal
470+
n = 1
471+
while l_edge >= 0:
472+
segment_slice = slice(max(0, l_edge - seg_width), l_edge)
473+
if n % 2:
474+
xp[segment_slice] *= -1
475+
l_edge -= seg_width
476+
n += 1
477+
elif mode == 'antireflect':
478+
npad_l, npad_r = pad_widths
479+
xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd')
480+
return xp

pywt/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# <https://github.com/PyWavelets/pywt>
33
# See COPYING for license details.
44
import inspect
5+
import numpy as np
56
import sys
67
from collections.abc import Iterable
78

@@ -17,7 +18,7 @@
1718

1819

1920
def _as_wavelet(wavelet):
20-
"""Convert wavelet name to a Wavelet object"""
21+
"""Convert wavelet name to a Wavelet object."""
2122
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
2223
wavelet = DiscreteContinuousWavelet(wavelet)
2324
if isinstance(wavelet, ContinuousWavelet):

0 commit comments

Comments
 (0)