3030
3131from imblearn .over_sampling .base import BaseOverSampler
3232from imblearn .under_sampling .base import BaseCleaningSampler , BaseUnderSampler
33- from imblearn .ensemble .base import BaseEnsembleSampler
3433from imblearn .under_sampling import NearMiss , ClusterCentroids
3534
3635
@@ -168,12 +167,6 @@ def check_samplers_fit_resample(name, Sampler):
168167 for class_sample in target_stats .keys ()
169168 if class_sample != class_minority
170169 )
171- elif isinstance (sampler , BaseEnsembleSampler ):
172- y_ensemble = y_res [0 ]
173- n_samples = min (target_stats .values ())
174- assert all (
175- value == n_samples for value in Counter (y_ensemble ).values ()
176- )
177170
178171
179172def check_samplers_sampling_strategy_fit_resample (name , Sampler ):
@@ -202,12 +195,6 @@ def check_samplers_sampling_strategy_fit_resample(name, Sampler):
202195 sampler .set_params (sampling_strategy = sampling_strategy )
203196 X_res , y_res = sampler .fit_resample (X , y )
204197 assert Counter (y_res )[1 ] == expected_stat
205- if isinstance (sampler , BaseEnsembleSampler ):
206- sampling_strategy = {2 : 201 , 0 : 201 }
207- sampler .set_params (sampling_strategy = sampling_strategy )
208- X_res , y_res = sampler .fit_resample (X , y )
209- y_ensemble = y_res [0 ]
210- assert Counter (y_ensemble )[1 ] == expected_stat
211198
212199
213200def check_samplers_sparse (name , Sampler ):
@@ -239,17 +226,9 @@ def check_samplers_sparse(name, Sampler):
239226 set_random_state (sampler )
240227 X_res_sparse , y_res_sparse = sampler .fit_resample (X_sparse , y )
241228 X_res , y_res = sampler .fit_resample (X , y )
242- if not isinstance (sampler , BaseEnsembleSampler ):
243- assert sparse .issparse (X_res_sparse )
244- assert_allclose (X_res_sparse .A , X_res )
245- assert_allclose (y_res_sparse , y_res )
246- else :
247- for x_sp , x , y_sp , y in zip (
248- X_res_sparse , X_res , y_res_sparse , y_res
249- ):
250- assert sparse .issparse (x_sp )
251- assert_allclose (x_sp .A , x )
252- assert_allclose (y_sp , y )
229+ assert sparse .issparse (X_res_sparse )
230+ assert_allclose (X_res_sparse .A , X_res )
231+ assert_allclose (y_res_sparse , y_res )
253232
254233
255234def check_samplers_pandas (name , Sampler ):
@@ -262,7 +241,7 @@ def check_samplers_pandas(name, Sampler):
262241 weights = [0.2 , 0.3 , 0.5 ],
263242 random_state = 0 ,
264243 )
265- X_pd = pd .DataFrame (X )
244+ X_pd = pd .DataFrame (X , columns = [ str ( i ) for i in range ( X . shape [ 1 ])] )
266245 sampler = Sampler ()
267246 if isinstance (Sampler (), NearMiss ):
268247 samplers = [Sampler (version = version ) for version in (1 , 2 , 3 )]
@@ -274,7 +253,11 @@ def check_samplers_pandas(name, Sampler):
274253 set_random_state (sampler )
275254 X_res_pd , y_res_pd = sampler .fit_resample (X_pd , y )
276255 X_res , y_res = sampler .fit_resample (X , y )
277- assert_allclose (X_res_pd , X_res )
256+
257+ # check that we return a pandas dataframe if a dataframe was given in
258+ assert isinstance (X_res_pd , pd .DataFrame )
259+ assert X_pd .columns .to_list () == X_res_pd .columns .to_list ()
260+ assert_allclose (X_res_pd .to_numpy (), X_res )
278261 assert_allclose (y_res_pd , y_res )
279262
280263
@@ -293,13 +276,8 @@ def check_samplers_multiclass_ova(name, Sampler):
293276 X_res , y_res = sampler .fit_resample (X , y )
294277 X_res_ova , y_res_ova = sampler .fit_resample (X , y_ova )
295278 assert_allclose (X_res , X_res_ova )
296- if issubclass (Sampler , BaseEnsembleSampler ):
297- for batch_y , batch_y_ova in zip (y_res , y_res_ova ):
298- assert type_of_target (batch_y_ova ) == type_of_target (y_ova )
299- assert_allclose (batch_y , batch_y_ova .argmax (axis = 1 ))
300- else :
301- assert type_of_target (y_res_ova ) == type_of_target (y_ova )
302- assert_allclose (y_res , y_res_ova .argmax (axis = 1 ))
279+ assert type_of_target (y_res_ova ) == type_of_target (y_ova )
280+ assert_allclose (y_res , y_res_ova .argmax (axis = 1 ))
303281
304282
305283def check_samplers_preserve_dtype (name , Sampler ):
0 commit comments