Skip to content

Commit 75c0f05

Browse files
committed
changed target in outputs such that original target column is not present
1 parent 338ac6b commit 75c0f05

File tree

7 files changed

+51
-32
lines changed

7 files changed

+51
-32
lines changed

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _train_model(self, i, target, df):
128128
for param in ['arima_res_', 'endog_index_']:
129129
if param in params:
130130
params.pop(param)
131-
self.model_parameters[target] = {
131+
self.model_parameters[utils.convert_target(target, self.original_target_column)] = {
132132
"framework": SupportedModels.Arima,
133133
**params,
134134
}
@@ -197,7 +197,7 @@ def _generate_report(self):
197197

198198
sec5_text = dp.Text(f"## ARIMA Model Parameters")
199199
blocks = [
200-
dp.HTML(m.summary().as_html(), label=target)
200+
dp.HTML(m.summary().as_html(), label=utils.convert_target(target, self.original_target_column))
201201
for i, (target, m) in enumerate(self.models.items())
202202
]
203203
sec5 = dp.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
@@ -242,7 +242,7 @@ def _generate_report(self):
242242
blocks = [
243243
dp.DataTable(
244244
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
245-
label=s_id,
245+
label=utils.convert_target(s_id, self.original_target_column),
246246
)
247247
for s_id, local_ex_df in self.local_explanation.items()
248248
]

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def _build_model(self) -> pd.DataFrame:
164164
outputs[target] = summary_frame
165165
# outputs_legacy[target] = summary_frame
166166

167-
self.model_parameters[target] = {
167+
self.model_parameters[utils.convert_target(target, self.original_target_column)] = {
168168
"framework": SupportedModels.AutoMLX,
169169
"score_metric": model.score_metric,
170170
"random_state": model.random_state,
@@ -304,7 +304,7 @@ def _generate_report(self):
304304
blocks = [
305305
dp.DataTable(
306306
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
307-
label=s_id,
307+
label=utils.convert_target(s_id, self.original_target_column),
308308
)
309309
for s_id, local_ex_df in self.local_explanation.items()
310310
]

ads/opctl/operator/lowcode/forecast/model/autots.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _build_model(self) -> pd.DataFrame:
7878
drop_data_older_than_periods=self.spec.model_kwargs.get(
7979
"drop_data_older_than_periods", None
8080
),
81-
model_list=self.spec.model_kwargs.get("model_list", "fast_parallel"),
81+
model_list=self.spec.model_kwargs.get("model_list", "fast_parallel"),
8282
transformer_list=self.spec.model_kwargs.get("transformer_list", "auto"),
8383
transformer_max_depth=self.spec.model_kwargs.get(
8484
"transformer_max_depth", 6
@@ -225,7 +225,7 @@ def _build_model(self) -> pd.DataFrame:
225225
category=cat, target_category_column=cat_target, forecast=output_i
226226
)
227227

228-
self.model_parameters[cat_target] = {
228+
self.model_parameters[utils.convert_target(cat_target, self.original_target_column)] = {
229229
"framework": SupportedModels.AutoTS,
230230
**params,
231231
}
@@ -266,6 +266,7 @@ def _generate_report(self) -> tuple:
266266
].min(),
267267
),
268268
target_columns=self.target_columns,
269+
original_target_column=self.original_target_column
269270
)
270271

271272
# Section 2: AutoTS Model Parameters
@@ -323,7 +324,7 @@ def _generate_report(self) -> tuple:
323324
blocks = [
324325
dp.DataTable(
325326
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
326-
label=s_id,
327+
label=utils.convert_target(s_id, self.original_target_column),
327328
)
328329
for s_id, local_ex_df in self.local_explanation.items()
329330
]

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def generate_report(self):
105105
self.datasets,
106106
self.forecast_output,
107107
self.spec.datetime_column.name,
108+
self.original_target_column,
108109
target_col=self.forecast_col_name,
109110
)
110111
else:
@@ -125,6 +126,7 @@ def generate_report(self):
125126
target_columns=self.target_columns,
126127
test_filename=self.spec.test_data.url,
127128
output=self.forecast_output,
129+
original_target_column=self.original_target_column,
128130
target_col=self.forecast_col_name,
129131
elapsed_time=elapsed_time,
130132
)
@@ -144,12 +146,13 @@ def generate_report(self):
144146

145147
title_text = dp.Text("# Forecast Report")
146148

147-
md_columns = " * ".join([f"{x} \n" for x in self.target_columns])
149+
md_columns = " * ".join([f"{utils.convert_target(x,self.original_target_column)} \n"
150+
for x in self.target_columns])
148151
first_10_rows_blocks = [
149152
dp.DataTable(
150153
df.head(10).rename({col: self.spec.target_column}, axis=1),
151154
caption="Start",
152-
label=col,
155+
label=utils.convert_target(col, self.original_target_column),
153156
)
154157
for col, df in self.full_data_dict.items()
155158
]
@@ -158,7 +161,7 @@ def generate_report(self):
158161
dp.DataTable(
159162
df.tail(10).rename({col: self.spec.target_column}, axis=1),
160163
caption="End",
161-
label=col,
164+
label=utils.convert_target(col, self.original_target_column),
162165
)
163166
for col, df in self.full_data_dict.items()
164167
]
@@ -167,7 +170,7 @@ def generate_report(self):
167170
dp.DataTable(
168171
df.rename({col: self.spec.target_column}, axis=1).describe(),
169172
caption="Summary Statistics",
170-
label=col,
173+
label=utils.convert_target(col, self.original_target_column),
171174
)
172175
for col, df in self.full_data_dict.items()
173176
]
@@ -254,6 +257,7 @@ def generate_report(self):
254257
forecast_sec = utils.get_forecast_plots(
255258
self.forecast_output,
256259
self.target_columns,
260+
self.original_target_column,
257261
horizon=self.spec.horizon,
258262
test_data=test_data,
259263
ci_interval_width=self.spec.confidence_interval_width,
@@ -280,7 +284,7 @@ def generate_report(self):
280284
)
281285

282286
def _test_evaluate_metrics(
283-
self, target_columns, test_filename, output, target_col="yhat", elapsed_time=0
287+
self, target_columns, test_filename, output, original_target_column, target_col="yhat", elapsed_time=0
284288
):
285289
total_metrics = pd.DataFrame()
286290
summary_metrics = pd.DataFrame()
@@ -335,7 +339,7 @@ def _test_evaluate_metrics(
335339
metrics_df = utils._build_metrics_df(
336340
y_true=y_true[-self.spec.horizon:],
337341
y_pred=y_pred[-self.spec.horizon:],
338-
column_name=target_column_i,
342+
column_name=utils.convert_target(target_column_i, original_target_column),
339343
)
340344
total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
341345
else:
@@ -675,7 +679,7 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
675679
f"No explanations generated. Ensure that additional data has been provided."
676680
)
677681
else:
678-
self.global_explanation[series_id] = dict(
682+
self.global_explanation[utils.convert_target(series_id, self.original_target_column)] = dict(
679683
zip(
680684
data_trimmed.columns[1:],
681685
np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0),
@@ -724,4 +728,4 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
724728
["series_id", self.spec.target_column], axis=1, inplace=True
725729
)
726730

727-
self.local_explanation[series_id] = local_kernel_explnr_df
731+
self.local_explanation[utils.convert_target(series_id, self.original_target_column)] = local_kernel_explnr_df

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def _load_model(self):
8181
except:
8282
logger.info("model.pkl/trainer.pkl is not present")
8383

84-
8584
def _train_model(self, i, target, df):
8685

8786
try:
@@ -236,7 +235,7 @@ def objective(trial):
236235
self.models[target] = model
237236
self.trainers[target] = model.trainer
238237

239-
self.model_parameters[target] = {
238+
self.model_parameters[utils.convert_target(target, self.original_target_column)] = {
240239
"framework": SupportedModels.NeuralProphet,
241240
"config": model.config,
242241
"config_trend": model.config_trend,
@@ -259,7 +258,7 @@ def objective(trial):
259258
"highlight_forecast_step_n": model.highlight_forecast_step_n,
260259
"true_ar_weights": model.true_ar_weights,
261260
}
262-
261+
263262
logger.debug("===========Done===========")
264263
except Exception as e:
265264
self.errors_dict[target] = {"model_name": self.spec.model, "error": str(e)}
@@ -286,7 +285,6 @@ def _build_model(self) -> pd.DataFrame:
286285
if self.loaded_trainers is not None:
287286
self.trainers = self.loaded_trainers
288287

289-
290288
# Merge the outputs from each model into 1 df with all outputs by target and category
291289
col = self.original_target_column
292290
output_col = pd.DataFrame()
@@ -349,18 +347,21 @@ def _generate_report(self):
349347
sec1 = utils._select_plot_list(
350348
lambda idx, target, *args: self.models[target].plot(self.outputs[target]),
351349
target_columns=self.target_columns,
350+
original_target_column=self.original_target_column
352351
)
353352

354353
sec2_text = dp.Text(f"## Forecast Broken Down by Trend Component")
355354
sec2 = utils._select_plot_list(
356355
lambda idx, target, *args: self.models[target].plot_components(self.outputs[target]),
357356
target_columns=self.target_columns,
357+
original_target_column=self.original_target_column
358358
)
359359

360360
sec3_text = dp.Text(f"## Forecast Parameter Plots")
361361
sec3 = utils._select_plot_list(
362362
lambda idx, target, *args: self.models[target].plot_parameters(),
363363
target_columns=self.target_columns,
364+
original_target_column=self.original_target_column
364365
)
365366

366367
sec5_text = dp.Text(f"## Neural Prophet Model Parameters")
@@ -370,7 +371,7 @@ def _generate_report(self):
370371
pd.Series(
371372
m.state_dict(),
372373
index=m.state_dict().keys(),
373-
name=target,
374+
name=utils.convert_target(target, self.original_target_column),
374375
)
375376
)
376377
all_model_states = pd.concat(model_states, axis=1)
@@ -406,7 +407,7 @@ def _generate_report(self):
406407
global_explanation_df = pd.DataFrame(self.global_explanation)
407408

408409
self.formatted_global_explanation = (
409-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
410+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
410411
)
411412

412413
# Create a markdown section for the global explainability
@@ -428,7 +429,7 @@ def _generate_report(self):
428429
blocks = [
429430
dp.DataTable(
430431
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
431-
label=s_id,
432+
label=utils.convert_target(s_id, self.original_target_column),
432433
)
433434
for s_id, local_ex_df in self.local_explanation.items()
434435
]

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def objective(trial):
207207
for param in ["history", "history_dates", "stan_fit"]:
208208
if param in params:
209209
params.pop(param)
210-
self.model_parameters[target] = {
210+
self.model_parameters[utils.convert_target(target, self.original_target_column)] = {
211211
"framework": SupportedModels.Prophet,
212212
**params,
213213
}
@@ -293,12 +293,14 @@ def _generate_report(self):
293293
self.outputs[target], include_legend=True
294294
),
295295
target_columns=self.target_columns,
296+
original_target_column=self.original_target_column
296297
)
297298

298299
sec2_text = dp.Text(f"## Forecast Broken Down by Trend Component")
299300
sec2 = utils._select_plot_list(
300301
lambda idx, target, *args: self.models[target].plot_components(self.outputs[target]),
301302
target_columns=self.target_columns,
303+
original_target_column=self.original_target_column
302304
)
303305

304306
sec3_text = dp.Text(f"## Forecast Changepoints")
@@ -313,7 +315,9 @@ def _generate_report(self):
313315
for idx in range(len(self.target_columns))
314316
]
315317
sec3 = utils._select_plot_list(
316-
lambda idx, *args: sec3_figs[idx], target_columns=self.target_columns
318+
lambda idx, *args: sec3_figs[idx],
319+
target_columns=self.target_columns,
320+
original_target_column=self.original_target_column
317321
)
318322

319323
all_sections = [sec1_text, sec1, sec2_text, sec2, sec3_text, sec3]
@@ -374,7 +378,7 @@ def _generate_report(self):
374378
blocks = [
375379
dp.DataTable(
376380
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
377-
label=s_id,
381+
label=utils.convert_target(s_id,self.original_target_column),
378382
)
379383
for s_id, local_ex_df in self.local_explanation.items()
380384
]

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,7 @@ def _build_metrics_df(y_true, y_pred, column_name):
394394

395395

396396
def evaluate_train_metrics(
397-
target_columns, datasets, output, datetime_col, target_col="yhat"
398-
):
397+
target_columns, datasets, output, datetime_col, original_target_column, target_col="yhat"):
399398
"""
400399
Training metrics
401400
"""
@@ -408,7 +407,7 @@ def evaluate_train_metrics(
408407
y_true = forecast_by_col["input_value"].values
409408
y_pred = forecast_by_col["fitted_value"].values
410409
metrics_df = _build_metrics_df(
411-
y_true=y_true, y_pred=y_pred, column_name=col
410+
y_true=y_true, y_pred=y_pred, column_name=convert_target(col, original_target_column)
412411
)
413412
total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
414413
except Exception as e:
@@ -417,10 +416,11 @@ def evaluate_train_metrics(
417416
return total_metrics
418417

419418

420-
def _select_plot_list(fn, target_columns):
419+
def _select_plot_list(fn, target_columns, original_target_column):
421420
import datapane as dp
422421

423-
blocks = [dp.Plot(fn(i, col), label=col) for i, col in enumerate(target_columns)]
422+
blocks = [dp.Plot(fn(i, target), label=convert_target(target, original_target_column))
423+
for i, target in enumerate(target_columns)]
424424
return dp.Select(blocks=blocks) if len(target_columns) > 1 else blocks[0]
425425

426426

@@ -431,6 +431,7 @@ def _add_unit(num, unit):
431431
def get_forecast_plots(
432432
forecast_output,
433433
target_columns,
434+
original_target_column,
434435
horizon,
435436
test_data=None,
436437
ci_interval_width=0.95,
@@ -524,7 +525,7 @@ def plot_forecast_plotly(idx, col):
524525
)
525526
return fig
526527

527-
return _select_plot_list(plot_forecast_plotly, target_columns)
528+
return _select_plot_list(plot_forecast_plotly, target_columns, original_target_column)
528529

529530

530531
def human_time_friendly(seconds):
@@ -627,6 +628,14 @@ def get_frequency_of_datetime(data: pd.DataFrame, dataset_info: ForecastOperator
627628
return freq
628629

629630

631+
def convert_target(target: str, target_col: str):
632+
if target_col is not None and target_col!='':
633+
temp = target_col + '_'
634+
if temp in target:
635+
target = target.replace(temp, '')
636+
return target
637+
638+
630639
def default_signer(**kwargs):
631640
os.environ["EXTRA_USER_AGENT_INFO"] = "Forecast-Operator"
632641
from ads.common.auth import default_signer

0 commit comments

Comments
 (0)