@@ -114,6 +114,16 @@ def __init__(self, features, default=False, sparse=False, df_out=False,
114114 if (df_out and (sparse or default )):
115115 raise ValueError ("Can not use df_out with sparse or default" )
116116
117+ def _build (self ):
118+ """
119+ Build attributes built_features and built_default.
120+ """
121+ if isinstance (self .features , list ):
122+ self .built_features = [_build_feature (* f ) for f in self .features ]
123+ else :
124+ self .built_features = self .features
125+ self .built_default = _build_transformer (self .default )
126+
117127 @property
118128 def _selected_columns (self ):
119129 """
@@ -198,12 +208,7 @@ def fit(self, X, y=None):
198208 y the target vector relative to X, optional
199209
200210 """
201- if isinstance (self .features , list ):
202- self .built_features = [_build_feature (* f ) for f in self .features ]
203- else :
204- self .built_features = self .features
205-
206- self .built_default = _build_transformer (self .default )
211+ self ._build ()
207212
208213 for columns , transformers , options in self .built_features :
209214 input_df = options .get ('input_df' , self .input_df )
@@ -273,23 +278,32 @@ def get_dtype(self, ex):
273278 else :
274279 raise TypeError (type (ex ))
275280
276- def transform (self , X ):
281+ def _transform (self , X , y = None , do_fit = False ):
277282 """
278- Transform the given data. Assumes that fit has already been called .
279-
280- X the data to transform
283+ Transform the given data with possibility to fit in advance .
284+ Avoids code duplication for implementation of transform and
285+ fit_transform.
281286 """
287+ if do_fit :
288+ self ._build ()
289+
282290 extracted = []
283291 self .transformed_names_ = []
284292 for columns , transformers , options in self .built_features :
285293 input_df = options .get ('input_df' , self .input_df )
294+
286295 # columns could be a string or list of
287296 # strings; we don't care because pandas
288297 # will handle either.
289298 Xt = self ._get_col_subset (X , columns , input_df )
290299 if transformers is not None :
291300 with add_column_names_to_exception (columns ):
292- Xt = transformers .transform (Xt )
301+ if do_fit and hasattr (transformers , 'fit_transform' ):
302+ Xt = _call_fit (transformers .fit_transform , Xt , y )
303+ else :
304+ if do_fit :
305+ _call_fit (transformers .fit , Xt , y )
306+ Xt = transformers .transform (Xt )
293307 extracted .append (_handle_feature (Xt ))
294308
295309 alias = options .get ('alias' )
@@ -302,7 +316,12 @@ def transform(self, X):
302316 Xt = self ._get_col_subset (X , unsel_cols , self .input_df )
303317 if self .built_default is not None :
304318 with add_column_names_to_exception (unsel_cols ):
305- Xt = self .built_default .transform (Xt )
319+ if do_fit and hasattr (self .built_default , 'fit_transform' ):
320+ Xt = _call_fit (self .built_default .fit_transform , Xt , y )
321+ else :
322+ if do_fit :
323+ _call_fit (self .built_default .fit , Xt , y )
324+ Xt = self .built_default .transform (Xt )
306325 self .transformed_names_ += self .get_names (
307326 unsel_cols , self .built_default , Xt )
308327 else :
@@ -348,3 +367,22 @@ def transform(self, X):
348367 return df_out
349368 else :
350369 return stacked
370+
371+ def transform (self , X ):
372+ """
373+ Transform the given data. Assumes that fit has already been called.
374+
375+ X the data to transform
376+ """
377+ return self ._transform (X )
378+
379+ def fit_transform (self , X , y = None ):
380+ """
381+ Fit a transformation from the pipeline and directly apply
382+ it to the given data.
383+
384+ X the data to fit
385+
386+ y the target vector relative to X, optional
387+ """
388+ return self ._transform (X , y , True )
0 commit comments