1+ import numpy as np
2+ from numba import njit , prange
3+
4+ # @njit(fastmath=True, cache=True)
5+ # def numba_digitize(x, bins, right=False):
6+ # """
7+ # Numba implementation that produces identical output to np.digitize.
8+ # """
9+ # x_flat = x.flatten()
10+ # result = np.empty(x_flat.shape[0], dtype=np.intp)
11+
12+ # for i in range(x_flat.shape[0]):
13+ # val = x_flat[i]
14+ # bin_idx = 0
15+
16+ # if right:
17+ # # bins[i] < x <= bins[i+1]
18+ # for j in range(len(bins)):
19+ # if val <= bins[j]:
20+ # bin_idx = j
21+ # break
22+ # bin_idx = j + 1
23+ # else:
24+ # # bins[i] <= x < bins[i+1] (default behavior)
25+ # for j in range(len(bins)):
26+ # if val < bins[j]:
27+ # bin_idx = j
28+ # break
29+ # bin_idx = j + 1
30+
31+ # result[i] = bin_idx
32+
33+ # return result.reshape(x.shape)
34+
# NOTE: fastmath is deliberately NOT enabled. This kernel performs only
# floating-point comparisons, so fast-math buys no speed, while its
# no-NaN / no-inf assumptions make the result undefined for NaN or +/-inf
# inputs, which np.digitize handles deterministically.
@njit(cache=True, parallel=True)
def numba_digitize_parallel(x, bins, right=False):
    """
    Return the index of the bin that each value of ``x`` falls into.

    Parallel Numba counterpart of ``np.digitize`` for monotonic ``bins``.
    Each value is located with a binary search, so the cost per value is
    logarithmic in the number of bin edges.

    Parameters
    ----------
    x : np.ndarray
        Values to bin. Any shape; the output has the same shape.
    bins : np.ndarray
        1D array of bin edges, monotonically increasing or decreasing.
        Monotonicity is not verified: unsorted edges silently produce
        meaningless indices (``np.digitize`` would raise ``ValueError``).
    right : bool, default=False
        Which side of each interval is closed. For increasing ``bins``:
        ``False`` gives ``bins[i-1] <= v < bins[i]``, ``True`` gives
        ``bins[i-1] < v <= bins[i]``. For decreasing ``bins`` the
        comparisons are mirrored, as in ``np.digitize``.

    Returns
    -------
    np.ndarray
        Array of ``np.intp`` with the shape of ``x``. Values beyond the
        edges map to ``0`` or ``len(bins)``. NaN is treated as larger than
        every edge (``len(bins)`` for increasing bins, ``0`` for decreasing).
    """
    # ravel() avoids the unconditional copy made by flatten(); x is only read.
    x_flat = x.ravel()
    n_values = x_flat.shape[0]
    n_bins = bins.shape[0]
    result = np.empty(n_values, dtype=np.intp)

    # For monotonic edges, comparing the two ends is enough to tell the
    # direction. Constant edges are treated as increasing.
    descending = n_bins > 1 and bins[0] > bins[n_bins - 1]

    for i in prange(n_values):
        val = x_flat[i]

        # Binary search over the edges viewed in ascending order. `lo` ends
        # as the number of (ascending-ordered) edges lying left of `val`.
        lo = 0
        hi = n_bins
        while lo < hi:
            mid = (lo + hi) // 2
            if descending:
                # Read the edges back-to-front instead of copying them.
                edge = bins[n_bins - 1 - mid]
            else:
                edge = bins[mid]

            # Written so that every comparison involving NaN is False,
            # which pushes NaN past the last edge.
            if right:
                go_left = val <= edge
            else:
                go_left = val < edge

            if go_left:
                hi = mid
            else:
                lo = mid + 1

        # For descending edges the index is counted from the other end.
        if descending:
            result[i] = n_bins - lo
        else:
            result[i] = lo

    return result.reshape(x.shape)
63+
64+
# NOTE: fastmath is deliberately NOT enabled. The kernel only compares
# floats, so fast-math gives no speed-up, and its no-NaN / no-inf
# assumptions would make comparisons against NaN values or infinite
# breakpoints undefined.
@njit(cache=True, parallel=True)
def _parallel_get_sax_symbols(X, breakpoints):
    """
    Map every value of ``X`` to the index of the interval that contains it.

    Interval ``b`` is the half-open range
    ``[breakpoints[b], breakpoints[b + 1])``.

    Parameters
    ----------
    X : np.ndarray
        3D array of shape ``(n_cases, n_channels, n_timepoints)``.
    breakpoints : np.ndarray
        1D array of interval edges, expected in ascending order so that
        the intervals are disjoint. ``len(breakpoints) - 1`` intervals are
        defined.

    Returns
    -------
    np.ndarray
        Array of ``np.intp`` with the shape of ``X`` holding the interval
        index of each value. Values that fall in no interval (below the
        first edge, at or above the last edge, or NaN) are left at ``0``.
    """
    n_cases, n_channels, n_timepoints = X.shape
    X_new = np.zeros((n_cases, n_channels, n_timepoints), dtype=np.intp)
    n_break = breakpoints.shape[0] - 1

    # Only the outermost prange is ever parallelised by Numba, so the inner
    # loops are plain ranges.
    for i_x in prange(n_cases):
        for i_c in range(n_channels):
            for i_t in range(n_timepoints):
                value = X[i_x, i_c, i_t]
                # Scalar scan: no temporary mask or index arrays per bin.
                for i_b in range(n_break):
                    lower = breakpoints[i_b]
                    upper = breakpoints[i_b + 1]
                    if lower <= value and value < upper:
                        # Ascending edges give disjoint intervals, so the
                        # first match is the only match.
                        X_new[i_x, i_c, i_t] = i_b
                        break

    return X_new
80+
81+
# Manual check: print the NumPy reference next to the Numba implementations,
# first with the default right=False and then with right=True.
if __name__ == "__main__":
    x = np.array([[[0.2, 6.4, 3.0, 1.6]]])
    bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0])

    for right in (False, True):
        if right:
            # Section header is only printed for the second pass.
            print("\n With right=True:")

        expected = np.digitize(x, bins, right=right)
        actual = numba_digitize_parallel(x, bins, right=right)

        print("Original:", expected)
        print("Numba: ", actual)
        print("Match: ", np.array_equal(expected, actual))
        # The SAX symbol routine has no `right` option; it is shown in both
        # passes purely for side-by-side comparison.
        print("Curr: ", _parallel_get_sax_symbols(x, bins))
0 commit comments