diff --git a/ads/opctl/operator/lowcode/common/transformations.py b/ads/opctl/operator/lowcode/common/transformations.py index 55e55370a..d8cfc7375 100644 --- a/ads/opctl/operator/lowcode/common/transformations.py +++ b/ads/opctl/operator/lowcode/common/transformations.py @@ -15,6 +15,7 @@ InvalidParameterError, ) from ads.opctl.operator.lowcode.common.utils import merge_category_columns +from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorSpec class Transformations(ABC): @@ -34,6 +35,7 @@ def __init__(self, dataset_info, name="historical_data"): self.dataset_info = dataset_info self.target_category_columns = dataset_info.target_category_columns self.target_column_name = dataset_info.target_column + self.raw_column_names = None self.dt_column_name = ( dataset_info.datetime_column.name if dataset_info.datetime_column else None ) @@ -60,7 +62,8 @@ def run(self, data): """ clean_df = self._remove_trailing_whitespace(data) - # clean_df = self._normalize_column_names(clean_df) + if isinstance(self.dataset_info, ForecastOperatorSpec): + clean_df = self._clean_column_names(clean_df) if self.name == "historical_data": self._check_historical_dataset(clean_df) clean_df = self._set_series_id_column(clean_df) @@ -98,8 +101,36 @@ def run(self, data): def _remove_trailing_whitespace(self, df): return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x) - # def _normalize_column_names(self, df): - # return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x)) + def _clean_column_names(self, df): + """ + Remove all whitespaces from column names in a DataFrame and store the original names. + + Parameters: + df (pd.DataFrame): The DataFrame whose column names need to be cleaned. + + Returns: + pd.DataFrame: The DataFrame with cleaned column names. + """ + + self.raw_column_names = { + col: col.replace(" ", "") for col in df.columns if " " in col + } + df.columns = [self.raw_column_names.get(col, col) for col in df.columns] + + if self.target_column_name: + self.target_column_name = self.raw_column_names.get( + self.target_column_name, self.target_column_name + ) + self.dt_column_name = self.raw_column_names.get( + self.dt_column_name, self.dt_column_name + ) + + if self.target_category_columns: + self.target_category_columns = [ + self.raw_column_names.get(col, col) + for col in self.target_category_columns + ] + return df def _set_series_id_column(self, df): self._target_category_columns_map = {} @@ -233,6 +264,10 @@ def _check_historical_dataset(self, df): expected_names = [self.target_column_name, self.dt_column_name] + ( self.target_category_columns if self.target_category_columns else [] ) + + if self.raw_column_names: + expected_names.extend(list(self.raw_column_names.values())) + if set(df.columns) != set(expected_names): raise DataMismatchError( f"Expected {self.name} to have columns: {expected_names}, but instead found column names: {df.columns}. Is the {self.name} path correct?"