88# """
99# x_flat = x.flatten()
1010# result = np.empty(x_flat.shape[0], dtype=np.intp)
11-
11+
1212# for i in range(x_flat.shape[0]):
1313# val = x_flat[i]
1414# bin_idx = 0
15-
15+
1616# if right:
1717# # bins[i] < x <= bins[i+1]
1818# for j in range(len(bins)):
2727# bin_idx = j
2828# break
2929# bin_idx = j + 1
30-
30+
3131# result[i] = bin_idx
32-
32+
3333# return result.reshape(x.shape)
3434
35+
3536@njit (fastmath = True , cache = True , parallel = True )
3637def numba_digitize_parallel (x , bins , right = False ):
3738 """
3839 Parallel version for better performance on large arrays.
3940 """
4041 x_flat = x .flatten ()
4142 result = np .empty (x_flat .shape [0 ], dtype = np .intp )
42-
43+
4344 for i in prange (x_flat .shape [0 ]):
4445 val = x_flat [i ]
4546 bin_idx = 0
46-
47+
4748 if right :
4849 for j in range (len (bins )):
4950 if val <= bins [j ]:
@@ -56,9 +57,9 @@ def numba_digitize_parallel(x, bins, right=False):
5657 bin_idx = j
5758 break
5859 bin_idx = j + 1
59-
60+
6061 result [i ] = bin_idx
61-
62+
6263 return result .reshape (x .shape )
6364
6465
@@ -83,16 +84,25 @@ def _parallel_get_sax_symbols(X, breakpoints):
8384if __name__ == "__main__" :
8485 x = np .array ([[[0.2 , 6.4 , 3.0 , 1.6 ]]])
8586 bins = np .array ([0.0 , 1.0 , 2.5 , 4.0 , 10.0 ])
86-
87+
8788 print ("Original:" , np .digitize (x , bins ))
8889 print ("Numba: " , numba_digitize_parallel (x , bins ))
89- print ("Match: " , np .array_equal (np .digitize (x , bins ), numba_digitize_parallel (x , bins )))
90+ print (
91+ "Match: " ,
92+ np .array_equal (np .digitize (x , bins ), numba_digitize_parallel (x , bins )),
93+ )
9094
9195 print ("Curr: " , _parallel_get_sax_symbols (x , bins ))
92-
96+
9397 # Test with right=True
9498 print ("\n With right=True:" )
9599 print ("Original:" , np .digitize (x , bins , right = True ))
96100 print ("Numba: " , numba_digitize_parallel (x , bins , right = True ))
97- print ("Match: " , np .array_equal (np .digitize (x , bins , right = True ), numba_digitize_parallel (x , bins , right = True )))
98- print ("Curr: " , _parallel_get_sax_symbols (x , bins ))
101+ print (
102+ "Match: " ,
103+ np .array_equal (
104+ np .digitize (x , bins , right = True ),
105+ numba_digitize_parallel (x , bins , right = True ),
106+ ),
107+ )
108+ print ("Curr: " , _parallel_get_sax_symbols (x , bins ))
0 commit comments