Skip to content

Commit f65f440

Browse files
committed
sklearn r2 giving wrong score
1 parent 09d471a commit f65f440

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pandas as pd
88
from ..operator_config import ForecastOperatorConfig
99
from .. import utils
10-
from ads.opctl.operator.lowcode.forecast.utils import default_signer
1110
from .transformations import Transformations
1211
from ads.opctl import logger
1312
import pandas as pd
@@ -40,7 +39,6 @@ def _load_data(self, spec):
4039
raw_data = utils._load_data(
4140
filename=spec.historical_data.url,
4241
format=spec.historical_data.format,
43-
storage_options=default_signer(),
4442
columns=spec.historical_data.columns,
4543
)
4644
self.original_user_data = raw_data.copy()
@@ -71,7 +69,6 @@ def _load_data(self, spec):
7169
additional_data = utils._load_data(
7270
filename=spec.additional_data.url,
7371
format=spec.additional_data.format,
74-
storage_options=default_signer(),
7572
columns=spec.additional_data.columns,
7673
)
7774
additional_data = data_transformer._sort_by_datetime_col(additional_data)

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
explained_variance_score,
1818
mean_absolute_percentage_error,
1919
mean_squared_error,
20-
r2_score,
2120
)
2221

22+
try:
23+
from scipy.stats import linregress
24+
except:
25+
from sklearn.metrics import r2_score
26+
2327
from ads.common.object_storage_details import ObjectStorageDetails
2428
from ads.dataset.label_encoder import DataFrameLabelEncoder
2529
from ads.opctl import logger
@@ -358,7 +362,10 @@ def _build_metrics_df(y_true, y_pred, column_name):
358362
metrics["sMAPE"] = smape(actual=y_true, predicted=y_pred)
359363
metrics["MAPE"] = mean_absolute_percentage_error(y_true=y_true, y_pred=y_pred)
360364
metrics["RMSE"] = np.sqrt(mean_squared_error(y_true=y_true, y_pred=y_pred))
361-
metrics["r2"] = r2_score(y_true=y_true, y_pred=y_pred)
365+
try:
366+
metrics["r2"] = linregress(y_true, y_pred).rvalue ** 2
367+
except:
368+
metrics["r2"] = r2_score(y_true=y_true, y_pred=y_pred)
362369
metrics["Explained Variance"] = explained_variance_score(
363370
y_true=y_true, y_pred=y_pred
364371
)

0 commit comments

Comments
 (0)