@@ -54,7 +54,6 @@ def _validate_estimator(self):
5454 self .nn_k_ = check_neighbors_object (
5555 "k_neighbors" , self .k_neighbors , additional_neighbor = 1
5656 )
57- self .nn_k_ .set_params (** {"n_jobs" : self .n_jobs })
5857
5958 def _make_samples (
6059 self , X , y_dtype , y_type , nn_data , nn_num , n_samples , step_size = 1.0
@@ -956,6 +955,7 @@ def _fit_resample(self, X, y):
956955 self .ohe_ = OneHotEncoder (
957956 sparse = True , handle_unknown = "ignore" , dtype = dtype_ohe
958957 )
958+
959959 # the input of the OneHotEncoder needs to be dense
960960 X_ohe = self .ohe_ .fit_transform (
961961 X_categorical .toarray ()
@@ -967,6 +967,15 @@ def _fit_resample(self, X, y):
967967 # median of the standard deviation. It will ensure that whenever
968968 # distance is computed between 2 samples, the difference will be equal
969969 # to the median of the standard deviation as in the original paper.
970+
971+ # In the edge case where the median of the std is equal to 0, the 1s
972+ # entries will be also nullified. In this case, we store the original
973+ # categorical encoding which will be later used for inversing the OHE
974+ if math .isclose (self .median_std_ , 0 ):
975+ self ._X_categorical_minority_encoded = _safe_indexing (
976+ X_ohe .toarray (), np .flatnonzero (y == class_minority )
977+ )
978+
970979 X_ohe .data = (
971980 np .ones_like (X_ohe .data , dtype = X_ohe .dtype ) * self .median_std_ / 2
972981 )
@@ -1027,6 +1036,14 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
10271036
10281037 # convert to dense array since scipy.sparse doesn't handle 3D
10291038 nn_data = (nn_data .toarray () if sparse .issparse (nn_data ) else nn_data )
1039+
1040+ # In the case that the median std was equal to zeros, we have to
1041+ # create non-null entry based on the encoded of OHE
1042+ if math .isclose (self .median_std_ , 0 ):
1043+ nn_data [:, self .continuous_features_ .size :] = (
1044+ self ._X_categorical_minority_encoded
1045+ )
1046+
10301047 all_neighbors = nn_data [nn_num [rows ]]
10311048
10321049 categories_size = [self .continuous_features_ .size ] + [
0 commit comments