11#!/usr/bin/env python
22
3- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
66import logging
5151 SupportedModels ,
5252)
5353from ..operator_config import ForecastOperatorConfig , ForecastOperatorSpec
54- from .forecast_datasets import ForecastDatasets
54+ from .forecast_datasets import ForecastDatasets , ForecastResults
5555
5656logging .getLogger ("report_creator" ).setLevel (logging .WARNING )
5757
@@ -350,11 +350,12 @@ def generate_report(self):
350350 )
351351
352352 # save the report and result CSV
353- self ._save_report (
353+ return self ._save_report (
354354 report_sections = report_sections ,
355355 result_df = result_df ,
356356 metrics_df = self .eval_metrics ,
357357 test_metrics_df = self .test_eval_metrics ,
358+ test_data = test_data ,
358359 )
359360
360361 def _test_evaluate_metrics (self , elapsed_time = 0 ):
@@ -471,10 +472,12 @@ def _save_report(
471472 result_df : pd .DataFrame ,
472473 metrics_df : pd .DataFrame ,
473474 test_metrics_df : pd .DataFrame ,
475+ test_data : pd .DataFrame ,
474476 ):
475477 """Saves resulting reports to the given folder."""
476478
477479 unique_output_dir = self .spec .output_directory .url
480+ results = ForecastResults ()
478481
479482 if ObjectStorageDetails .is_oci_path (unique_output_dir ):
480483 storage_options = default_signer ()
@@ -500,6 +503,11 @@ def _save_report(
500503 f2 .write (f1 .read ())
501504
502505 # forecast csv report
506+ # todo: add test data into forecast.csv
507+ # if self.spec.test_data is not None:
508+ # test_data_dict = test_data.get_dict_by_series()
509+ # for series_id, test_data_values in test_data_dict.items():
510+ # result_df[DataColumns.Series] = test_data_values[]
503511 result_df = (
504512 result_df
505513 if self .target_cat_col
@@ -511,6 +519,7 @@ def _save_report(
511519 format = "csv" ,
512520 storage_options = storage_options ,
513521 )
522+ results .set_forecast (result_df )
514523
515524 # metrics csv report
516525 if self .spec .generate_metrics :
@@ -520,17 +529,19 @@ def _save_report(
520529 else "Series 1"
521530 )
522531 if metrics_df is not None :
532+ metrics_df_formatted = metrics_df .reset_index ().rename (
533+ {"index" : "metrics" , "Series 1" : metrics_col_name }, axis = 1
534+ )
523535 write_data (
524- data = metrics_df .reset_index ().rename (
525- {"index" : "metrics" , "Series 1" : metrics_col_name }, axis = 1
526- ),
536+ data = metrics_df_formatted ,
527537 filename = os .path .join (
528538 unique_output_dir , self .spec .metrics_filename
529539 ),
530540 format = "csv" ,
531541 storage_options = storage_options ,
532542 index = False ,
533543 )
544+ results .set_metrics (metrics_df_formatted )
534545 else :
535546 logger .warn (
536547 f"Attempted to generate the { self .spec .metrics_filename } file with the training metrics, however the training metrics could not be properly generated."
@@ -539,17 +550,19 @@ def _save_report(
539550 # test_metrics csv report
540551 if self .spec .test_data is not None :
541552 if test_metrics_df is not None :
553+ test_metrics_df_formatted = test_metrics_df .reset_index ().rename (
554+ {"index" : "metrics" , "Series 1" : metrics_col_name }, axis = 1
555+ )
542556 write_data (
543- data = test_metrics_df .reset_index ().rename (
544- {"index" : "metrics" , "Series 1" : metrics_col_name }, axis = 1
545- ),
557+ data = test_metrics_df_formatted ,
546558 filename = os .path .join (
547559 unique_output_dir , self .spec .test_metrics_filename
548560 ),
549561 format = "csv" ,
550562 storage_options = storage_options ,
551563 index = False ,
552564 )
565+ results .set_test_metrics (test_metrics_df_formatted )
553566 else :
554567 logger .warn (
555568 f"Attempted to generate the { self .spec .test_metrics_filename } file with the test metrics, however the test metrics could not be properly generated."
@@ -567,6 +580,7 @@ def _save_report(
567580 storage_options = storage_options ,
568581 index = True ,
569582 )
583+ results .set_global_explanations (self .formatted_global_explanation )
570584 else :
571585 logger .warn (
572586 f"Attempted to generate global explanations for the { self .spec .global_explanation_filename } file, but an issue occured in formatting the explanations."
@@ -582,6 +596,7 @@ def _save_report(
582596 storage_options = storage_options ,
583597 index = True ,
584598 )
599+ results .set_local_explanations (self .formatted_local_explanation )
585600 else :
586601 logger .warn (
587602 f"Attempted to generate local explanations for the { self .spec .local_explanation_filename } file, but an issue occured in formatting the explanations."
@@ -602,10 +617,12 @@ def _save_report(
602617 index = True ,
603618 indent = 4 ,
604619 )
620+ results .set_model_parameters (self .model_parameters )
605621
606622 # model pickle
607623 if self .spec .generate_model_pickle :
608624 self ._save_model (unique_output_dir , storage_options )
625+ results .set_models (self .models )
609626
610627 logger .info (
611628 f"The outputs have been successfully "
@@ -625,8 +642,10 @@ def _save_report(
625642 index = True ,
626643 indent = 4 ,
627644 )
645+ results .set_errors_dict (self .errors_dict )
628646 else :
629647 logger .info ("All modeling completed successfully." )
648+ return results
630649
631650 def preprocess (self , df , series_id ):
632651 """The method that needs to be implemented on the particular model level."""
0 commit comments