From e9ad6bcb6d2ed7e3456c75098fa2256d84ae76f1 Mon Sep 17 00:00:00 2001 From: Sefa Ozalp Date: Fri, 26 Feb 2021 14:29:28 +0000 Subject: [PATCH 1/2] make colorbar optional in plot_confusion_matrix() make colorbar optional in `plot_confusion_matrix()`. Useful when multiple confusion matrices are plotted together within a `plt.subplots()`. --- scikitplot/metrics.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/scikitplot/metrics.py b/scikitplot/metrics.py index 08ec693..d1bc092 100644 --- a/scikitplot/metrics.py +++ b/scikitplot/metrics.py @@ -34,7 +34,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, pred_labels=None, title=None, normalize=False, hide_zeros=False, hide_counts=False, x_tick_rotation=0, ax=None, figsize=None, cmap='Blues', title_fontsize="large", - text_fontsize="medium"): + text_fontsize="medium",colorbar=True): """Generates confusion matrix plot from predictions and true labels Args: @@ -65,7 +65,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, hide_zeros (bool, optional): If True, does not plot cells containing a value of zero. Defaults to False. - hide_counts (bool, optional): If True, doe not overlay counts. + hide_counts (bool, optional): If True, does not overlay counts. Defaults to False. x_tick_rotation (int, optional): Rotates x-axis tick labels by the @@ -90,6 +90,10 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, text_fontsize (string or int, optional): Matplotlib-style fontsizes. Use e.g. "small", "medium", "large" or integer-values. Defaults to "medium". + + colorbar (bool, optional): If False, does not add colourmap. + Defaults to True. + Returns: ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was @@ -153,7 +157,10 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, ax.set_title('Confusion Matrix', fontsize=title_fontsize) image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.get_cmap(cmap)) - plt.colorbar(mappable=image) + + if colorbar == True: + plt.colorbar(mappable=image) + x_tick_marks = np.arange(len(pred_classes)) y_tick_marks = np.arange(len(true_classes)) ax.set_xticks(x_tick_marks) From 28e65a05cd37c6aa4d6c190cf5d886d805e64a8a Mon Sep 17 00:00:00 2001 From: Sefa Ozalp Date: Fri, 26 Feb 2021 14:37:44 +0000 Subject: [PATCH 2/2] Update metrics.py --- scikitplot/metrics.py | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/scikitplot/metrics.py b/scikitplot/metrics.py index d1bc092..492078d 100644 --- a/scikitplot/metrics.py +++ b/scikitplot/metrics.py @@ -34,71 +34,52 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, pred_labels=None, title=None, normalize=False, hide_zeros=False, hide_counts=False, x_tick_rotation=0, ax=None, figsize=None, cmap='Blues', title_fontsize="large", - text_fontsize="medium",colorbar=True): + text_fontsize="medium", colorbar=True): """Generates confusion matrix plot from predictions and true labels - Args: y_true (array-like, shape (n_samples)): Ground truth (correct) target values. - y_pred (array-like, shape (n_samples)): Estimated targets as returned by a classifier. - labels (array-like, shape (n_classes), optional): List of labels to index the matrix. This may be used to reorder or select a subset of labels. If none is given, those that appear at least once in ``y_true`` or ``y_pred`` are used in sorted order. (new in v0.2.5) - true_labels (array-like, optional): The true labels to display. If none is given, then all of the labels are used. - pred_labels (array-like, optional): The predicted labels to display. If none is given, then all of the labels are used. - title (string, optional): Title of the generated plot. Defaults to "Confusion Matrix" if `normalize` is True. Else, defaults to "Normalized Confusion Matrix. - normalize (bool, optional): If True, normalizes the confusion matrix before plotting. Defaults to False. - hide_zeros (bool, optional): If True, does not plot cells containing a value of zero. Defaults to False. - - hide_counts (bool, optional): If True, does not overlay counts. + hide_counts (bool, optional): If True, doe not overlay counts. Defaults to False. - x_tick_rotation (int, optional): Rotates x-axis tick labels by the specified angle. This is useful in cases where there are numerous categories and the labels overlap each other. - ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to plot the curve. If None, the plot is drawn on a new set of axes. - figsize (2-tuple, optional): Tuple denoting figure size of the plot e.g. (6, 6). Defaults to ``None``. - cmap (string or :class:`matplotlib.colors.Colormap` instance, optional): Colormap used for plotting the projection. View Matplotlib Colormap documentation for available options. https://matplotlib.org/users/colormaps.html - title_fontsize (string or int, optional): Matplotlib-style fontsizes. Use e.g. "small", "medium", "large" or integer-values. Defaults to "large". - text_fontsize (string or int, optional): Matplotlib-style fontsizes. Use e.g. "small", "medium", "large" or integer-values. Defaults to "medium". - - colorbar (bool, optional): If False, does not add colourmap. + colorbar (bool, optional): If False, does not add colour bar. Defaults to True. - - Returns: ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was drawn. - Example: >>> import scikitplot as skplt >>> rf = RandomForestClassifier() @@ -107,7 +88,6 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, >>> skplt.metrics.plot_confusion_matrix(y_test, y_pred, normalize=True) >>> plt.show() - .. image:: _static/examples/plot_confusion_matrix.png :align: center :alt: Confusion matrix @@ -157,10 +137,10 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, ax.set_title('Confusion Matrix', fontsize=title_fontsize) image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.get_cmap(cmap)) - + if colorbar == True: plt.colorbar(mappable=image) - + x_tick_marks = np.arange(len(pred_classes)) y_tick_marks = np.arange(len(true_classes)) ax.set_xticks(x_tick_marks)