@@ -299,17 +299,28 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid):
299299
300300
301301@njit (fastmath = True , cache = True , parallel = True )
302- def _parallel_get_sax_symbols (X , breakpoints ):
303- n_cases , n_channels , n_timepoints = X .shape
304- X_new = np .zeros ((n_cases , n_channels , n_timepoints ), dtype = np .intp )
305- n_break = breakpoints .shape [0 ] - 1
306- for i_x in prange (n_cases ):
307- for i_c in prange (n_channels ):
308- for i_b in prange (n_break ):
309- mask = np .where (
310- (X [i_x , i_c ] >= breakpoints [i_b ])
311- & (X [i_x , i_c ] < breakpoints [i_b + 1 ])
312- )[0 ]
313- X_new [i_x , i_c , mask ] += np .array (i_b ).astype (np .intp )
314-
315- return X_new
302+ def _parallel_get_sax_symbols (x , bins , right = False ):
303+ """Parallel version of `np.digitize`."""
304+ x_flat = x .flatten ()
305+ result = np .empty (x_flat .shape [0 ], dtype = np .intp )
306+
307+ for i in prange (x_flat .shape [0 ]):
308+ val = x_flat [i ]
309+ bin_idx = 0
310+
311+ if right :
312+ for j in range (len (bins )):
313+ if val <= bins [j ]:
314+ bin_idx = j
315+ break
316+ bin_idx = j + 1
317+ else :
318+ for j in range (len (bins )):
319+ if val < bins [j ]:
320+ bin_idx = j
321+ break
322+ bin_idx = j + 1
323+
324+ result [i ] = bin_idx
325+
326+ return result .reshape (x .shape )
0 commit comments