jax: Unimplemented: DNN library is not found.
I'm working on a local GPU (RTX 2060 Super) with CUDA 11.1 and got this error.
JAX installed successfully with the following command:
pip install --upgrade jax jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
and this symlink:
sudo ln -s /path/to/cuda /usr/local/cuda-11.1
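To confirm which wheels pip actually installed, you can query the installed distribution metadata. For the CUDA wheel in this report, the jaxlib version should carry a local suffix such as +cuda111; newer JAX releases package GPU support differently, so treat that suffix as specific to this era. A minimal sketch, assuming Python 3.8+ for importlib.metadata (the environment above is Python 3.7, where the importlib_metadata backport provides the same functions); the helper name installed_version is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+


def installed_version(dist_name: str) -> str:
    """Return the installed version string of a distribution, or 'not installed'."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return "not installed"


if __name__ == "__main__":
    for dist in ("jax", "jaxlib"):
        # For the CUDA 11.1 wheel installed above, jaxlib should report e.g. "0.1.57+cuda111".
        print(f"{dist}: {installed_version(dist)}")
```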
JAX reports the GPU platform with:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
and can do basic math such as:
rng_key = random.PRNGKey(0)
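Both checks above can pass even when cuDNN cannot be loaded, which matches this report: the GPU backend is found and simple math runs, but the model fails. A minimal sketch of the same checks using the public jax.devices() and jax.default_backend() calls (available in recent JAX versions) instead of the internal xla_bridge module; judging from this report, none of this appears to require cuDNN, so success here says nothing about the DNN library. The helper name basic_gpu_smoke_test is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import random


def basic_gpu_smoke_test():
    """Report the default backend and run a small computation with no convolution."""
    backend = jax.default_backend()  # e.g. "cpu" or "gpu"
    devices = jax.devices()
    key = random.PRNGKey(0)
    x = random.normal(key, (4, 4))
    # Element-wise math and a reduction only; no convolution is involved.
    total = float(jnp.sum(x * 2.0))
    return backend, devices, total


if __name__ == "__main__":
    backend, devices, total = basic_gpu_smoke_test()
    print(backend, devices, total)
```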
However, I still can't train the model:
evaluate(model, test_ds)
FilteredStackTrace: RuntimeError: Unimplemented: DNN library is not found.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-8-0f8618edbb7d> in <module>()
13 return compute_metrics(logits, eval_ds['label'])
14
---> 15 evaluate(model, test_ds)
/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
131 def reraise_with_filtered_traceback(*args, **kwargs):
132 try:
--> 133 return fun(*args, **kwargs)
134 except Exception as e:
135 if not is_under_reraiser(e):
/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
221 backend=backend,
222 name=flat_fun.__name__,
--> 223 donated_invars=donated_invars)
224 return tree_unflatten(out_tree(), out)
225
/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1175
1176 def bind(self, fun, *args, **params):
-> 1177 return call_bind(self, fun, *args, **params)
1178
1179 def process(self, trace, fun, tracers, params):
/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1166 tracers = map(top_trace.full_raise, args)
1167 with maybe_new_sublevel(top_trace):
-> 1168 outs = primitive.process(top_trace, fun, tracers, params)
1169 return map(full_lower, apply_todos(env_trace_todo(), outs))
1170
/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1178
1179 def process(self, trace, fun, tracers, params):
-> 1180 return trace.process_call(self, fun, tracers, params)
1181
1182 def post_process(self, trace, out_tracers, params):
/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
577
578 def process_call(self, primitive, f, tracers, params):
--> 579 return primitive.impl(f, *tracers, **params)
580 process_map = process_call
581
/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
557 *unsafe_map(arg_spec, args))
558 try:
--> 559 return compiled_fun(*args)
560 except FloatingPointError:
561 assert FLAGS.jax_debug_nans # compiled_fun can only raise in this case
/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, avals, handlers, *args)
805 device, = compiled.local_devices()
806 input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 807 out_bufs = compiled.execute(input_bufs)
808 if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
809 return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
RuntimeError: Unimplemented: DNN library is not found.
About this issue
- State: open
- Created 4 years ago
- Comments: 27 (4 by maintainers)
I was able to solve this problem by adding these 4 lines of code at the head of the file:
I was also getting this error. I don't know the details of what was happening, but for me the issue seemed to stem from JAX and TensorFlow not sharing the GPU nicely. When I added this code snippet to the top of my code, it seems to run (taken from the Flax MNIST example):
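The snippet itself is not preserved here. As an assumption about what it does: the Flax examples hide GPUs from TensorFlow, so that TensorFlow (used there only for data loading) does not reserve GPU memory that JAX needs. A minimal sketch of that pattern using tf.config.experimental.set_visible_devices; the helper name hide_gpus_from_tensorflow is hypothetical, and the call must run before TensorFlow initializes its GPUs:

```python
def hide_gpus_from_tensorflow() -> bool:
    """Hide every GPU from TensorFlow so it cannot reserve GPU memory JAX needs.

    Hypothetical helper. Returns True if TensorFlow was configured, False if it
    is not installed. Must be called before TensorFlow touches the GPU
    (i.e. before building any tf.data pipeline or creating tensors).
    """
    try:
        import tensorflow as tf
    except ImportError:
        return False  # No TensorFlow installed, so there is nothing to hide.
    # An empty device list means TensorFlow sees no GPUs at all; JAX is unaffected.
    tf.config.experimental.set_visible_devices([], "GPU")
    return True
```

Call it at the very top of the training script, before any TensorFlow data pipeline is built and before JAX runs its first computation.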
The comment suggests this is a known issue, but a quick Google search only turns up an old, closed issue, #120. I don't get the same issue running the same code on Colab (without the above snippet), so it may be particular to my machine's configuration.
Edit: Ah, I now see there is a whole page on this in the JAX documentation. It would be very useful if JAX could detect this issue and give a helpful error message. Based on the "DNN library not found" error, I went down the rabbit hole of thinking I had the wrong version of CUDA/cuDNN.
This issue seems related to deepmind/dm-haiku#83; perhaps something changed recently?
What worked for me:
conda install -c anaconda cudnn=8.2.1 cudatoolkit=11.3
Check if LD_LIBRARY_PATH is empty:
echo $LD_LIBRARY_PATH
If it is empty:
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/
Otherwise:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/
Finally
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
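The if/else above appends $CONDA_PREFIX/lib/ to LD_LIBRARY_PATH, or sets it when empty. One caveat, consistent with the dlopen() explanation in the maintainer comment: the dynamic loader reads LD_LIBRARY_PATH when the process starts, so assigning os.environ["LD_LIBRARY_PATH"] inside an already running Python process or notebook kernel generally does not affect later dlopen() calls. A minimal sketch that computes the same value and launches a fresh Python process with it; the helper names and the script path you pass in are hypothetical, and it assumes an activated conda environment (CONDA_PREFIX set):

```python
import os
import subprocess
import sys


def conda_ld_library_path(env) -> str:
    """Mirror the shell if/else: set or append $CONDA_PREFIX/lib/ in LD_LIBRARY_PATH."""
    conda_lib = os.path.join(env["CONDA_PREFIX"], "lib") + "/"
    current = env.get("LD_LIBRARY_PATH", "")
    if not current:
        return conda_lib
    return current + ":" + conda_lib  # LD_LIBRARY_PATH entries are ':'-separated


def run_with_conda_libs(script: str) -> int:
    """Run `script` in a new Python process whose dynamic loader sees the conda libs."""
    env = dict(os.environ)  # raises KeyError below if CONDA_PREFIX is not set
    env["LD_LIBRARY_PATH"] = conda_ld_library_path(env)
    return subprocess.run([sys.executable, script], env=env).returncode
```

For example, run_with_conda_libs("train.py") would start the (hypothetical) training script with the conda cuDNN directory on the loader's search path.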
About 5 months ago (https://github.com/google/jax/commit/a141cc6e8d36ff10e28180683588bedf5432df1a) we switched how we link GPU libraries to be the same as TensorFlow, namely, we use dlopen() to find libraries like cuDNN rather than linking against them directly. dlopen() looks for libraries using LD_LIBRARY_PATH, so that's ultimately the cause of this error: we can't find the libraries.
I suspect you would see the exact same behavior with tensorflow with GPU support: as far as I am aware, it uses the same code to find the GPU libraries. It might be interesting to verify that hypothesis: install a GPU version of TF and try running a convolution. You should see the same error as JAX (if you haven't set LD_LIBRARY_PATH).
I also suspect that if you set TF_CPP_MIN_LOG_LEVEL=0, you may see better logging that more clearly indicates what the real problem is. I agree the error message isn't very helpful; we should probably fix that.
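The suggested experiment can also be run in JAX itself. On GPU, XLA typically hands convolutions to cuDNN, so a tiny convolution should reproduce the error when cuDNN cannot be loaded, even while element-wise math still works. A minimal sketch; setting TF_CPP_MIN_LOG_LEVEL=0 before importing JAX follows the logging suggestion above, assuming your jaxlib honors that variable. The helper name tiny_convolution is hypothetical:

```python
import os

# Must be set before jax/jaxlib is imported to affect the C++ log level.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

import jax.numpy as jnp
from jax import lax


def tiny_convolution():
    """Convolve an 8x8 all-ones image with a 3x3 all-ones kernel (NCHW / OIHW)."""
    x = jnp.ones((1, 1, 8, 8), dtype=jnp.float32)  # batch, channels, height, width
    k = jnp.ones((1, 1, 3, 3), dtype=jnp.float32)  # out-chan, in-chan, height, width
    y = lax.conv(x, k, window_strides=(1, 1), padding="VALID")
    # block_until_ready() forces execution, so a missing cuDNN surfaces here.
    return y.block_until_ready()


if __name__ == "__main__":
    try:
        out = tiny_convolution()
        print("convolution OK, output shape:", out.shape)
    except RuntimeError as err:  # e.g. "Unimplemented: DNN library is not found."
        print("convolution failed:", err)
```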
Still broken.
I am also having the same issue. I can confirm that my LD_LIBRARY_PATH is correctly configured: I pointed LD_LIBRARY_PATH to the CUDA path, and libcudnn.so.7 is under it. But I get:
2021-12-20 00:51:16.501936: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.7'; dlerror: libcudnn.so.7: cannot open shared object file: No such file or directory;
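Since JAX locates cuDNN via dlopen(), you can check from Python whether the current process can load a given library: on Linux, ctypes.CDLL also goes through dlopen() and therefore searches the LD_LIBRARY_PATH the process was started with. Running it in the same environment as JAX (for example the same Jupyter kernel, which may not inherit exports from your shell) shows what JAX's loader sees. A minimal sketch; the sonames below are examples, and the one your jaxlib needs is whatever its own log line names (libcudnn.so.7 in the message above). The helper name can_dlopen is hypothetical:

```python
import ctypes
import ctypes.util
import os


def can_dlopen(soname: str) -> bool:
    """Return True if the dynamic loader can open `soname` in this process."""
    try:
        ctypes.CDLL(soname)  # calls dlopen() under the hood on Linux
    except OSError as err:
        print(f"cannot load {soname}: {err}")
        return False
    return True


if __name__ == "__main__":
    print("LD_LIBRARY_PATH =", os.environ.get("LD_LIBRARY_PATH", "<unset>"))
    for soname in ("libcudnn.so.7", "libcudnn.so.8"):
        print(soname, "->", "loadable" if can_dlopen(soname) else "NOT loadable")
```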
I solved it by doing this, thanks a lot!
Arrived here after googling, running into the same "DNN library" error. The comment from @gabehope helped me resolve my problem. Specifically, I was running both TensorFlow and JAX in the same script and, presumably, they were both fighting for GPU memory.
For reference, here’s the (quite helpful!) page on memory allocation with JAX.
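That page describes environment variables controlling JAX's GPU memory preallocation. A minimal sketch of turning preallocation off so that JAX and TensorFlow can share a GPU more easily; the variables must be set before JAX initializes its GPU backend, and setting them before importing jax is the safe choice:

```python
import os

# Set before importing jax: JAX reads these when it creates its GPU client.
# Allocate GPU memory on demand instead of preallocating a large block up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative: keep preallocation but cap it at half of the GPU's memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax
import jax.numpy as jnp

if __name__ == "__main__":
    print(jax.devices())
    print(jnp.arange(4.0).sum())  # the first computation initializes the backend
```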
It would be helpful if there were some way for the error to better indicate that it’s a memory issue, though it sounds like for others it may be a different problem than what Gabe and I were running into.