@@ -198,20 +198,27 @@ def _get_binary_metrics(self):
198198 self .y_true , self .y_pred , pos_label = self .positive_class , average = "binary"
199199 )
200200
201- (
202- self .metrics ["false_positive_rate" ],
203- self .metrics ["true_positive_rate" ],
204- _ ,
205- ) = metrics .roc_curve (self .y_true , self .y_pred , pos_label = self .positive_class )
206- self .metrics ["auc" ] = metrics .auc (
207- self .metrics ["false_positive_rate" ], self .metrics ["true_positive_rate" ]
208- )
209-
210201 if self .y_score is not None :
211202 if not all (0 >= x >= 1 for x in self .y_score ):
212203 self .y_score = np .asarray (
213204 [0 if x < 0 else 1 if x > 1 else x for x in self .y_score ]
214205 )
206+ if len (np .asarray (self .y_score ).shape ) > 1 :
207+ # If the SKLearn classifier doesn't correctly identify the problem as
208+ # binary classification, y_score may be of shape (n_rows, 2)
209+ # instead of (n_rows,)
210+ pos_class_idx = self .classes .index (self .positive_class )
211+ positive_class_scores = self .y_score [:, pos_class_idx ]
212+ else :
213+ positive_class_scores = self .y_score
214+ (
215+ self .metrics ["false_positive_rate" ],
216+ self .metrics ["true_positive_rate" ],
217+ _ ,
218+ ) = metrics .roc_curve (y_true = self .y_true , y_score = positive_class_scores , pos_label = self .positive_class )
219+ self .metrics ["auc" ] = metrics .auc (
220+ self .metrics ["false_positive_rate" ], self .metrics ["true_positive_rate" ]
221+ )
215222 self .y_score = list (self .y_score )
216223 self .metrics ["youden_j" ] = (
217224 self .metrics ["true_positive_rate" ] - self .metrics ["false_positive_rate" ]
0 commit comments