scikit-learn: check_estimator is not sufficiently general

I’ve been writing a lot of scikit-learn compatible code lately, and I love the idea of the general checks in check_estimator to make sure that my code is scikit-learn compatible. But in nearly all cases, I’m finding that these checks are not general enough for the objects I’m creating (for example, MSTClustering, which fails because it doesn’t support specifying the number of clusters through an n_clusters argument).
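As a rough illustration of that failure mode, here is a minimal sketch in plain Python. Both `ThresholdClusterer` and `check_clustering_sketch` are hypothetical stand-ins invented for this example, not scikit-learn code; the check simply assumes, as described above, that every clusterer accepts `n_clusters`.

```python
class ThresholdClusterer:
    """Hypothetical clusterer configured by a cutoff, not a cluster count."""

    def __init__(self, cutoff=1.0):
        self.cutoff = cutoff

    def get_params(self):
        return {"cutoff": self.cutoff}

    def set_params(self, **params):
        # Reject any parameter the estimator does not declare.
        valid = self.get_params()
        for name, value in params.items():
            if name not in valid:
                raise ValueError(
                    f"Invalid parameter {name!r} for {type(self).__name__}"
                )
            setattr(self, name, value)
        return self


def check_clustering_sketch(estimator):
    """Hypothetical check that hard-codes the n_clusters assumption."""
    estimator.set_params(n_clusters=3)


def passes(check, estimator):
    """Return True if the check runs without raising ValueError."""
    try:
        check(estimator)
    except ValueError:
        return False
    return True
```

In this sketch the estimator is perfectly usable through its own parameters, yet `passes(check_clustering_sketch, ThresholdClusterer())` reports a failure purely because of the check's built-in assumption.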

Digging through the code, there are a lot of hard-coded special cases for estimators within scikit-learn itself. This implies that, absent those special cases, scikit-learn’s own estimators would not pass the general estimator checks, which is obviously a huge issue.

Making these checks sufficiently general would be hard; I imagine it would be a rather large project, and would probably even involve designing an API so that estimators themselves can tune the checks that are run on them.
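One possible shape for such an API, sketched in plain Python under heavy assumptions: the `check_tags` attribute, the `supports_n_clusters` tag, and every function and class below are hypothetical names made up for illustration, not an existing scikit-learn interface. The idea is only that each check declares which tag it depends on, and an estimator can opt out by overriding that tag.

```python
def check_accepts_n_clusters(estimator):
    # Hypothetical check: assumes the estimator exposes n_clusters.
    assert hasattr(estimator, "n_clusters"), "no n_clusters parameter"


def check_has_fit(estimator):
    # Hypothetical check that applies to every estimator.
    assert callable(getattr(estimator, "fit", None)), "no fit method"


# Each check is paired with the tag that must be true for it to run.
# None means the check always runs.
CHECKS = [
    ("supports_n_clusters", check_accepts_n_clusters),
    (None, check_has_fit),
]

# Defaults mirror the hard-coded assumption: n_clusters is expected.
DEFAULT_TAGS = {"supports_n_clusters": True}


def run_checks(estimator):
    """Run every applicable check and report passed/failed/skipped."""
    tags = {**DEFAULT_TAGS, **getattr(estimator, "check_tags", {})}
    results = {}
    for tag, check in CHECKS:
        if tag is not None and not tags.get(tag, False):
            results[check.__name__] = "skipped"
            continue
        try:
            check(estimator)
            results[check.__name__] = "passed"
        except AssertionError:
            results[check.__name__] = "failed"
    return results


class CutoffClusterer:
    """Hypothetical estimator that opts out of the n_clusters check."""

    check_tags = {"supports_n_clusters": False}

    def __init__(self, cutoff=1.0):
        self.cutoff = cutoff

    def fit(self, X, y=None):
        return self


class UntaggedClusterer:
    """Same estimator without tags; the defaults apply to it."""

    def __init__(self, cutoff=1.0):
        self.cutoff = cutoff

    def fit(self, X, y=None):
        return self
```

With this layout, `run_checks(CutoffClusterer())` skips the `n_clusters` check, while `run_checks(UntaggedClusterer())` still fails it. A class attribute is used here because it can be inherited and overridden by subclasses, but a method returning the tags would work just as well.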

In the meantime, I think it would be better for us not to recommend that people run this code on their own estimators, or at least to let them know that failing the estimator checks does not necessarily imply non-compatibility.

About this issue

  • State: closed
  • Created 8 years ago
  • Reactions: 1
  • Comments: 28 (27 by maintainers)

Most upvoted comments

I removed 560 lines of test code from my package after catching up to the latest sklearn version. Awesome! And thank you for supporting us downstream packages. Much appreciated.

@rth I’d argue you should use include_meta_estimators=False. On the other hand, that could be another “tag”, though it could be done with inheritance, too.

To put some numbers on this issue, at present 41 estimators out of 147 do not pass the check_estimator validation in scikit-learn. Out of the 41 validations that fail (+4 skipped), 29 are masked in sklearn.utils.testing.DONT_TEST and 2 are private classes (and probably shouldn’t be validated anyway). Full output here.

Tested with,

validate_estimators.py

import pytest

from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.testing import all_estimators


# Collect every estimator, including meta-estimators and the ones
# normally masked by DONT_TEST, and run the general checks on each.
@pytest.mark.parametrize(
    'name,estimator',
    all_estimators(include_meta_estimators=True,
                   include_dont_test=True))
def test_estimator(name, estimator):
    check_estimator(estimator)

using py.test -v --tb=no validate_estimators.py on the code in the master branch (PY3, > sklearn-0.18rc2, Linux).