Skip to content

Commit 0e8b4c6

Browse files
committed
fix: try it again
1 parent d17bb34 commit 0e8b4c6

File tree

2 files changed

+5
-73
lines changed

2 files changed

+5
-73
lines changed

onedal/neighbors/neighbors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,4 +691,4 @@ def fit(self, X, y, queue=None):
691691

692692
@supports_queue
693693
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
694-
return self._kneighbors(X, n_neighbors, return_distance)
694+
return self._kneighbors(X, n_neighbors, return_distance)

sklearnex/neighbors/common.py

Lines changed: 4 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -31,81 +31,14 @@
3131
from .._utils import PatchingConditionsChain
3232
from ..base import oneDALEstimator
3333
from ..utils._array_api import get_namespace
34+
from ..utils.validation import check_feature_names
3435

3536

3637
class KNeighborsDispatchingBase(oneDALEstimator):
37-
38-
def _parse_auto_method(self, method, n_samples, n_features):
39-
"""Parse auto method selection for neighbors algorithm."""
40-
result_method = method
41-
42-
if method in ["auto", "ball_tree"]:
43-
condition = (
44-
self.n_neighbors is not None and self.n_neighbors >= n_samples // 2
45-
)
46-
if self.metric == "precomputed" or n_features > 15 or condition:
47-
result_method = "brute"
48-
else:
49-
if self.metric == "euclidean":
50-
result_method = "kd_tree"
51-
else:
52-
result_method = "brute"
53-
54-
return result_method
55-
56-
def _get_weights(self, dist, weights):
57-
"""Get weights for neighbors based on distance and weights parameter."""
58-
if weights in (None, "uniform"):
59-
return None
60-
if weights == "distance":
61-
# if user attempts to classify a point that was zero distance from one
62-
# or more training points, those training points are weighted as 1.0
63-
# and the other points as 0.0
64-
if dist.dtype is np.dtype(object):
65-
for point_dist_i, point_dist in enumerate(dist):
66-
# check if point_dist is iterable
67-
# (ex: RadiusNeighborClassifier.predict may set an element of
68-
# dist to 1e-6 to represent an 'outlier')
69-
if hasattr(point_dist, "__contains__") and 0.0 in point_dist:
70-
dist[point_dist_i] = point_dist == 0.0
71-
else:
72-
dist[point_dist_i] = 1.0 / point_dist
73-
else:
74-
with np.errstate(divide="ignore"):
75-
dist = 1.0 / dist
76-
inf_mask = np.isinf(dist)
77-
inf_row = np.any(inf_mask, axis=1)
78-
dist[inf_row] = inf_mask[inf_row]
79-
return dist
80-
elif callable(weights):
81-
return weights(dist)
82-
else:
83-
raise ValueError(
84-
"weights not recognized: should be 'uniform', "
85-
"'distance', or a callable function"
86-
)
87-
88-
def _validate_targets(self, y, dtype):
89-
"""Validate and convert target values."""
90-
from onedal.utils.validation import _column_or_1d
91-
arr = _column_or_1d(y, warn=True)
92-
93-
try:
94-
return arr.astype(dtype, copy=False)
95-
except ValueError:
96-
return arr
97-
98-
def _validate_n_classes(self):
99-
"""Validate that we have at least 2 classes for classification."""
100-
length = 0 if self.classes_ is None else len(self.classes_)
101-
if length < 2:
102-
raise ValueError(
103-
f"The number of classes has to be greater than one; got {length}"
104-
)
10538
def _fit_validation(self, X, y=None):
10639
if sklearn_check_version("1.2"):
10740
self._validate_params()
108-
41+
check_feature_names(self, X, reset=True)
10942
if self.metric_params is not None and "p" in self.metric_params:
11043
if self.p is not None:
11144
warnings.warn(
@@ -134,9 +67,8 @@ def _fit_validation(self, X, y=None):
13467
self.effective_metric_ = "chebyshev"
13568

13669
if not isinstance(X, (KDTree, BallTree, _sklearn_NeighborsBase)):
137-
xp, _ = get_namespace(X)
13870
self._fit_X = _check_array(
139-
X, dtype=[xp.float64, xp.float32], accept_sparse=True
71+
X, dtype=[np.float64, np.float32], accept_sparse=True
14072
)
14173
self.n_samples_fit_ = _num_samples(self._fit_X)
14274
self.n_features_in_ = _num_features(self._fit_X)
@@ -378,4 +310,4 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
378310

379311
return kneighbors_graph
380312

381-
kneighbors_graph.__doc__ = KNeighborsMixin.kneighbors_graph.__doc__
313+
kneighbors_graph.__doc__ = KNeighborsMixin.kneighbors_graph.__doc__

0 commit comments

Comments
 (0)