Skip to content

Commit 62c8ddd

Browse files
committed
fix: fix tupleerror
1 parent 02da9e9 commit 62c8ddd

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

sklearnex/neighbors/knn_classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ def _save_attributes(self):
240240
self.classes_ = self._onedal_estimator.classes_
241241
self.n_features_in_ = self._onedal_estimator.n_features_in_
242242
self.n_samples_fit_ = self._onedal_estimator.n_samples_fit_
243-
self._fit_X = self._onedal_estimator._fit_X
243+
fit_x = self._onedal_estimator._fit_X
244+
self._fit_X = fit_x[0] if isinstance(fit_x, tuple) else fit_x
244245
self._y = self._onedal_estimator._y
245246
self._fit_method = self._onedal_estimator._fit_method
246247
self.outputs_2d_ = self._onedal_estimator.outputs_2d_

sklearnex/neighbors/knn_regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ def _onedal_score(self, X, y, sample_weight=None, queue=None):
201201
def _save_attributes(self):
202202
self.n_features_in_ = self._onedal_estimator.n_features_in_
203203
self.n_samples_fit_ = self._onedal_estimator.n_samples_fit_
204-
self._fit_X = self._onedal_estimator._fit_X
204+
fit_x = self._onedal_estimator._fit_X
205+
self._fit_X = fit_x[0] if isinstance(fit_x, tuple) else fit_x
205206
self._y = self._onedal_estimator._y
206207
self._fit_method = self._onedal_estimator._fit_method
207208
self._tree = self._onedal_estimator._tree

sklearnex/neighbors/knn_unsupervised.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ def _save_attributes(self):
186186
self.classes_ = self._onedal_estimator.classes_
187187
self.n_features_in_ = self._onedal_estimator.n_features_in_
188188
self.n_samples_fit_ = self._onedal_estimator.n_samples_fit_
189-
self._fit_X = self._onedal_estimator._fit_X
189+
fit_x = self._onedal_estimator._fit_X
190+
self._fit_X = fit_x[0] if isinstance(fit_x, tuple) else fit_x
190191
self._fit_method = self._onedal_estimator._fit_method
191192
self._tree = self._onedal_estimator._tree
192193

0 commit comments

Comments
 (0)