@@ -127,9 +127,11 @@ def __init__(self, sampling_strategy="auto"):
127127 self .sampling_strategy = sampling_strategy
128128
129129 @staticmethod
130- def _check_X_y (X , y ):
130+ def _check_X_y (X , y , accept_sparse = None ):
131+ if accept_sparse is None :
132+ accept_sparse = ["csr" , "csc" ]
131133 y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
132- X , y = check_X_y (X , y , accept_sparse = [ "csr" , "csc" ] )
134+ X , y = check_X_y (X , y , accept_sparse = accept_sparse )
133135 return X , y , binarize_y
134136
135137
@@ -156,6 +158,11 @@ class FunctionSampler(BaseSampler):
156158 kw_args : dict, optional (default=None)
157159 The keyword argument expected by ``func``.
158160
161+ validate : bool, default=True
162+ Whether or not to bypass the validation of ``X`` and ``y``. Turning-off
163+ validation allows to use the ``FunctionSampler`` with any type of
164+ data.
165+
159166 Notes
160167 -----
161168
@@ -202,16 +209,55 @@ class FunctionSampler(BaseSampler):
202209
203210 _sampling_type = "bypass"
204211
205- def __init__ (self , func = None , accept_sparse = True , kw_args = None ):
212+ def __init__ (self , func = None , accept_sparse = True , kw_args = None ,
213+ validate = True ):
206214 super ().__init__ ()
207215 self .func = func
208216 self .accept_sparse = accept_sparse
209217 self .kw_args = kw_args
218+ self .validate = validate
210219
211- def _fit_resample (self , X , y ):
212- X , y = check_X_y (
213- X , y , accept_sparse = ["csr" , "csc" ] if self .accept_sparse else False
220+ def fit_resample (self , X , y ):
221+ """Resample the dataset.
222+
223+ Parameters
224+ ----------
225+ X : {array-like, sparse matrix}, shape (n_samples, n_features)
226+ Matrix containing the data which have to be sampled.
227+
228+ y : array-like, shape (n_samples,)
229+ Corresponding label for each sample in X.
230+
231+ Returns
232+ -------
233+ X_resampled : {array-like, sparse matrix}, shape \
234+ (n_samples_new, n_features)
235+ The array containing the resampled data.
236+
237+ y_resampled : array-like, shape (n_samples_new,)
238+ The corresponding label of `X_resampled`.
239+
240+ """
241+ if self .validate :
242+ check_classification_targets (y )
243+ X , y , binarize_y = self ._check_X_y (
244+ X , y , accept_sparse = self .accept_sparse
245+ )
246+
247+ self .sampling_strategy_ = check_sampling_strategy (
248+ self .sampling_strategy , y , self ._sampling_type
214249 )
250+
251+ output = self ._fit_resample (X , y )
252+
253+ if self .validate and binarize_y :
254+ y_sampled = label_binarize (output [1 ], np .unique (y ))
255+ if len (output ) == 2 :
256+ return output [0 ], y_sampled
257+ return output [0 ], y_sampled , output [2 ]
258+ return output
259+
260+ def _fit_resample (self , X , y ):
215261 func = _identity if self .func is None else self .func
216262 output = func (X , y , ** (self .kw_args if self .kw_args else {}))
217263 return output
0 commit comments