@@ -32,7 +32,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
3232 pred_labels = None , title = None , normalize = False ,
3333 hide_zeros = False , hide_counts = False , x_tick_rotation = 0 , ax = None ,
3434 figsize = None , cmap = 'Blues' , title_fontsize = "large" ,
35- text_fontsize = "medium" ):
35+ text_fontsize = "medium" , show_colorbar = True ):
3636 """Generates confusion matrix plot from predictions and true labels
3737
3838 Args:
@@ -89,6 +89,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
8989 Use e.g. "small", "medium", "large" or integer-values. Defaults to
9090 "medium".
9191
92+ show_colorbar (bool, optional): If False, does not add colour bar.
93+ Defaults to True.
94+
9295 Returns:
9396 ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
9497 drawn.
@@ -151,7 +154,10 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
151154 ax .set_title ('Confusion Matrix' , fontsize = title_fontsize )
152155
153156 image = ax .imshow (cm , interpolation = 'nearest' , cmap = plt .cm .get_cmap (cmap ))
154- plt .colorbar (mappable = image )
157+
158+ if show_colorbar == True :
159+ plt .colorbar (mappable = image )
160+
155161 x_tick_marks = np .arange (len (pred_classes ))
156162 y_tick_marks = np .arange (len (true_classes ))
157163 ax .set_xticks (x_tick_marks )
0 commit comments