|
1 | 1 | #!/usr/bin/env python |
2 | | -# -*- coding: utf-8 -*-- |
3 | 2 |
|
4 | | -# Copyright (c) 2023 Oracle and/or its affiliates. |
| 3 | +# Copyright (c) 2023, 2024 Oracle and/or its affiliates. |
5 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
6 | 5 |
|
| 6 | +from abc import ABC |
| 7 | + |
| 8 | +import pandas as pd |
| 9 | +import re |
| 10 | + |
7 | 11 | from ads.opctl import logger |
| 12 | +from ads.opctl.operator.lowcode.common.const import DataColumns |
8 | 13 | from ads.opctl.operator.lowcode.common.errors import ( |
9 | | - InvalidParameterError, |
10 | 14 | DataMismatchError, |
| 15 | + InvalidParameterError, |
11 | 16 | ) |
12 | | -from ads.opctl.operator.lowcode.common.const import DataColumns |
13 | 17 | from ads.opctl.operator.lowcode.common.utils import merge_category_columns |
14 | | -import pandas as pd |
15 | | -from abc import ABC |
16 | 18 |
|
17 | 19 |
|
18 | 20 | class Transformations(ABC): |
@@ -58,6 +60,7 @@ def run(self, data): |
58 | 60 |
|
59 | 61 | """ |
60 | 62 | clean_df = self._remove_trailing_whitespace(data) |
| 63 | + clean_df = self._normalize_column_names(clean_df) |
61 | 64 | if self.name == "historical_data": |
62 | 65 | self._check_historical_dataset(clean_df) |
63 | 66 | clean_df = self._set_series_id_column(clean_df) |
@@ -95,8 +98,11 @@ def run(self, data): |
95 | 98 | def _remove_trailing_whitespace(self, df): |
96 | 99 | return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x) |
97 | 100 |
|
| 101 | + def _normalize_column_names(self, df): |
| 102 | + return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x)) |
| 103 | + |
98 | 104 | def _set_series_id_column(self, df): |
99 | | - self._target_category_columns_map = dict() |
| 105 | + self._target_category_columns_map = {} |
100 | 106 | if not self.target_category_columns: |
101 | 107 | df[DataColumns.Series] = "Series 1" |
102 | 108 | self.has_artificial_series = True |
@@ -125,10 +131,10 @@ def _format_datetime_col(self, df): |
125 | 131 | df[self.dt_column_name] = pd.to_datetime( |
126 | 132 | df[self.dt_column_name], format=self.dt_column_format |
127 | 133 | ) |
128 | | - except: |
| 134 | + except Exception as ee: |
129 | 135 | raise InvalidParameterError( |
130 | 136 | f"Unable to determine the datetime type for column: {self.dt_column_name} in dataset: {self.name}. Please specify the format explicitly. (For example adding 'format: %d/%m/%Y' underneath 'name: {self.dt_column_name}' in the datetime_column section of the yaml file if you haven't already. For reference, here is the first datetime given: {df[self.dt_column_name].values[0]}" |
131 | | - ) |
| 137 | + ) from ee |
132 | 138 | return df |
133 | 139 |
|
134 | 140 | def _set_multi_index(self, df): |
@@ -242,7 +248,6 @@ def _check_historical_dataset(self, df): |
242 | 248 | "Class": "A", |
243 | 249 | "Num": 2 |
244 | 250 | }, |
245 | | - |
246 | 251 | } |
247 | 252 | """ |
248 | 253 |
|
|
0 commit comments