Skip to content

Commit 7f29766

Browse files
committed
transform columns with space to without space and preserve a map
1 parent 3e3a2a0 commit 7f29766

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(self, dataset_info, name="historical_data"):
3333
self.dataset_info = dataset_info
3434
self.target_category_columns = dataset_info.target_category_columns
3535
self.target_column_name = dataset_info.target_column
36+
self.raw_column_names = None
3637
self.dt_column_name = (
3738
dataset_info.datetime_column.name if dataset_info.datetime_column else None
3839
)
@@ -59,7 +60,7 @@ def run(self, data):
5960
6061
"""
6162
clean_df = self._remove_trailing_whitespace(data)
62-
# clean_df = self._normalize_column_names(clean_df)
63+
clean_df = self._clean_column_names(clean_df)
6364
if self.name == "historical_data":
6465
self._check_historical_dataset(clean_df)
6566
clean_df = self._set_series_id_column(clean_df)
@@ -97,8 +98,31 @@ def run(self, data):
9798
def _remove_trailing_whitespace(self, df):
9899
return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
99100

100-
# def _normalize_column_names(self, df):
101-
# return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
101+
def _clean_column_names(self, df):
102+
"""
103+
Remove all whitespaces from column names in a DataFrame and store the original names.
104+
105+
Parameters:
106+
df (pd.DataFrame): The DataFrame whose column names need to be cleaned.
107+
108+
Returns:
109+
pd.DataFrame: The DataFrame with cleaned column names.
110+
"""
111+
self.raw_column_names = {
112+
col: col.replace(" ", "") for col in df.columns if " " in col
113+
}
114+
115+
self.target_column_name = self.raw_column_names.get(
116+
self.target_column_name, self.target_column_name
117+
)
118+
self.dt_column_name = self.raw_column_names.get(
119+
self.dt_column_name, self.dt_column_name
120+
)
121+
self.target_category_columns = [
122+
self.raw_column_names.get(col, col) for col in self.target_category_columns
123+
]
124+
df.columns = df.columns.str.replace(" ", "")
125+
return df
102126

103127
def _set_series_id_column(self, df):
104128
self._target_category_columns_map = {}
@@ -226,6 +250,10 @@ def _check_historical_dataset(self, df):
226250
expected_names = [self.target_column_name, self.dt_column_name] + (
227251
self.target_category_columns if self.target_category_columns else []
228252
)
253+
254+
if self.raw_column_names:
255+
expected_names.extend(list(self.raw_column_names.values()))
256+
229257
if set(df.columns) != set(expected_names):
230258
raise DataMismatchError(
231259
f"Expected {self.name} to have columns: {expected_names}, but instead found column names: {df.columns}. Is the {self.name} path correct?"

0 commit comments

Comments
 (0)