Skip to content

Commit afc3cf8

Browse files
fixes
1 parent cd1b776 commit afc3cf8

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

tsml/tests/_sklearn_checks.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
_is_public_parameter,
4646
_NotAnArray,
4747
_regression_dataset,
48-
check_classifiers_predictions,
4948
check_estimators_data_not_an_array,
5049
)
5150
from sklearn.utils.metaestimators import _safe_split
@@ -1745,6 +1744,65 @@ def check_classifiers_classes(name, classifier_orig):
17451744
check_classifiers_predictions(X_binary, y_binary, name, classifier_orig)
17461745

17471746

1747+
@ignore_warnings
1748+
def check_classifiers_predictions(X, y, name, classifier_orig):
1749+
classes = np.unique(y)
1750+
classifier = clone(classifier_orig)
1751+
if name == "BernoulliNB":
1752+
X = X > X.mean()
1753+
set_random_state(classifier)
1754+
1755+
classifier.fit(X, y)
1756+
y_pred = classifier.predict(X)
1757+
1758+
if hasattr(classifier, "decision_function"):
1759+
decision = classifier.decision_function(X)
1760+
assert isinstance(decision, np.ndarray)
1761+
if len(classes) == 2:
1762+
dec_pred = (decision.ravel() > 0).astype(int)
1763+
dec_exp = classifier.classes_[dec_pred]
1764+
assert_array_equal(
1765+
dec_exp,
1766+
y_pred,
1767+
err_msg=(
1768+
"decision_function does not match "
1769+
"classifier for %r: expected '%s', got '%s'"
1770+
)
1771+
% (
1772+
classifier,
1773+
", ".join(map(str, dec_exp)),
1774+
", ".join(map(str, y_pred)),
1775+
),
1776+
)
1777+
elif getattr(classifier, "decision_function_shape", "ovr") == "ovr":
1778+
decision_y = np.argmax(decision, axis=1).astype(int)
1779+
y_exp = classifier.classes_[decision_y]
1780+
assert_array_equal(
1781+
y_exp,
1782+
y_pred,
1783+
err_msg=(
1784+
"decision_function does not match "
1785+
"classifier for %r: expected '%s', got '%s'"
1786+
)
1787+
% (
1788+
classifier,
1789+
", ".join(map(str, y_exp)),
1790+
", ".join(map(str, y_pred)),
1791+
),
1792+
)
1793+
1794+
assert_array_equal(
1795+
classes,
1796+
classifier.classes_,
1797+
err_msg="Unexpected classes_ attribute for %r: expected '%s', got '%s'"
1798+
% (
1799+
classifier,
1800+
", ".join(map(str, classes)),
1801+
", ".join(map(str, classifier.classes_)),
1802+
),
1803+
)
1804+
1805+
17481806
@ignore_warnings(category=FutureWarning)
17491807
def check_regressors_int(name, regressor_orig):
17501808
"""

tsml/transformations/interval_extraction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,8 @@ class SupervisedIntervalTransformer(TransformerMixin, BaseTimeSeriesEstimator):
552552
.. [2] Cabello, N., Naghizade, E., Qi, J. and Kulik, L., 2021. Fast, accurate and
553553
interpretable time series classification through randomization. arXiv preprint
554554
arXiv:2105.14876.
555+
556+
Examples
555557
"""
556558

557559
def __init__(

tsml/utils/numba_functions/stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def count_mean_crossing(X: np.ndarray) -> float:
115115
>>> c = count_mean_crossing(X)
116116
"""
117117
m = mean(X)
118-
d = general_numba.first_order_differences(X > m)
118+
d = general_numba.first_order_differences((X > m).astype(np.int32))
119119
count = 0
120120
for i in range(d.shape[0]):
121121
if d[i] != 0:

0 commit comments

Comments
 (0)