numba: Function accepting njitted functions as arguments is slow

I was trying out Numba 0.38 and its new support for passing jitted functions as arguments, using this code snippet:

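# coding: utf-8
# Run with: ipython test_perf.py  (get_ipython() is only defined under IPython)
from scipy.optimize import newton
from numba import njit


@njit
def func(x):
    return x**3 - 1


@njit
def fprime(x):
    return 3 * x**2


@njit
def njit_newton(func, x0, fprime):
    # Newton iteration mirroring the defaults of scipy.optimize.newton (50 iterations, tol=1.48e-8)
    for _ in range(50):
        fder = fprime(x0)
        fval = func(x0)
        newton_step = fval / fder
        x = x0 - newton_step
        if abs(x - x0) < 1.48e-8:
            return x
        x0 = x
    # Falls through and returns None if it has not converged after 50 iterations


# 1. SciPy newton with the pure-Python originals (.py_func) of func and fprime
get_ipython().run_line_magic('timeit', 'newton(func.py_func, 1.5, fprime=fprime.py_func)')
# 2. SciPy newton with the jitted func and fprime
get_ipython().run_line_magic('timeit', 'newton(func, 1.5, fprime=fprime)')
# 3. Pure-Python njit_newton calling the jitted func and fprime
get_ipython().run_line_magic('timeit', 'njit_newton.py_func(func, 1.5, fprime=fprime)')
# 4. Jitted njit_newton receiving the jitted func and fprime as arguments
get_ipython().run_line_magic('timeit', 'njit_newton(func, 1.5, fprime=fprime)')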

I was surprised to find that njit_newton is the slowest of all, while njit_newton.py_func is the fastest:

$ ipython test_perf.py 
4.76 µs ± 8.52 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
4.14 µs ± 30.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.58 µs ± 26 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
20 µs ± 85.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
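One way to avoid passing the functions as arguments on every call from Python is to bake them into the solver with a closure. This is a hypothetical workaround, not something proposed in this thread. The names `make_newton` and `solve` are illustrative. It assumes that Numba treats closure variables, like globals, as compile-time constants, and it has not been timed against the numbers above. If Numba is not installed, the import falls back to a no-op decorator, so the logic can still be checked (without any speedup).

```python
try:
    from numba import njit
except ImportError:  # fallback so the sketch runs without Numba (no compilation, no speedup)
    def njit(f):
        return f


@njit
def func(x):
    return x**3 - 1


@njit
def fprime(x):
    return 3 * x**2


def make_newton(f, fp, maxiter=50, tol=1.48e-8):
    """Hypothetical factory: returns a solver with f and fp captured in a closure."""

    @njit
    def solve(x0):
        # Same Newton iteration as njit_newton, but f and fp are not call arguments
        for _ in range(maxiter):
            x = x0 - f(x0) / fp(x0)
            if abs(x - x0) < tol:
                return x
            x0 = x
        return x0  # last iterate if not converged (njit_newton returns None instead)

    return solve


solve = make_newton(func, fprime)
print(solve(1.5))  # approximately 1.0, the real root of x**3 - 1
```

Each call to `make_newton` creates a new jitted `solve`, which compiles on its first call. This design trades the per-call cost of a function argument for one compilation per pair of functions.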

cc @nikita-astronaut

(Inspiration: https://github.com/scipy/scipy/blob/607a21e07dad234f8e63fcf03b7994137a3ccd5b/scipy/optimize/zeros.py#L164-L182)

About this issue

  • State: open
  • Created 6 years ago
  • Reactions: 1
  • Comments: 21 (15 by maintainers)


Most upvoted comments

I can confirm that this issue exists. However, as mentioned above, it does in fact seem to be caused by the overhead of calling Numba-jitted code from Python.

import numpy as np
import numba as nb

@nb.njit
def foo(x):
    return x

@nb.njit
def foo_bad(x, func):
    return x

@nb.njit
def foo_bad_alt(x, func):
    return func(x)

@nb.njit
def bar(x):
    s = 0.0

    for v in x:
        s += v

    return s

@nb.njit
def bar_bad(x, func):
    s = 0.0

    for v in x:
        s += v

    return s

@nb.njit
def bar_bad_alt(x, func):
    return func(x)

%timeit foo(10)
%timeit foo_bad(10, foo)
%timeit foo_bad_alt(10, foo)

# 167 ns ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
# 17.5 µs ± 1.54 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# 15.7 µs ± 591 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

data = np.random.random(1000000)

%timeit bar(data)
%timeit bar_bad(data, bar)
%timeit bar_bad_alt(data, bar)

# 1.02 ms ± 57.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 1.1 ms ± 65.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 1.05 ms ± 29.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

The performance difference between the foo functions is large. However, because %timeit calls them from Python, these timings largely reflect the cost of invoking Numba code from the interpreter rather than the work done inside the functions.
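If the cost comes from crossing from Python into Numba, it can be paid once per batch instead of once per solve by moving the loop over starting points into jitted code. This is a sketch, not a measured fix. The names `newton_scalar` and `solve_many` are hypothetical, and `func` and `fprime` are redefined as in the issue. As the issue's own njit_newton does, it assumes that jitted functions can be passed as arguments between jitted functions. If Numba is not installed, the import falls back to a no-op decorator, so the logic can still be checked.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch runs without Numba (no compilation, no speedup)
    def njit(f):
        return f


@njit
def func(x):
    return x**3 - 1


@njit
def fprime(x):
    return 3 * x**2


@njit
def newton_scalar(f, x0, fp):
    # Same iteration as njit_newton in the issue, returning the last iterate if not converged
    for _ in range(50):
        x = x0 - f(x0) / fp(x0)
        if abs(x - x0) < 1.48e-8:
            return x
        x0 = x
    return x0


@njit
def solve_many(f, fp, x0s):
    # One call from Python solves every starting point; with Numba the inner calls stay compiled
    out = np.empty_like(x0s)
    for i in range(x0s.shape[0]):
        out[i] = newton_scalar(f, x0s[i], fp)
    return out


print(solve_many(func, fprime, np.array([0.5, 1.5, 2.0])))
```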

The performance difference between the bar functions is minimal, because most of the time is now spent inside the function rather than in the interface between Numba and Python.
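Here is a back-of-the-envelope check that uses only the means reported above. It assumes that the whole gap between foo_bad and foo is a fixed per-call overhead caused by the function argument.

```python
# Mean %timeit results from the comment above, in microseconds
foo_us = 0.167       # foo(10)
foo_bad_us = 17.5    # foo_bad(10, foo)
bar_us = 1020.0      # bar(data): 1.02 ms on one million elements

# Assumption: the extra time of foo_bad is a fixed cost paid once per call from Python
overhead_us = foo_bad_us - foo_us
share = overhead_us / bar_us

print(f"overhead per call: {overhead_us:.1f} us")  # overhead per call: 17.3 us
print(f"share of bar(data): {share:.1%}")          # share of bar(data): 1.7%
```

A fixed cost of that size is small next to a millisecond of work. It is also well inside the ±57.5 µs spread reported for bar(data).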

For reference, if the functions do any real work, the differences disappear (and, strangely, reverse, which I cannot explain):

from numba import njit

@njit
def foo(x):
    a = 0.
    for i in range(10000000):
        a += i
    return a

def bar(x, f):
    # f is accepted but never called
    a = 0.
    for i in range(10000000):
        a += i
    return a

bar_jit = njit(bar)

foo(1)  # warm-up call: triggers compilation so %timeit measures only execution
%timeit foo(1)
#26.6 ms ± 2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

bar_jit(1, foo)  # warm-up call, as above
%timeit bar_jit(1, foo)
#25.7 ms ± 2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
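Comparing the reversed gap with the reported standard deviations suggests it may simply be run-to-run noise. This is a quick check on the two measurements above, not an explanation of the reversal.

```python
# %timeit results above, in milliseconds (mean, std. dev.)
foo_mean, foo_std = 26.6, 2.0          # foo(1)
bar_jit_mean, bar_jit_std = 25.7, 2.0  # bar_jit(1, foo)

gap = foo_mean - bar_jit_mean
within_noise = gap < min(foo_std, bar_jit_std)

print(f"gap: {gap:.1f} ms")                   # gap: 0.9 ms
print("within one std. dev.:", within_noise)  # within one std. dev.: True
```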