LightGBM: [dask] DaskRegressor.predict() fails on DataFrame / Series input
How are you using LightGBM?
LightGBM component: Python-package
Environment info
Operating System: Ubuntu 18.04
C++ compiler version: gcc 8.3.0
CMake version: 3.13.4
Python version:
output of 'conda info'
active environment : saturn
active env location : /opt/conda/envs/saturn
shell level : 0
user config file : /home/jovyan/.condarc
populated config files : /opt/conda/.condarc
conda version : 4.8.2
conda-build version : not installed
python version : 3.7.7.final.0
virtual packages : __glibc=2.28
base environment : /opt/conda (writable)
channel URLs : https://conda.saturncloud.io/pkgs/linux-64
https://conda.saturncloud.io/pkgs/noarch
https://conda.anaconda.org/conda-forge/linux-64
https://conda.anaconda.org/conda-forge/noarch
https://repo.anaconda.com/pkgs/main/linux-64
https://repo.anaconda.com/pkgs/main/noarch
https://repo.anaconda.com/pkgs/r/linux-64
https://repo.anaconda.com/pkgs/r/noarch
package cache : /opt/conda/pkgs
/home/jovyan/.conda/pkgs
envs directories : /opt/conda/envs
/home/jovyan/.conda/envs
platform : linux-64
user-agent : conda/4.8.2 requests/2.22.0 CPython/3.7.7 Linux/4.14.203-156.332.amzn2.x86_64 debian/10 glibc/2.28
UID:GID : 1000:100
netrc file : None
offline mode : False
LightGBM version or commit hash: https://github.com/microsoft/LightGBM/tree/9f70e9685dfb5c82f2ee87176a8433a6b7a4b98f
Error message and / or logs
Training with lightgbm.dask.DaskLGBMRegressor succeeds, but .predict() fails with this error:
ValueError: Metadata inference failed in `_predict_part`.
You have supplied a custom function and Dask is unable to
determine the type of output that that function returns.
To resolve this please provide a meta= keyword.
The docstring of the Dask function you ran should have more information.
Original error is below:
------------------------
TypeError('Unknown type of parameter:y, got:Series')
Traceback:
---------
File "/opt/conda/envs/saturn/lib/python3.7/site-packages/dask/dataframe/utils.py", line 174, in raise_on_meta_error
yield
File "/opt/conda/envs/saturn/lib/python3.7/site-packages/dask/dataframe/core.py", line 5165, in _emulate
return func(*_extract_meta(args, True), **_extract_meta(kwargs, True))
File "/opt/conda/envs/saturn/lib/python3.7/site-packages/lightgbm/dask.py", line 319, in _predict_part
**kwargs
File "/opt/conda/envs/saturn/lib/python3.7/site-packages/lightgbm/sklearn.py", line 707, in predict
pred_leaf=pred_leaf, pred_contrib=pred_contrib, **kwargs)
File "/opt/conda/envs/saturn/lib/python3.7/site-packages/lightgbm/basic.py", line 3118, in predict
predictor = self._to_predictor(deepcopy(kwargs))
File "/opt/conda/envs/saturn/lib/python3.7/site-packages/lightgbm/basic.py", line 3204, in _to_predictor
predictor = _InnerPredictor(booster_handle=self.handle, pred_parameter=pred_parameter)
File "/opt/conda/envs/saturn/lib/python3.7/site-packages/lightgbm/basic.py", line 638, in __init__
self.pred_parameter = param_dict_to_str(pred_parameter)
File "/opt/conda/envs/saturn/lib/python3.7/site-packages/lightgbm/basic.py", line 221, in param_dict_to_str
% (key, type(val).__name__))
Reproducible example(s)
I’ll update this with a better, smaller reproducible example soon. I’m rushing right now to finish something else for work, but I wanted to document this so search engines return this issue if others google that error message.
I’m training and trying to .predict() on a Dask DataFrame. Something like this:
import dask.dataframe as dd
from dask.distributed import Client, wait  # wait() is used below
import lightgbm as lgb

client = Client()  # or connect to an existing cluster
taxi_train = dd.read_csv(
"s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv",
parse_dates=["tpep_pickup_datetime", "tpep_dropoff_datetime"],
storage_options={"anon": True},
assume_missing=True,
).sample(frac=0.01, replace=False)
def prep_df(df: dd.DataFrame, target_col: str) -> dd.DataFrame:
"""
Prepare a raw taxi dataframe for training.
* computes the target ('tip_fraction')
* adds features
* removes unused features
"""
numeric_feat = [
"pickup_weekday",
"pickup_weekofyear",
"pickup_hour",
"pickup_week_hour",
"pickup_minute",
"passenger_count",
]
categorical_feat = [
"PULocationID",
"DOLocationID",
]
features = numeric_feat + categorical_feat
df = df[df.fare_amount > 0] # avoid divide-by-zero
df[target_col] = df.tip_amount / df.fare_amount
df["pickup_weekday"] = df.tpep_pickup_datetime.dt.weekday
df["pickup_weekofyear"] = df.tpep_pickup_datetime.dt.isocalendar().week
df["pickup_hour"] = df.tpep_pickup_datetime.dt.hour
df["pickup_week_hour"] = (df.pickup_weekday * 24) + df.pickup_hour
df["pickup_minute"] = df.tpep_pickup_datetime.dt.minute
df = df[features + [target_col]].astype(float).fillna(-1)
return df
target_col = "tip_fraction"
taxi_train = prep_df(taxi_train, target_col)
taxi_train = taxi_train.persist()
_ = wait(taxi_train)
features = [c for c in taxi_train.columns if c != target_col]
data = taxi_train[features]
label = taxi_train[target_col]
dask_reg = lgb.dask.DaskLGBMRegressor(
silent=False,
max_depth=8,
random_state=708,
learning_rate=0.05,
tree_learner="data",
n_estimators=100,
n_jobs=-1,
categorical_features=[6, 7]
)
dask_reg.fit(
client=client,
X=data,
y=label,
)
taxi_test = dd.read_csv(
"s3://nyc-tlc/trip data/yellow_tripdata_2019-02.csv",
parse_dates=["tpep_pickup_datetime", "tpep_dropoff_datetime"],
storage_options={"anon": True},
assume_missing=True,
).sample(frac=0.01, replace=False)
taxi_test = prep_df(taxi_test, target_col=target_col)
taxi_test = taxi_test.persist()
_ = wait(taxi_test)
preds = dask_reg.predict(
X=taxi_test[features]
)
See the output of conda env export below for versions of Dask and its dependencies.
output of 'conda env export'
name: saturn
channels:
- https://conda.saturncloud.io/pkgs
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- argon2-cffi=20.1.0=py37h7b6447c_1
- async_generator=1.10=py37h28b3542_0
- attrs=20.3.0=pyhd3eb1b0_0
- backcall=0.2.0=py_0
- blas=1.0=mkl
- bleach=3.2.2=pyhd3eb1b0_0
- bokeh=2.2.3=py37_0
- boto3=1.16.59=pyhd3eb1b0_0
- botocore=1.19.59=pyhd3eb1b0_0
- brotlipy=0.7.0=py37h27cfd23_1003
- ca-certificates=2021.1.19=h06a4308_0
- cairo=1.14.12=h8948797_3
- certifi=2020.12.5=py37h06a4308_0
- cffi=1.14.0=py37h2e261b9_0
- click=7.1.2=pyhd3eb1b0_0
- cloudpickle=1.6.0=py_0
- cryptography=3.3.1=py37h3c74f83_0
- cycler=0.10.0=py37_0
- cytoolz=0.11.0=py37h7b6447c_0
- dask-glm=0.2.0=py37_0
- dask-ml=1.7.0=py_0
- dbus=1.13.18=hb2f20db_0
- decorator=4.4.2=py_0
- defusedxml=0.6.0=py_0
- docutils=0.15.2=py37_0
- entrypoints=0.3=py37_0
- expat=2.2.10=he6710b0_2
- fastparquet=0.5.0=py37h6323ea4_1
- fontconfig=2.13.0=h9420a91_0
- freetype=2.10.4=h5ab3b9f_0
- fribidi=1.0.10=h7b6447c_0
- fsspec=0.8.3=py_0
- glib=2.63.1=h5a9c865_0
- graphite2=1.3.14=h23475e2_0
- graphviz=2.40.1=h21bd128_2
- gst-plugins-base=1.14.0=hbbd80ab_1
- gstreamer=1.14.0=hb453b48_1
- h5py=2.10.0=py37hd6299e0_1
- harfbuzz=1.8.8=hffaf4a1_0
- hdf5=1.10.6=hb1b8bf9_0
- heapdict=1.0.1=py_0
- icu=58.2=he6710b0_3
- importlib-metadata=2.0.0=py_1
- importlib_metadata=2.0.0=1
- intel-openmp=2020.2=254
- ipykernel=5.3.4=py37h5ca1d4c_0
- ipython=7.19.0=py37hb070fc8_0
- ipython_genutils=0.2.0=pyhd3eb1b0_1
- ipywidgets=7.6.3=pyhd3eb1b0_1
- jedi=0.18.0=py37h06a4308_1
- jinja2=2.11.2=pyhd3eb1b0_0
- jmespath=0.10.0=py_0
- joblib=1.0.0=pyhd3eb1b0_0
- jpeg=9b=h024ee3a_2
- jsonschema=3.2.0=py_2
- jupyter_client=6.1.7=py_0
- jupyter_core=4.7.0=py37h06a4308_0
- jupyterlab_pygments=0.1.2=py_0
- jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
- kiwisolver=1.3.0=py37h2531618_0
- lcms2=2.11=h396b838_0
- ld_impl_linux-64=2.33.1=h53a641e_7
- libedit=3.1.20191231=h14c3975_1
- libffi=3.2.1=hf484d3e_1007
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libllvm10=10.0.1=hbcb73fb_5
- libpng=1.6.37=hbc83047_0
- libsodium=1.0.18=h7b6447c_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h2733197_1
- libuuid=1.0.3=h1bed415_2
- libxcb=1.14=h7b6447c_0
- libxml2=2.9.10=hb55368b_3
- llvmlite=0.34.0=py37h269e1b5_4
- locket=0.2.1=py37h06a4308_1
- lz4-c=1.9.3=h2531618_0
- markupsafe=1.1.1=py37h14c3975_1
- matplotlib=3.3.2=h06a4308_0
- matplotlib-base=3.3.2=py37h817c723_0
- mistune=0.8.4=py37h14c3975_1001
- mkl=2020.2=256
- mkl-service=2.3.0=py37he8ac12f_0
- mkl_fft=1.2.0=py37h23d657b_0
- mkl_random=1.1.1=py37h0573a6f_0
- msgpack-python=1.0.1=py37hff7bd54_0
- multipledispatch=0.6.0=py37_0
- nbclient=0.5.1=py_0
- nbconvert=6.0.7=py37_0
- nbformat=5.1.2=pyhd3eb1b0_1
- ncurses=6.2=he6710b0_1
- nest-asyncio=1.4.3=pyhd3eb1b0_0
- notebook=6.2.0=py37h06a4308_0
- numba=0.51.2=py37h04863e7_1
- numpy=1.19.2=py37h54aff64_0
- numpy-base=1.19.2=py37hfa32c7d_0
- olefile=0.46=py37_0
- openssl=1.1.1i=h27cfd23_0
- packaging=20.8=pyhd3eb1b0_0
- pandas=1.1.0=py37he6710b0_0
- pandoc=2.11=hb0f4dca_0
- pandocfilters=1.4.3=py37h06a4308_1
- pango=1.42.4=h049681c_0
- parso=0.8.1=pyhd3eb1b0_0
- partd=1.1.0=py_0
- pcre=8.44=he6710b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.1.0=py37he98fc37_0
- pip=20.3.3=py37h06a4308_0
- pixman=0.40.0=h7b6447c_0
- prometheus_client=0.9.0=pyhd3eb1b0_0
- prompt-toolkit=3.0.8=py_0
- psutil=5.7.2=py37h7b6447c_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pycparser=2.20=py_2
- pygments=2.7.4=pyhd3eb1b0_0
- pyopenssl=20.0.1=pyhd3eb1b0_1
- pyparsing=2.4.7=pyhd3eb1b0_0
- pyqt=5.9.2=py37h05f1152_2
- pyrsistent=0.17.3=py37h7b6447c_0
- pysocks=1.7.1=py37_1
- python=3.7.7=hcf32534_0_cpython
- python-dateutil=2.8.1=py_0
- pytz=2020.5=pyhd3eb1b0_0
- pyyaml=5.4.1=py37h27cfd23_1
- pyzmq=20.0.0=py37h2531618_1
- qt=5.9.7=h5867ecd_1
- readline=8.0=h7b6447c_0
- s3fs=0.4.2=py_0
- s3transfer=0.3.4=pyhd3eb1b0_0
- scikit-learn=0.23.2=py37h0573a6f_0
- scipy=1.5.2=py37h0b6359f_0
- send2trash=1.5.0=pyhd3eb1b0_1
- setuptools=52.0.0=py37h06a4308_0
- sip=4.19.8=py37hf484d3e_0
- six=1.15.0=py37h06a4308_0
- sortedcontainers=2.3.0=pyhd3eb1b0_0
- sqlite=3.33.0=h62c20be_0
- tbb=2020.3=hfd86e86_0
- tblib=1.7.0=py_0
- terminado=0.9.2=py37h06a4308_0
- testpath=0.4.4=py_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- thrift=0.11.0=py37hf484d3e_0
- tk=8.6.10=hbc83047_0
- toolz=0.11.1=pyhd3eb1b0_0
- tornado=6.1=py37h27cfd23_0
- traitlets=5.0.5=py_0
- typing_extensions=3.7.4.3=py_0
- urllib3=1.25.11=py_0
- wcwidth=0.2.5=py_0
- webencodings=0.5.1=py37_1
- wheel=0.36.2=pyhd3eb1b0_0
- widgetsnbextension=3.5.1=py37_0
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h7b6447c_0
- zeromq=4.3.3=he6710b0_3
- zict=2.0.0=py_0
- zipp=3.4.0=pyhd3eb1b0_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.5=h9ceee32_0
- pip:
- chardet==4.0.0
- dask==2021.1.1
- dask-saturn==0.2.2
- distributed==2021.1.1
- idna==2.10
- lightgbm==3.1.1.99
- requests==2.25.1
prefix: /opt/conda/envs/saturn
References
I think that changing the uses of map_blocks() and map_partitions() based on this description from the Dask docs could fix this issue.
meta
An empty pd.DataFrame or pd.Series that matches the dtypes and column names of the output. This metadata is necessary for many algorithms in dask dataframe to work. For ease of use, some alternative inputs are also available. Instead of a DataFrame, a dict of {name: dtype} or iterable of (name, dtype) can be provided (note that the order of the names should match the order of the columns). Instead of a series, a tuple of (name, dtype) can be used. If not provided, dask will try to infer the metadata. This may lead to unexpected results, so providing meta is recommended. For more information, see dask.dataframe.utils.make_meta.
But I’m confused and concerned about this error showing up, since it does not appear in any of the tests at https://github.com/microsoft/LightGBM/blob/9f70e9685dfb5c82f2ee87176a8433a6b7a4b98f/tests/python_package_test/test_dask.py, and we test against Dask DataFrame inputs there.
For anyone new to LightGBM looking to help with this before I get to it, here’s the place where we’re using _predict_part() in map_partitions() --> https://github.com/microsoft/LightGBM/blob/9f70e9685dfb5c82f2ee87176a8433a6b7a4b98f/python-package/lightgbm/dask.py#L351-L360
About this issue
- Original URL
- State: closed
- Created 3 years ago
- Comments: 15
Haha ok, thanks! I actually get some dedicated time to work on LightGBM at work…so how about I try this tomorrow and open a draft PR, and maybe I’ll @you for a review and other ideas?

Now that you found a small reproducible example, it should go quickly.
Oh, I meant it as in: if no one’s taking it, I’d like to check it out, haha. I’m not sure I’d be able to pull it off on my own; maybe I can help you with some findings or discussions instead.