Skip to content

Commit bc9e010

Browse files
committed
enable auto-select-series testcases
1 parent c9b73ed commit bc9e010

File tree

2 files changed

+160
-82
lines changed

2 files changed

+160
-82
lines changed

tests/operators/forecast/test_datasets.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,21 @@
22

33
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5-
import os
65
import json
7-
import yaml
8-
import tempfile
6+
import os
97
import subprocess
8+
import tempfile
9+
from copy import deepcopy
10+
from time import sleep
11+
1012
import pandas as pd
1113
import pytest
12-
from time import sleep, time
13-
from copy import deepcopy
14-
from pathlib import Path
15-
import random
16-
import pathlib
17-
import datetime
14+
import yaml
15+
1816
from ads.opctl.operator.cmd import run
1917
from ads.opctl.operator.lowcode.forecast.__main__ import operate as forecast_operate
2018
from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorConfig
2119

22-
2320
DATASET_PREFIX = f"{os.path.dirname(os.path.abspath(__file__))}/../data/timeseries/"
2421

2522
DATASETS_LIST = [
@@ -37,6 +34,7 @@
3734
"autots",
3835
# "lgbforecast",
3936
"auto-select",
37+
"auto-select-series",
4038
]
4139

4240
TEMPLATE_YAML = {
@@ -77,14 +75,43 @@
7775

7876

7977
def verify_explanations(tmpdirname, additional_cols, target_category_columns):
80-
glb_expl = pd.read_csv(f"{tmpdirname}/results/global_explanation.csv", index_col=0)
81-
loc_expl = pd.read_csv(f"{tmpdirname}/results/local_explanation.csv")
82-
assert loc_expl.shape[0] == PERIODS
83-
columns = ["Date", "Series"]
84-
if not target_category_columns:
85-
columns.remove("Series")
86-
for x in columns:
87-
assert x in set(loc_expl.columns)
78+
result_files = os.listdir(f"{tmpdirname}/results")
79+
if model == "auto-select-series":
80+
# Find all local and global explanation files
81+
local_expl_files = [
82+
f
83+
for f in result_files
84+
if f.startswith("local_explanation_") and f.endswith(".csv")
85+
]
86+
global_expl_files = [
87+
f
88+
for f in result_files
89+
if f.startswith("global_explanation_") and f.endswith(".csv")
90+
]
91+
92+
# Verify for each model's explanation files
93+
for loc_file, glb_file in zip(local_expl_files, global_expl_files):
94+
glb_expl = pd.read_csv(f"{tmpdirname}/results/{glb_file}", index_col=0)
95+
loc_expl = pd.read_csv(f"{tmpdirname}/results/{loc_file}")
96+
97+
assert loc_expl.shape[0] == PERIODS
98+
columns = ["Date", "Series"]
99+
if not target_category_columns:
100+
columns.remove("Series")
101+
for x in columns:
102+
assert x in set(loc_expl.columns)
103+
else:
104+
glb_expl = pd.read_csv(
105+
f"{tmpdirname}/results/global_explanation.csv", index_col=0
106+
)
107+
loc_expl = pd.read_csv(f"{tmpdirname}/results/local_explanation.csv")
108+
109+
assert loc_expl.shape[0] == PERIODS
110+
columns = ["Date", "Series"]
111+
if not target_category_columns:
112+
columns.remove("Series")
113+
for x in columns:
114+
assert x in set(loc_expl.columns)
88115
# for x in additional_cols:
89116
# assert x in set(loc_expl.columns)
90117
# assert x in set(glb_expl.index)
@@ -159,10 +186,38 @@ def test_load_datasets(model, data_details):
159186
target_category_columns=yaml_i["spec"]["target_category_columns"],
160187
)
161188
if include_test_data:
162-
test_metrics = pd.read_csv(f"{tmpdirname}/results/test_metrics.csv")
163-
print(test_metrics)
164-
train_metrics = pd.read_csv(f"{tmpdirname}/results/metrics.csv")
165-
print(train_metrics)
189+
result_files = os.listdir(f"{tmpdirname}/results")
190+
if model == "auto-select-series":
191+
# Find all metrics files for each model
192+
test_metrics_files = [
193+
f
194+
for f in result_files
195+
if f.startswith("test_metrics_") and f.endswith(".csv")
196+
]
197+
train_metrics_files = [
198+
f
199+
for f in result_files
200+
if f.startswith("metrics_") and f.endswith(".csv")
201+
]
202+
203+
# Print metrics for each model
204+
for test_file, train_file in zip(
205+
test_metrics_files, train_metrics_files
206+
):
207+
print(
208+
f"\nMetrics for {test_file.replace('test_metrics_', '').replace('.csv', '')}:"
209+
)
210+
test_metrics = pd.read_csv(f"{tmpdirname}/results/{test_file}")
211+
print("Test metrics:")
212+
print(test_metrics)
213+
train_metrics = pd.read_csv(f"{tmpdirname}/results/{train_file}")
214+
print("Train metrics:")
215+
print(train_metrics)
216+
else:
217+
test_metrics = pd.read_csv(f"{tmpdirname}/results/test_metrics.csv")
218+
print(test_metrics)
219+
train_metrics = pd.read_csv(f"{tmpdirname}/results/metrics.csv")
220+
print(train_metrics)
166221

167222

168223
@pytest.mark.parametrize("model", MODELS[:-2])

tests/operators/forecast/test_explainers.py

Lines changed: 83 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,14 @@
22

33
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5-
import datetime
6-
import json
75
import os
8-
import pathlib
9-
import random
10-
import subprocess
116
import tempfile
127
from copy import deepcopy
13-
from pathlib import Path
14-
from time import sleep, time
158

169
import numpy as np
17-
import pandas as pd
1810
import pytest
19-
import yaml
2011

21-
from ads.opctl.operator.cmd import run
2212
from ads.opctl.operator.lowcode.forecast.__main__ import operate as forecast_operate
23-
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import (
24-
ForecastDatasets,
25-
)
2613
from ads.opctl.operator.lowcode.forecast.operator_config import (
2714
ForecastOperatorConfig,
2815
)
@@ -33,6 +20,7 @@
3320
# "automlx", # FIXME: automlx is failing, no errors
3421
"prophet",
3522
"neuralprophet",
23+
"auto-select-series",
3624
]
3725

3826
TEMPLATE_YAML = {
@@ -170,31 +158,31 @@ def test_explanations_output_and_columns(model, freq, num_series):
170158
global_explanations = results.get_global_explanations()
171159
local_explanations = results.get_local_explanations()
172160

173-
assert (
174-
not (global_explanations.isna()).all().all()
175-
), "Global explanations contain NaN values"
176-
assert (
177-
not (global_explanations == 0).all().all()
178-
), "Global explanations contain only 0 values"
179-
assert (
180-
not (local_explanations.isna()).all().all()
181-
), "Local explanations contain NaN values"
182-
assert (
183-
not (local_explanations == 0).all().all()
184-
), "Local explanations contain only 0 values"
161+
assert not (global_explanations.isna()).all().all(), (
162+
"Global explanations contain NaN values"
163+
)
164+
assert not (global_explanations == 0).all().all(), (
165+
"Global explanations contain only 0 values"
166+
)
167+
assert not (local_explanations.isna()).all().all(), (
168+
"Local explanations contain NaN values"
169+
)
170+
assert not (local_explanations == 0).all().all(), (
171+
"Local explanations contain only 0 values"
172+
)
185173

186174
additional_columns = list(
187175
set(additional.columns.tolist())
188176
- set(operator_config.spec.target_category_columns)
189177
- {operator_config.spec.datetime_column.name}
190178
)
191179
for column in additional_columns:
192-
assert (
193-
column in global_explanations.T.columns
194-
), f"Column {column} missing in global explanations"
195-
assert (
196-
column in local_explanations.columns
197-
), f"Column {column} missing in local explanations"
180+
assert column in global_explanations.T.columns, (
181+
f"Column {column} missing in global explanations"
182+
)
183+
assert column in local_explanations.columns, (
184+
f"Column {column} missing in local explanations"
185+
)
198186

199187

200188
@pytest.mark.parametrize("model", MODELS) # MODELS
@@ -221,24 +209,60 @@ def test_explanations_filenames(model, num_series):
221209
operator_config.spec.local_explanation_filename = local_explanation_filename
222210

223211
results = forecast_operate(operator_config)
224-
assert (
225-
not results.get_global_explanations().empty
226-
), "Error generating Global Expl"
227-
assert not results.get_local_explanations().empty, "Error generating Local Expl"
228-
229-
global_explanation_path = os.path.join(
230-
output_directory, global_explanation_filename
231-
)
232-
local_explanation_path = os.path.join(
233-
output_directory, local_explanation_filename
212+
assert not results.get_global_explanations().empty, (
213+
"Error generating Global Expl"
234214
)
215+
assert not results.get_local_explanations().empty, "Error generating Local Expl"
235216

236-
assert os.path.exists(
237-
global_explanation_path
238-
), f"Global explanation file not found at {global_explanation_path}"
239-
assert os.path.exists(
240-
local_explanation_path
241-
), f"Local explanation file not found at {local_explanation_path}"
217+
if model == "auto-select-series":
218+
# List all files in output directory
219+
files = os.listdir(output_directory)
220+
# Find all explanation files
221+
global_explanation_files = [
222+
f
223+
for f in files
224+
if f.startswith("custom_global_explanation_") and f.endswith(".csv")
225+
]
226+
local_explanation_files = [
227+
f
228+
for f in files
229+
if f.startswith("custom_local_explanation_") and f.endswith(".csv")
230+
]
231+
232+
# Should have at least one file of each type
233+
assert len(global_explanation_files) > 0, (
234+
"No global explanation files found for auto-select-series"
235+
)
236+
assert len(local_explanation_files) > 0, (
237+
"No local explanation files found for auto-select-series"
238+
)
239+
240+
# Check each file exists
241+
for gfile in global_explanation_files:
242+
gpath = os.path.join(output_directory, gfile)
243+
assert os.path.exists(gpath), (
244+
f"Global explanation file not found at {gpath}"
245+
)
246+
247+
for lfile in local_explanation_files:
248+
lpath = os.path.join(output_directory, lfile)
249+
assert os.path.exists(lpath), (
250+
f"Local explanation file not found at {lpath}"
251+
)
252+
else:
253+
global_explanation_path = os.path.join(
254+
output_directory, global_explanation_filename
255+
)
256+
local_explanation_path = os.path.join(
257+
output_directory, local_explanation_filename
258+
)
259+
260+
assert os.path.exists(global_explanation_path), (
261+
f"Global explanation file not found at {global_explanation_path}"
262+
)
263+
assert os.path.exists(local_explanation_path), (
264+
f"Local explanation file not found at {local_explanation_path}"
265+
)
242266

243267

244268
@pytest.mark.parametrize("model", MODELS)
@@ -297,7 +321,7 @@ def test_explanations_accuracy_mode(mode, model, num_series):
297321
operator_config.spec.output_directory.url = output_directory
298322
operator_config.spec.explanations_accuracy_mode = mode
299323

300-
results = forecast_operate(operator_config)
324+
forecast_operate(operator_config)
301325

302326
global_explanation_path = os.path.join(
303327
output_directory, operator_config.spec.global_explanation_filename
@@ -306,12 +330,12 @@ def test_explanations_accuracy_mode(mode, model, num_series):
306330
output_directory, operator_config.spec.local_explanation_filename
307331
)
308332

309-
assert os.path.exists(
310-
global_explanation_path
311-
), f"Global explanation file not found at {global_explanation_path}"
312-
assert os.path.exists(
313-
local_explanation_path
314-
), f"Local explanation file not found at {local_explanation_path}"
333+
assert os.path.exists(global_explanation_path), (
334+
f"Global explanation file not found at {global_explanation_path}"
335+
)
336+
assert os.path.exists(local_explanation_path), (
337+
f"Local explanation file not found at {local_explanation_path}"
338+
)
315339

316340

317341
@pytest.mark.parametrize("model", MODELS)
@@ -345,19 +369,18 @@ def test_explanations_values(model, num_series, freq):
345369

346370
# Check decimal precision for local explanations
347371
local_numeric = local_explanations.select_dtypes(include=["int64", "float64"])
348-
assert np.allclose(local_numeric, np.round(local_numeric, 4), atol=1e-8), \
372+
assert np.allclose(local_numeric, np.round(local_numeric, 4), atol=1e-8), (
349373
"Local explanations have values with more than 4 decimal places"
374+
)
350375

351376
# Check decimal precision for global explanations
352377
global_explanations = results.get_global_explanations()
353378
global_numeric = global_explanations.select_dtypes(include=["int64", "float64"])
354-
assert np.allclose(global_numeric, np.round(global_numeric, 4), atol=1e-8), \
379+
assert np.allclose(global_numeric, np.round(global_numeric, 4), atol=1e-8), (
355380
"Global explanations have values with more than 4 decimal places"
356-
357-
local_explain_vals = (
358-
local_numeric.sum(axis=1)
359-
+ forecast.fitted_value.mean()
360381
)
382+
383+
local_explain_vals = local_numeric.sum(axis=1) + forecast.fitted_value.mean()
361384
assert np.allclose(
362385
local_explain_vals,
363386
forecast[-operator_config.spec.horizon :]["forecast_value"],

0 commit comments

Comments
 (0)