Skip to content

Commit f36f223

Browse files
committed
fix
1 parent cea39a5 commit f36f223

File tree

3 files changed

+101
-2
lines changed

3 files changed

+101
-2
lines changed

aeon/transformations/collection/dictionary_based/_paa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ def _parallel_paa_transform(X, n_segments, split_segments):
174174
@njit(parallel=True, fastmath=True)
175175
def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments):
176176
"""Parallelize the inverse PAA transformation for cases where the series length is not
177-
divisible by the number of segments."""
177+
divisible by the number of segments.
178+
"""
178179
n_samples, n_channels, _ = X.shape
179180
X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length))
180181

aeon/transformations/collection/dictionary_based/_sax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid):
299299

300300

301301
@njit(fastmath=True, cache=True, parallel=True)
302-
def _parallel_get_sax_symbols(x, bins, right=False):
302+
def _parallel_get_sax_symbols(x, breakpoints, right=False):
303303
"""Parallel version of `np.digitize`."""
304304
x_flat = x.flatten()
305305
result = np.empty(x_flat.shape[0], dtype=np.intp)

test.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
from numba import njit, prange
3+
4+
# @njit(fastmath=True, cache=True)
5+
# def numba_digitize(x, bins, right=False):
6+
# """
7+
# Numba implementation that produces identical output to np.digitize.
8+
# """
9+
# x_flat = x.flatten()
10+
# result = np.empty(x_flat.shape[0], dtype=np.intp)
11+
12+
# for i in range(x_flat.shape[0]):
13+
# val = x_flat[i]
14+
# bin_idx = 0
15+
16+
# if right:
17+
# # bins[i] < x <= bins[i+1]
18+
# for j in range(len(bins)):
19+
# if val <= bins[j]:
20+
# bin_idx = j
21+
# break
22+
# bin_idx = j + 1
23+
# else:
24+
# # bins[i] <= x < bins[i+1] (default behavior)
25+
# for j in range(len(bins)):
26+
# if val < bins[j]:
27+
# bin_idx = j
28+
# break
29+
# bin_idx = j + 1
30+
31+
# result[i] = bin_idx
32+
33+
# return result.reshape(x.shape)
34+
35+
@njit(fastmath=True, cache=True, parallel=True)
def numba_digitize_parallel(x, bins, right=False):
    """Compute a bin index for every element of ``x`` using a parallel loop.

    Each value is compared against the edges in ``bins`` from left to right;
    its index is the position of the first edge that lies above it (or the
    number of edges when no such edge exists).

    Parameters
    ----------
    x : np.ndarray
        Values to bin. Any shape; the array is flattened internally.
    bins : 1D array
        Bin edges, scanned in order.
        NOTE(review): assumes the edges are sorted in increasing order, as
        the scan stops at the first matching edge - confirm with callers.
    right : bool, default=False
        If False, the scan stops at the first edge with ``value < edge``.
        If True, it stops at the first edge with ``value <= edge``.

    Returns
    -------
    np.ndarray
        Array of ``np.intp`` indices with the same shape as ``x``.
    """
    values = x.flatten()
    n_values = values.shape[0]
    n_edges = len(bins)
    indices = np.empty(n_values, dtype=np.intp)

    for pos in prange(n_values):
        current = values[pos]
        edge = 0

        # Advance past every edge the value is not below. The negated
        # comparison is kept so that values which compare False against
        # every edge (e.g. NaN) end up in the last bin.
        if right:
            while edge < n_edges and not (current <= bins[edge]):
                edge += 1
        else:
            while edge < n_edges and not (current < bins[edge]):
                edge += 1

        indices[pos] = edge

    return indices.reshape(x.shape)
63+
64+
65+
@njit(fastmath=True, cache=True, parallel=True)
def _parallel_get_sax_symbols(X, breakpoints):
    """Map every value of a 3D collection to an interval index.

    For each value, the indices ``i`` of all half-open intervals
    ``[breakpoints[i], breakpoints[i + 1])`` that contain it are summed.
    When the intervals do not overlap, the result is simply the index of
    the containing interval. Values that fall in no interval keep the
    value 0, so they are indistinguishable from values in the first
    interval (this differs from ``np.digitize``).

    Parameters
    ----------
    X : np.ndarray of shape (n_cases, n_channels, n_timepoints)
        Input values.
    breakpoints : 1D np.ndarray
        Interval edges; ``breakpoints.shape[0] - 1`` intervals are checked.

    Returns
    -------
    np.ndarray of shape (n_cases, n_channels, n_timepoints)
        ``np.intp`` array of interval indices.
    """
    n_cases, n_channels, n_timepoints = X.shape
    X_new = np.zeros((n_cases, n_channels, n_timepoints), dtype=np.intp)
    n_break = breakpoints.shape[0] - 1

    # Only the outermost loop is a prange: Numba does not nest parallel
    # regions, so inner prange loops would run as plain range loops anyway.
    for i_x in prange(n_cases):
        for i_c in range(n_channels):
            for i_t in range(n_timepoints):
                val = X[i_x, i_c, i_t]
                symbol = 0
                # Scalar comparisons avoid allocating a mask and an index
                # array for every (case, channel, interval) combination.
                for i_b in range(n_break):
                    if val >= breakpoints[i_b] and val < breakpoints[i_b + 1]:
                        symbol += i_b
                X_new[i_x, i_c, i_t] = symbol

    return X_new
80+
81+
82+
# Test to verify identical output
83+
if __name__ == "__main__":
    # Manual comparison of np.digitize, the parallel Numba re-implementation
    # and the interval-based SAX symbol lookup on one small example.
    x = np.array([[[0.2, 6.4, 3.0, 1.6]]])
    bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0])

    # Default behaviour (right=False).
    reference = np.digitize(x, bins)
    print("Original:", reference)
    candidate = numba_digitize_parallel(x, bins)
    print("Numba: ", candidate)
    print("Match: ", np.array_equal(reference, candidate))

    print("Curr: ", _parallel_get_sax_symbols(x, bins))

    # Test with right=True
    print("\nWith right=True:")
    reference_right = np.digitize(x, bins, right=True)
    print("Original:", reference_right)
    candidate_right = numba_digitize_parallel(x, bins, right=True)
    print("Numba: ", candidate_right)
    print("Match: ", np.array_equal(reference_right, candidate_right))
    # The SAX lookup has no ``right`` option, so this repeats the call above.
    print("Curr: ", _parallel_get_sax_symbols(x, bins))

0 commit comments

Comments
 (0)