@@ -80,20 +80,22 @@ def fit_resample(self, X, y):
8080
8181 output = self ._fit_resample (X , y )
8282
83- if self ._columns is not None :
83+ if self ._X_columns is not None or self . _y_name is not None :
8484 import pandas as pd
85- X_ = pd .DataFrame (output [0 ], columns = self ._columns )
85+
86+ if self ._X_columns is not None :
87+ X_ = pd .DataFrame (output [0 ], columns = self ._X_columns )
88+ X_ = X_ .astype (self ._X_dtypes )
8689 else :
8790 X_ = output [0 ]
8891
89- if binarize_y :
90- y_sampled = label_binarize (output [1 ], np .unique (y ))
91- if len (output ) == 2 :
92- return X_ , y_sampled
93- return X_ , y_sampled , output [2 ]
94- if len (output ) == 2 :
95- return X_ , output [1 ]
96- return X_ , output [1 ], output [2 ]
92+ y_ = (label_binarize (output [1 ], np .unique (y ))
93+ if binarize_y else output [1 ])
94+
95+ if self ._y_name is not None :
96+ y_ = pd .Series (y_ , dtype = self ._y_dtype , name = self ._y_name )
97+
98+ return (X_ , y_ ) if len (output ) == 2 else (X_ , y_ , output [2 ])
9799
98100 # define an alias for back-compatibility
99101 fit_sample = fit_resample
@@ -135,8 +137,22 @@ def __init__(self, sampling_strategy="auto"):
135137 self .sampling_strategy = sampling_strategy
136138
137139 def _check_X_y (self , X , y , accept_sparse = None ):
138- # store the columns name to reconstruct a dataframe
139- self ._columns = X .columns if hasattr (X , "loc" ) else None
140+ if hasattr (X , "loc" ):
141+ # store information to build dataframe
142+ self ._X_columns = X .columns
143+ self ._X_dtypes = X .dtypes
144+ else :
145+ self ._X_columns = None
146+ self ._X_dtypes = None
147+
148+ if hasattr (y , "loc" ):
149+ # store information to build a series
150+ self ._y_name = y .name
151+ self ._y_dtype = y .dtype
152+ else :
153+ self ._y_name = None
154+ self ._y_dtype = None
155+
140156 if accept_sparse is None :
141157 accept_sparse = ["csr" , "csc" ]
142158 y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
@@ -263,20 +279,24 @@ def fit_resample(self, X, y):
263279
264280 output = self ._fit_resample (X , y )
265281
266- if self ._columns is not None :
267- import pandas as pd
268- X_ = pd .DataFrame (output [0 ], columns = self ._columns )
269- else :
270- X_ = output [0 ]
282+ if self .validate :
283+ if self ._X_columns is not None or self ._y_name is not None :
284+ import pandas as pd
271285
272- if self .validate and binarize_y :
273- y_sampled = label_binarize (output [1 ], np .unique (y ))
274- if len (output ) == 2 :
275- return X_ , y_sampled
276- return X_ , y_sampled , output [2 ]
277- if len (output ) == 2 :
278- return X_ , output [1 ]
279- return X_ , output [1 ], output [2 ]
286+ if self ._X_columns is not None :
287+ X_ = pd .DataFrame (output [0 ], columns = self ._X_columns )
288+ X_ = X_ .astype (self ._X_dtypes )
289+ else :
290+ X_ = output [0 ]
291+
292+ y_ = (label_binarize (output [1 ], np .unique (y ))
293+ if binarize_y else output [1 ])
294+
295+ if self ._y_name is not None :
296+ y_ = pd .Series (y_ , dtype = self ._y_dtype , name = self ._y_name )
297+
298+ return (X_ , y_ ) if len (output ) == 2 else (X_ , y_ , output [2 ])
299+ return output
280300
281301 def _fit_resample (self , X , y ):
282302 func = _identity if self .func is None else self .func
0 commit comments