scikit-learn: plot_roc_curve doesn't correctly infer pos_label

from sklearn.metrics import plot_roc_curve
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression

import numpy as np
X, y = make_blobs(centers=2)
y = y.astype(str)  # np.str was only an alias of str and was removed in NumPy 1.24
lr = LogisticRegression().fit(X, y)
plot_roc_curve(lr, X, y)
# -> raises ValueError("Data is not binary and pos_label is not specified")

I would argue that pos_label=lr.classes_[1] is the right choice here.
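A minimal sketch of that choice, computed with `roc_curve` directly so it needs no plotting backend (`plot_roc_curve` itself was deprecated in scikit-learn 1.0 and removed in 1.2 in favour of `RocCurveDisplay.from_estimator`). The `random_state=0` seed is an assumption added for reproducibility. Because `classes_` is sorted, `classes_[1]` is the class whose probability sits in column 1 of `predict_proba`, so passing it as `pos_label` keeps the labels and scores consistent:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve

# Same setup as the report: a binary problem with string class labels.
X, y = make_blobs(centers=2, random_state=0)
y = y.astype(str)
lr = LogisticRegression().fit(X, y)

# Infer the positive class from the fitted estimator, as proposed above.
pos_label = lr.classes_[1]
# Column 1 of predict_proba holds the probability of classes_[1].
scores = lr.predict_proba(X)[:, 1]

# With pos_label given explicitly, roc_curve accepts string labels;
# with pos_label=None it rejects labels outside {0, 1} / {-1, 1}.
fpr, tpr, thresholds = roc_curve(y, scores, pos_label=pos_label)
roc_auc = auc(fpr, tpr)
print(pos_label, roc_auc)
```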

cc @thomasjpfan

About this issue

  • State: closed
  • Created 5 years ago
  • Comments: 17 (17 by maintainers)

Most upvoted comments

In this case, since we have the trained estimator, we can infer the pos_label from est.classes_[1], as suggested above.

If a user wants to pass a different pos_label, we can validate that it appears in est.classes_.
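The two rules above (default to `classes_[1]`, otherwise validate against `classes_`) can be sketched as a small helper. `resolve_pos_label` is a hypothetical name used only for illustration; it is not scikit-learn's implementation. It also returns the matching `predict_proba` column, since a user-supplied `pos_label` of `classes_[0]` must be scored with column 0, not column 1:

```python
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve


def resolve_pos_label(estimator, pos_label=None):
    """Hypothetical helper: pick or validate pos_label for a fitted binary classifier.

    Returns (pos_label, column), where column indexes predict_proba.
    """
    labels = estimator.classes_.tolist()
    if len(labels) != 2:
        raise ValueError(f"Expected a binary classifier, got classes {labels!r}")
    if pos_label is None:
        # Default proposed in the issue: the second (sorted) class.
        pos_label = labels[1]
    elif pos_label not in labels:
        raise ValueError(
            f"pos_label={pos_label!r} is not a valid label; expected one of {labels!r}"
        )
    # The position of pos_label in classes_ is its column in predict_proba.
    return pos_label, labels.index(pos_label)


X, y = make_blobs(centers=2, random_state=0)
y = y.astype(str)
lr = LogisticRegression().fit(X, y)

label, col = resolve_pos_label(lr)  # inferred from classes_[1]
fpr, tpr, _ = roc_curve(y, lr.predict_proba(X)[:, col], pos_label=label)
```

Calling `resolve_pos_label(lr, pos_label="2")` would raise a `ValueError`, because `"2"` is not in `lr.classes_`.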