Skip to content

Commit 3685043

Browse files
baralinepatrickzib
andauthored
[BUG] Fix redcomets bug when using only one sample (#2952)
* Fix * fix exemples * Empty commit for CI * Rework examples and add test * Try fixing example indentation error * Fix examples --------- Co-authored-by: baraline <10759117+baraline@users.noreply.github.com> Co-authored-by: Patrick Schäfer <patrick.schaefer@hu-berlin.de>
1 parent bf628f4 commit 3685043

File tree

5 files changed

+34
-15
lines changed

5 files changed

+34
-15
lines changed

aeon/classification/dictionary_based/_redcomets.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def _fit(self, X, y):
150150
self.sfa_clfs,
151151
self.sax_transforms,
152152
self.sax_clfs,
153-
) = self._build_univariate_ensemble(np.squeeze(X), y)
153+
) = self._build_univariate_ensemble(np.squeeze(X, 1), y)
154154
else: # Multivariate
155155

156156
if self.variant in [1, 2, 3]: # Concatenate
@@ -204,7 +204,7 @@ def _build_univariate_ensemble(self, X, y):
204204

205205
from imblearn.over_sampling import SMOTE, RandomOverSampler
206206

207-
X = Normalizer().fit_transform(X).squeeze()
207+
X = Normalizer().fit_transform(X).squeeze(1)
208208

209209
if self.variant in [1, 2, 3]:
210210
perc_length = self.perc_length / self._n_channels
@@ -391,7 +391,7 @@ def _predict_proba(self, X) -> np.ndarray:
391391
Predicted probabilities using the ordering in ``classes_``.
392392
"""
393393
if X.shape[1] == 1: # Univariate
394-
return self._predict_proba_unvivariate(np.squeeze(X))
394+
return self._predict_proba_unvivariate(np.squeeze(X, 1))
395395
else: # Multivariate
396396
if self.variant in [1, 2, 3]: # Concatenate
397397
X_concat = X.reshape(*X.shape[:-2], -1)
@@ -414,7 +414,7 @@ def _predict_proba_unvivariate(self, X) -> np.ndarray:
414414
2D np.ndarray of shape (n_cases, n_classes_)
415415
Predicted probabilities using the ordering in ``classes_``.
416416
"""
417-
X = Normalizer().fit_transform(X).squeeze()
417+
X = Normalizer().fit_transform(X).squeeze(1)
418418

419419
pred_mat = np.zeros((X.shape[0], self.n_classes_))
420420

@@ -588,7 +588,7 @@ def _parallel_sax(self, sax_transforms, X):
588588
"""
589589

590590
def _sax_wrapper(sax):
591-
return np.squeeze(sax.fit_transform(X))
591+
return np.squeeze(sax.fit_transform(X), 1)
592592

593593
sax_parallel_res = Parallel(n_jobs=self._n_jobs, backend=self.parallel_backend)(
594594
delayed(_sax_wrapper)(sax) for sax in sax_transforms

aeon/distances/mindist/_dft_sfa.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,24 @@ def mindist_dft_sfa_distance(
4545
>>> import numpy as np
4646
>>> from aeon.distances import mindist_dft_sfa_distance
4747
>>> from aeon.transformations.collection.dictionary_based import SFAWhole
48-
>>> x = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
49-
>>> y = np.array([[11, 12, 13, 14, 15, 16, 17, 18, 19, 20]])
48+
>>> x = np.random.rand(1,1,10) # (n_cases, n_channels, n_timepoints)
49+
>>> y = np.random.rand(1,1,10)
5050
>>> transform = SFAWhole(
5151
... word_length=8,
5252
... alphabet_size=8,
5353
... norm=True,
5454
... )
5555
>>> x_sfa, _ = transform.fit_transform(x)
5656
>>> _, y_dft = transform.transform(y)
57-
>>> dist = mindist_dft_sfa_distance(y_dft, x_sfa, transform.breakpoints)
57+
>>> for i in range(x.shape[0]):
58+
... dist = mindist_dft_sfa_distance(y_dft[0], x_sfa[0], transform.breakpoints)
5859
"""
5960
if x_dft.ndim == 1 and y_sfa.ndim == 1:
6061
return _univariate_dft_sfa_distance(x_dft, y_sfa, breakpoints)
61-
raise ValueError("x and y must be 1D")
62+
raise ValueError(
63+
f"x and y must be 1D, but got x of shape {x_dft.shape} and y of shape"
64+
f"{y_sfa.shape}"
65+
)
6266

6367

6468
@njit(cache=True, fastmath=True)

aeon/distances/mindist/_sfa.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,23 @@ def mindist_sfa_distance(
4545
>>> import numpy as np
4646
>>> from aeon.distances import mindist_sfa_distance
4747
>>> from aeon.transformations.collection.dictionary_based import SFAWhole
48-
>>> x = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
49-
>>> y = np.array([[11, 12, 13, 14, 15, 16, 17, 18, 19, 20]])
48+
>>> x = np.random.rand(1,1,10) # (n_cases, n_channels, n_timepoints)
49+
>>> y = np.random.rand(1,1,10)
5050
>>> transform = SFAWhole(
5151
... word_length=8,
5252
... alphabet_size=8,
5353
... norm=True
5454
... )
5555
>>> x_sfa, _ = transform.fit_transform(x)
5656
>>> y_sfa, _ = transform.transform(y)
57-
>>> dist = mindist_sfa_distance(x_sfa, y_sfa, transform.breakpoints)
57+
>>> for i in range(x.shape[0]):
58+
... dist = mindist_sfa_distance(x_sfa[i], y_sfa[i], transform.breakpoints)
5859
"""
5960
if x.ndim == 1 and y.ndim == 1:
6061
return _univariate_sfa_distance(x, y, breakpoints)
61-
raise ValueError("x and y must be 1D")
62+
raise ValueError(
63+
f"x and y must be 1D, but got x of shape {x.shape} and y of shape {y.shape}"
64+
)
6265

6366

6467
@njit(cache=True, fastmath=True)

aeon/distances/tests/test_symbolic_mindist.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aeon.distances.mindist._paa_sax import mindist_paa_sax_distance
99
from aeon.distances.mindist._sax import mindist_sax_distance
1010
from aeon.distances.mindist._sfa import mindist_sfa_distance
11+
from aeon.testing.data_generation import make_example_3d_numpy
1112
from aeon.transformations.collection.dictionary_based import SAX, SFA, SFAFast, SFAWhole
1213

1314

@@ -49,6 +50,18 @@ def test_sax_mindist():
4950
assert mindist_paa_sax <= ed
5051

5152

53+
def test_single_sample():
54+
"""Test the SFA Min-Distance function."""
55+
x, _ = make_example_3d_numpy(n_cases=1, n_channels=1, n_timepoints=10)
56+
y = x + 10
57+
transform = SFAWhole(word_length=8, alphabet_size=8, norm=True)
58+
x_sfa, _ = transform.fit_transform(x)
59+
_, y_dft = transform.transform(y)
60+
for i in range(len(x_sfa)):
61+
dist = mindist_dft_sfa_distance(y_dft[i], x_sfa[i], transform.breakpoints)
62+
assert dist == 0
63+
64+
5265
def test_sfa_mindist():
5366
"""Test the SFA Min-Distance function."""
5467
n_segments = 16

aeon/transformations/collection/dictionary_based/_sfa_fast.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -729,8 +729,7 @@ def transform_words(self, X):
729729
self.alphabet_size,
730730
self.breakpoints,
731731
)
732-
733-
return words.squeeze(), dfts.squeeze()
732+
return words.squeeze(1), dfts.squeeze(1)
734733

735734
@classmethod
736735
def _get_test_params(cls, parameter_set="default"):

0 commit comments

Comments
 (0)