From f1b04ce99d8357d11b3a6a5abceb88e97b87d357 Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai Date: Thu, 31 Jul 2025 14:50:33 +0530 Subject: [PATCH 1/8] parallelize saa and pax --- .../collection/dictionary_based/_paa.py | 84 ++++++++++++++----- .../collection/dictionary_based/_sax.py | 23 ++++- 2 files changed, 86 insertions(+), 21 deletions(-) diff --git a/aeon/transformations/collection/dictionary_based/_paa.py b/aeon/transformations/collection/dictionary_based/_paa.py index 2cba574a1c..7c99e60adb 100644 --- a/aeon/transformations/collection/dictionary_based/_paa.py +++ b/aeon/transformations/collection/dictionary_based/_paa.py @@ -3,8 +3,10 @@ __maintainer__ = [] import numpy as np +from numba import njit, prange, get_num_threads, set_num_threads from aeon.transformations.collection import BaseCollectionTransformer +from aeon.utils.validation import check_n_jobs class PAA(BaseCollectionTransformer): @@ -39,12 +41,14 @@ class PAA(BaseCollectionTransformer): _tags = { "capability:multivariate": True, + "capability:multithreading": True, "fit_is_empty": True, "algorithm_type": "dictionary", } - def __init__(self, n_segments=8): + def __init__(self, n_segments=8, n_jobs=1): self.n_segments = n_segments + self.n_jobs = n_jobs super().__init__() @@ -71,7 +75,6 @@ def _transform(self, X, y=None): # of segments is 3, the indices will be [0:3], [3:6] and [6:10] # so 3 segments, two of length 3 and one of length 4 split_segments = np.array_split(all_indices, self.n_segments) - # If the series length is divisible by the number of segments # then the transformation can be done in one line # If not, a for loop is needed only on the segments while @@ -82,13 +85,13 @@ def _transform(self, X, y=None): return X_paa else: - n_samples, n_channels, _ = X.shape - X_paa = np.zeros(shape=(n_samples, n_channels, self.n_segments)) - - for _s, segment in enumerate(split_segments): - if X[:, :, segment].shape[-1] > 0: # avoids mean of empty slice error - X_paa[:, :, _s] = X[:, :, segment].mean(axis=-1) - + prev_threads = get_num_threads() + _n_jobs = check_n_jobs(self.n_jobs) + set_num_threads(_n_jobs) + X_paa = _parallel_paa_transform( + X, n_segments=self.n_segments, split_segments=split_segments + ) + set_num_threads(prev_threads) return X_paa def inverse_paa(self, X, original_length): @@ -110,17 +113,17 @@ def inverse_paa(self, X, original_length): return np.repeat(X, repeats=int(original_length / self.n_segments), axis=-1) else: - n_samples, n_channels, _ = X.shape - X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length)) - - all_indices = np.arange(original_length) - split_segments = np.array_split(all_indices, self.n_segments) - - for _s, segment in enumerate(split_segments): - X_inverse_paa[:, :, segment] = np.repeat( - X[:, :, [_s]], repeats=len(segment), axis=-1 - ) - + split_segments = np.array_split(np.arange(original_length), self.n_segments) + prev_threads = get_num_threads() + _n_jobs = check_n_jobs(self.n_jobs) + set_num_threads(_n_jobs) + X_inverse_paa = _parallel_inverse_paa_transform( + X, + original_length=original_length, + n_segments=self.n_segments, + split_segments=split_segments, + ) + set_num_threads(prev_threads) return X_inverse_paa @classmethod @@ -143,3 +146,44 @@ def _get_test_params(cls, parameter_set="default"): """ params = {"n_segments": 10} return params + + +@njit(parallel=True, fastmath=True) +def _parallel_paa_transform(X, n_segments, split_segments): + """Parallelized PAA for uneven segment splits using Numba.""" + n_samples, n_channels, _ = X.shape + X_paa = np.zeros((n_samples, n_channels, n_segments), dtype=X.dtype) + + for _s in prange(n_segments): # Parallel over segments + segment = split_segments[_s] + seg_len = segment.shape[0] + + if seg_len == 0: + continue # skip empty segment + + for i in range(n_samples): + for j in range(n_channels): + acc = 0.0 + for k in range(seg_len): + acc += X[i, j, segment[k]] + X_paa[i, j, _s] = acc / seg_len + + return X_paa + + +@njit(parallel=True, fastmath=True) +def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments): + """Parallelize the inverse PAA transformation for cases where the series length is not + divisible by the number of segments.""" + n_samples, n_channels, _ = X.shape + X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length)) + + for _s in prange(n_segments): + segment = split_segments[_s] + for idx in prange(len(segment)): + t = segment[idx] + for i in prange(n_samples): + for j in prange(n_channels): + X_inverse_paa[i, j, t] = X[i, j, _s] + + return X_inverse_paa diff --git a/aeon/transformations/collection/dictionary_based/_sax.py b/aeon/transformations/collection/dictionary_based/_sax.py index 8200f804ad..32c48708d6 100644 --- a/aeon/transformations/collection/dictionary_based/_sax.py +++ b/aeon/transformations/collection/dictionary_based/_sax.py @@ -167,7 +167,11 @@ def _get_sax_symbols(self, X_paa): sax_symbols : np.ndarray of shape = (n_cases, n_channels, n_segments) The output of the SAX transformation using np.digitize """ - sax_symbols = np.digitize(x=X_paa, bins=self.breakpoints) + prev_threads = get_num_threads() + _n_jobs = check_n_jobs(self.n_jobs) + set_num_threads(_n_jobs) + sax_symbols = _parallel_get_sax_symbols(X_paa, breakpoints=self.breakpoints) + set_num_threads(prev_threads) return sax_symbols def inverse_sax(self, X, original_length, y=None): @@ -292,3 +296,20 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid): ] return sax_inverse + + +@njit(fastmath=True, cache=True, parallel=True) +def _parallel_get_sax_symbols(X, breakpoints): + n_cases, n_channels, n_timepoints = X.shape + X_new = np.zeros((n_cases, n_channels, n_timepoints), dtype=np.intp) + n_break = breakpoints.shape[0] - 1 + for i_x in prange(n_cases): + for i_c in prange(n_channels): + for i_b in prange(n_break): + mask = np.where( + (X[i_x, i_c] >= breakpoints[i_b]) + & (X[i_x, i_c] < breakpoints[i_b + 1]) + )[0] + X_new[i_x, i_c, mask] += np.array(i_b).astype(np.intp) + + return X_new From cea39a53e4f05f03bd8f22bf9c743a8548d961c3 Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai Date: Thu, 31 Jul 2025 15:14:36 +0530 Subject: [PATCH 2/8] bug fix --- .../collection/dictionary_based/_sax.py | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/aeon/transformations/collection/dictionary_based/_sax.py b/aeon/transformations/collection/dictionary_based/_sax.py index 32c48708d6..398beb9b2b 100644 --- a/aeon/transformations/collection/dictionary_based/_sax.py +++ b/aeon/transformations/collection/dictionary_based/_sax.py @@ -299,17 +299,28 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid): @njit(fastmath=True, cache=True, parallel=True) -def _parallel_get_sax_symbols(X, breakpoints): - n_cases, n_channels, n_timepoints = X.shape - X_new = np.zeros((n_cases, n_channels, n_timepoints), dtype=np.intp) - n_break = breakpoints.shape[0] - 1 - for i_x in prange(n_cases): - for i_c in prange(n_channels): - for i_b in prange(n_break): - mask = np.where( - (X[i_x, i_c] >= breakpoints[i_b]) - & (X[i_x, i_c] < breakpoints[i_b + 1]) - )[0] - X_new[i_x, i_c, mask] += np.array(i_b).astype(np.intp) - - return X_new +def _parallel_get_sax_symbols(x, bins, right=False): + """Parallel version of `np.digitize`.""" + x_flat = x.flatten() + result = np.empty(x_flat.shape[0], dtype=np.intp) + + for i in prange(x_flat.shape[0]): + val = x_flat[i] + bin_idx = 0 + + if right: + for j in range(len(bins)): + if val <= bins[j]: + bin_idx = j + break + bin_idx = j + 1 + else: + for j in range(len(bins)): + if val < bins[j]: + bin_idx = j + break + bin_idx = j + 1 + + result[i] = bin_idx + + return result.reshape(x.shape) From f36f223decd397b9841ae8bed0af96b654a1d6b1 Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai Date: Thu, 31 Jul 2025 15:24:18 +0530 Subject: [PATCH 3/8] fix --- .../collection/dictionary_based/_paa.py | 3 +- .../collection/dictionary_based/_sax.py | 2 +- test.py | 98 +++++++++++++++++++ 3 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 test.py diff --git a/aeon/transformations/collection/dictionary_based/_paa.py b/aeon/transformations/collection/dictionary_based/_paa.py index 7c99e60adb..668f2f03f8 100644 --- a/aeon/transformations/collection/dictionary_based/_paa.py +++ b/aeon/transformations/collection/dictionary_based/_paa.py @@ -174,7 +174,8 @@ def _parallel_paa_transform(X, n_segments, split_segments): @njit(parallel=True, fastmath=True) def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments): """Parallelize the inverse PAA transformation for cases where the series length is not - divisible by the number of segments.""" + divisible by the number of segments. + """ n_samples, n_channels, _ = X.shape X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length)) diff --git a/aeon/transformations/collection/dictionary_based/_sax.py b/aeon/transformations/collection/dictionary_based/_sax.py index 398beb9b2b..e08a973b28 100644 --- a/aeon/transformations/collection/dictionary_based/_sax.py +++ b/aeon/transformations/collection/dictionary_based/_sax.py @@ -299,7 +299,7 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid): @njit(fastmath=True, cache=True, parallel=True) -def _parallel_get_sax_symbols(x, bins, right=False): +def _parallel_get_sax_symbols(x, breakpoints, right=False): """Parallel version of `np.digitize`.""" x_flat = x.flatten() result = np.empty(x_flat.shape[0], dtype=np.intp) diff --git a/test.py b/test.py new file mode 100644 index 0000000000..fa5d7d30a5 --- /dev/null +++ b/test.py @@ -0,0 +1,98 @@ +import numpy as np +from numba import njit, prange + +# @njit(fastmath=True, cache=True) +# def numba_digitize(x, bins, right=False): +# """ +# Numba implementation that produces identical output to np.digitize. +# """ +# x_flat = x.flatten() +# result = np.empty(x_flat.shape[0], dtype=np.intp) + +# for i in range(x_flat.shape[0]): +# val = x_flat[i] +# bin_idx = 0 + +# if right: +# # bins[i] < x <= bins[i+1] +# for j in range(len(bins)): +# if val <= bins[j]: +# bin_idx = j +# break +# bin_idx = j + 1 +# else: +# # bins[i] <= x < bins[i+1] (default behavior) +# for j in range(len(bins)): +# if val < bins[j]: +# bin_idx = j +# break +# bin_idx = j + 1 + +# result[i] = bin_idx + +# return result.reshape(x.shape) + +@njit(fastmath=True, cache=True, parallel=True) +def numba_digitize_parallel(x, bins, right=False): + """ + Parallel version for better performance on large arrays. + """ + x_flat = x.flatten() + result = np.empty(x_flat.shape[0], dtype=np.intp) + + for i in prange(x_flat.shape[0]): + val = x_flat[i] + bin_idx = 0 + + if right: + for j in range(len(bins)): + if val <= bins[j]: + bin_idx = j + break + bin_idx = j + 1 + else: + for j in range(len(bins)): + if val < bins[j]: + bin_idx = j + break + bin_idx = j + 1 + + result[i] = bin_idx + + return result.reshape(x.shape) + + +@njit(fastmath=True, cache=True, parallel=True) +def _parallel_get_sax_symbols(X, breakpoints): + n_cases, n_channels, n_timepoints = X.shape + X_new = np.zeros((n_cases, n_channels, n_timepoints), dtype=np.intp) + n_break = breakpoints.shape[0] - 1 + for i_x in prange(n_cases): + for i_c in prange(n_channels): + for i_b in prange(n_break): + mask = np.where( + (X[i_x, i_c] >= breakpoints[i_b]) + & (X[i_x, i_c] < breakpoints[i_b + 1]) + )[0] + X_new[i_x, i_c, mask] += np.array(i_b).astype(np.intp) + + return X_new + + +# Test to verify identical output +if __name__ == "__main__": + x = np.array([[[0.2, 6.4, 3.0, 1.6]]]) + bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0]) + + print("Original:", np.digitize(x, bins)) + print("Numba: ", numba_digitize_parallel(x, bins)) + print("Match: ", np.array_equal(np.digitize(x, bins), numba_digitize_parallel(x, bins))) + + print("Curr: ", _parallel_get_sax_symbols(x, bins)) + + # Test with right=True + print("\nWith right=True:") + print("Original:", np.digitize(x, bins, right=True)) + print("Numba: ", numba_digitize_parallel(x, bins, right=True)) + print("Match: ", np.array_equal(np.digitize(x, bins, right=True), numba_digitize_parallel(x, bins, right=True))) + print("Curr: ", _parallel_get_sax_symbols(x, bins)) \ No newline at end of file From 6b8d269729a4efbeaebbf57296e74b760d65b5be Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai Date: Thu, 31 Jul 2025 15:25:59 +0530 Subject: [PATCH 4/8] fix --- test.py | 98 --------------------------------------------------------- 1 file changed, 98 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index fa5d7d30a5..0000000000 --- a/test.py +++ /dev/null @@ -1,98 +0,0 @@ -import numpy as np -from numba import njit, prange - -# @njit(fastmath=True, cache=True) -# def numba_digitize(x, bins, right=False): -# """ -# Numba implementation that produces identical output to np.digitize. -# """ -# x_flat = x.flatten() -# result = np.empty(x_flat.shape[0], dtype=np.intp) - -# for i in range(x_flat.shape[0]): -# val = x_flat[i] -# bin_idx = 0 - -# if right: -# # bins[i] < x <= bins[i+1] -# for j in range(len(bins)): -# if val <= bins[j]: -# bin_idx = j -# break -# bin_idx = j + 1 -# else: -# # bins[i] <= x < bins[i+1] (default behavior) -# for j in range(len(bins)): -# if val < bins[j]: -# bin_idx = j -# break -# bin_idx = j + 1 - -# result[i] = bin_idx - -# return result.reshape(x.shape) - -@njit(fastmath=True, cache=True, parallel=True) -def numba_digitize_parallel(x, bins, right=False): - """ - Parallel version for better performance on large arrays. - """ - x_flat = x.flatten() - result = np.empty(x_flat.shape[0], dtype=np.intp) - - for i in prange(x_flat.shape[0]): - val = x_flat[i] - bin_idx = 0 - - if right: - for j in range(len(bins)): - if val <= bins[j]: - bin_idx = j - break - bin_idx = j + 1 - else: - for j in range(len(bins)): - if val < bins[j]: - bin_idx = j - break - bin_idx = j + 1 - - result[i] = bin_idx - - return result.reshape(x.shape) - - -@njit(fastmath=True, cache=True, parallel=True) -def _parallel_get_sax_symbols(X, breakpoints): - n_cases, n_channels, n_timepoints = X.shape - X_new = np.zeros((n_cases, n_channels, n_timepoints), dtype=np.intp) - n_break = breakpoints.shape[0] - 1 - for i_x in prange(n_cases): - for i_c in prange(n_channels): - for i_b in prange(n_break): - mask = np.where( - (X[i_x, i_c] >= breakpoints[i_b]) - & (X[i_x, i_c] < breakpoints[i_b + 1]) - )[0] - X_new[i_x, i_c, mask] += np.array(i_b).astype(np.intp) - - return X_new - - -# Test to verify identical output -if __name__ == "__main__": - x = np.array([[[0.2, 6.4, 3.0, 1.6]]]) - bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0]) - - print("Original:", np.digitize(x, bins)) - print("Numba: ", numba_digitize_parallel(x, bins)) - print("Match: ", np.array_equal(np.digitize(x, bins), numba_digitize_parallel(x, bins))) - - print("Curr: ", _parallel_get_sax_symbols(x, bins)) - - # Test with right=True - print("\nWith right=True:") - print("Original:", np.digitize(x, bins, right=True)) - print("Numba: ", numba_digitize_parallel(x, bins, right=True)) - print("Match: ", np.array_equal(np.digitize(x, bins, right=True), numba_digitize_parallel(x, bins, right=True))) - print("Curr: ", _parallel_get_sax_symbols(x, bins)) \ No newline at end of file From 6b15711f7f086c10b3e385951c2d48ba23763b75 Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai Date: Thu, 31 Jul 2025 15:31:51 +0530 Subject: [PATCH 5/8] minor bug fix --- aeon/transformations/collection/dictionary_based/_sax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aeon/transformations/collection/dictionary_based/_sax.py b/aeon/transformations/collection/dictionary_based/_sax.py index e08a973b28..460abf9c26 100644 --- a/aeon/transformations/collection/dictionary_based/_sax.py +++ b/aeon/transformations/collection/dictionary_based/_sax.py @@ -309,14 +309,14 @@ def _parallel_get_sax_symbols(x, breakpoints, right=False): bin_idx = 0 if right: - for j in range(len(bins)): - if val <= bins[j]: + for j in range(len(breakpoints)): + if val <= breakpoints[j]: bin_idx = j break bin_idx = j + 1 else: - for j in range(len(bins)): - if val < bins[j]: + for j in range(len(breakpoints)): + if val < breakpoints[j]: bin_idx = j break bin_idx = j + 1 From 1bf5b6b65d8ab87b0565b6306a9aef7938d3bc1a Mon Sep 17 00:00:00 2001 From: aadya940 <77720426+aadya940@users.noreply.github.com> Date: Thu, 31 Jul 2025 10:02:31 +0000 Subject: [PATCH 6/8] Automatic `pre-commit` fixes --- aeon/transformations/collection/dictionary_based/_paa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/transformations/collection/dictionary_based/_paa.py b/aeon/transformations/collection/dictionary_based/_paa.py index 668f2f03f8..1c8a78dae7 100644 --- a/aeon/transformations/collection/dictionary_based/_paa.py +++ b/aeon/transformations/collection/dictionary_based/_paa.py @@ -3,7 +3,7 @@ __maintainer__ = [] import numpy as np -from numba import njit, prange, get_num_threads, set_num_threads +from numba import get_num_threads, njit, prange, set_num_threads from aeon.transformations.collection import BaseCollectionTransformer from aeon.utils.validation import check_n_jobs From a6cc829cd149bfac3f6de9a370f36eacd4c74aa6 Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai Date: Thu, 31 Jul 2025 18:59:10 +0530 Subject: [PATCH 7/8] add cache=True to numba decorator --- aeon/transformations/collection/dictionary_based/_paa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aeon/transformations/collection/dictionary_based/_paa.py b/aeon/transformations/collection/dictionary_based/_paa.py index 1c8a78dae7..eb8f57599c 100644 --- a/aeon/transformations/collection/dictionary_based/_paa.py +++ b/aeon/transformations/collection/dictionary_based/_paa.py @@ -148,7 +148,7 @@ def _get_test_params(cls, parameter_set="default"): return params -@njit(parallel=True, fastmath=True) +@njit(parallel=True, cache=True, fastmath=True) def _parallel_paa_transform(X, n_segments, split_segments): """Parallelized PAA for uneven segment splits using Numba.""" n_samples, n_channels, _ = X.shape @@ -171,7 +171,7 @@ def _parallel_paa_transform(X, n_segments, split_segments): return X_paa -@njit(parallel=True, fastmath=True) +@njit(parallel=True, cache=True, fastmath=True) def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments): """Parallelize the inverse PAA transformation for cases where the series length is not divisible by the number of segments. From 8335ba95c020ae43b7673886b9c182345f354a51 Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai Date: Sun, 21 Sep 2025 19:39:51 -0700 Subject: [PATCH 8/8] Use np.digitize and use .mean in numba parallel implementation. --- .../collection/dictionary_based/_paa.py | 9 ++---- .../collection/dictionary_based/_sax.py | 29 +++++-------------- 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/aeon/transformations/collection/dictionary_based/_paa.py b/aeon/transformations/collection/dictionary_based/_paa.py index eb8f57599c..bf5d461e5c 100644 --- a/aeon/transformations/collection/dictionary_based/_paa.py +++ b/aeon/transformations/collection/dictionary_based/_paa.py @@ -163,19 +163,14 @@ def _parallel_paa_transform(X, n_segments, split_segments): for i in range(n_samples): for j in range(n_channels): - acc = 0.0 - for k in range(seg_len): - acc += X[i, j, segment[k]] - X_paa[i, j, _s] = acc / seg_len + X_paa[i, j, _s] = X[i, j, segment].mean() return X_paa @njit(parallel=True, cache=True, fastmath=True) def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments): - """Parallelize the inverse PAA transformation for cases where the series length is not - divisible by the number of segments. - """ + """Parallelize inverse PAA when series len % segments ≠ 0.""" n_samples, n_channels, _ = X.shape X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length)) diff --git a/aeon/transformations/collection/dictionary_based/_sax.py b/aeon/transformations/collection/dictionary_based/_sax.py index 460abf9c26..37319123ac 100644 --- a/aeon/transformations/collection/dictionary_based/_sax.py +++ b/aeon/transformations/collection/dictionary_based/_sax.py @@ -300,27 +300,12 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid): @njit(fastmath=True, cache=True, parallel=True) def _parallel_get_sax_symbols(x, breakpoints, right=False): - """Parallel version of `np.digitize`.""" - x_flat = x.flatten() - result = np.empty(x_flat.shape[0], dtype=np.intp) - - for i in prange(x_flat.shape[0]): - val = x_flat[i] - bin_idx = 0 - - if right: - for j in range(len(breakpoints)): - if val <= breakpoints[j]: - bin_idx = j - break - bin_idx = j + 1 - else: - for j in range(len(breakpoints)): - if val < breakpoints[j]: - bin_idx = j - break - bin_idx = j + 1 + """Parallel version using np.digitize within prange loop.""" + n_samples, n_channels, n_segments = x.shape + result = np.empty_like(x, dtype=np.intp) - result[i] = bin_idx + for i in prange(n_samples): + for c in range(n_channels): + result[i, c, :] = np.digitize(x[i, c, :], breakpoints, right=right) - return result.reshape(x.shape) + return result