1414 InvalidParameterError ,
1515)
1616from ads .opctl .operator .lowcode .common .utils import merge_category_columns
17+ from ads .opctl .operator .lowcode .forecast .operator_config import ForecastOperatorSpec
1718
1819
1920class Transformations (ABC ):
@@ -60,7 +61,7 @@ def run(self, data):
6061
6162 """
6263 clean_df = self ._remove_trailing_whitespace (data )
63- if hasattr (self .dataset_info , 'horizon' ):
64+ if isinstance (self .dataset_info , ForecastOperatorSpec ):
6465 clean_df = self ._clean_column_names (clean_df )
6566 if self .name == "historical_data" :
6667 self ._check_historical_dataset (clean_df )
@@ -109,9 +110,11 @@ def _clean_column_names(self, df):
109110 Returns:
110111 pd.DataFrame: The DataFrame with cleaned column names.
111112 """
113+
112114 self .raw_column_names = {
113115 col : col .replace (" " , "" ) for col in df .columns if " " in col
114116 }
117+ df .columns = [self .raw_column_names .get (col , col ) for col in df .columns ]
115118
116119 if self .target_column_name :
117120 self .target_column_name = self .raw_column_names .get (
@@ -123,9 +126,9 @@ def _clean_column_names(self, df):
123126
124127 if self .target_category_columns :
125128 self .target_category_columns = [
126- self .raw_column_names .get (col , col ) for col in self .target_category_columns
129+ self .raw_column_names .get (col , col )
130+ for col in self .target_category_columns
127131 ]
128- df .columns = df .columns .str .replace (" " , "" )
129132 return df
130133
131134 def _set_series_id_column (self , df ):
0 commit comments