scikit-learn: plot_roc_curve doesn't correctly infer pos_label
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
import numpy as np
X, y = make_blobs(centers=2)
y = y.astype(str)  # np.str was an alias for the builtin str and was removed in NumPy 1.24
lr = LogisticRegression().fit(X, y)
plot_roc_curve(lr, X, y)
-> raise ValueError("Data is not binary and pos_label is not specified")
I would argue that pos_label=lr.classes_[1] is the right choice here.
cc @thomasjpfan
About this issue
- State: closed
- Created 5 years ago
- Comments: 17 (17 by maintainers)
In this case, since we have the trained estimator, we can infer the pos_label from est.classes_[1], as suggested above. If a user wants to pass another pos_label, we can make sure that pos_label appears in est.classes_.
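The inference and validation described in this comment could look roughly like the sketch below. The helper name resolve_pos_label is hypothetical and is not part of scikit-learn. The SimpleNamespace object is only a stand-in for a fitted binary classifier, assumed to expose classes_ the way LogisticRegression does.

```python
from types import SimpleNamespace


def resolve_pos_label(estimator, pos_label=None):
    """Hypothetical helper: infer or validate pos_label from a fitted binary classifier."""
    classes = list(estimator.classes_)
    if len(classes) != 2:
        raise ValueError(f"Expected a binary classifier, got classes {classes!r}")
    if pos_label is None:
        # Default to the second class; predict_proba columns follow the order of classes_.
        return classes[1]
    if pos_label not in classes:
        raise ValueError(f"pos_label={pos_label!r} is not in classes {classes!r}")
    return pos_label


# Stand-in for a fitted classifier trained on string labels (assumption for illustration).
fitted = SimpleNamespace(classes_=["0", "1"])

inferred = resolve_pos_label(fitted)                 # no pos_label given: use classes_[1]
explicit = resolve_pos_label(fitted, pos_label="0")  # user-supplied label is validated
```

With a real estimator, the resolved label would then be passed on to the metric together with the matching probability column, so string labels no longer trigger the "pos_label is not specified" error.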