|
45 | 45 | _is_public_parameter, |
46 | 46 | _NotAnArray, |
47 | 47 | _regression_dataset, |
48 | | - check_classifiers_predictions, |
49 | 48 | check_estimators_data_not_an_array, |
50 | 49 | ) |
51 | 50 | from sklearn.utils.metaestimators import _safe_split |
@@ -1745,6 +1744,65 @@ def check_classifiers_classes(name, classifier_orig): |
1745 | 1744 | check_classifiers_predictions(X_binary, y_binary, name, classifier_orig) |
1746 | 1745 |
|
1747 | 1746 |
|
| 1747 | +@ignore_warnings |
| 1748 | +def check_classifiers_predictions(X, y, name, classifier_orig): |
| 1749 | + classes = np.unique(y) |
| 1750 | + classifier = clone(classifier_orig) |
| 1751 | + if name == "BernoulliNB": |
| 1752 | + X = X > X.mean() |
| 1753 | + set_random_state(classifier) |
| 1754 | + |
| 1755 | + classifier.fit(X, y) |
| 1756 | + y_pred = classifier.predict(X) |
| 1757 | + |
| 1758 | + if hasattr(classifier, "decision_function"): |
| 1759 | + decision = classifier.decision_function(X) |
| 1760 | + assert isinstance(decision, np.ndarray) |
| 1761 | + if len(classes) == 2: |
| 1762 | + dec_pred = (decision.ravel() > 0).astype(int) |
| 1763 | + dec_exp = classifier.classes_[dec_pred] |
| 1764 | + assert_array_equal( |
| 1765 | + dec_exp, |
| 1766 | + y_pred, |
| 1767 | + err_msg=( |
| 1768 | + "decision_function does not match " |
| 1769 | + "classifier for %r: expected '%s', got '%s'" |
| 1770 | + ) |
| 1771 | + % ( |
| 1772 | + classifier, |
| 1773 | + ", ".join(map(str, dec_exp)), |
| 1774 | + ", ".join(map(str, y_pred)), |
| 1775 | + ), |
| 1776 | + ) |
| 1777 | + elif getattr(classifier, "decision_function_shape", "ovr") == "ovr": |
| 1778 | + decision_y = np.argmax(decision, axis=1).astype(int) |
| 1779 | + y_exp = classifier.classes_[decision_y] |
| 1780 | + assert_array_equal( |
| 1781 | + y_exp, |
| 1782 | + y_pred, |
| 1783 | + err_msg=( |
| 1784 | + "decision_function does not match " |
| 1785 | + "classifier for %r: expected '%s', got '%s'" |
| 1786 | + ) |
| 1787 | + % ( |
| 1788 | + classifier, |
| 1789 | + ", ".join(map(str, y_exp)), |
| 1790 | + ", ".join(map(str, y_pred)), |
| 1791 | + ), |
| 1792 | + ) |
| 1793 | + |
| 1794 | + assert_array_equal( |
| 1795 | + classes, |
| 1796 | + classifier.classes_, |
| 1797 | + err_msg="Unexpected classes_ attribute for %r: expected '%s', got '%s'" |
| 1798 | + % ( |
| 1799 | + classifier, |
| 1800 | + ", ".join(map(str, classes)), |
| 1801 | + ", ".join(map(str, classifier.classes_)), |
| 1802 | + ), |
| 1803 | + ) |
| 1804 | + |
| 1805 | + |
1748 | 1806 | @ignore_warnings(category=FutureWarning) |
1749 | 1807 | def check_regressors_int(name, regressor_orig): |
1750 | 1808 | """ |
|
0 commit comments