@@ -33,6 +33,7 @@ def __init__(self, dataset_info, name="historical_data"):
3333 self .dataset_info = dataset_info
3434 self .target_category_columns = dataset_info .target_category_columns
3535 self .target_column_name = dataset_info .target_column
36+ self .raw_column_names = None
3637 self .dt_column_name = (
3738 dataset_info .datetime_column .name if dataset_info .datetime_column else None
3839 )
@@ -59,7 +60,7 @@ def run(self, data):
5960
6061 """
6162 clean_df = self ._remove_trailing_whitespace (data )
62- # clean_df = self._normalize_column_names (clean_df)
63+ clean_df = self ._clean_column_names (clean_df )
6364 if self .name == "historical_data" :
6465 self ._check_historical_dataset (clean_df )
6566 clean_df = self ._set_series_id_column (clean_df )
@@ -97,8 +98,31 @@ def run(self, data):
def _remove_trailing_whitespace(self, df):
    """Return ``df`` with ``str.strip`` applied to every object-dtype column.

    Columns whose dtype does not compare equal to ``"object"`` are passed
    through unchanged. ``str.strip`` removes leading as well as trailing
    whitespace, so both ends of each value are trimmed.

    Parameters:
        df (pd.DataFrame): The DataFrame whose object columns are stripped.

    Returns:
        pd.DataFrame: The result of applying the strip column by column.
    """

    def _strip_if_object(column):
        # Only object-dtype columns are routed through the .str accessor.
        if column.dtype == "object":
            return column.str.strip()
        return column

    return df.apply(_strip_if_object)
99100
100- # def _normalize_column_names(self, df):
101- # return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
def _clean_column_names(self, df):
    """
    Remove space characters from column names and record the renames.

    Only the literal space character (" ") is removed; other whitespace
    such as tabs is left untouched. The mapping of original name to
    cleaned name is stored in ``self.raw_column_names`` (only for columns
    that actually changed), and ``self.target_column_name``,
    ``self.dt_column_name`` and ``self.target_category_columns`` are
    rewritten to the cleaned names so they keep matching the DataFrame.

    Parameters:
        df (pd.DataFrame): The DataFrame whose column names need to be
            cleaned. Its columns are renamed in place.

    Returns:
        pd.DataFrame: The same DataFrame object with cleaned column names.
    """
    # Non-string labels (e.g. integer column names) cannot contain a space
    # and do not support the substring test, so they are skipped.
    renamed = {
        col: col.replace(" ", "")
        for col in df.columns
        if isinstance(col, str) and " " in col
    }
    self.raw_column_names = renamed

    self.target_column_name = renamed.get(
        self.target_column_name, self.target_column_name
    )
    self.dt_column_name = renamed.get(self.dt_column_name, self.dt_column_name)

    # The category columns may be unset (None) or empty; leave them as they
    # are in that case instead of iterating over them.
    if self.target_category_columns:
        self.target_category_columns = [
            renamed.get(col, col) for col in self.target_category_columns
        ]

    # Rename through the mapping rather than the .str accessor so that
    # non-string labels are preserved instead of raising or becoming NaN.
    df.rename(columns=renamed, inplace=True)
    return df
102126
103127 def _set_series_id_column (self , df ):
104128 self ._target_category_columns_map = {}
@@ -226,6 +250,10 @@ def _check_historical_dataset(self, df):
226250 expected_names = [self .target_column_name , self .dt_column_name ] + (
227251 self .target_category_columns if self .target_category_columns else []
228252 )
253+
254+ if self .raw_column_names :
255+ expected_names .extend (list (self .raw_column_names .values ()))
256+
229257 if set (df .columns ) != set (expected_names ):
230258 raise DataMismatchError (
231259 f"Expected { self .name } to have columns: { expected_names } , but instead found column names: { df .columns } . Is the { self .name } path correct?"
0 commit comments