|
17 | 17 | from ads.opctl.operator.lowcode.forecast.const import ( |
18 | 18 | AUTOMLX_METRIC_MAP, |
19 | 19 | ForecastOutputColumns, |
| 20 | + SpeedAccuracyMode, |
20 | 21 | SupportedModels, |
21 | 22 | ) |
22 | 23 | from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe |
@@ -241,18 +242,18 @@ def _generate_report(self): |
241 | 242 | # If the key is present, call the "explain_model" method |
242 | 243 | self.explain_model() |
243 | 244 |
|
244 | | - # Convert the global explanation data to a DataFrame |
245 | | - global_explanation_df = pd.DataFrame(self.global_explanation) |
| 245 | + global_explanation_section = None |
| 246 | + if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX: |
| 247 | + # Convert the global explanation data to a DataFrame |
| 248 | + global_explanation_df = pd.DataFrame(self.global_explanation) |
246 | 249 |
|
247 | | - self.formatted_global_explanation = ( |
248 | | - global_explanation_df / global_explanation_df.sum(axis=0) * 100 |
249 | | - ) |
250 | | - self.formatted_global_explanation = ( |
251 | | - self.formatted_global_explanation.rename( |
| 250 | + self.formatted_global_explanation = ( |
| 251 | + global_explanation_df / global_explanation_df.sum(axis=0) * 100 |
| 252 | + ) |
| 253 | + self.formatted_global_explanation = self.formatted_global_explanation.rename( |
252 | 254 | {self.spec.datetime_column.name: ForecastOutputColumns.DATE}, |
253 | 255 | axis=1, |
254 | 256 | ) |
255 | | - ) |
256 | 257 |
|
257 | 258 | aggregate_local_explanations = pd.DataFrame() |
258 | 259 | for s_id, local_ex_df in self.local_explanation.items(): |
@@ -293,8 +294,11 @@ def _generate_report(self): |
293 | 294 | ) |
294 | 295 |
|
295 | 296 | # Append the global explanation text and section to the "other_sections" list |
| 297 | + if global_explanation_section: |
| 298 | + other_sections.append(global_explanation_section) |
| 299 | + |
| 300 | + # Append the local explanation text and section to the "other_sections" list |
296 | 301 | other_sections = other_sections + [ |
297 | | - global_explanation_section, |
298 | 302 | local_explanation_section, |
299 | 303 | ] |
300 | 304 | except Exception as e: |
@@ -375,3 +379,79 @@ def _custom_predict_automlx(self, data): |
375 | 379 | return self.models.get(self.series_id).forecast( |
376 | 380 | X=data_temp, periods=data_temp.shape[0] |
377 | 381 | )[self.series_id] |
| 382 | + |
@runtime_dependency(
    module="automlx",
    err_msg=(
        "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
    ),
)
def explain_model(self):
    """
    Generate model explanations, using the AutoMLx explainer when requested.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Notes
    -----
    When ``self.spec.explanations_accuracy_mode`` is
    ``SpeedAccuracyMode.AUTOMLX``, an ``automlx.MLExplainer`` is built for
    each series returned by ``self.datasets.get_data_by_series``. The
    per-timepoint feature attributions it returns are pivoted into one
    column per feature and stored in ``self.local_explanation[s_id]``.

    For every other mode the parent class's ``explain_model`` is invoked
    exactly once.

    Failures are logged as warnings (with the traceback at debug level) and
    are not propagated; in AutoMLx mode a failure on one series does not
    prevent the remaining series from being explained.
    """
    import automlx

    if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
        # The parent implementation takes no series argument, so it is
        # delegated to once here rather than once per series inside the
        # loop below, which repeated the same work for every series.
        try:
            super().explain_model()
        except Exception as e:
            logger.warning(f"Failed to generate explanations with error: {e}.")
            logger.debug(f"Full Traceback: {traceback.format_exc()}")
        return

    horizon = self.spec.horizon
    datetime_col = self.spec.datetime_column.name

    # Loop through each series in the dataset
    for s_id, data_i in self.datasets.get_data_by_series(
        include_horizon=False
    ).items():
        try:
            # Additional (exogenous) data for this series, without the
            # datetime column; None when no additional data is configured.
            additional = None
            if self.spec.additional_data:
                additional = self.datasets.additional_data.get_data_for_series(
                    series_id=s_id
                ).drop(datetime_col, axis=1)

            # Use the MLExplainer class from AutoMLx to generate explanations.
            # Everything except the last `horizon` rows of the additional
            # data is handed to the explainer alongside the target column.
            explainer = automlx.MLExplainer(
                self.models[s_id],
                additional.head(-horizon) if additional is not None else None,
                pd.DataFrame(data_i[self.spec.target_column]),
                task="forecasting",
            )

            # Generate explanations for the forecast, passing the last
            # `horizon` rows of the additional data as X.
            explanations = explainer.explain_prediction(
                X=additional.tail(horizon) if additional is not None else None,
                forecast_timepoints=list(range(horizon + 1)),
            )

            # Convert the explanations to a DataFrame
            explanations_df = pd.concat(
                [exp.to_dataframe() for exp in explanations]
            )
            # Number the occurrences of each feature so the long-format
            # frame can be pivoted to one row per occurrence and one column
            # per feature.
            explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
            explanations_df = explanations_df.pivot(
                index="row", columns="Feature", values="Attribution"
            )
            explanations_df = explanations_df.reset_index(drop=True)

            # Store the explanations in the local_explanation dictionary
            self.local_explanation[s_id] = explanations_df
        except Exception as e:
            logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
            logger.debug(f"Full Traceback: {traceback.format_exc()}")
0 commit comments