jax: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed
Description
I have a python virtual environment with a clean installation of JAX
# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
When I run my scripts they usually work, but they fail intermittently with the following error: roughly 2 to 10 successful executions are followed by 1 to 3 failed ones.
2023-04-02 16:00:19.964652: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-04-02 16:00:19.964737: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 530.30.2
Traceback (most recent call last):
File "ddpg_jax_gymnasium_pendulum.py", line 73, in <module>
key = jax.random.PRNGKey(0)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/random.py", line 136, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 561, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
return seed(seeds)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 813, in threefry_seed
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/util.py", line 253, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/util.py", line 246, in cached
return f(*args, **kwargs)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
What jax/jaxlib version are you using?
jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.8.10, Ubuntu 20.04
NVIDIA GPU info
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3080 L... On | 00000000:01:00.0 Off | N/A |
| N/A 38C P3 N/A / 55W| 10MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 1528 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 2435 G /usr/lib/xorg/Xorg 4MiB |
+---------------------------------------------------------------------------------------+
cuDNN version (/usr/local/cuda/include/cudnn_version.h):
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 8
#define CUDNN_PATCHLEVEL 1
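For reference, the version defines above can be checked programmatically rather than by eyeballing the header. A minimal sketch (the header contents are inlined here for illustration; on a real system you would read /usr/local/cuda/include/cudnn_version.h instead):

```python
import re

def parse_cudnn_version(text):
    """Extract (major, minor, patch) from the contents of cudnn_version.h."""
    fields = {}
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        m = re.search(rf"#define\s+{name}\s+(\d+)", text)
        fields[name] = int(m.group(1)) if m else None
    return (fields["CUDNN_MAJOR"], fields["CUDNN_MINOR"], fields["CUDNN_PATCHLEVEL"])

# Inlined copy of the defines shown above; normally:
#   header = open("/usr/local/cuda/include/cudnn_version.h").read()
header = """
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 8
#define CUDNN_PATCHLEVEL 1
"""
print(parse_cudnn_version(header))  # (8, 8, 1)
```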
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 25 (2 by maintainers)
I got the same error; maybe it is due to a mismatch between your CUDA version and the installed jax. I use Ubuntu 20.04 with the CUDA version as below:
At first, I installed the newest jax as
Then I got the error as reported, so I switched to
Everything works well!
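The version-mismatch theory can be checked against the jaxlib version string itself: the local CUDA wheels encode the CUDA and cuDNN versions they were built against (e.g. 0.4.7+cuda12.cudnn88 from the report above, meaning CUDA 12 and cuDNN 8.8). A small sketch to decode that tag and compare it with what is installed locally (the parsing logic is my own assumption about the tag format, not an official API):

```python
import re

def wheel_requirements(jaxlib_version):
    """Parse a jaxlib version like '0.4.7+cuda12.cudnn88' into
    (cuda_major, (cudnn_major, cudnn_minor)); returns None for
    CPU-only wheels that carry no CUDA build tag."""
    m = re.search(r"\+cuda(\d+)\.cudnn(\d)(\d+)", jaxlib_version)
    if m is None:
        return None
    return int(m.group(1)), (int(m.group(2)), int(m.group(3)))

print(wheel_requirements("0.4.7+cuda12.cudnn88"))  # (12, (8, 8))
print(wheel_requirements("0.4.7"))                 # None
```

In practice you would compare the result against `nvidia-smi`'s reported CUDA version and the defines in cudnn_version.h; a wheel built for a newer cuDNN than the one installed is a classic cause of "DNN library initialization failed".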
I also got this error, and it was due to the GPU reaching its memory limit. (I'm also unable to run any JAX code, e.g. a = jnp.ones((3,)).)

Thanks @hosseybposh, for a simple use case I was able to use JAX 0.4.13 and CUDA 11.8 with cuDNN 8.6. I needed to add /usr/lib/x86_64-linux-gnu to the LD_LIBRARY_PATH (installed libcudnn8 with apt-get).

I'm getting the same kind of error trying to install jax/jaxlib on an EC2 p2.xlarge (with K80s), to provide solidarity! I can provide more details if useful, but basically running some vanilla installation script of Anaconda and trying different variants of
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
leads JAX to report seeing the GPU when I check print(xla_bridge.get_backend().platform), but otherwise gives the DNN error above.

I'm having the same problem, but for me it's consistent and I'm unable to run even simple JAX code. I only have this problem on my newest system with 4x RTX 4090 GPUs. I have a server A100 and a PC with a 3090 Ti that work smoothly, Ubuntu 22 across all systems. First I installed CUDA 11 from conda-forge as suggested; same issue. Then I switched to a local installation of CUDA and cuDNN. Same problem.
After a fresh installation of everything, when I run
a = jnp.ones((3,))
I get this error:
2023-05-13 09:04:27.790057: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-05-13 09:04:27.790140: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 5853872128 bytes free, 25393692672 bytes total.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2122, in ones
return lax.full(shape, 1, _jnp_dtype(dtype))
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1203, in full
return broadcast(fill_value, shape)
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
return broadcast_in_dim_p.bind(
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
compiled_fun = xla_primitive_callable(
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 222, in xla_primitive_callable
compiled = _xla_callable_uncached(
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 252, in _xla_callable_uncached
return computation.compile().unsafe_call
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 495, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 463, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
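Several of the reports above point at GPU memory pressure (note the "Memory usage: ... bytes free" line in the log). A commonly suggested mitigation, not a guaranteed fix for this issue, is to stop XLA from preallocating most of the GPU's memory, so that a cuDNN handle can still be created when other processes already hold memory. A minimal sketch:

```python
import os

# Must be set before jax is imported, otherwise they have no effect.
# Disable the default behaviour of preallocating ~75% of GPU memory:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, cap the fraction of GPU memory JAX may claim:
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.5")

print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])  # false

# import jax  # only import jax after the environment is configured
```

Whether this helps depends on what is actually exhausting the GPU; it will not fix a genuine CUDA/cuDNN version mismatch.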