Skip to content

Commit f1b04ce

Browse files
committed
parallelize sax and paa
1 parent 636d7e4 commit f1b04ce

File tree

2 files changed

+86
-21
lines changed

2 files changed

+86
-21
lines changed

aeon/transformations/collection/dictionary_based/_paa.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
__maintainer__ = []
44

55
import numpy as np
6+
from numba import njit, prange, get_num_threads, set_num_threads
67

78
from aeon.transformations.collection import BaseCollectionTransformer
9+
from aeon.utils.validation import check_n_jobs
810

911

1012
class PAA(BaseCollectionTransformer):
@@ -39,12 +41,14 @@ class PAA(BaseCollectionTransformer):
3941

4042
_tags = {
4143
"capability:multivariate": True,
44+
"capability:multithreading": True,
4245
"fit_is_empty": True,
4346
"algorithm_type": "dictionary",
4447
}
4548

46-
def __init__(self, n_segments=8):
49+
def __init__(self, n_segments=8, n_jobs=1):
4750
self.n_segments = n_segments
51+
self.n_jobs = n_jobs
4852

4953
super().__init__()
5054

@@ -71,7 +75,6 @@ def _transform(self, X, y=None):
7175
# of segments is 3, the indices will be [0:4], [4:7] and [7:10]
7276
# so 3 segments, one of length 4 and two of length 3
7377
split_segments = np.array_split(all_indices, self.n_segments)
74-
7578
# If the series length is divisible by the number of segments
7679
# then the transformation can be done in one line
7780
# If not, a for loop is needed only on the segments while
@@ -82,13 +85,13 @@ def _transform(self, X, y=None):
8285
return X_paa
8386

8487
else:
85-
n_samples, n_channels, _ = X.shape
86-
X_paa = np.zeros(shape=(n_samples, n_channels, self.n_segments))
87-
88-
for _s, segment in enumerate(split_segments):
89-
if X[:, :, segment].shape[-1] > 0: # avoids mean of empty slice error
90-
X_paa[:, :, _s] = X[:, :, segment].mean(axis=-1)
91-
88+
prev_threads = get_num_threads()
89+
_n_jobs = check_n_jobs(self.n_jobs)
90+
set_num_threads(_n_jobs)
91+
X_paa = _parallel_paa_transform(
92+
X, n_segments=self.n_segments, split_segments=split_segments
93+
)
94+
set_num_threads(prev_threads)
9295
return X_paa
9396

9497
def inverse_paa(self, X, original_length):
@@ -110,17 +113,17 @@ def inverse_paa(self, X, original_length):
110113
return np.repeat(X, repeats=int(original_length / self.n_segments), axis=-1)
111114

112115
else:
113-
n_samples, n_channels, _ = X.shape
114-
X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length))
115-
116-
all_indices = np.arange(original_length)
117-
split_segments = np.array_split(all_indices, self.n_segments)
118-
119-
for _s, segment in enumerate(split_segments):
120-
X_inverse_paa[:, :, segment] = np.repeat(
121-
X[:, :, [_s]], repeats=len(segment), axis=-1
122-
)
123-
116+
split_segments = np.array_split(np.arange(original_length), self.n_segments)
117+
prev_threads = get_num_threads()
118+
_n_jobs = check_n_jobs(self.n_jobs)
119+
set_num_threads(_n_jobs)
120+
X_inverse_paa = _parallel_inverse_paa_transform(
121+
X,
122+
original_length=original_length,
123+
n_segments=self.n_segments,
124+
split_segments=split_segments,
125+
)
126+
set_num_threads(prev_threads)
124127
return X_inverse_paa
125128

126129
@classmethod
@@ -143,3 +146,44 @@ def _get_test_params(cls, parameter_set="default"):
143146
"""
144147
params = {"n_segments": 10}
145148
return params
149+
150+
151+
@njit(parallel=True, fastmath=True)
def _parallel_paa_transform(X, n_segments, split_segments):
    """Compute the per-segment mean of X when segments have unequal length.

    Parameters
    ----------
    X : np.ndarray of shape (n_cases, n_channels, n_timepoints)
        Collection to reduce, indexed as ``X[case, channel, timepoint]``.
    n_segments : int
        Number of segments; this is the size of the last axis of the output.
    split_segments : sequence of 1D integer arrays
        ``split_segments[s]`` holds the time indices belonging to segment ``s``.

    Returns
    -------
    np.ndarray of shape (n_cases, n_channels, n_segments)
        Mean of X over the indices of each segment, stored with the dtype of
        X. Columns of empty segments keep their initial value of 0.

    Notes
    -----
    NOTE(review): the output is allocated with ``X.dtype``, so for integer
    input the means would presumably be truncated on assignment - confirm
    that callers only pass floating point collections.
    """
    n_cases, n_channels, _ = X.shape
    means = np.zeros((n_cases, n_channels, n_segments), dtype=X.dtype)

    # Each segment writes to its own output column, so the segment loop is
    # the one handed to prange.
    for seg in prange(n_segments):
        indices = split_segments[seg]
        width = indices.shape[0]

        # An empty segment has no mean; its column keeps the zero fill.
        if width > 0:
            for case in range(n_cases):
                for channel in range(n_channels):
                    total = 0.0
                    for pos in range(width):
                        total += X[case, channel, indices[pos]]
                    means[case, channel, seg] = total / width

    return means
172+
173+
174+
@njit(parallel=True, fastmath=True)
def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments):
    """Expand PAA values back to the original length for uneven segments.

    Every time index listed in ``split_segments[s]`` receives the value that
    X holds for segment ``s``.

    Parameters
    ----------
    X : np.ndarray of shape (n_cases, n_channels, n_segments)
        PAA representation, indexed as ``X[case, channel, segment]``.
    original_length : int
        Size of the last axis of the reconstructed output.
    n_segments : int
        Number of segments to expand.
    split_segments : sequence of 1D integer arrays
        ``split_segments[s]`` holds the output time indices of segment ``s``.

    Returns
    -------
    np.ndarray of shape (n_cases, n_channels, original_length)
        Reconstructed collection, allocated with the default dtype of
        ``np.zeros``. Time indices not listed in any segment stay 0.
    """
    n_cases, n_channels, _ = X.shape
    X_inverse_paa = np.zeros(shape=(n_cases, n_channels, original_length))

    # Numba parallelises only the outermost prange and serialises any nested
    # ones, so the inner loops are written as plain range loops.
    for _s in prange(n_segments):
        segment = split_segments[_s]
        seg_len = len(segment)
        for i in range(n_cases):
            for j in range(n_channels):
                # Read the segment value once, then broadcast it over the
                # time indices of this segment.
                value = X[i, j, _s]
                for idx in range(seg_len):
                    X_inverse_paa[i, j, segment[idx]] = value

    return X_inverse_paa

aeon/transformations/collection/dictionary_based/_sax.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,11 @@ def _get_sax_symbols(self, X_paa):
167167
sax_symbols : np.ndarray of shape = (n_cases, n_channels, n_segments)
168168
The output of the SAX transformation using np.digitize
169169
"""
170-
sax_symbols = np.digitize(x=X_paa, bins=self.breakpoints)
170+
prev_threads = get_num_threads()
171+
_n_jobs = check_n_jobs(self.n_jobs)
172+
set_num_threads(_n_jobs)
173+
sax_symbols = _parallel_get_sax_symbols(X_paa, breakpoints=self.breakpoints)
174+
set_num_threads(prev_threads)
171175
return sax_symbols
172176

173177
def inverse_sax(self, X, original_length, y=None):
@@ -292,3 +296,20 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid):
292296
]
293297

294298
return sax_inverse
299+
300+
301+
@njit(fastmath=True, cache=True, parallel=True)
def _parallel_get_sax_symbols(X, breakpoints):
    """Map every value of X to the index of the breakpoint interval it is in.

    A value ``v`` is assigned symbol ``b`` when
    ``breakpoints[b] <= v < breakpoints[b + 1]``. Values that fall in no
    interval (below ``breakpoints[0]`` or at/above ``breakpoints[-1]``) are
    assigned 0.

    Parameters
    ----------
    X : np.ndarray of shape (n_cases, n_channels, n_timepoints)
        Values to discretise.
    breakpoints : 1D np.ndarray
        Interval edges; consecutive pairs define the intervals.

    Returns
    -------
    np.ndarray of shape (n_cases, n_channels, n_timepoints), dtype np.intp
        Interval index for every value of X.

    Notes
    -----
    NOTE(review): this is not the same mapping as
    ``np.digitize(X, breakpoints)``, which returns ``b + 1`` for a value in
    ``[breakpoints[b], breakpoints[b + 1])`` and ``len(breakpoints)`` for
    values at/above the last edge. Confirm that the caller's breakpoints
    include lower and upper sentinel edges, otherwise the symbols differ.
    """
    n_cases, n_channels, n_timepoints = X.shape
    X_new = np.zeros((n_cases, n_channels, n_timepoints), dtype=np.intp)
    n_break = breakpoints.shape[0] - 1

    # Numba parallelises only the outermost prange, so the inner loops are
    # plain range loops. Scanning the intervals per value avoids allocating
    # mask and index arrays for every (case, channel, interval) triple.
    for i_x in prange(n_cases):
        for i_c in range(n_channels):
            for i_t in range(n_timepoints):
                value = X[i_x, i_c, i_t]
                symbol = 0
                for i_b in range(n_break):
                    # Accumulate rather than assign so the result matches a
                    # masked ``+=`` per interval even if intervals overlap.
                    if value >= breakpoints[i_b] and value < breakpoints[i_b + 1]:
                        symbol += i_b
                X_new[i_x, i_c, i_t] = symbol

    return X_new

0 commit comments

Comments
 (0)