scikit-learn: plot_confusion_matrix example breaks down if not all classes are present in the test data

Description

The example at https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py

silently produces a mislabeled plot, without any warning or error, if the test data does not contain all labels. This can easily happen with imbalanced datasets, or with real datasets that have many classes.

Steps/Code to Reproduce

import itertools
import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=45, test_size=0.05)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()


# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')
plt.show()
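A workaround that keeps the example's plotting function unchanged is to pin the label set when computing the matrix, using the labels parameter of confusion_matrix. The sketch below reuses the reproduction's split; the classifier settings are incidental to the point. Passing every class index guarantees one row and column per entry of class_names, with zeros for classes absent from the split.

```python
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X, y = iris.data, iris.target
class_names = iris.target_names

# Same split as the reproduction: only 8 test samples, so some classes may be missing.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=45, test_size=0.05)
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

# Pin the label set to every class index so the matrix always has one row/column
# per entry of class_names; classes missing from this split get rows/columns of zeros.
cnf_matrix = confusion_matrix(y_test, y_pred, labels=np.arange(len(class_names)))
print(cnf_matrix)
```

With this matrix, plot_confusion_matrix(cnf_matrix, classes=class_names) draws three ticks over a 3x3 image, so every name lines up with its row and column.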

Expected Results

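The tick labels on both axes match the rows and columns of the confusion matrix, or a warning or error is raised when they cannot be matched.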

Actual Results

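No warning or error is raised, but the plot is mislabeled: confusion_matrix(y_test, y_pred) only has one row and column per label that occurs in y_test or y_pred, while plot_confusion_matrix draws one tick per entry in class_names, so when a class is missing the names no longer line up with the matrix.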
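The mismatch can be seen without any plotting. The sketch below uses hypothetical hand-made labels (not the iris split) in which class 2 never occurs. The matrix comes back 2x2, while the example's plotting function sets tick_marks = np.arange(len(classes)), i.e. three ticks. If the missing class is the last one, the first two names happen to be right and an extra tick appears; if it is any other class, names are attached to the wrong rows and columns.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

class_names = np.array(['setosa', 'versicolor', 'virginica'])

# Hypothetical labels in which class 2 ('virginica') never occurs,
# neither in the ground truth nor in the predictions.
y_true = np.array([0, 0, 1, 1, 1])
y_pred = np.array([0, 1, 1, 1, 0])

cm = confusion_matrix(y_true, y_pred)
print(cm.shape)          # (2, 2): one row/column per label present in the data
print(len(class_names))  # 3: the plotting function still draws three ticks
```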

Versions

Windows-7-6.1.7601-SP1 Python 3.6.5 |Anaconda, Inc.| (default, Mar 29 2018, 13:32:41) [MSC v.1900 64 bit (AMD64)] NumPy 1.14.3 SciPy 1.1.0 Scikit-Learn 0.19.1

About this issue

  • State: closed
  • Created 6 years ago

Most upvoted comments

@sanu11: @MLopez-Ibanez has something in his fork, and there’s a PR from @trungpham10. Feel free to take it over if they don’t mind, or if there’s no reply after some time.

That’s ok with me.

On Mon, 17 Dec 2018, 09:24 Adrin Jalali <notifications@github.com> wrote:

Sure, go ahead @trungpham10, unless @MLopez-Ibanez is working on a PR.

Sure, go ahead @trungpham10 , unless @MLopez-Ibanez is working on a PR.

@adrinjalali You’re also a core dev now :)

We already have the classes parameter. Simply replacing classes=class_names with something like classes=class_names[[0, 2]] will solve your problem. I don’t think we need to modify the example.

To do that, one needs to know what the problem is (the plotting function assumes the data contains all labels) and then do something like classes = class_names[list(set(y_true) | set(y_pred))] to make sure the problem does not reappear with different data. That is what I’d like to avoid.

But I don’t think we need to modify the function, and I won’t vote +1 for things like classes = class_names[list(set(y_true) | set(y_pred))] when the data contains all the labels.
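To compare the two approaches discussed above, here is a small sketch with hypothetical labels in which 'versicolor' (label 1) never occurs. The hard-coded subset is only right for this particular split; the data-derived one adapts. Deriving it with unique_labels gives the sorted union of labels in y_true and y_pred, which is the label set confusion_matrix uses when labels=None. The sorted(...) variant is a substitution for the list(set(...)) snippet quoted in the thread: iterating a bare set does not guarantee sorted order in general.

```python
import numpy as np
from sklearn.utils.multiclass import unique_labels

class_names = np.array(['setosa', 'versicolor', 'virginica'])

# Hypothetical split in which 'versicolor' (label 1) never occurs.
y_true = np.array([0, 2, 2, 0])
y_pred = np.array([0, 2, 0, 0])

# Option 1: hard-code the subset. Correct only for this particular split.
hard_coded = class_names[[0, 2]]

# Option 2: derive the subset from the data. unique_labels returns the sorted
# union of labels in y_true and y_pred, matching confusion_matrix's default.
present = unique_labels(y_true, y_pred)
derived = class_names[present]

# Set-based variant with an explicit sort, so the order matches the matrix rows.
from_sets = class_names[sorted(set(y_true) | set(y_pred))]

print(hard_coded)  # ['setosa' 'virginica']
print(derived)     # ['setosa' 'virginica']
```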

My proposed function has the following benefits:

  • Make sure we don’t use class labels that do not appear in the data using unique_labels (same thing that confusion_matrix does)

  • The default title says whether the matrix uses normalization or not.

  • Use explicit fig, ax instead of the default figure and axes, and return ax, as recommended by Matplotlib.

  • Avoid itertools, it doesn’t make the code simpler and this is not code that needs to be optimized.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'
    
    cm = confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')
    
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax
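One edge case of the normalize=True path is worth keeping in mind, since the function now keeps every label found in either array: a label that appears only in y_pred gets a row of zeros, and the plain division cm / cm.sum(axis=1) then yields 0/0 = nan for that row (with a RuntimeWarning). A minimal sketch of a guarded normalization, assuming leaving such rows at 0 is acceptable, using hypothetical labels:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels: class 2 is predicted once but never occurs in y_true,
# so its row of the confusion matrix is all zeros.
y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 2, 1, 1])

cm = confusion_matrix(y_true, y_pred)
row_sums = cm.sum(axis=1)[:, np.newaxis]

# Divide only where a row has true samples; other entries keep the value 0
# from the preallocated output array instead of becoming nan.
cm_norm = np.divide(cm.astype(float), row_sums,
                    out=np.zeros(cm.shape, dtype=float),
                    where=row_sums != 0)
print(cm_norm)
```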

That said, I think it would really help to showcase the usage of the parameter in this or another example, and to me this looks like a nice improvement to this example. @qinhanmin2014, I’m not sure why you’d oppose changing it?