From 0207814f32458f95740e5564646b2ad892fae62e Mon Sep 17 00:00:00 2001 From: Hamidreza Keshavarz <32555614+hamidkm9@users.noreply.github.com> Date: Tue, 19 Aug 2025 22:08:20 +0200 Subject: [PATCH 1/2] Using precomputed kernels Until now, the kernel matrix was computed inside SEFR. LinearBoost now precomputes it once and passes it to SEFR as a precomputed kernel, avoiding redundant kernel computation in SEFR. --- src/linearboost/linear_boost.py | 87 +++++++++++++++++++++++-- src/linearboost/sefr.py | 111 ++++++++++++++++++++++++-------- 2 files changed, 164 insertions(+), 34 deletions(-) diff --git a/src/linearboost/linear_boost.py b/src/linearboost/linear_boost.py index 82ef8b7..64241fb 100644 --- a/src/linearboost/linear_boost.py +++ b/src/linearboost/linear_boost.py @@ -26,6 +26,7 @@ import numpy as np from sklearn.base import clone from sklearn.ensemble import AdaBoostClassifier +from sklearn.metrics.pairwise import pairwise_kernels from sklearn.pipeline import make_pipeline from sklearn.preprocessing import ( MaxAbsScaler, @@ -95,8 +96,9 @@ def _boost(self, iboost, X, y, sample_weight, random_state): iboost : int The index of the current boost iteration. - X : {array-like} of shape (n_samples, n_features) - The training input samples. + X : {array-like} of shape (n_samples, n_features) or (n_samples, n_samples) + The training input samples. For kernel methods, this will be a + precomputed kernel matrix. y : array-like of shape (n_samples,) The target values (class labels). @@ -375,6 +377,14 @@ class LinearBoostClassifier(_DenseAdaBoostClassifier): scaler_ : transformer The scaler instance used to transform the data. + X_fit_ : ndarray of shape (n_samples, n_features) + The training data after scaling, stored when kernel != 'linear' + for prediction purposes. + + K_train_ : ndarray of shape (n_samples, n_samples) + The precomputed kernel matrix on training data, stored when + kernel != 'linear'. + Notes ----- This classifier only supports binary classification tasks.
@@ -426,8 +436,14 @@ def __init__( degree=3, coef0=1, ): + # Create SEFR estimator with 'precomputed' kernel if we're using kernels + if kernel == "linear": + base_estimator = SEFR(kernel="linear") + else: + base_estimator = SEFR(kernel="precomputed") + super().__init__( - estimator=SEFR(kernel=kernel, gamma=gamma, degree=degree, coef0=coef0), + estimator=base_estimator, n_estimators=n_estimators, learning_rate=learning_rate, ) @@ -489,6 +505,37 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]: return X, y + def _get_kernel_matrix(self, X, Y=None): + """Compute kernel matrix between X and Y. + + Parameters + ---------- + X : array-like of shape (n_samples_X, n_features) + Input samples. + Y : array-like of shape (n_samples_Y, n_features), default=None + Input samples. If None, use X. + + Returns + ------- + K : ndarray of shape (n_samples_X, n_samples_Y) + Kernel matrix. + """ + if Y is None: + Y = X + + if callable(self.kernel): + return self.kernel(X, Y) + else: + return pairwise_kernels( + X, + Y, + metric=self.kernel, + filter_params=True, + gamma=self.gamma, + degree=self.degree, + coef0=self.coef0, + ) + def fit(self, X, y, sample_weight=None) -> Self: """Build a LinearBoost classifier from the training set (X, y). @@ -515,6 +562,7 @@ def fit(self, X, y, sample_weight=None) -> Self: if self.scaler not in _scalers: raise ValueError('Invalid scaler provided; got "%s".' 
% self.scaler) + # Apply scaling if self.scaler == "minmax": self.scaler_ = clone(_scalers["minmax"]) else: @@ -538,10 +586,20 @@ def fit(self, X, y, sample_weight=None) -> Self: X_transformed = X_transformed[nonzero_mask] y = y[nonzero_mask] sample_weight = sample_weight[nonzero_mask] + X_transformed, y = self._check_X_y(X_transformed, y) self.classes_ = np.unique(y) self.n_classes_ = self.classes_.shape[0] + # Store training data for kernel computation during prediction + if self.kernel != "linear": + self.X_fit_ = X_transformed + # Precompute kernel matrix ONCE for all estimators + self.K_train_ = self._get_kernel_matrix(X_transformed) + training_data = self.K_train_ + else: + training_data = X_transformed + if self.class_weight is not None: if isinstance(self.class_weight, str) and self.class_weight != "balanced": raise ValueError( @@ -566,7 +624,8 @@ def fit(self, X, y, sample_weight=None) -> Self: category=FutureWarning, message=".*parameter 'algorithm' is deprecated.*", ) - return super().fit(X_transformed, y, sample_weight) + # Pass the precomputed kernel matrix (or raw features for linear) + return super().fit(training_data, y, sample_weight) @staticmethod def _samme_proba(estimator, n_classes, X): @@ -590,6 +649,15 @@ def _samme_proba(estimator, n_classes, X): ) def _boost(self, iboost, X, y, sample_weight, random_state): + """ + Implement a single boost using precomputed kernel matrix or raw features. + + Parameters + ---------- + X : ndarray + For kernel methods, this is the precomputed kernel matrix. + For linear methods, this is the raw feature matrix. + """ estimator = self._make_estimator(random_state=random_state) estimator.fit(X, y, sample_weight=sample_weight) @@ -668,13 +736,20 @@ class in ``classes_``, respectively. 
check_is_fitted(self) X_transformed = self.scaler_.transform(X) + if self.kernel == "linear": + # For linear kernel, pass raw features + test_data = X_transformed + else: + # For kernel methods, compute kernel matrix between test and training data + test_data = self._get_kernel_matrix(X_transformed, self.X_fit_) + if self.algorithm == "SAMME.R": # Proper SAMME.R decision function classes = self.classes_ n_classes = len(classes) pred = sum( - self._samme_proba(estimator, n_classes, X_transformed) + self._samme_proba(estimator, n_classes, test_data) for estimator in self.estimators_ ) pred /= self.estimator_weights_.sum() @@ -685,7 +760,7 @@ class in ``classes_``, respectively. else: # Standard SAMME algorithm from AdaBoostClassifier (discrete) - return super().decision_function(X_transformed) + return super().decision_function(test_data) def predict(self, X): """Predict classes for X. diff --git a/src/linearboost/sefr.py b/src/linearboost/sefr.py index 37a3c30..0b4e723 100644 --- a/src/linearboost/sefr.py +++ b/src/linearboost/sefr.py @@ -44,13 +44,14 @@ class SEFR(LinearClassifierMixin, BaseEstimator): Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function. - kernel : {'linear', 'poly', 'rbf', 'sigmoid'} or callable, default='linear' + kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'} or callable, default='linear' Specifies the kernel type to be used in the algorithm. If a callable is given, it is used to pre-compute the kernel matrix. + If 'precomputed', X is assumed to be a kernel matrix. gamma : float, default=None Kernel coefficient for 'rbf', 'poly' and 'sigmoid'. If None, then it is - set to 1.0 / n_features. + set to 1.0 / n_features. Ignored when kernel='precomputed'. degree : int, default=3 Degree for 'poly' kernels. Ignored by other kernels. @@ -80,7 +81,7 @@ class SEFR(LinearClassifierMixin, BaseEstimator): has feature names that are all strings. 
X_fit_ : ndarray of shape (n_samples, n_features) - The training data, stored when a kernel is used. + The training data, stored when a kernel is used (except for 'precomputed'). Notes ----- @@ -100,7 +101,10 @@ class SEFR(LinearClassifierMixin, BaseEstimator): _parameter_constraints: dict = { "fit_intercept": ["boolean"], - "kernel": [StrOptions({"linear", "poly", "rbf", "sigmoid"}), callable], + "kernel": [ + StrOptions({"linear", "poly", "rbf", "sigmoid", "precomputed"}), + callable, + ], "gamma": [Interval(Real, 0, None, closed="left"), None], "degree": [Interval(Integral, 1, None, closed="left"), None], "coef0": [Real, None], @@ -144,28 +148,58 @@ def _more_tags(self) -> dict[str, bool]: } def _check_X(self, X) -> np.ndarray: - X = validate_data( - self, - X, - dtype="numeric", - force_all_finite=True, - reset=False, - ) - if X.shape[1] != self.n_features_in_: - raise ValueError( - "Expected input with %d features, got %d instead." - % (self.n_features_in_, X.shape[1]) + if self.kernel == "precomputed": + X = validate_data( + self, + X, + dtype="numeric", + force_all_finite=True, + reset=False, + ) + # For precomputed kernels during prediction, X should be (n_test_samples, n_train_samples) + if hasattr(self, "n_features_in_") and X.shape[1] != self.n_features_in_: + raise ValueError( + f"Precomputed kernel matrix should have {self.n_features_in_} columns " + f"(number of training samples), got {X.shape[1]}." + ) + else: + X = validate_data( + self, + X, + dtype="numeric", + force_all_finite=True, + reset=False, ) + if hasattr(self, "n_features_in_") and X.shape[1] != self.n_features_in_: + raise ValueError( + "Expected input with %d features, got %d instead." 
+ % (self.n_features_in_, X.shape[1]) + ) return X def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]: - X, y = check_X_y( - X, - y, - dtype="numeric", - force_all_finite=True, - estimator=self, - ) + if self.kernel == "precomputed": + # For precomputed kernels, X should be a square kernel matrix + X, y = check_X_y( + X, + y, + dtype="numeric", + force_all_finite=True, + estimator=self, + ) + if X.shape[0] != X.shape[1]: + raise ValueError( + f"Precomputed kernel matrix should be square, got shape {X.shape}." + ) + else: + X, y = check_X_y( + X, + y, + dtype="numeric", + force_all_finite=True, + estimator=self, + ) + check_classification_targets(y) if np.unique(y).shape[0] == 1: @@ -180,6 +214,10 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]: return X, y def _get_kernel_matrix(self, X, Y=None): + if self.kernel == "precomputed": + # X is already a kernel matrix + return X + if Y is None: Y = self.X_fit_ @@ -203,9 +241,10 @@ def fit(self, X, y, sample_weight=None) -> Self: Parameters ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) + X : {array-like, sparse matrix} of shape (n_samples, n_features) or (n_samples, n_samples) Training vector, where `n_samples` is the number of samples and `n_features` is the number of features. + If kernel='precomputed', X should be a square kernel matrix. y : array-like of shape (n_samples,) Target vector relative to X. @@ -219,15 +258,25 @@ def fit(self, X, y, sample_weight=None) -> Self: self Fitted estimator. 
""" - _check_n_features(self, X=X, reset=True) - _check_feature_names(self, X=X, reset=True) + if self.kernel == "precomputed": + _check_n_features(self, X=X, reset=True) + _check_feature_names(self, X=X, reset=True) + else: + _check_n_features(self, X=X, reset=True) + _check_feature_names(self, X=X, reset=True) X, y = self._check_X_y(X, y) - self.X_fit_ = X + + # Store training data only for non-precomputed kernels + if self.kernel != "precomputed": + self.X_fit_ = X + self.classes_, y_ = np.unique(y, return_inverse=True) if self.kernel == "linear": K = X + elif self.kernel == "precomputed": + K = X # X is already the kernel matrix else: K = self._get_kernel_matrix(X) @@ -277,10 +326,14 @@ def fit(self, X, y, sample_weight=None) -> Self: def decision_function(self, X): check_is_fitted(self) X = self._check_X(X) + if self.kernel == "linear": K = X + elif self.kernel == "precomputed": + K = X # X is already a kernel matrix else: K = self._get_kernel_matrix(X) + return ( safe_sparse_dot(K, self.coef_.T, dense_output=True) + self.intercept_ ).ravel() @@ -294,9 +347,10 @@ def predict_proba(self, X): Parameters ---------- - X : array-like of shape (n_samples, n_features) + X : array-like of shape (n_samples, n_features) or (n_samples, n_train_samples) Vector to be scored, where `n_samples` is the number of samples and `n_features` is the number of features. + If kernel='precomputed', X should have shape (n_samples, n_train_samples). Returns ------- @@ -324,9 +378,10 @@ def predict_log_proba(self, X): Parameters ---------- - X : array-like of shape (n_samples, n_features) + X : array-like of shape (n_samples, n_features) or (n_samples, n_train_samples) Vector to be scored, where `n_samples` is the number of samples and `n_features` is the number of features. + If kernel='precomputed', X should have shape (n_samples, n_train_samples). 
Returns ------- From 4f81032ed5371b59a12648f4a55a783794edf6e0 Mon Sep 17 00:00:00 2001 From: Hamidreza Keshavarz <32555614+hamidkm9@users.noreply.github.com> Date: Tue, 19 Aug 2025 22:17:12 +0200 Subject: [PATCH 2/2] Guard kernel comparison in __init__ against non-string kernels --- src/linearboost/linear_boost.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/linearboost/linear_boost.py b/src/linearboost/linear_boost.py index 64241fb..93d8918 100644 --- a/src/linearboost/linear_boost.py +++ b/src/linearboost/linear_boost.py @@ -437,10 +437,16 @@ def __init__( coef0=1, ): # Create SEFR estimator with 'precomputed' kernel if we're using kernels - if kernel == "linear": + # Guard the comparison: a non-string kernel (e.g. an array) may raise here + try: + if kernel == "linear": + base_estimator = SEFR(kernel="linear") + else: + base_estimator = SEFR(kernel="precomputed") + except (ValueError, TypeError): + # If kernel is an array or invalid type, default to linear + # Parameter validation will catch this later in fit() base_estimator = SEFR(kernel="linear") - else: - base_estimator = SEFR(kernel="precomputed") super().__init__( estimator=base_estimator,