22
33# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5- import datetime
6- import json
75import os
8- import pathlib
9- import random
10- import subprocess
116import tempfile
127from copy import deepcopy
13- from pathlib import Path
14- from time import sleep , time
158
169import numpy as np
17- import pandas as pd
1810import pytest
19- import yaml
2011
21- from ads .opctl .operator .cmd import run
2212from ads .opctl .operator .lowcode .forecast .__main__ import operate as forecast_operate
23- from ads .opctl .operator .lowcode .forecast .model .forecast_datasets import (
24- ForecastDatasets ,
25- )
2613from ads .opctl .operator .lowcode .forecast .operator_config import (
2714 ForecastOperatorConfig ,
2815)
3320 # "automlx", # FIXME: automlx is failing, no errors
3421 "prophet" ,
3522 "neuralprophet" ,
23+ "auto-select-series" ,
3624]
3725
3826TEMPLATE_YAML = {
@@ -170,31 +158,31 @@ def test_explanations_output_and_columns(model, freq, num_series):
170158 global_explanations = results .get_global_explanations ()
171159 local_explanations = results .get_local_explanations ()
172160
173- assert (
174- not ( global_explanations . isna ()). all (). all ()
175- ), "Global explanations contain NaN values"
176- assert (
177- not ( global_explanations == 0 ). all (). all ()
178- ), "Global explanations contain only 0 values"
179- assert (
180- not ( local_explanations . isna ()). all (). all ()
181- ), "Local explanations contain NaN values"
182- assert (
183- not ( local_explanations == 0 ). all (). all ()
184- ), "Local explanations contain only 0 values"
161+ assert not ( global_explanations . isna ()). all (). all (), (
162+ "Global explanations contain NaN values"
163+ )
164+ assert not ( global_explanations == 0 ). all (). all (), (
165+ "Global explanations contain only 0 values"
166+ )
167+ assert not ( local_explanations . isna ()). all (). all (), (
168+ "Local explanations contain NaN values"
169+ )
170+ assert not ( local_explanations == 0 ). all (). all (), (
171+ "Local explanations contain only 0 values"
172+ )
185173
186174 additional_columns = list (
187175 set (additional .columns .tolist ())
188176 - set (operator_config .spec .target_category_columns )
189177 - {operator_config .spec .datetime_column .name }
190178 )
191179 for column in additional_columns :
192- assert (
193- column in global_explanations . T . columns
194- ), f"Column { column } missing in global explanations"
195- assert (
196- column in local_explanations . columns
197- ), f"Column { column } missing in local explanations"
180+ assert column in global_explanations . T . columns , (
181+ f"Column { column } missing in global explanations"
182+ )
183+ assert column in local_explanations . columns , (
184+ f"Column { column } missing in local explanations"
185+ )
198186
199187
200188@pytest .mark .parametrize ("model" , MODELS ) # MODELS
@@ -221,24 +209,60 @@ def test_explanations_filenames(model, num_series):
221209 operator_config .spec .local_explanation_filename = local_explanation_filename
222210
223211 results = forecast_operate (operator_config )
224- assert (
225- not results .get_global_explanations ().empty
226- ), "Error generating Global Expl"
227- assert not results .get_local_explanations ().empty , "Error generating Local Expl"
228-
229- global_explanation_path = os .path .join (
230- output_directory , global_explanation_filename
231- )
232- local_explanation_path = os .path .join (
233- output_directory , local_explanation_filename
212+ assert not results .get_global_explanations ().empty , (
213+ "Error generating Global Expl"
234214 )
215+ assert not results .get_local_explanations ().empty , "Error generating Local Expl"
235216
236- assert os .path .exists (
237- global_explanation_path
238- ), f"Global explanation file not found at { global_explanation_path } "
239- assert os .path .exists (
240- local_explanation_path
241- ), f"Local explanation file not found at { local_explanation_path } "
217+ if model == "auto-select-series" :
218+ # List all files in output directory
219+ files = os .listdir (output_directory )
220+ # Find all explanation files
221+ global_explanation_files = [
222+ f
223+ for f in files
224+ if f .startswith ("custom_global_explanation_" ) and f .endswith (".csv" )
225+ ]
226+ local_explanation_files = [
227+ f
228+ for f in files
229+ if f .startswith ("custom_local_explanation_" ) and f .endswith (".csv" )
230+ ]
231+
232+ # Should have at least one file of each type
233+ assert len (global_explanation_files ) > 0 , (
234+ "No global explanation files found for auto-select-series"
235+ )
236+ assert len (local_explanation_files ) > 0 , (
237+ "No local explanation files found for auto-select-series"
238+ )
239+
240+ # Check each file exists
241+ for gfile in global_explanation_files :
242+ gpath = os .path .join (output_directory , gfile )
243+ assert os .path .exists (gpath ), (
244+ f"Global explanation file not found at { gpath } "
245+ )
246+
247+ for lfile in local_explanation_files :
248+ lpath = os .path .join (output_directory , lfile )
249+ assert os .path .exists (lpath ), (
250+ f"Local explanation file not found at { lpath } "
251+ )
252+ else :
253+ global_explanation_path = os .path .join (
254+ output_directory , global_explanation_filename
255+ )
256+ local_explanation_path = os .path .join (
257+ output_directory , local_explanation_filename
258+ )
259+
260+ assert os .path .exists (global_explanation_path ), (
261+ f"Global explanation file not found at { global_explanation_path } "
262+ )
263+ assert os .path .exists (local_explanation_path ), (
264+ f"Local explanation file not found at { local_explanation_path } "
265+ )
242266
243267
244268@pytest .mark .parametrize ("model" , MODELS )
@@ -297,7 +321,7 @@ def test_explanations_accuracy_mode(mode, model, num_series):
297321 operator_config .spec .output_directory .url = output_directory
298322 operator_config .spec .explanations_accuracy_mode = mode
299323
300- results = forecast_operate (operator_config )
324+ forecast_operate (operator_config )
301325
302326 global_explanation_path = os .path .join (
303327 output_directory , operator_config .spec .global_explanation_filename
@@ -306,12 +330,12 @@ def test_explanations_accuracy_mode(mode, model, num_series):
306330 output_directory , operator_config .spec .local_explanation_filename
307331 )
308332
309- assert os .path .exists (
310- global_explanation_path
311- ), f"Global explanation file not found at { global_explanation_path } "
312- assert os .path .exists (
313- local_explanation_path
314- ), f"Local explanation file not found at { local_explanation_path } "
333+ assert os .path .exists (global_explanation_path ), (
334+ f"Global explanation file not found at { global_explanation_path } "
335+ )
336+ assert os .path .exists (local_explanation_path ), (
337+ f"Local explanation file not found at { local_explanation_path } "
338+ )
315339
316340
317341@pytest .mark .parametrize ("model" , MODELS )
@@ -345,19 +369,18 @@ def test_explanations_values(model, num_series, freq):
345369
346370 # Check decimal precision for local explanations
347371 local_numeric = local_explanations .select_dtypes (include = ["int64" , "float64" ])
348- assert np .allclose (local_numeric , np .round (local_numeric , 4 ), atol = 1e-8 ), \
372+ assert np .allclose (local_numeric , np .round (local_numeric , 4 ), atol = 1e-8 ), (
349373 "Local explanations have values with more than 4 decimal places"
374+ )
350375
351376 # Check decimal precision for global explanations
352377 global_explanations = results .get_global_explanations ()
353378 global_numeric = global_explanations .select_dtypes (include = ["int64" , "float64" ])
354- assert np .allclose (global_numeric , np .round (global_numeric , 4 ), atol = 1e-8 ), \
379+ assert np .allclose (global_numeric , np .round (global_numeric , 4 ), atol = 1e-8 ), (
355380 "Global explanations have values with more than 4 decimal places"
356-
357- local_explain_vals = (
358- local_numeric .sum (axis = 1 )
359- + forecast .fitted_value .mean ()
360381 )
382+
383+ local_explain_vals = local_numeric .sum (axis = 1 ) + forecast .fitted_value .mean ()
361384 assert np .allclose (
362385 local_explain_vals ,
363386 forecast [- operator_config .spec .horizon :]["forecast_value" ],
0 commit comments