dask-ml: XGBClassifier raises an error with GridSearchCV and dask DataFrames
```python
import dask.dataframe as dd
import numpy as np
import pandas as pd
from dask_ml.model_selection import GridSearchCV
from dask_ml.xgboost import XGBClassifier
from distributed import Client
from sklearn.datasets import load_iris

if __name__ == '__main__':
    client = Client()

    data = load_iris()

    # Wrap the iris features and target in dask collections
    x = pd.DataFrame(data=data['data'], columns=data['feature_names'])
    x = dd.from_pandas(x, npartitions=2)

    y = pd.Series(data['target'])
    y = dd.from_pandas(y, npartitions=2)

    # Iris has three classes (0, 1, 2)
    estimator = XGBClassifier(objective='multi:softmax', num_class=3)

    grid_search = GridSearchCV(
        estimator,
        param_grid={
            'n_estimators': np.arange(15, 105, 15)
        },
        scheduler='threads'
    )

    # This call raises the AssertionError in the traceback
    grid_search.fit(x, y)

    results = pd.DataFrame(grid_search.cv_results_)
    print(results.to_string())
```
This fails with the following traceback:
```
Traceback (most recent call last):
  File "d.py", line 30, in <module>
    grid_search.fit(x, y)
  File "/usr/local/lib/python3.7/site-packages/dask_ml/model_selection/_search.py", line 1233, in fit
    cache_cv=self.cache_cv,
  File "/usr/local/lib/python3.7/site-packages/dask_ml/model_selection/_search.py", line 203, in build_cv_graph
    X_name, y_name, groups_name = to_keys(dsk, X, y, groups)
  File "/usr/local/lib/python3.7/site-packages/dask_ml/model_selection/utils.py", line 85, in to_keys
    assert not is_dask_collection(x)
AssertionError
```
About this issue
- State: closed
- Created 5 years ago
- Comments: 28 (14 by maintainers)
@TomAugspurger Hello, I have run into exactly the problem you summarized above. Dask DataFrames should be accepted by GridSearchCV; it's really important to have a complete chain with this component. I have re-implemented some scikit-learn estimators for dask DataFrames, and the big problem I found is that GridSearchCV does not support dask DataFrames when these estimators are used in a pipeline.
@stsievert, who works on dask-ml a bit, will also be around.
I won’t be. I know that @mrocklin and @jrbourbeau will be around, and they’re hosting a Dask sprint: https://github.com/dask/dask/issues/4639