scikit-learn: plot_confusion_matrix example breaks down if not all classes are present in the test data
Description
The plot_confusion_matrix example easily breaks down, without any warning or error, if the data does not contain all labels. This can easily happen with imbalanced datasets, or with real datasets that have many classes.
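The failure can be seen without any plotting. When no labels argument is given, confusion_matrix builds its rows and columns only from the labels that actually occur in y_true or y_pred. If a class is missing, the matrix silently shrinks while the plot still gets one tick label per class name. The sketch below uses made-up labels; y_true, y_pred and the three-name class_names array are illustrative assumptions, not data from the report:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Three class names, as in the iris example.
class_names = np.array(['setosa', 'versicolor', 'virginica'])

# Hypothetical labels in which class 2 never occurs, in y_true or in y_pred.
y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0, 1, 0, 0])

# By default only the labels present in y_true or y_pred are used,
# so the matrix is 2x2 although three class names exist.
cm = confusion_matrix(y_true, y_pred)
print(cm.shape)  # (2, 2)

# Passing the full label list keeps one row/column per class, in a fixed order.
cm_full = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
print(cm_full.shape)  # (3, 3)
```

Plotting the 2x2 matrix against all three names is the silent mismatch this issue is about.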
Steps/Code to Reproduce
```python
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=45, test_size=0.05)

# Run a linear SVM classifier. With only 8 test samples (5% of 150),
# not every class is guaranteed to appear in y_test or y_pred.
classifier = svm.SVC(kernel='linear')
y_pred = classifier.fit(X_train, y_train).predict(X_test)


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    # One tick per class name, whether or not cm has that many rows/columns.
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()


# Compute confusion matrix (only over labels present in y_test or y_pred)
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')
plt.show()
```
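One way to make the reproduction above behave is to pass the full label set to confusion_matrix through its labels argument, so the matrix always has one row and column per entry of class_names. This is a sketch under that assumption, not necessarily what the fix merged in #13126 does. A class absent from y_test then shows up as an all-zero row, which would make the report's normalize=True branch divide 0 by 0 and produce NaN, so the sketch also row-normalizes with np.divide(..., where=...):

```python
import numpy as np
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

iris = datasets.load_iris()
X, y = iris.data, iris.target
class_names = iris.target_names

# Same split as in the report: a tiny test set (5% of 150 samples).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=45, test_size=0.05)
y_pred = svm.SVC(kernel='linear').fit(X_train, y_train).predict(X_test)

# Ask for one row/column per class, whether or not the class occurs in
# this particular split, so the matrix always lines up with class_names.
labels = np.arange(len(class_names))
cnf_matrix = confusion_matrix(y_test, y_pred, labels=labels)

# Row-normalize without dividing by zero: a class absent from y_test has
# an all-zero row, which stays at zero instead of becoming NaN.
row_sums = cnf_matrix.sum(axis=1, keepdims=True)
cm_normalized = np.divide(cnf_matrix, row_sums,
                          out=np.zeros(cnf_matrix.shape, dtype=float),
                          where=row_sums != 0)
print(cnf_matrix)
print(cm_normalized)
```

The resulting cnf_matrix lines up with classes=class_names when passed to the plot_confusion_matrix function from the report. The np.divide call can stand in for the division in that function's normalize branch.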
Expected Results
Actual Results
Versions
Windows-7-6.1.7601-SP1
Python 3.6.5 |Anaconda, Inc.| (default, Mar 29 2018, 13:32:41) [MSC v.1900 64 bit (AMD64)]
NumPy 1.14.3
SciPy 1.1.0
Scikit-Learn 0.19.1
About this issue
- State: closed
- Created 6 years ago
- Reactions: 2
- Comments: 16 (13 by maintainers)
Commits related to this issue
- EXA plot_confusion_matrix example breaks down if not all classes present (#13126) * fix #12700 plot_confusion_matrix example breaks down if not all classes are present in the test data * plot_conf... — committed to scikit-learn/scikit-learn by MLopez-Ibanez 5 years ago
@sanu11: @MLopez-Ibanez has something in his fork, and there’s a PR from @trungpham10. Feel free to take it if they don’t mind, or if there’s no reply after some time.
That’s ok with me.
On Mon, 17 Dec 2018, 09:24 Adrin Jalali <notifications@github.com> wrote:
Sure, go ahead, @trungpham10, unless @MLopez-Ibanez is working on a PR.
@adrinjalali You’re also a core dev now :)
To do that, one needs to know what the problem is (the plotting function assumes the data contains all labels) and then do something like
classes = class_names[sorted(set(y_true) | set(y_pred))]
to make sure the problem does not reappear with different data. Having users work that out and patch it themselves is what I’d like to avoid. My proposed function has the following benefits:
- Make sure we don’t use class labels that do not appear in the data, using unique_labels (the same thing that confusion_matrix does).
- The default title says whether the matrix uses normalization or not.
- Use explicit fig, ax instead of the default figure and axes, and return ax, as recommended by Matplotlib.
- Avoid itertools: it doesn’t make the code simpler, and this is not code that needs to be optimized.
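The benefits listed above can be sketched as a single function. This is an illustrative reconstruction from that list, not the code from the fork or PR mentioned in this thread. It assumes integer labels 0..n-1, so that the output of sklearn.utils.multiclass.unique_labels can index the array of class names. It also forces Matplotlib's non-interactive Agg backend so it runs headless.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop for interactive use
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels


def plot_confusion_matrix(y_true, y_pred, classes, normalize=False,
                          title=None, cmap='Blues'):
    """Plot a confusion matrix labelled only with classes present in the data.

    Assumes integer labels 0..n-1, so that classes[label] names each label.
    """
    if title is None:
        # Default title states whether the matrix is normalized.
        title = ('Normalized confusion matrix' if normalize
                 else 'Confusion matrix, without normalization')

    cm = confusion_matrix(y_true, y_pred)
    # unique_labels gives the sorted label set that confusion_matrix uses
    # when no labels are passed, so names and rows/columns line up.
    classes = np.asarray(classes)[unique_labels(y_true, y_pred)]

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A label that only occurs in y_pred has an all-zero row; keep it at 0.
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    # Explicit figure and axes instead of pyplot's implicit current ones.
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    fig.colorbar(im, ax=ax)
    ax.set_xticks(np.arange(len(classes)))
    ax.set_yticks(np.arange(len(classes)))
    ax.set_xticklabels(classes, rotation=45, ha='right')
    ax.set_yticklabels(classes)
    ax.set_title(title)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.0
    # Plain nested loops rather than itertools.product.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
                    color='white' if cm[i, j] > thresh else 'black')
    fig.tight_layout()
    return ax
```

In this sketch the function takes y_true and y_pred instead of a precomputed matrix. That is what lets it look up the present labels itself. The trade-off is that it can no longer plot a matrix computed elsewhere, for example one built with an explicit labels list.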
That said, I think it would really help to showcase the usage of the parameter in this or another example; to me this looks like a nice potential improvement to this example, for instance. I’m not sure why you’d oppose changing the example, @qinhanmin2014?