Skip to content

Commit 648a4ce

Browse files
committed
add batch cwt cases to the benchmarks
1 parent d974834 commit 648a4ce

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

benchmarks/benchmarks/cwt_benchmarks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def setup(self, n, wavelet, max_scale, dtype, method):
2020
except ImportError:
2121
raise NotImplementedError("cwt not available")
2222
self.data = np.ones(n, dtype=dtype)
23+
self.batch_data = np.ones((5, n), dtype=dtype)
2324
self.scales = np.arange(1, max_scale + 1)
2425

2526

@@ -33,3 +34,12 @@ def time_cwt(self, n, wavelet, max_scale, dtype, method):
3334
raise NotImplementedError(
3435
"fft-based convolution not available.")
3536
pywt.cwt(self.data, self.scales, wavelet)
37+
38+
def time_cwt_batch(self, n, wavelet, max_scale, dtype, method):
39+
try:
40+
pywt.cwt(self.batch_data, self.scales, wavelet, method=method,
41+
axis=-1)
42+
except TypeError:
43+
# older PyWavelets does not support the axis argument
44+
raise NotImplementedError(
45+
"axis argument not available.")

0 commit comments

Comments
 (0)