jax: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed

Description

I have a Python virtual environment with a clean installation of JAX:

# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
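
For reference, a quick sanity check right after installing, using nothing beyond the standard version and device queries, confirms which build was picked up:

import jax
import jaxlib

# Versions should match the wheel that was installed,
# e.g. jax 0.4.8 and jaxlib 0.4.7+cuda12.cudnn88.
print(jax.__version__, jaxlib.__version__)

# Should list the GPU device(s) if the CUDA backend loaded.
print(jax.devices())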

My scripts usually run fine, but they fail intermittently with the error below (roughly every 2 to 10 successful executions are followed by 1 to 3 failed ones):

2023-04-02 16:00:19.964652: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-04-02 16:00:19.964737: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 530.30.2
Traceback (most recent call last):
  File "ddpg_jax_gymnasium_pendulum.py", line 73, in <module>
    key = jax.random.PRNGKey(0)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/random.py", line 136, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 561, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
    return seed(seeds)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 813, in threefry_seed
    lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using?

jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.8.10, Ubuntu 20.04

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080 L...    On | 00000000:01:00.0 Off |                  N/A |
| N/A   38C    P3               N/A /  55W|     10MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1528      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A      2435      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+

CUDNN version (/usr/local/cuda/include/cudnn_version.h)

#define CUDNN_MAJOR 8
#define CUDNN_MINOR 8
#define CUDNN_PATCHLEVEL 1

About this issue

  • State: closed
  • Created a year ago
  • Comments: 25 (2 by maintainers)

Most upvoted comments

I got the same error; it may be due to a mismatch between your CUDA version and the installed JAX. I use Ubuntu 20.04 with the CUDA version shown in the attached screenshot. At first, I installed the newest JAX with

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then I got the error as reported. So I switched to

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Everything works well!
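
For anyone choosing between the cuda11 and cuda12 wheels, a small sketch (assuming nvidia-smi is on the PATH) to compare the CUDA build jaxlib targets with what the driver supports:

import subprocess
import jaxlib

# The local-version suffix shows the CUDA/cuDNN the wheel was built for,
# e.g. "0.4.7+cuda12.cudnn88" means CUDA 12 and cuDNN 8.8.
print(jaxlib.__version__)

# The nvidia-smi banner reports the driver version and the highest CUDA
# version that driver supports (e.g. "CUDA Version: 12.1").
subprocess.run(["nvidia-smi"])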

I also got this error, and it was due to the GPU reaching its memory limit.
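
If others hit the same memory-related cause, JAX's documented GPU memory flags can be set before the first JAX import; a rough sketch (the 0.5 fraction is just an example value):

import os

# Don't preallocate a large fraction of GPU memory at startup...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or alternatively keep preallocation but cap the fraction:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax.numpy as jnp
print(jnp.ones((3,)))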

(I'm also unable to run any JAX code, e.g. a = jnp.ones((3,)).)

Thanks @hosseybposh. For a simple use case I was able to use JAX 0.4.13 and CUDA 11.8 with cuDNN 8.6. I needed to add /usr/lib/x86_64-linux-gnu to LD_LIBRARY_PATH (libcudnn8 was installed with apt-get).
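
A quick way to confirm the dynamic loader can actually find that cuDNN before importing JAX (a rough check; LD_LIBRARY_PATH has to be exported before the Python process starts):

import ctypes

# Raises OSError if libcudnn.so.8 is not on the loader's search path
# (e.g. /usr/lib/x86_64-linux-gnu for the apt-get libcudnn8 package).
ctypes.CDLL("libcudnn.so.8")
print("cuDNN found by the dynamic loader")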

To provide solidarity: I'm getting the same kind of error trying to install jax/jaxlib on an EC2 p2.xlarge (with K80s). I can provide more details if useful, but basically, after running a vanilla Anaconda installation script and trying different variants of pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, JAX reports seeing the GPU when I check print(xla_bridge.get_backend().platform), but otherwise gives the DNN error above.
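
Note that seeing the GPU listed doesn't mean cuDNN initialized; compiling even a trivial op is what exercises the path that fails above. A minimal check using the public API rather than the private xla_bridge module:

import jax
import jax.numpy as jnp

print(jax.default_backend())   # "gpu" (or "cuda") when the CUDA backend was selected
print(jax.devices())

# Forcing a small computation triggers backend/cuDNN initialization.
x = jnp.ones((3,))
x.block_until_ready()
print(x)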

I'm having the same problem, but for me it's consistent and I'm unable to run even simple JAX code. I only see this on my newest system with 4x RTX 4090 GPUs; a server with an A100 and a PC with a 3090 Ti work smoothly, all on Ubuntu 22. I first installed CUDA 11 from conda-forge as suggested: same issue. Then I switched to a local installation of CUDA and cuDNN: same problem.

After a fresh installation of everything, when I run a = jnp.ones((3,)) I get this error:

2023-05-13 09:04:27.790057: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-05-13 09:04:27.790140: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 5853872128 bytes free, 25393692672 bytes total.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2122, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype))
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1203, in full
    return broadcast(fill_value, shape)
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
    return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
    return broadcast_in_dim_p.bind(
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
    compiled_fun = xla_primitive_callable(
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 222, in xla_primitive_callable
    compiled = _xla_callable_uncached(
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 252, in _xla_callable_uncached
    return computation.compile().unsafe_call
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 495, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 463, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.