numba: numba.jit fails on sympy.lambdify functions

On SymPy master, something has broken the ability to call numba.jit on a lambdified function.

>>> import numba
>>> from sympy import *
>>> x = symbols('x')
>>> f = lambdify(x, sin(x), 'numpy')
>>> f(1)
0.8414709848078965
>>> numba.jit(f)(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/dispatcher.py", line 287, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/dispatcher.py", line 555, in compile
    cres = self._compiler.compile(args, return_type)
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/dispatcher.py", line 81, in compile
    flags=flags, locals=self.locals)
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/compiler.py", line 699, in compile_extra
    return pipeline.compile_extra(func)
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/compiler.py", line 352, in compile_extra
    return self._compile_bytecode()
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/compiler.py", line 660, in _compile_bytecode
    return self._compile_core()
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/compiler.py", line 647, in _compile_core
    res = pm.run(self.status)
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
    raise patched_exception
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/compiler.py", line 230, in run
    stage()
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/compiler.py", line 366, in stage_analyze_bytecode
    func_ir = translate_stage(self.func_id, self.bc)
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/compiler.py", line 762, in translate_stage
    return interp.interpret(bytecode)
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/interpreter.py", line 97, in interpret
    self.dfa.run()
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/dataflow.py", line 27, in run
    self.infos[blk.offset] = self.run_on_block(blk)
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/dataflow.py", line 71, in run_on_block
    self.dispatch(info, inst)
  File "/Users/aaronmeurer/anaconda/lib/python3.5/site-packages/numba/dataflow.py", line 80, in dispatch
    fn = getattr(self, fname)
AttributeError: Failed at object (analyzing bytecode)
'DataFlowAnalysis' object has no attribute 'op_LOAD_CLOSURE'

Note that you need to run this from this pull request because of another issue (we weren’t using functools.wraps).

I bisected the change in SymPy, and I think it is caused by the wrapper function we are using.

If this is not easy to fix in Numba, a suggested workaround that we could use in SymPy would be great.
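One possible workaround, assuming the wrapper is applied with functools.wraps as in the pull request mentioned above: functools.wraps stores the original function on the wrapper as `__wrapped__`, so `inspect.unwrap` can recover the undecorated function, and that plain function can be handed to numba.jit instead of the *args/**kwargs wrapper. The `passthrough` decorator below is a hypothetical stand-in, not SymPy's actual wrapper, and Numba is only used if it is installed.

```python
import functools
import inspect


def passthrough(func):
    """Hypothetical stand-in for a lambdify-style wrapper (not SymPy's code)."""
    @functools.wraps(func)  # also sets wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


f = passthrough(lambda x: x + 1)

# Follow the __wrapped__ chain back to the undecorated function.
inner = inspect.unwrap(f)
print(inner is f)  # False: the wrapper has been stripped
print(inner(1))    # 2

try:
    import numba
except ImportError:  # Numba is optional for this sketch
    numba = None

if numba is not None:
    # Compile the plain function rather than the pass-through wrapper.
    print(numba.jit(inner)(1))
```

Note that unwrapping bypasses anything the wrapper does to its arguments (for example, type conversion), so this only helps if the inner function can be called directly.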

About this issue

  • State: closed
  • Created 7 years ago
  • Comments: 28 (25 by maintainers)

Most upvoted comments

We refactored how lambdify works: it now produces a regular function instead of a lambda, and it no longer wraps that function in a decorator. With sympy 1.1.1 I still get the "CALL_FUNCTION_EX with **kwargs not supported" error, because the decorator in that release used the standard *args, **kwargs pass-through semantics (https://github.com/sympy/sympy/blob/sympy-1.1.1/sympy/utilities/lambdify.py#L440).
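The failure comes from the pass-through pattern itself: a wrapper written as `def wrapper(*args, **kwargs): return func(*args, **kwargs)` forwards its arguments with a star-argument call, which CPython 3.6+ compiles to the CALL_FUNCTION_EX opcode named in the Numba error. The sketch below lists the call opcodes in such a wrapper using the standard-library dis module; `passthrough` is a hypothetical decorator, not SymPy's implementation, and exact opcode names vary between Python versions.

```python
import dis
import functools


def passthrough(func):
    """Hypothetical decorator using *args, **kwargs pass-through semantics."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


@passthrough
def f(x):
    return x + 1


# functools.wraps does not copy __code__, so f.__code__ is the wrapper's bytecode.
call_ops = [ins.opname for ins in dis.get_instructions(f.__code__)
            if ins.opname.startswith("CALL")]
print(call_ops)  # on CPython 3.6+ this includes 'CALL_FUNCTION_EX'
```

A wrapper that forwards only *args also compiles to CALL_FUNCTION_EX, just without a keyword mapping; the wording of the Numba message suggests it is specifically the **kwargs form that was unsupported.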

I’m curious whether there are any plans, distant or otherwise, to support **kwargs in CALL_FUNCTION_EX?

It looks like this works now:

In [14]: >>> import numba
    ...: >>> from sympy import *
    ...: >>> f = lambdify(x, sin(x), 'numpy')
    ...: >>> print(f(1))
    ...: >>> print(numba.jit(f)(1))
0.8414709848078965
0.8414709848078965

In [15]: sympy.__version__
Out[15]: '1.3'

In [16]: numba.__version__
Out[16]: '0.43.0'
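To check whether a particular SymPy installation still hands Numba a pass-through wrapper, one rough test is whether the function returned by lambdify declares a **kwargs parameter in its own code object, since that is the shape the Numba error complains about. The helper `accepts_var_kwargs` below is a hypothetical sketch using only the standard inspect module; it does not call Numba, a False result does not guarantee that numba.jit will succeed, and the SymPy part is skipped if SymPy or NumPy is not installed.

```python
import inspect


def accepts_var_kwargs(fn):
    """Return True if fn's own code object declares a **kwargs parameter."""
    return bool(fn.__code__.co_flags & inspect.CO_VARKEYWORDS)


def plain(x):
    return x


def wrapper(*args, **kwargs):
    return plain(*args, **kwargs)


print(accepts_var_kwargs(plain))    # False
print(accepts_var_kwargs(wrapper))  # True

try:
    import numpy  # noqa: F401  (lambdify's "numpy" backend needs it)
    import sympy
except ImportError:
    sympy = None

if sympy is not None:
    x = sympy.symbols("x")
    f = sympy.lambdify(x, sympy.sin(x), "numpy")
    # Expected to be False on releases where lambdify no longer wraps its output.
    print(sympy.__version__, accepts_var_kwargs(f))
```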

Is this supposed to work now? With sympy 1.1.1 and numba 0.35.0rc1+15.g24edca2, I’m getting a CALL_FUNCTION_EX error when trying to jit a lambdified function:

from numba import jit
import sympy as sym

x = sym.symbols('x')
y = sym.exp(-x)
tanh = (1.0 - y) / (1.0 + y)
ddx_tanh = tanh.diff(x)
f = sym.lambdify(x, ddx_tanh, 'numpy')
g = jit(f)
In [12]: f(1.0)
Out[12]: 0.3932238664829637
In [13]: g(1.0)

Traceback (most recent call last):
  File "<ipython-input-13-7c2977bc7ed3>", line 1, in <module>
    g(1.0)
  File "C:\Miniconda3\lib\site-packages\numba\dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "C:\Miniconda3\lib\site-packages\numba\dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "C:\Miniconda3\lib\site-packages\numba\dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "C:\Miniconda3\lib\site-packages\numba\compiler.py", line 763, in compile_extra
    return pipeline.compile_extra(func)
  File "C:\Miniconda3\lib\site-packages\numba\compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "C:\Miniconda3\lib\site-packages\numba\compiler.py", line 722, in _compile_bytecode
    return self._compile_core()
  File "C:\Miniconda3\lib\site-packages\numba\compiler.py", line 709, in _compile_core
    res = pm.run(self.status)
  File "C:\Miniconda3\lib\site-packages\numba\compiler.py", line 246, in run
    raise patched_exception
  File "C:\Miniconda3\lib\site-packages\numba\compiler.py", line 238, in run
    stage()
  File "C:\Miniconda3\lib\site-packages\numba\compiler.py", line 374, in stage_analyze_bytecode
    func_ir = translate_stage(self.func_id, self.bc)
  File "C:\Miniconda3\lib\site-packages\numba\compiler.py", line 827, in translate_stage
    return interp.interpret(bytecode)
  File "C:\Miniconda3\lib\site-packages\numba\interpreter.py", line 98, in interpret
    self.dfa.run()
  File "C:\Miniconda3\lib\site-packages\numba\dataflow.py", line 28, in run
    self.infos[blk.offset] = self.run_on_block(blk)
  File "C:\Miniconda3\lib\site-packages\numba\dataflow.py", line 78, in run_on_block
    self.dispatch(info, inst)
  File "C:\Miniconda3\lib\site-packages\numba\dataflow.py", line 88, in dispatch
    fn(info, inst)
  File "C:\Miniconda3\lib\site-packages\numba\dataflow.py", line 354, in op_CALL_FUNCTION_EX
    raise NotImplementedError(errmsg)
NotImplementedError: Failed at object (analyzing bytecode)
CALL_FUNCTION_EX with **kwargs not supported