@@ -31,13 +31,6 @@ class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):
3131
3232 _estimator_type = 'sampler'
3333
34- def _check_X_y (self , X , y ):
35- """Private function to check that the X and y in fitting are the same
36- than in sampling."""
37- X_hash , y_hash = hash_X_y (X , y )
38- if self .X_hash_ != X_hash or self .y_hash_ != y_hash :
39- raise RuntimeError ("X and y need to be same array earlier fitted." )
40-
4134 def sample (self , X , y ):
4235 """Resample the dataset.
4336
@@ -60,11 +53,10 @@ def sample(self, X, y):
6053
6154 """
6255 # Check the consistency of X and y
63- y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
64- X , y = check_X_y (X , y , accept_sparse = ['csr' , 'csc' ])
56+ X , y , binarize_y = self ._check_X_y (X , y )
6557
6658 check_is_fitted (self , 'sampling_strategy_' )
67- self ._check_X_y (X , y )
59+ self ._check_X_y_hash (X , y )
6860
6961 output = self ._sample (X , y )
7062
@@ -151,6 +143,19 @@ def __init__(self, sampling_strategy='auto', ratio=None):
151143 self .ratio = ratio
152144 self .logger = logging .getLogger (self .__module__ )
153145
146+ @staticmethod
147+ def _check_X_y (X , y ):
148+ y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
149+ X , y = check_X_y (X , y , accept_sparse = ['csr' , 'csc' ])
150+ return X , y , binarize_y
151+
152+ def _check_X_y_hash (self , X , y ):
153+ """Private function to check that the X and y in fitting are the same
154+ than in sampling."""
155+ X_hash , y_hash = hash_X_y (X , y )
156+ if self .X_hash_ != X_hash or self .y_hash_ != y_hash :
157+ raise RuntimeError ("X and y need to be same array earlier fitted." )
158+
154159 @property
155160 def ratio_ (self ):
156161 # FIXME: remove in 0.6
@@ -183,9 +188,9 @@ def fit(self, X, y):
183188
184189 """
185190 self ._deprecate_ratio ()
186- y = check_target_type (y )
187- X , y = check_X_y (X , y , accept_sparse = ['csr' , 'csc' ])
191+ X , y , _ = self ._check_X_y (X , y )
188192 self .X_hash_ , self .y_hash_ = hash_X_y (X , y )
193+ # _sampling_type is defined in the children base class
189194 self .sampling_strategy_ = check_sampling_strategy (
190195 self .sampling_strategy , y , self ._sampling_type )
191196
0 commit comments