33__maintainer__ = []
44
55import numpy as np
6+ from numba import njit , prange , get_num_threads , set_num_threads
67
78from aeon .transformations .collection import BaseCollectionTransformer
9+ from aeon .utils .validation import check_n_jobs
810
911
1012class PAA (BaseCollectionTransformer ):
@@ -39,12 +41,14 @@ class PAA(BaseCollectionTransformer):
3941
4042 _tags = {
4143 "capability:multivariate" : True ,
44+ "capability:multithreading" : True ,
4245 "fit_is_empty" : True ,
4346 "algorithm_type" : "dictionary" ,
4447 }
4548
46- def __init__ (self , n_segments = 8 ):
49+ def __init__ (self , n_segments = 8 , n_jobs = 1 ):
4750 self .n_segments = n_segments
51+ self .n_jobs = n_jobs
4852
4953 super ().__init__ ()
5054
@@ -71,7 +75,6 @@ def _transform(self, X, y=None):
7175 # of segments is 3, the indices will be [0:4], [4:7] and [7:10]
7276 # so 3 segments, one of length 4 and two of length 3
7377 split_segments = np .array_split (all_indices , self .n_segments )
74-
7578 # If the series length is divisible by the number of segments
7679 # then the transformation can be done in one line
7780 # If not, a for loop is needed only on the segments while
@@ -82,13 +85,13 @@ def _transform(self, X, y=None):
8285 return X_paa
8386
8487 else :
85- n_samples , n_channels , _ = X . shape
86- X_paa = np . zeros ( shape = ( n_samples , n_channels , self .n_segments ) )
87-
88- for _s , segment in enumerate ( split_segments ):
89- if X [:, :, segment ]. shape [ - 1 ] > 0 : # avoids mean of empty slice error
90- X_paa [:, :, _s ] = X [:, :, segment ]. mean ( axis = - 1 )
91-
88+ prev_threads = get_num_threads ()
89+ _n_jobs = check_n_jobs ( self .n_jobs )
90+ set_num_threads ( _n_jobs )
91+ X_paa = _parallel_paa_transform (
92+ X , n_segments = self . n_segments , split_segments = split_segments
93+ )
94+ set_num_threads ( prev_threads )
9295 return X_paa
9396
9497 def inverse_paa (self , X , original_length ):
@@ -110,17 +113,17 @@ def inverse_paa(self, X, original_length):
110113 return np .repeat (X , repeats = int (original_length / self .n_segments ), axis = - 1 )
111114
112115 else :
113- n_samples , n_channels , _ = X . shape
114- X_inverse_paa = np . zeros ( shape = ( n_samples , n_channels , original_length ) )
115-
116- all_indices = np . arange ( original_length )
117- split_segments = np . array_split ( all_indices , self . n_segments )
118-
119- for _s , segment in enumerate ( split_segments ):
120- X_inverse_paa [:, :, segment ] = np . repeat (
121- X [:, :, [ _s ]], repeats = len ( segment ), axis = - 1
122- )
123-
116+ split_segments = np . array_split ( np . arange ( original_length ), self . n_segments )
117+ prev_threads = get_num_threads ( )
118+ _n_jobs = check_n_jobs ( self . n_jobs )
119+ set_num_threads ( _n_jobs )
120+ X_inverse_paa = _parallel_inverse_paa_transform (
121+ X ,
122+ original_length = original_length ,
123+ n_segments = self . n_segments ,
124+ split_segments = split_segments ,
125+ )
126+ set_num_threads ( prev_threads )
124127 return X_inverse_paa
125128
126129 @classmethod
@@ -143,3 +146,44 @@ def _get_test_params(cls, parameter_set="default"):
143146 """
144147 params = {"n_segments" : 10 }
145148 return params
149+
150+
@njit(parallel=True, fastmath=True)
def _parallel_paa_transform(X, n_segments, split_segments):
    """Compute segment means of a collection for uneven segment splits.

    Parameters
    ----------
    X : np.ndarray
        3D array indexed as ``X[i, j, t]``; its shape is unpacked as
        ``(n_samples, n_channels, _)``.
    n_segments : int
        Number of segments, i.e. the size of the last output axis.
    split_segments : sequence of 1D integer arrays
        ``split_segments[s]`` holds the positions along the last axis of ``X``
        that are averaged to produce output segment ``s``.

    Returns
    -------
    X_paa : np.ndarray of shape (n_samples, n_channels, n_segments)
        Mean of ``X`` over each segment. Always ``float64`` so that means of
        integer-valued input are not truncated when stored. Entries belonging
        to an empty segment are left at zero.
    """
    n_samples, n_channels, _ = X.shape
    # float64 rather than X.dtype: a mean is generally fractional, and writing
    # it into an integer buffer would silently drop the fractional part.
    X_paa = np.zeros((n_samples, n_channels, n_segments), dtype=np.float64)

    # Each iteration writes only to X_paa[:, :, _s], so segments are independent
    # and safe to distribute across threads.
    for _s in prange(n_segments):
        segment = split_segments[_s]
        seg_len = segment.shape[0]

        if seg_len == 0:
            continue  # nothing to average; avoids a division by zero

        for i in range(n_samples):
            for j in range(n_channels):
                acc = 0.0
                for k in range(seg_len):
                    acc += X[i, j, segment[k]]
                X_paa[i, j, _s] = acc / seg_len

    return X_paa
172+
173+
@njit(parallel=True, fastmath=True)
def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments):
    """Expand per-segment values back to ``original_length`` time points.

    Used when the series length is not divisible by the number of segments,
    so segments have different lengths.

    Parameters
    ----------
    X : np.ndarray
        3D array indexed as ``X[i, j, s]`` where ``s`` is the segment index;
        its shape is unpacked as ``(n_samples, n_channels, _)``.
    original_length : int
        Size of the last axis of the returned array.
    n_segments : int
        Number of segments to expand.
    split_segments : sequence of 1D integer arrays
        ``split_segments[s]`` holds the output positions that receive the
        value of segment ``s``.

    Returns
    -------
    X_inverse_paa : np.ndarray of shape (n_samples, n_channels, original_length)
        ``float64`` array in which every position listed in
        ``split_segments[s]`` holds ``X[i, j, s]``. Positions not listed in
        any segment remain zero.
    """
    n_samples, n_channels, _ = X.shape
    X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length))

    # Only the outer loop is a prange: numba parallelises the outermost prange
    # and runs nested ones sequentially, so the inner loops are plain ranges.
    # NOTE(review): thread safety assumes no position appears in two segments
    # (true for np.array_split output) - confirm for other callers.
    for _s in prange(n_segments):
        segment = split_segments[_s]
        for idx in range(len(segment)):
            t = segment[idx]
            for i in range(n_samples):
                for j in range(n_channels):
                    X_inverse_paa[i, j, t] = X[i, j, _s]

    return X_inverse_paa
0 commit comments