scikit-learn: Allow for refit=callable in *SearchCV to balance score and model complexity

GridSearchCV and RandomizedSearchCV currently allow refit=my_scorer_name to select the model that maximises some chosen metric. But each such score must be calculated independently of the other candidates' results.
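For reference, a minimal sketch of the existing string-valued behaviour with multi-metric scoring. The dataset, estimator and grid are illustrative choices only. Each candidate's score is computed on its own, and best_index_ simply picks the candidate ranked first by the named scorer.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Two scorers are computed for every candidate; refit names the one to maximise.
search = GridSearchCV(
    SVC(),
    param_grid={"C": [0.1, 1, 10]},
    scoring={"accuracy": "accuracy", "f1_macro": "f1_macro"},
    refit="f1_macro",
    cv=5,
)
search.fit(X, y)

# The selected candidate is the one ranked first on the refit scorer.
print(search.best_params_, search.cv_results_["mean_test_f1_macro"][search.best_index_])
```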

To balance model complexity with cross-validated score, it is common to use an approach like choosing the model that is least complex (by some metric or ordering) but is within 1 standard deviation of the best score. (Variant approaches exist, and may relate to budget constraints etc.)
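The "least complex model within 1 standard deviation of the best score" rule can be sketched as a pure function over per-candidate summaries. This is an illustration under assumptions: higher scores are better, the threshold uses the best candidate's own standard deviation (some variants use the standard error instead), and complexity is supplied as a number where lower means simpler. The function name one_std_rule is hypothetical.

```python
import numpy as np


def one_std_rule(mean_scores, std_scores, complexity):
    """Index of the least complex candidate within 1 std of the best mean score.

    mean_scores, std_scores: per-candidate CV mean and std (higher is better).
    complexity: per-candidate complexity measure (lower is simpler).
    """
    mean_scores = np.asarray(mean_scores, dtype=float)
    std_scores = np.asarray(std_scores, dtype=float)
    complexity = np.asarray(complexity)

    best = np.argmax(mean_scores)
    threshold = mean_scores[best] - std_scores[best]
    # Candidates whose mean score is no worse than (best mean - best std).
    eligible = np.flatnonzero(mean_scores >= threshold)
    # Among the eligible candidates, return the simplest one.
    return int(eligible[np.argmin(complexity[eligible])])


means = [0.80, 0.86, 0.88, 0.89]
stds = [0.02, 0.03, 0.02, 0.02]
n_components = [2, 5, 10, 20]
print(one_std_rule(means, stds, n_components))  # 2
```

Here the best mean is 0.89 with std 0.02, so the threshold is 0.87; candidates 2 and 3 qualify and candidate 2 is the simpler of the two.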

We could consider allowing a callable to be passed to refit:

"""
...
    refit : boolean, string, or callable, default=True
        Refit an estimator using the best found parameters on the whole
        dataset.

        For multiple metric evaluation, this can be a string denoting the
        scorer maximised to find the best parameters for refitting the estimator
        at the end.

        Where there are considerations other than maximum model performance in
        choosing a best estimator, ``refit`` can be set to a function which
        returns the selected ``best_index_`` given the ``cv_results_``.

...
"""

Does this interface sound reasonable, @janvanrijn, @betatim?

About this issue

  • State: closed
  • Created 6 years ago
  • Comments: 15 (9 by maintainers)

Most upvoted comments

That’s about right… but you should think of it in terms of documentation (how do I describe the feature to users?) and tests (what do you need to assert in order to ascertain that the implementation is likely correct?).
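As an illustration of the kind of assertion such a test might make, here is a sketch of a pytest-style test (the test name is hypothetical). It uses a callable that deliberately ignores the scores, so the expected outcome is known in advance: best_index_, best_params_ and the refitted estimator must all follow the callable's choice rather than the best score.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC


def test_refit_callable_sets_best_index():
    X, y = load_iris(return_X_y=True)

    def refit_callable(cv_results):
        # Ignore the scores entirely and always choose the first candidate.
        return 0

    search = GridSearchCV(SVC(), {"C": [0.01, 0.1, 1]}, cv=3, refit=refit_callable)
    search.fit(X, y)

    # The selection must come from the callable, not from the best score.
    assert search.best_index_ == 0
    assert search.best_params_ == {"C": 0.01}
    # The estimator refitted on the whole dataset uses the chosen parameters.
    assert search.best_estimator_.C == 0.01
    assert search.predict(X).shape == (X.shape[0],)
```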