@@ -60,7 +60,7 @@ def run(self, data):
6060
6161 """
6262 clean_df = self ._remove_trailing_whitespace (data )
63- if self .dataset_info . horizon :
63+ if hasattr ( self .dataset_info , ' horizon' ) :
6464 clean_df = self ._clean_column_names (clean_df )
6565 if self .name == "historical_data" :
6666 self ._check_historical_dataset (clean_df )
@@ -113,15 +113,18 @@ def _clean_column_names(self, df):
113113 col : col .replace (" " , "" ) for col in df .columns if " " in col
114114 }
115115
116- self .target_column_name = self .raw_column_names .get (
117- self .target_column_name , self .target_column_name
118- )
116+ if self .target_column_name :
117+ self .target_column_name = self .raw_column_names .get (
118+ self .target_column_name , self .target_column_name
119+ )
119120 self .dt_column_name = self .raw_column_names .get (
120121 self .dt_column_name , self .dt_column_name
121122 )
122- self .target_category_columns = [
123- self .raw_column_names .get (col , col ) for col in self .target_category_columns
124- ]
123+
124+ if self .target_category_columns :
125+ self .target_category_columns = [
126+ self .raw_column_names .get (col , col ) for col in self .target_category_columns
127+ ]
125128 df .columns = df .columns .str .replace (" " , "" )
126129 return df
127130
0 commit comments