@@ -336,11 +336,13 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves',
336336 return ax
337337
338338
339- def plot_roc (y_true , y_probas , title = 'ROC Curves' ,
340- plot_micro = True , plot_macro = True , classes_to_plot = None ,
341- ax = None , figsize = None , cmap = 'nipy_spectral' ,
342- title_fontsize = "large" , text_fontsize = "medium" ,
343- show_labels = True ,):
339+ def plot_roc (
340+ y_true , y_probas , title = 'ROC Curves' ,
341+ plot_micro = True , plot_macro = True , classes_to_plot = None ,
342+ ax = None , figsize = None , cmap = 'nipy_spectral' ,
343+ title_fontsize = "large" , text_fontsize = "medium" ,
344+ show_labels = True , digits = 3 ,
345+ ):
344346 """Generates the ROC curves from labels and predicted scores/probabilities
345347
346348 Args:
@@ -386,6 +388,9 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
386388 show_labels (boolean, optional): Shows the labels in the plot.
387389 Defaults to ``True``.
388390
391+ digits (int, optional): Number of digits for formatting output floating point values.
392+ Use e.g. 2 or 4. Defaults to 3.
393+
389394 Returns:
390395 ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
391396 drawn.
@@ -428,8 +433,8 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
428433 roc_auc = auc (fpr_dict [i ], tpr_dict [i ])
429434 color = plt .cm .get_cmap (cmap )(float (i ) / len (classes ))
430435 ax .plot (fpr_dict [i ], tpr_dict [i ], lw = 2 , color = color ,
431- label = 'ROC curve of class {0} (area = {1:0.2f })'
432- '' .format (classes [i ], roc_auc ))
436+ label = 'ROC curve of class {0} (area = {1:.{digits}f })'
437+ '' .format (classes [i ], roc_auc , digits = digits ))
433438
434439 if plot_micro :
435440 binarized_y_true = label_binarize (y_true , classes = classes )
@@ -440,7 +445,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
440445 roc_auc = auc (fpr , tpr )
441446 ax .plot (fpr , tpr ,
442447 label = 'micro-average ROC curve '
443- '(area = {0:0.2f} )' .format (roc_auc ),
448+ '(area = {0:.{digits}f} )' .format (roc_auc , digits = digits ),
444449 color = 'deeppink' , linestyle = ':' , linewidth = 4 )
445450
446451 if plot_macro :
@@ -459,7 +464,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
459464
460465 ax .plot (all_fpr , mean_tpr ,
461466 label = 'macro-average ROC curve '
462- '(area = {0:0.2f} )' .format (roc_auc ),
467+ '(area = {0:.{digits}f} )' .format (roc_auc , digits = digits ),
463468 color = 'navy' , linestyle = ':' , linewidth = 4 )
464469
465470 ax .plot ([0 , 1 ], [0 , 1 ], 'k--' , lw = 2 )
@@ -475,7 +480,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves',
475480
476481def plot_ks_statistic (y_true , y_probas , title = 'KS Statistic Plot' ,
477482 ax = None , figsize = None , title_fontsize = "large" ,
478- text_fontsize = "medium" ):
483+ text_fontsize = "medium" , digits = 3 ):
479484 """Generates the KS Statistic plot from labels and scores/probabilities
480485
481486 Args:
@@ -503,6 +508,9 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
503508 Use e.g. "small", "medium", "large" or integer-values. Defaults to
504509 "medium".
505510
511+ digits (int, optional): Number of digits for formatting output floating point values.
512+ Use e.g. 2 or 4. Defaults to 3.
513+
506514 Returns:
507515 ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
508516 drawn.
@@ -543,9 +551,10 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
543551 ax .plot (thresholds , pct2 , lw = 3 , label = 'Class {}' .format (classes [1 ]))
544552 idx = np .where (thresholds == max_distance_at )[0 ][0 ]
545553 ax .axvline (max_distance_at , * sorted ([pct1 [idx ], pct2 [idx ]]),
546- label = 'KS Statistic: {:.3f} at {:.3f}' .format (ks_statistic ,
547- max_distance_at ),
548- linestyle = ':' , lw = 3 , color = 'black' )
554+ label = 'KS Statistic: {:.{digits}f} at {:.{digits}f}' .format (
555+ ks_statistic , max_distance_at , digits = digits
556+ ),
557+ linestyle = ':' , lw = 3 , color = 'black' )
549558
550559 ax .set_xlim ([0.0 , 1.0 ])
551560 ax .set_ylim ([0.0 , 1.0 ])
@@ -685,13 +694,16 @@ def plot_precision_recall_curve(y_true, y_probas,
685694 return ax
686695
687696
688- def plot_precision_recall (y_true , y_probas ,
689- title = 'Precision-Recall Curve' ,
690- plot_micro = True ,
691- classes_to_plot = None , ax = None ,
692- figsize = None , cmap = 'nipy_spectral' ,
693- title_fontsize = "large" ,
694- text_fontsize = "medium" ):
697+ def plot_precision_recall (
698+ y_true , y_probas ,
699+ title = 'Precision-Recall Curve' ,
700+ plot_micro = True ,
701+ classes_to_plot = None , ax = None ,
702+ figsize = None , cmap = 'nipy_spectral' ,
703+ title_fontsize = "large" ,
704+ text_fontsize = "medium" ,
705+ digits = 3 ,
706+ ):
695707 """Generates the Precision Recall Curve from labels and probabilities
696708
697709 Args:
@@ -731,6 +743,9 @@ def plot_precision_recall(y_true, y_probas,
731743 Use e.g. "small", "medium", "large" or integer-values. Defaults to
732744 "medium".
733745
746+ digits (int, optional): Number of digits for formatting output floating point values.
747+ Use e.g. 2 or 4. Defaults to 3.
748+
734749 Returns:
735750 ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
736751 drawn.
@@ -778,8 +793,9 @@ def plot_precision_recall(y_true, y_probas,
778793 color = plt .cm .get_cmap (cmap )(float (i ) / len (classes ))
779794 ax .plot (recall , precision , lw = 2 ,
780795 label = 'Precision-recall curve of class {0} '
781- '(area = {1:0.3f})' .format (classes [i ],
782- average_precision ),
796+ '(area = {1:.{digits}f})' .format (classes [i ],
797+ average_precision ,
798+ digits = digits ),
783799 color = color )
784800
785801 if plot_micro :
@@ -790,7 +806,7 @@ def plot_precision_recall(y_true, y_probas,
790806 average = 'micro' )
791807 ax .plot (recall , precision ,
792808 label = 'micro-average Precision-recall curve '
793- '(area = {0:0.3f} )' .format (average_precision ),
809+ '(area = {0:.{digits}f} )' .format (average_precision , digits = digits ),
794810 color = 'navy' , linestyle = ':' , linewidth = 4 )
795811
796812 ax .set_xlim ([0.0 , 1.0 ])
@@ -802,10 +818,12 @@ def plot_precision_recall(y_true, y_probas,
802818 return ax
803819
804820
805- def plot_silhouette (X , cluster_labels , title = 'Silhouette Analysis' ,
806- metric = 'euclidean' , copy = True , ax = None , figsize = None ,
807- cmap = 'nipy_spectral' , title_fontsize = "large" ,
808- text_fontsize = "medium" ):
821+ def plot_silhouette (
822+ X , cluster_labels , title = 'Silhouette Analysis' ,
823+ metric = 'euclidean' , copy = True , ax = None , figsize = None ,
824+ cmap = 'nipy_spectral' , title_fontsize = "large" ,
825+ text_fontsize = "medium" , digits = 3 ,
826+ ):
809827 """Plots silhouette analysis of clusters provided.
810828
811829 Args:
@@ -847,6 +865,9 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
847865 Use e.g. "small", "medium", "large" or integer-values. Defaults to
848866 "medium".
849867
868+ digits (int, optional): Number of digits for formatting output floating point values.
869+ Use e.g. 2 or 4. Defaults to 3.
870+
850871 Returns:
851872 ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
852873 drawn.
@@ -908,8 +929,10 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
908929
909930 y_lower = y_upper + 10
910931
911- ax .axvline (x = silhouette_avg , color = "red" , linestyle = "--" ,
912- label = 'Silhouette score: {0:0.3f}' .format (silhouette_avg ))
932+ ax .axvline (
933+ x = silhouette_avg , color = "red" , linestyle = "--" ,
934+ label = 'Silhouette score: {0:.{digits}f}' .format (silhouette_avg , digits = 2 )
935+ )
913936
914937 ax .set_yticks ([]) # Clear the y-axis labels / ticks
915938 ax .set_xticks (np .arange (- 0.1 , 1.0 , 0.2 ))
@@ -920,11 +943,13 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
920943 return ax
921944
922945
923- def plot_calibration_curve (y_true , probas_list , clf_names = None , n_bins = 10 ,
924- title = 'Calibration plots (Reliability Curves)' ,
925- ax = None , figsize = None , cmap = 'nipy_spectral' ,
926- title_fontsize = "large" , text_fontsize = "medium" ,
927- pos_label = None , strategy = "uniform" ,):
946+ def plot_calibration_curve (
947+ y_true , probas_list , clf_names = None , n_bins = 10 ,
948+ title = 'Calibration plots (Reliability Curves)' ,
949+ ax = None , figsize = None , cmap = 'nipy_spectral' ,
950+ title_fontsize = "large" , text_fontsize = "medium" ,
951+ pos_label = None , strategy = "uniform" ,
952+ ):
928953 """Plots calibration curves for a set of classifier probability estimates.
929954
930955 Plotting the calibration curves of a classifier is useful for determining
@@ -1073,9 +1098,13 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
10731098 return ax
10741099
10751100
1076- def plot_cumulative_gain (y_true , y_probas , title = 'Cumulative Gains Curve' ,
1077- ax = None , figsize = None , title_fontsize = "large" ,
1078- text_fontsize = "medium" , class_names = None ):
1101+ def plot_cumulative_gain (
1102+ y_true , y_probas , title = 'Cumulative Gains Curve' ,
1103+ classes_to_plot = None , plot_micro = True , plot_macro = True ,
1104+ ax = None , figsize = None , title_fontsize = "large" ,
1105+ text_fontsize = "medium" , cmap = 'nipy_spectral' ,
1106+ class_names = None ,
1107+ ):
10791108 """Generates the Cumulative Gains Plot from labels and scores/probabilities
10801109
10811110 The cumulative gains chart is used to determine the effectiveness of a
@@ -1093,6 +1122,17 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
10931122 title (string, optional): Title of the generated plot. Defaults to
10941123 "Cumulative Gains Curve".
10951124
1125+ classes_to_plot (list-like, optional): Classes for which the Cumulative Gain
1126+ curve should be plotted, e.g. [0, 'cold']. If a given class does not exist,
1127+ it will be ignored. If ``None``, all classes will be plotted. Defaults to
1128+ ``None``
1129+
1130+ plot_micro (boolean, optional): Plot the micro average Cumulative Gain curve.
1131+ Defaults to ``True``.
1132+
1133+ plot_macro (boolean, optional): Plot the macro average Cumulative Gain curve.
1134+ Defaults to ``True``.
1135+
10961136 ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
10971137 plot the cumulative gains curves. If None, the plot is drawn on a new set of
10981138 axes.
@@ -1107,6 +1147,11 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11071147 text_fontsize (string or int, optional): Matplotlib-style fontsizes.
11081148 Use e.g. "small", "medium", "large" or integer-values. Defaults to
11091149 "medium".
1150+
1151+ cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
1152+ Colormap used for plotting the per-class curves. View Matplotlib Colormap
1153+ documentation for available options.
1154+ https://matplotlib.org/users/colormaps.html
11101155
11111156 class_names (list of strings, optional): List of class names. Used for
11121157 the legend. Order should be synchronized with the order of classes
@@ -1129,28 +1174,58 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11291174 :align: center
11301175 :alt: Cumulative Gains Plot
11311176 """
1177+ if ax is None :
1178+ fig , ax = plt .subplots (1 , 1 , figsize = figsize )
1179+ ax .set_title (title , fontsize = title_fontsize )
1180+
11321181 y_true = np .array (y_true )
11331182 y_probas = np .array (y_probas )
1134-
11351183 classes = np .unique (y_true )
1136- if class_names is None : class_names = classes
1137- if len (classes ) != 2 :
1184+
1185+ if classes_to_plot is None :
1186+ classes_to_plot = classes
1187+ if class_names is None : class_names = classes_to_plot
1188+
1189+ if len (classes_to_plot ) != 2 :
11381190 raise ValueError ('Cannot calculate Cumulative Gains for data with '
11391191 '{} category/ies' .format (len (classes )))
11401192
1141- # Compute Cumulative Gain Curves
1142- percentages , gains1 = cumulative_gain_curve (y_true , y_probas [:, 0 ],
1143- classes [0 ])
1144- percentages , gains2 = cumulative_gain_curve (y_true , y_probas [:, 1 ],
1145- classes [1 ])
1193+ perc_dict = dict ()
1194+ gain_dict = dict ()
11461195
1147- if ax is None :
1148- fig , ax = plt .subplots (1 , 1 , figsize = figsize )
1196+ indices_to_plot = np .isin (classes , classes_to_plot )
1197+ # Loop for all classes to get different class gain
1198+ for i , to_plot in enumerate (indices_to_plot ):
1199+ perc_dict [i ], gain_dict [i ] = cumulative_gain_curve (y_true , y_probas [:, i ], pos_label = classes [i ])
11491200
1150- ax .set_title (title , fontsize = title_fontsize )
1201+ if to_plot :
1202+ color = plt .cm .get_cmap (cmap )(float (i ) / len (classes ))
1203+ ax .plot (perc_dict [i ], gain_dict [i ], lw = 2 , color = color ,
1204+ label = 'Class {} Cumulative Gain curve' .format (class_names [i ]))
11511205
1152- ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (class_names [0 ]))
1153- ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (class_names [1 ]))
1206+ # Optionally plot the micro- and macro-average curves
1207+ if plot_micro :
1208+ binarized_y_true = label_binarize (y_true , classes = classes )
1209+ if len (classes ) == 2 :
1210+ binarized_y_true = np .hstack ((1 - binarized_y_true , binarized_y_true ))
1211+
1212+ perc , gain = cumulative_gain_curve (binarized_y_true .ravel (), y_probas .ravel ())
1213+ ax .plot (perc , gain , label = 'micro-average Cumulative Gain curve' ,
1214+ color = 'deeppink' , linestyle = ':' , linewidth = 4 )
1215+
1216+ if plot_macro :
1217+ # First aggregate all percentages
1218+ all_perc = np .unique (np .concatenate ([perc_dict [x ] for x in range (len (classes ))]))
1219+
1220+ # Then interpolate all cumulative gain
1221+ mean_gain = np .zeros_like (all_perc )
1222+ for i in range (len (classes )):
1223+ mean_gain += np .interp (all_perc , perc_dict [i ], gain_dict [i ])
1224+
1225+ mean_gain /= len (classes )
1226+
1227+ ax .plot (all_perc , mean_gain , label = 'macro-average Cumulative Gain curve' ,
1228+ color = 'navy' , linestyle = ':' , linewidth = 4 )
11541229
11551230 ax .set_xlim ([0.0 , 1.0 ])
11561231 ax .set_ylim ([0.0 , 1.0 ])
@@ -1159,16 +1234,19 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11591234
11601235 ax .set_xlabel ('Percentage of sample' , fontsize = text_fontsize )
11611236 ax .set_ylabel ('Gain' , fontsize = text_fontsize )
1237+
11621238 ax .tick_params (labelsize = text_fontsize )
11631239 ax .grid ('on' )
11641240 ax .legend (loc = 'lower right' , fontsize = text_fontsize )
11651241
11661242 return ax
11671243
11681244
1169- def plot_lift_curve (y_true , y_probas , title = 'Lift Curve' ,
1170- ax = None , figsize = None , title_fontsize = "large" ,
1171- text_fontsize = "medium" , class_names = None ):
1245+ def plot_lift_curve (
1246+ y_true , y_probas , title = 'Lift Curve' ,
1247+ ax = None , figsize = None , title_fontsize = "large" ,
1248+ text_fontsize = "medium" , class_names = None
1249+ ):
11721250 """Generates the Lift Curve from labels and scores/probabilities
11731251
11741252 The lift curve is used to determine the effectiveness of a
0 commit comments