3131from .._utils import PatchingConditionsChain
3232from ..base import oneDALEstimator
3333from ..utils ._array_api import get_namespace
34+ from ..utils .validation import check_feature_names
3435
3536
3637class KNeighborsDispatchingBase (oneDALEstimator ):
37-
38- def _parse_auto_method (self , method , n_samples , n_features ):
39- """Parse auto method selection for neighbors algorithm."""
40- result_method = method
41-
42- if method in ["auto" , "ball_tree" ]:
43- condition = (
44- self .n_neighbors is not None and self .n_neighbors >= n_samples // 2
45- )
46- if self .metric == "precomputed" or n_features > 15 or condition :
47- result_method = "brute"
48- else :
49- if self .metric == "euclidean" :
50- result_method = "kd_tree"
51- else :
52- result_method = "brute"
53-
54- return result_method
55-
56- def _get_weights (self , dist , weights ):
57- """Get weights for neighbors based on distance and weights parameter."""
58- if weights in (None , "uniform" ):
59- return None
60- if weights == "distance" :
61- # if user attempts to classify a point that was zero distance from one
62- # or more training points, those training points are weighted as 1.0
63- # and the other points as 0.0
64- if dist .dtype is np .dtype (object ):
65- for point_dist_i , point_dist in enumerate (dist ):
66- # check if point_dist is iterable
67- # (ex: RadiusNeighborClassifier.predict may set an element of
68- # dist to 1e-6 to represent an 'outlier')
69- if hasattr (point_dist , "__contains__" ) and 0.0 in point_dist :
70- dist [point_dist_i ] = point_dist == 0.0
71- else :
72- dist [point_dist_i ] = 1.0 / point_dist
73- else :
74- with np .errstate (divide = "ignore" ):
75- dist = 1.0 / dist
76- inf_mask = np .isinf (dist )
77- inf_row = np .any (inf_mask , axis = 1 )
78- dist [inf_row ] = inf_mask [inf_row ]
79- return dist
80- elif callable (weights ):
81- return weights (dist )
82- else :
83- raise ValueError (
84- "weights not recognized: should be 'uniform', "
85- "'distance', or a callable function"
86- )
87-
88- def _validate_targets (self , y , dtype ):
89- """Validate and convert target values."""
90- from onedal .utils .validation import _column_or_1d
91- arr = _column_or_1d (y , warn = True )
92-
93- try :
94- return arr .astype (dtype , copy = False )
95- except ValueError :
96- return arr
97-
98- def _validate_n_classes (self ):
99- """Validate that we have at least 2 classes for classification."""
100- length = 0 if self .classes_ is None else len (self .classes_ )
101- if length < 2 :
102- raise ValueError (
103- f"The number of classes has to be greater than one; got { length } "
104- )
10538 def _fit_validation (self , X , y = None ):
10639 if sklearn_check_version ("1.2" ):
10740 self ._validate_params ()
108-
41+ check_feature_names ( self , X , reset = True )
10942 if self .metric_params is not None and "p" in self .metric_params :
11043 if self .p is not None :
11144 warnings .warn (
@@ -134,9 +67,8 @@ def _fit_validation(self, X, y=None):
13467 self .effective_metric_ = "chebyshev"
13568
13669 if not isinstance (X , (KDTree , BallTree , _sklearn_NeighborsBase )):
137- xp , _ = get_namespace (X )
13870 self ._fit_X = _check_array (
139- X , dtype = [xp .float64 , xp .float32 ], accept_sparse = True
71+ X , dtype = [np .float64 , np .float32 ], accept_sparse = True
14072 )
14173 self .n_samples_fit_ = _num_samples (self ._fit_X )
14274 self .n_features_in_ = _num_features (self ._fit_X )
@@ -378,4 +310,4 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
378310
379311 return kneighbors_graph
380312
381- kneighbors_graph .__doc__ = KNeighborsMixin .kneighbors_graph .__doc__
313+ kneighbors_graph .__doc__ = KNeighborsMixin .kneighbors_graph .__doc__
0 commit comments