@@ -61,6 +61,8 @@ def _yield_sampler_checks(sampler):
6161 yield check_samplers_pandas
6262 if "dask-array" in tags ["X_types" ]:
6363 yield check_samplers_dask_array
64+ if "dask-dataframe" in tags ["X_types" ]:
65+ yield check_samplers_dask_dataframe
6466 yield check_samplers_list
6567 yield check_samplers_multiclass_ova
6668 yield check_samplers_preserve_dtype
@@ -295,7 +297,7 @@ def check_samplers_pandas(name, sampler):
295297
296298def check_samplers_dask_array (name , sampler ):
297299 dask = pytest .importorskip ("dask" )
298- # Check that the samplers handle pandas dataframe and pandas series
300+ # Check that the samplers handle dask array
299301 X , y = make_classification (
300302 n_samples = 1000 ,
301303 n_classes = 3 ,
@@ -317,6 +319,37 @@ def check_samplers_dask_array(name, sampler):
317319 assert_allclose (y_res_dask , y_res )
318320
319321
def check_samplers_dask_dataframe(name, sampler):
    """Check that a sampler handles dask dataframes and dask series.

    The sampler is fitted once on a dask ``DataFrame``/``Series`` pair and
    once on the equivalent NumPy arrays. The dask outputs must keep their
    container types, column names and series name, and must hold the same
    values as the NumPy outputs.

    Parameters
    ----------
    name : str
        Name of the sampler under test (not used by this check).
    sampler : object
        Sampler exposing ``fit_resample(X, y)``.
    """
    # Import the submodule itself: importing only ``dask`` does not guarantee
    # that ``dask.dataframe`` is loaded, and the check must be skipped (not
    # error out) when the dataframe backend is unavailable.
    dd = pytest.importorskip("dask.dataframe")
    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    X_df = dd.from_array(X, columns=[str(i) for i in range(X.shape[1])])
    y_s = dd.from_array(y)

    X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
    X_res, y_res = sampler.fit_resample(X, y)

    # check that we return the same type for dataframes or series types
    assert isinstance(X_res_df, dd.DataFrame)
    assert isinstance(y_res_s, dd.Series)

    # the column names and the series name must survive the resampling
    assert X_df.columns.to_list() == X_res_df.columns.to_list()
    assert y_s.name == y_res_s.name

    # dask collections are lazy: materialize them before comparing the values
    # with the results obtained from the in-memory arrays
    assert_allclose(X_res_df.compute().to_numpy(), X_res)
    assert_allclose(y_res_s.compute().to_numpy(), y_res)
351+
352+
320353def check_samplers_list (name , sampler ):
321354 # Check that the samplers can handle simple lists
322355 X , y = make_classification (
0 commit comments