jax: FFT on CPU noticeably slower than SciPy's FFT

The FFT implementation in JAX seems to be noticeably slower than SciPy's, even though both use some flavor of the PocketFFT implementation.
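Before comparing speed, it is worth confirming that the libraries compute the same transform. The sketch below checks only NumPy against SciPy on a small, arbitrarily sized float64 signal (not the 2**26 samples used in the benchmark); a JAX result converted with `np.asarray` could be compared the same way.

```python
import numpy as np
from scipy import fft as sp_fft

# Small random real-valued float64 signal (size chosen arbitrarily for a quick check)
rng = np.random.default_rng(42)
x = rng.standard_normal(2**12)

# Both calls compute the complex forward DFT of x and return complex128 arrays
y_np = np.fft.fft(x)
y_sp = sp_fft.fft(x)

# The two implementations should agree up to floating-point rounding
max_abs_diff = np.max(np.abs(y_np - y_sp))
print(f"max |NumPy - SciPy| = {max_abs_diff:.3e}")
```

If the results agree, any timing gap reflects the implementation and build rather than different math.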

import jax

# Enable float64 in JAX so it runs at the same precision as NumPy/SciPy.
# This must happen before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import timeit
import numpy as np
from scipy import fft as sp_fft
from jax import numpy as jnp
from jax import jit, random

key = random.PRNGKey(42)
jr = random.normal(key, (2**26, ))  # 2**26 float64 samples (512 MiB)
r = np.array(jr)  # the same data as a NumPy array

N_IT = 7  # each timing is averaged over this many calls

timing = timeit.timeit(lambda: np.fft.fft(r), number=N_IT) / N_IT
print(f"NumPy: {timing} s")
timing = timeit.timeit(lambda: sp_fft.fft(r), number=N_IT) / N_IT
print(f"SciPy: {timing} s")

# JAX dispatches asynchronously, so block until the result is ready;
# otherwise only the dispatch time would be measured.
jax_fft = jnp.fft.fft
timing = timeit.timeit(
    lambda: jax_fft(jr).block_until_ready(), number=N_IT
) / N_IT
print(f"JAX (unjitted): {timing} s")

jax_fft_jit = jit(jax_fft)
jax_fft_jit(jr).block_until_ready()  # warm-up: compile before timing
timing = timeit.timeit(
    lambda: jax_fft_jit(jr).block_until_ready(), number=N_IT
) / N_IT
print(f"JAX (jitted): {timing} s")

On an AMD Ryzen 7 4800H (with Radeon Graphics) with JAX 0.2.18 and jaxlib 0.1.69 installed from PyPI, I get the following timings:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
NumPy: 3.5420487681424544 s
SciPy: 1.9076589474290293 s
JAX (unjitted): 4.5806920641433475 s
JAX (jitted): 4.63735637514245 s

The timings improve if I compile jaxlib myself, though JAX still compares unfavorably to SciPy (self-compiled JAX, unjitted: 3.506302507857202 s).
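The figures above are means over seven calls. A somewhat more robust measurement, sketched here as an assumption rather than the method used in this issue, warms each function up once and keeps the best of several repeats. `bench` is a hypothetical helper name, and the small array size only keeps the example quick.

```python
import timeit

import numpy as np
from scipy import fft as sp_fft


def bench(fn, number=3, repeat=5):
    """Return the best per-call time in seconds over `repeat` runs of `number` calls."""
    fn()  # warm-up call: keeps one-time costs (e.g. JIT compilation) out of the timed runs
    times = timeit.repeat(fn, number=number, repeat=repeat)
    # The minimum is the least noisy estimate; slower runs mostly reflect interference
    return min(times) / number


r = np.random.default_rng(0).standard_normal(2**16)
print(f"NumPy: {bench(lambda: np.fft.fft(r)):.6f} s")
print(f"SciPy: {bench(lambda: sp_fft.fft(r)):.6f} s")
# For JAX, time e.g. lambda: jax_fft_jit(jr).block_until_ready()
# so that the asynchronously dispatched computation is actually included.
```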

My hypothesis is that the JAX binaries on PyPI are built suboptimally for recent AMD CPUs and, much more importantly, that JAX is probably using an outdated version of PocketFFT.
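Testing this hypothesis starts with recording exactly which builds are being compared, since the FFT code each library bundles depends on the installed release. A minimal sketch using only the standard library; `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for pkg, ver in installed_versions(["numpy", "scipy", "jax", "jaxlib"]).items():
    print(f"{pkg}: {ver or 'not installed'}")
```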

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 32 (23 by maintainers)

Most upvoted comments

Just to clarify earlier parts of the discussion: the C version of pocketfft is the older one, which should no longer be used unless only a C environment is available. The C++ version of pocketfft has better performance and many more capabilities, and has the nice property of coming as a single header. If you have access to C++17, I recommend the FFT component of ducc0, which is the evolution of C++ pocketfft with many minor improvements (but nothing really groundbreaking).
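One capability the C++ pocketfft adds over the C version is multithreading. SciPy's `scipy.fft`, which builds on the C++ code, exposes it through the `workers` argument. A minimal sketch follows; the array shape is an arbitrary choice, and a batch of signals is used rather than one long signal on the assumption that the threads split the work across independent 1-D transforms.

```python
import os

import numpy as np
from scipy import fft as sp_fft

# A batch of 64 real signals, each 4096 samples long
x = np.random.default_rng(1).standard_normal((64, 4096))

# Single-threaded transform along the last axis
y1 = sp_fft.fft(x, workers=1)

# Same transform using all cores: negative values wrap around from os.cpu_count()
y_all = sp_fft.fft(x, workers=-1)

# Threading changes speed, not the result (up to rounding)
print(f"cores available: {os.cpu_count()}")
print(f"results match: {np.allclose(y1, y_all)}")
```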

I don’t think we can use ducc0 ourselves in JAX, because it has a GPL license and JAX has an Apache license.

That’s correct, but if we are talking about the FFT component only, I think I could make that available under BSD terms. Please let me know if you’d be interested in that.

(The directly FFT-related files in ducc0 still have the BSD licensing header, as far as I remember, but some of the support header files don’t, so they’d need to be adjusted.)