Skip to content

Commit 26579c2

Browse files
committed
TST: add tests for pywt.pad
1 parent 1ff0132 commit 26579c2

File tree

2 files changed

+75
-9
lines changed

2 files changed

+75
-9
lines changed

pywt/_dwt.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,14 @@ def pad(x, pad_widths, mode):
434434
for modes `smooth` and `antisymmetric` as these modes are not supported in
435435
an efficient manner by the underlying `numpy.pad` function.
436436
"""
437-
if np.isscalar(pad_widths):
438-
pad_widths = (pad_widths, )
439-
if len(pad_widths) == 1:
440-
pad_widths = (pad_widths[0], ) * x.ndim
441-
pad_widths = [(p, p) if np.isscalar(p) else p for p in pad_widths]
437+
x = np.asanyarray(x)
438+
439+
# process pad_widths exactly as in numpy.pad
440+
pad_widths = np.array(pad_widths)
441+
pad_widths = np.round(pad_widths).astype(np.intp, copy=False)
442+
if pad_widths.min() < 0:
443+
raise ValueError("pad_widths must be > 0")
444+
pad_widths = np.broadcast_to(pad_widths, (x.ndim, 2)).tolist()
442445

443446
if mode in ['symmetric', 'reflect']:
444447
xp = np.pad(x, pad_widths, mode=mode)
@@ -504,7 +507,6 @@ def pad_antisymmetric(vector, pad_width, iaxis, kwargs):
504507
return vector
505508
xp = np.pad(x, pad_widths, pad_antisymmetric)
506509
elif mode == 'antireflect':
507-
npad_l, npad_r = pad_widths
508510
xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd')
509511
else:
510512
raise ValueError(

pywt/tests/test_dwt_idwt.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from __future__ import division, print_function, absolute_import
33

44
import numpy as np
5-
from numpy.testing import assert_allclose, assert_, assert_raises
6-
5+
from numpy.testing import (assert_allclose, assert_, assert_raises,
6+
assert_array_equal)
77
import pywt
88

99
# Check that float32, float64, complex64, complex128 are preserved.
@@ -228,8 +228,72 @@ def test_error_on_continuous_wavelet():
228228
def test_dwt_zero_size_axes():
229229
# raise on empty input array
230230
assert_raises(ValueError, pywt.dwt, [], 'db2')
231-
231+
232232
# >1D case uses a different code path so check there as well
233233
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
234234
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
235235

236+
237+
def test_pad_1d():
238+
x = [1, 2, 3]
239+
assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
240+
[1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
241+
assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
242+
[3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
243+
assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
244+
[1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
245+
assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
246+
[0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
247+
assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
248+
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
249+
assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
250+
[3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
251+
assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
252+
[3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
253+
assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
254+
[1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
255+
assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
256+
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
257+
258+
# equivalence of various pad_width formats
259+
assert_array_equal(pywt.pad(x, 4, 'periodic'),
260+
pywt.pad(x, (4, 4), 'periodic'))
261+
262+
assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
263+
pywt.pad(x, (4, 4), 'periodic'))
264+
265+
assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
266+
pywt.pad(x, (4, 4), 'periodic'))
267+
268+
269+
def test_pad_errors():
270+
# negative pad width
271+
x = [1, 2, 3]
272+
assert_raises(ValueError, pywt.pad, x, -2, 'periodic')
273+
274+
# wrong length pad width
275+
assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')
276+
277+
# invalid mode name
278+
assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')
279+
280+
281+
def test_pad_nd():
282+
for ndim in [2, 3]:
283+
x = np.arange(4**ndim).reshape((4, ) * ndim)
284+
if ndim == 2:
285+
pad_widths = [(2, 1), (2, 3)]
286+
else:
287+
pad_widths = [(2, 1), ] * ndim
288+
for mode in pywt.Modes.modes:
289+
xp = pywt.pad(x, pad_widths, mode)
290+
291+
# expected result is the same as applying along axes separably
292+
xp_expected = x.copy()
293+
for ax in range(ndim):
294+
xp_expected = np.apply_along_axis(pywt.pad,
295+
ax,
296+
xp_expected,
297+
pad_widths=[pad_widths[ax]],
298+
mode=mode)
299+
assert_array_equal(xp, xp_expected)

0 commit comments

Comments
 (0)