distributed: problems with numba ufunc + distributed

We have created a new software package called fastjmd95 that uses numba to accelerate computation of the ocean equation of state. Everything works fine with dask and a local scheduler. Now I want to run this code on a distributed dask cluster. It isn’t working, I think because the workers are not able to deserialize the numba functions properly.

Original Full Example

This example with real data can be run on any Pangeo cluster:

from intake import open_catalog
from fastjmd95 import rho

cat = open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean.yaml")
ds  = cat["SOSE"].to_dask()

rhonil = 1025
pa_to_dbar = 1.0/10000
p = ds.PHrefC * rhonil * pa_to_dbar
s = ds.SALT
t = ds.THETA
r = rho(s.data, t.data, 0)
# works fine with local scheduler
r_mean = r[:5].compute()

# now start distributed scheduler
from dask.distributed import Client
client = Client()
r_mean = r[:5].compute()
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-7316322484d4> in <module>
----> 1 r_mean = r[:5].compute()

/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    163         dask.base.compute
    164         """
--> 165         (result,) = compute(self, traverse=False, **kwargs)
    166         return result
    167 

/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2571                     should_rejoin = False
   2572             try:
-> 2573                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2574             finally:
   2575                 for f in futures.values():

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1871                 direct=direct,
   1872                 local_worker=local_worker,
-> 1873                 asynchronous=asynchronous,
   1874             )
   1875 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    766         else:
    767             return sync(
--> 768                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    769             )
    770 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    332     if error[0]:
    333         typ, exc, tb = error[0]
--> 334         raise exc.with_traceback(tb)
    335     else:
    336         return result[0]

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in f()
    316             if callback_timeout is not None:
    317                 future = gen.with_timeout(timedelta(seconds=callback_timeout), future)
--> 318             result[0] = yield future
    319         except Exception as exc:
    320             error[0] = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1727                             exc = CancelledError(key)
   1728                         else:
-> 1729                             raise exception.with_traceback(traceback)
   1730                         raise exc
   1731                     if errors == "skip":

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py in loads()
     57 def loads(x):
     58     try:
---> 59         return pickle.loads(x)
     60     except Exception:
     61         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)

/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py in _ufunc_reconstruct()
    123     # scipy.special.expit for instance.
    124     mod = __import__(module, fromlist=[name])
--> 125     return getattr(mod, name)
    126 
    127 def _ufunc_reduce(func):

AttributeError: module '__main__' has no attribute 'rho'

Minimal Example

I believe this reproduces the core problem:

import numpy as np
from numba import vectorize, float64, float32
import dask.array as dsa
from dask.distributed import Client
client = Client()

# define a numba ufunc
@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
    return a**2

# verify that the client can run it
def try_numba_on_client():
    data = np.arange(5, dtype='f4')
    return test_numba(data)
client.run(try_numba_on_client)
# works, output is:
# > {'tcp://127.0.0.1:37583': array([ 0.,  1.,  4.,  9., 16.]),
# > 'tcp://127.0.0.1:44855': array([ 0.,  1.,  4.,  9., 16.])}

# use in a computation
data_dask = dsa.arange(5, dtype='f4')
test_numba(data_dask).compute()

At this point I get a KilledWorker error. In the worker log, I can see the following error:

distributed.worker - ERROR - module '__main__' has no attribute 'test_numba'
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/worker.py", line 905, in handle_scheduler
    comm, every_cycle=[self.ensure_communicating, self.ensure_computing]
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/core.py", line 456, in handle_stream
    msgs = await comm.read()
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/tcp.py", line 222, in read
    frames, deserialize=self.deserialize, deserializers=deserializers
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 69, in from_frames
    res = _from_frames()
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 55, in _from_frames
    frames, deserialize=deserialize, deserializers=deserializers
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/core.py", line 124, in loads
    value = _deserialize(head, fs, deserializers=deserializers)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 255, in deserialize
    deserializers=deserializers,
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 268, in deserialize
    return loads(header, frames)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 62, in pickle_loads
    return pickle.loads(b"".join(frames))
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py", line 59, in loads
    return pickle.loads(x)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py", line 125, in _ufunc_reconstruct
    return getattr(mod, name)
AttributeError: module '__main__' has no attribute 'test_numba'

The basic error appears to be the same as in the full example.

This seems like a pretty straightforward use of numba + distributed, and I assumed this sort of usage was supported. Am I missing something obvious?

Installed versions

I’m on dask 2.9.0 and numba 0.48.0.

>>> client.get_versions(check=True)
{'scheduler': {'host': (('python', '3.7.6.final.0'),
   ('python-bits', 64),
   ('OS', 'Linux'),
   ('OS-release', '4.19.76+'),
   ('machine', 'x86_64'),
   ('processor', 'x86_64'),
   ('byteorder', 'little'),
   ('LC_ALL', 'en_US.UTF-8'),
   ('LANG', 'en_US.UTF-8'),
   ('LOCALE', 'en_US.UTF-8')),
  'packages': {'required': (('dask', '2.9.0'),
    ('distributed', '2.9.0'),
    ('msgpack', '0.6.2'),
    ('cloudpickle', '1.2.2'),
    ('tornado', '6.0.3'),
    ('toolz', '0.10.0')),
   'optional': (('numpy', '1.17.3'),
    ('pandas', '0.25.3'),
    ('bokeh', '1.4.0'),
    ('lz4', '2.2.1'),
    ('dask_ml', '1.1.1'),
    ('blosc', '1.8.1'))}},
 'workers': {'tcp://10.32.181.10:45663': {'host': (('python', '3.7.6.final.0'),
    ('python-bits', 64),
    ('OS', 'Linux'),
    ('OS-release', '4.19.76+'),
    ('machine', 'x86_64'),
    ('processor', 'x86_64'),
    ('byteorder', 'little'),
    ('LC_ALL', 'en_US.UTF-8'),
    ('LANG', 'en_US.UTF-8'),
    ('LOCALE', 'en_US.UTF-8')),
   'packages': {'required': (('dask', '2.9.0'),
     ('distributed', '2.9.0'),
     ('msgpack', '0.6.2'),
     ('cloudpickle', '1.2.2'),
     ('tornado', '6.0.3'),
     ('toolz', '0.10.0')),
    'optional': (('numpy', '1.17.3'),
     ('pandas', '0.25.3'),
     ('bokeh', '1.4.0'),
     ('lz4', '2.2.1'),
     ('dask_ml', '1.1.1'),
     ('blosc', '1.8.1'))}},
  'tcp://10.32.181.11:37259': {'host': (('python', '3.7.6.final.0'),
    ('python-bits', 64),
    ('OS', 'Linux'),
    ('OS-release', '4.19.76+'),
    ('machine', 'x86_64'),
    ('processor', 'x86_64'),
    ('byteorder', 'little'),
    ('LC_ALL', 'en_US.UTF-8'),
    ('LANG', 'en_US.UTF-8'),
    ('LOCALE', 'en_US.UTF-8')),
   'packages': {'required': (('dask', '2.9.0'),
     ('distributed', '2.9.0'),
     ('msgpack', '0.6.2'),
     ('cloudpickle', '1.2.2'),
     ('tornado', '6.0.3'),
     ('toolz', '0.10.0')),
    'optional': (('numpy', '1.17.3'),
     ('pandas', '0.25.3'),
     ('bokeh', '1.4.0'),
     ('lz4', '2.2.1'),
     ('dask_ml', '1.1.1'),
     ('blosc', '1.8.1'))}}},
 'client': {'host': [('python', '3.7.6.final.0'),
   ('python-bits', 64),
   ('OS', 'Linux'),
   ('OS-release', '4.19.76+'),
   ('machine', 'x86_64'),
   ('processor', 'x86_64'),
   ('byteorder', 'little'),
   ('LC_ALL', 'en_US.UTF-8'),
   ('LANG', 'en_US.UTF-8'),
   ('LOCALE', 'en_US.UTF-8')],
  'packages': {'required': [('dask', '2.9.0'),
    ('distributed', '2.9.0'),
    ('msgpack', '0.6.2'),
    ('cloudpickle', '1.2.2'),
    ('tornado', '6.0.3'),
    ('toolz', '0.10.0')],
   'optional': [('numpy', '1.17.3'),
    ('pandas', '0.25.3'),
    ('bokeh', '1.4.0'),
    ('lz4', '2.2.1'),
    ('dask_ml', '1.1.1'),
    ('blosc', '1.8.1')]}}}

About this issue

  • State: open
  • Created 4 years ago
  • Comments: 29 (18 by maintainers)


Most upvoted comments

I have a slightly better understanding of the situation now. The call order is something like:

numba_ufunc(dask_array) ->
  numba_ufunc.ufunc(dask_array) ->
    dask_array.__array_ufunc__(...)
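A pure-Python mimic of this call order may make the consequence clearer. FakeUfunc, FakeDUFunc and FakeDaskArray below are hypothetical stand-ins, not numba, NumPy or dask code; the sketch only assumes what the chain above says: the wrapper forwards to an inner .ufunc object, and that ufunc hands itself to the array's __array_ufunc__. Under that assumption, the callable that dask records for the workers is the inner ufunc, not the wrapper.

```python
# Hypothetical stand-ins: FakeUfunc plays numpy.ufunc, FakeDUFunc plays
# numba's _DUFunc wrapper, FakeDaskArray plays dask.array.Array.


class FakeUfunc:
    def __init__(self, name, func):
        self.__name__ = name
        self._func = func

    def __call__(self, *args):
        # Like a real ufunc, defer to an argument whose type defines __array_ufunc__.
        for arg in args:
            if hasattr(type(arg), "__array_ufunc__"):
                return arg.__array_ufunc__(self, "__call__", *args)
        return self._func(*args)


class FakeDUFunc:
    def __init__(self, name, func):
        # The wrapper owns a separate inner ufunc object.
        self.ufunc = FakeUfunc(name, func)

    def __call__(self, *args):
        # Step 2 of the chain: wrapper(x) -> wrapper.ufunc(x)
        return self.ufunc(*args)


class FakeDaskArray:
    def __init__(self):
        self.graph = []  # records the callables that would be shipped to workers

    def __array_ufunc__(self, ufunc, method, *inputs):
        # Step 3 of the chain: dask receives the *inner* ufunc here and
        # would embed it in its task graph.
        self.graph.append(ufunc)
        return self


test_numba = FakeDUFunc("test_numba", lambda a: a ** 2)
arr = FakeDaskArray()
test_numba(arr)

print(test_numba(3))                     # plain input: computed directly
print(arr.graph[0] is test_numba.ufunc)  # the graph holds the inner ufunc
print(arr.graph[0] is test_numba)        # ... and not the wrapper
```

Calling the wrapper on a plain number just computes the result, while calling it on the fake array leaves the inner ufunc in arr.graph. That inner object is the one that has to survive pickling on its way to a worker.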

The test_numba.ufunc is a NumPy ufunc that is (I think) dynamically generated by numba.

In [2]: b.test_numba
Out[2]: <numba._DUFunc 'test_numba'>

In [3]: b.test_numba.ufunc
Out[3]: <ufunc 'test_numba'>

In [4]: type(b.test_numba.ufunc)
Out[4]: numpy.ufunc

And that’s what chokes up dask’s serialization:

In [9]: pickle.loads(pickle.dumps(b.test_numba))
Out[9]: <numba._DUFunc 'test_numba'>

In [10]: pickle.loads(pickle.dumps(b.test_numba.ufunc))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-10-cff7e39b4aa8> in <module>
----> 1 pickle.loads(pickle.dumps(b.test_numba.ufunc))

~/Envs/dask-dev/lib/python3.7/site-packages/numpy/core/__init__.py in _ufunc_reconstruct(module, name)
    130     # scipy.special.expit for instance.
    131     mod = __import__(module, fromlist=[name])
--> 132     return getattr(mod, name)
    133
    134 def _ufunc_reduce(func):

AttributeError: module '__main__' has no attribute 'test_numba'
# b.py
from numba import vectorize, float64, float32


@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
    return a**2
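Here is a stdlib-only sketch of why the lookup ends up in __main__. It mimics by-reference pickling rather than using NumPy: _reconstruct imitates numpy.core._ufunc_reconstruct from the traceback above, _find_module is a simplified stand-in for pickle.whichmodule (assumed here to be what NumPy's reducer relies on to pick the module), and fake_lib is a hypothetical module playing the role of b.py. Because fake_lib.test_numba is the wrapper, no module exposes the inner object under the name test_numba, so the search falls back to __main__.

```python
import importlib
import pickle
import sys
import types


def _reconstruct(module, name):
    # Imitates numpy.core._ufunc_reconstruct: import the module, then getattr.
    mod = importlib.import_module(module)
    return getattr(mod, name)


def _find_module(obj, name):
    # Simplified stand-in for pickle.whichmodule: return the name of a module
    # whose top-level attribute `name` is obj itself, else fall back to "__main__".
    for mod_name, mod in list(sys.modules.items()):
        try:
            if getattr(mod, name, None) is obj:
                return mod_name
        except Exception:
            continue
    return "__main__"


class InnerUfunc:
    """Stand-in for the numpy.ufunc stored on the wrapper as .ufunc."""

    def __init__(self, name):
        self.__name__ = name

    def __reduce__(self):
        # Pickle by reference: (module, name), looked up again on load.
        return _reconstruct, (_find_module(self, self.__name__), self.__name__)


class Wrapper:
    """Stand-in for numba's _DUFunc; only its inner object matters here."""

    def __init__(self, name):
        self.ufunc = InnerUfunc(name)


# Hypothetical library module: its attribute test_numba is the *wrapper*.
fake_lib = types.ModuleType("fake_lib")
sys.modules["fake_lib"] = fake_lib
fake_lib.test_numba = Wrapper("test_numba")

payload = pickle.dumps(fake_lib.test_numba.ufunc)
error_message = None
try:
    pickle.loads(payload)
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```

Run as a script, pickle.dumps succeeds but pickle.loads should raise an AttributeError naming __main__, the same shape as the error in the worker log.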

Will start looking for solutions now.

@TomAugspurger, it seems hackish, but maybe a band-aid is better than nothing.

However, I experimented a bit, and I think we are overlooking how much better pickle has become (or how reliable it actually is). That is, I think NumPy's ufunc pickling is over-engineered, and that makes the solution harder than necessary. I tried modifying NumPy like this, but you can also apply the same override manually:


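import copyreg

import numpy as np

# Only patch NumPy older than 1.20 (compare version numbers, not strings).
if tuple(int(part) for part in np.__version__.split(".")[:2]) < (1, 20):

    def _ufunc_reduce(func):
        # Returning a string makes pickle store the ufunc as a global
        # reference (module plus name) instead of calling a reconstructor.
        return func.__name__

    copyreg.pickle(np.ufunc, _ufunc_reduce, np.core._ufunc_reconstruct)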

Now you need one more ingredient: test_numba.ufunc has to report its __name__ as test_numba.ufunc (a bit like a __qualname__). I tried this by hacking NumPy so that the ufunc name is mutable. A __qualname__ would probably be better, and I guess we could add a __qualname__ to ufuncs, but if printing the extra .ufunc seems OK, this solution may be possible right now.

Overriding the ufunc pickling outside of NumPy seems pretty extreme, but I am not actually sure it’s all that bad. I did not check, but I think the replacement above is effectively identical to what NumPy does, except that it supports attribute lookup in a __qualname__-like fashion.
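Both ingredients can be sketched with stand-in classes instead of patching numpy.ufunc (an assumption made so the example runs without NumPy or numba). When a copyreg reducer returns a string, pickle saves the object as a global reference, and with protocol 4 or higher a dotted name such as test_numba.ufunc is resolved attribute by attribute on load, which is the __qualname__-like lookup described above. fake_lib is a hypothetical module; on a real cluster the defining module (b.py or fastjmd95) would have to be importable on every worker.

```python
import copyreg
import pickle
import sys
import types


class InnerUfunc:
    """Stand-in for the numpy.ufunc that a numba wrapper exposes as .ufunc."""

    def __init__(self, qualname, module):
        self.__name__ = qualname  # e.g. "test_numba.ufunc", as proposed above
        self.__module__ = module  # tells pickle which module to import on load


class Wrapper:
    """Stand-in for numba's _DUFunc wrapper."""

    def __init__(self, name, module):
        self.ufunc = InnerUfunc(f"{name}.ufunc", module)


def _reduce_by_name(obj):
    # Returning a str tells pickle to save obj as a global reference by this name.
    return obj.__name__


# Register the reducer for the stand-in type (the proposal targets numpy.ufunc).
copyreg.pickle(InnerUfunc, _reduce_by_name)

# Hypothetical library module holding the wrapper at top level, like b.py.
fake_lib = types.ModuleType("fake_lib")
sys.modules["fake_lib"] = fake_lib
fake_lib.test_numba = Wrapper("test_numba", "fake_lib")

# Protocol 4+ can resolve the dotted name "test_numba.ufunc" on load.
payload = pickle.dumps(fake_lib.test_numba.ufunc, protocol=4)
restored = pickle.loads(payload)
print(restored is fake_lib.test_numba.ufunc)  # True
```

The identity check confirms the object was found again by name rather than rebuilt.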

Thanks. My hope is that my file b.py sufficiently simulates a third-party library function. I’m able to reproduce the same error with:

import pickle
import fastjmd95

if __name__ == "__main__":
    pickle.loads(pickle.dumps(fastjmd95.rho.ufunc))