Skip to content

Commit b9aa22d

Browse files
aadya940github-actions[bot]
authored andcommitted
Automatic pre-commit fixes
1 parent f36f223 commit b9aa22d

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

aeon/transformations/collection/dictionary_based/_paa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
__maintainer__ = []
44

55
import numpy as np
6-
from numba import njit, prange, get_num_threads, set_num_threads
6+
from numba import get_num_threads, njit, prange, set_num_threads
77

88
from aeon.transformations.collection import BaseCollectionTransformer
99
from aeon.utils.validation import check_n_jobs

test.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
# """
99
# x_flat = x.flatten()
1010
# result = np.empty(x_flat.shape[0], dtype=np.intp)
11-
11+
1212
# for i in range(x_flat.shape[0]):
1313
# val = x_flat[i]
1414
# bin_idx = 0
15-
15+
1616
# if right:
1717
# # bins[i] < x <= bins[i+1]
1818
# for j in range(len(bins)):
@@ -27,23 +27,24 @@
2727
# bin_idx = j
2828
# break
2929
# bin_idx = j + 1
30-
30+
3131
# result[i] = bin_idx
32-
32+
3333
# return result.reshape(x.shape)
3434

35+
3536
@njit(fastmath=True, cache=True, parallel=True)
3637
def numba_digitize_parallel(x, bins, right=False):
3738
"""
3839
Parallel version for better performance on large arrays.
3940
"""
4041
x_flat = x.flatten()
4142
result = np.empty(x_flat.shape[0], dtype=np.intp)
42-
43+
4344
for i in prange(x_flat.shape[0]):
4445
val = x_flat[i]
4546
bin_idx = 0
46-
47+
4748
if right:
4849
for j in range(len(bins)):
4950
if val <= bins[j]:
@@ -56,9 +57,9 @@ def numba_digitize_parallel(x, bins, right=False):
5657
bin_idx = j
5758
break
5859
bin_idx = j + 1
59-
60+
6061
result[i] = bin_idx
61-
62+
6263
return result.reshape(x.shape)
6364

6465

@@ -83,16 +84,25 @@ def _parallel_get_sax_symbols(X, breakpoints):
8384
if __name__ == "__main__":
8485
x = np.array([[[0.2, 6.4, 3.0, 1.6]]])
8586
bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0])
86-
87+
8788
print("Original:", np.digitize(x, bins))
8889
print("Numba: ", numba_digitize_parallel(x, bins))
89-
print("Match: ", np.array_equal(np.digitize(x, bins), numba_digitize_parallel(x, bins)))
90+
print(
91+
"Match: ",
92+
np.array_equal(np.digitize(x, bins), numba_digitize_parallel(x, bins)),
93+
)
9094

9195
print("Curr: ", _parallel_get_sax_symbols(x, bins))
92-
96+
9397
# Test with right=True
9498
print("\nWith right=True:")
9599
print("Original:", np.digitize(x, bins, right=True))
96100
print("Numba: ", numba_digitize_parallel(x, bins, right=True))
97-
print("Match: ", np.array_equal(np.digitize(x, bins, right=True), numba_digitize_parallel(x, bins, right=True)))
98-
print("Curr: ", _parallel_get_sax_symbols(x, bins))
101+
print(
102+
"Match: ",
103+
np.array_equal(
104+
np.digitize(x, bins, right=True),
105+
numba_digitize_parallel(x, bins, right=True),
106+
),
107+
)
108+
print("Curr: ", _parallel_get_sax_symbols(x, bins))

0 commit comments

Comments
 (0)