scikit-learn: check_estimator is not sufficiently general
I’ve been writing a lot of scikit-learn-compatible code lately, and I love the idea of the general checks in check_estimator for making sure that my code is scikit-learn compatible. But in nearly all cases, I’m finding that these checks are not general enough for the objects I’m creating. For example, MSTClustering fails because it doesn’t support specifying the number of clusters through an n_clusters argument.
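To make the failure mode concrete, here is a minimal, self-contained sketch. It is not scikit-learn's actual check code: check_clustering_n_clusters is a hypothetical generic check that hard-codes the assumption that every clusterer accepts n_clusters, and ToyMSTClustering is a hypothetical stand-in for an MST-style clusterer whose cluster count follows from a distance threshold (named cutoff_scale here for illustration). The set_params method only mimics the way scikit-learn's BaseEstimator.set_params rejects unknown parameter names with a ValueError.

```python
class ToyEstimator:
    """Minimal stand-in for a scikit-learn estimator's parameter handling."""

    def __init__(self, **params):
        self._params = dict(params)

    def get_params(self):
        return dict(self._params)

    def set_params(self, **params):
        # Mimics BaseEstimator.set_params: unknown names raise ValueError.
        for name, value in params.items():
            if name not in self._params:
                raise ValueError(
                    f"Invalid parameter {name!r} for {type(self).__name__}"
                )
            self._params[name] = value
        return self


class ToyKMeans(ToyEstimator):
    def __init__(self, n_clusters=8):
        super().__init__(n_clusters=n_clusters)


class ToyMSTClustering(ToyEstimator):
    # Hypothetical: the number of clusters is implied by a distance cutoff,
    # so there is no n_clusters parameter to set.
    def __init__(self, cutoff_scale=None):
        super().__init__(cutoff_scale=cutoff_scale)


def check_clustering_n_clusters(estimator):
    """Hypothetical generic check that assumes an n_clusters parameter."""
    estimator.set_params(n_clusters=3)
    return estimator.get_params()["n_clusters"] == 3


def run_check(check, estimator):
    # Turn the check's outcome (or its exception) into a short status string.
    try:
        return "pass" if check(estimator) else "fail"
    except ValueError as exc:
        return f"error: {exc}"


for est in (ToyKMeans(), ToyMSTClustering()):
    print(type(est).__name__, run_check(check_clustering_n_clusters, est))
```

The second estimator errors out not because it is broken, but because the check's assumption does not apply to it.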
Digging through the code, I find many hard-coded special cases for estimators within scikit-learn itself. This implies that, absent those special cases, scikit-learn’s own estimators would not pass the general estimator checks, which is obviously a huge issue.
Making these checks sufficiently general would be hard; I imagine it would be a rather large project, probably involving the design of an API through which estimators themselves can tune the checks that are run on them.
In the meantime, I think we should either stop recommending that people run these checks on their own estimators, or make clear that failing an estimator check does not necessarily imply incompatibility.
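One possible shape for the API suggested above, sketched in plain Python under stated assumptions: each estimator declares capability tags (a check_tags dict here; the attribute name and the accepts_n_clusters key are invented for illustration), each generic check declares which tag it requires, and the runner skips inapplicable checks and collects failures into a report instead of raising on the first one. That way a failing check reads as something to investigate rather than a verdict of incompatibility. scikit-learn later added an estimator-tags mechanism that serves a similar purpose, but its names and semantics differ from this sketch.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class Check:
    name: str
    func: Callable[[object], None]      # raises on failure
    requires_tag: Optional[str] = None  # hypothetical capability tag


@dataclass
class Report:
    passed: List[str] = field(default_factory=list)
    skipped: List[str] = field(default_factory=list)
    failed: Dict[str, str] = field(default_factory=dict)


def run_checks(estimator, checks):
    """Run every applicable check, collecting results instead of raising."""
    # Hypothetical convention: estimators advertise capabilities in check_tags.
    tags = getattr(estimator, "check_tags", {})
    report = Report()
    for check in checks:
        if check.requires_tag and not tags.get(check.requires_tag, False):
            report.skipped.append(check.name)
            continue
        try:
            check.func(estimator)
        except Exception as exc:  # record the failure and keep going
            report.failed[check.name] = repr(exc)
        else:
            report.passed.append(check.name)
    return report


class ThresholdClusterer:
    """Hypothetical clusterer whose cluster count follows from a cutoff."""

    check_tags = {"accepts_n_clusters": False}

    def __init__(self, cutoff=1.0):
        self.cutoff = cutoff

    def fit(self, X):
        # Toy rule for sorted, non-empty 1-D data: a gap wider than cutoff
        # starts a new cluster.
        labels = [0]
        for prev, x in zip(X, X[1:]):
            labels.append(labels[-1] + (1 if x - prev > self.cutoff else 0))
        self.labels_ = labels
        return self


X = [0.0, 0.5, 3.0]


def check_fit_returns_self(est):
    assert est.fit(X) is est


def check_labels_length(est):
    assert len(est.fit(X).labels_) == len(X)


def check_respects_n_clusters(est):
    est.n_clusters = 3
    assert len(set(est.fit(X).labels_)) == 3


CHECKS = [
    Check("fit_returns_self", check_fit_returns_self),
    Check("labels_length", check_labels_length),
    Check("respects_n_clusters", check_respects_n_clusters,
          requires_tag="accepts_n_clusters"),
]

report = run_checks(ThresholdClusterer(), CHECKS)
print(report)
```

If the same clusterer wrongly advertised accepts_n_clusters=True, respects_n_clusters would run and land in report.failed (the toy rule ignores n_clusters), while the other checks would still be reported as passing.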
About this issue
- State: closed
- Created 8 years ago
- Reactions: 1
- Comments: 28 (27 by maintainers)
560 lines of test code removed from my package after catching up to the latest sklearn version. Awesome! Thank you for supporting us downstream packages; much appreciated.

@rth I’d argue you should do include_meta_estimators=False. On the other hand, that could be another “tag”. That could also be done with inheritance, though.

To put some numbers on this issue: at present, 41 of the 147 estimators in scikit-learn do not pass check_estimator validation. Of the 41 failures (plus 4 skipped), 29 are masked in sklearn.utils.testing.DONT_TEST and 2 are private classes (which probably shouldn’t be validated anyway). Full output here. Tested with validate_estimators.py, run as

py.test -v --tb=no validate_estimators.py

on the code in the master branch (PY3, > sklearn-0.18rc2, Linux).
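A quick breakdown of the figures in the comment above. It assumes the 29 masked estimators and the 2 private classes are non-overlapping subsets of the 41 failures (the comment does not say so explicitly), and it leaves out the 4 skipped checks because their relationship to the 41 is not stated.

```python
total_estimators = 147
failing = 41           # estimators that do not pass check_estimator
masked_dont_test = 29  # failures listed in sklearn.utils.testing.DONT_TEST
private_classes = 2    # failures that are private classes

# Assumption: the masked and private groups do not overlap.
unexplained = failing - masked_dont_test - private_classes
failing_share = failing / total_estimators

print(f"share of estimators failing: {failing_share:.1%}")
print(f"failures neither masked nor private: {unexplained}")
```

Under that assumption, about 28% of estimators fail, and 10 of them fail without being either masked in DONT_TEST or private.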