1+
12#!/usr/bin/env python
23# -*- coding: utf-8 -*--
34
910import fsspec
1011from .operator_config import AnomalyOperatorSpec
1112from .const import SupportedMetrics
12-
13+ from ads . opctl import logger
1314
1415def _build_metrics_df (y_true , y_pred , column_name ):
1516 from sklearn .metrics import recall_score , precision_score , accuracy_score , f1_score , confusion_matrix , \
@@ -19,12 +20,18 @@ def _build_metrics_df(y_true, y_pred, column_name):
1920 metrics [SupportedMetrics .PRECISION ] = precision_score (y_true , y_pred )
2021 metrics [SupportedMetrics .ACCURACY ] = accuracy_score (y_true , y_pred )
2122 metrics [SupportedMetrics .F1_SCORE ] = f1_score (y_true , y_pred )
22- tn , fp , fn , tp = confusion_matrix (y_true , y_pred ).ravel ()
23+ tn , * fn_fp_tp = confusion_matrix (y_true , y_pred ).ravel ()
24+ fp , fn , tp = fn_fp_tp if fn_fp_tp else (0 , 0 , 0 )
2325 metrics [SupportedMetrics .FP ] = fp
2426 metrics [SupportedMetrics .FN ] = fn
2527 metrics [SupportedMetrics .TP ] = tp
2628 metrics [SupportedMetrics .TN ] = tn
27- metrics [SupportedMetrics .ROC_AUC ] = roc_auc_score (y_true , y_pred )
29+ try :
30+ # Throws exception if y_true has only one class
31+ metrics [SupportedMetrics .ROC_AUC ] = roc_auc_score (y_true , y_pred )
32+ except Exception as e :
33+ logger .warn (f"An exception occurred: { e } " )
34+ metrics [SupportedMetrics .ROC_AUC ] = None
2835 precision , recall , thresholds = precision_recall_curve (y_true , y_pred )
2936 metrics [SupportedMetrics .PRC_AUC ] = auc (recall , precision )
3037 metrics [SupportedMetrics .MCC ] = matthews_corrcoef (y_true , y_pred )
@@ -64,6 +71,12 @@ def _write_data(data, filename, format, storage_options, index=False, **kwargs):
6471 )
6572 raise ValueError (f"Unrecognized format: { format } " )
6673
74+ def _merge_category_columns (data , target_category_columns ):
75+ result = data .apply (
76+ lambda x : "__" .join ([str (x [col ]) for col in target_category_columns ]), axis = 1
77+ )
78+ return result if not result .empty else pd .Series ([], dtype = str )
79+
6780
6881def get_frequency_of_datetime (data : pd .DataFrame , dataset_info : AnomalyOperatorSpec ):
6982 """
0 commit comments