jax: cannot find libdevice
Hi
JAX cannot find libdevice. I’m running Python 3.7 with CUDA 10.0 on my personal laptop with a GeForce RTX 2080. I installed jax using pip.
I made a little test script, shown below:
import os
os.environ["XLA_FLAGS"]="--xla_gpu_cuda_data_dir=/home/murphyk/miniconda3/lib"
os.environ["CUDA_HOME"]="/usr"
import jax
import jax.numpy as np
print("jax version {}".format(jax.__version__))
from jax.lib import xla_bridge
print("jax backend {}".format(xla_bridge.get_backend().platform))
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (5,5))
print(x)
The output is shown below.
jax version 0.1.39
jax backend gpu
2019-07-07 16:44:03.905071: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc:105] Unknown compute capability (7, 5) .Defaulting to libdevice for compute_20
Traceback (most recent call last):
File "<ipython-input-15-e39e42274024>", line 1, in <module>
runfile('/home/murphyk/github/pyprobml/scripts/jax_debug.py', wdir='/home/murphyk/github/pyprobml/scripts')
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 827, in runfile
execfile(filename, namespace)
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 110, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "/home/murphyk/github/pyprobml/scripts/jax_debug.py", line 18, in <module>
x = random.normal(key, (5,5))
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/random.py", line 389, in normal
return _normal(key, shape, dtype)
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/api.py", line 123, in f_jitted
out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/core.py", line 663, in call_bind
ans = primitive.impl(f, *args, **params)
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 606, in xla_call_impl
compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/linear_util.py", line 208, in memoized_fun
ans = call(f, *args)
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 621, in xla_callable
compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 207, in compile_jaxpr
backend=xb.get_backend()), result_shape
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jaxlib/xla_client.py", line 535, in Compile
return backend.compile(self.computation, compile_options)
File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jaxlib/xla_client.py", line 118, in compile
compile_options.device_assignment)
RuntimeError: Not found: ./libdevice.compute_20.10.bc not found
About this issue
- State: closed
- Created 5 years ago
- Reactions: 1
- Comments: 35 (4 by maintainers)
Commits related to this issue
- Update installation directions in README to mention expected CUDA location. See https://github.com/google/jax/issues/989 — committed to skye/jax by skye 4 years ago
- Update installation directions in README to mention expected CUDA location. (#3190) See https://github.com/google/jax/issues/989 — committed to google/jax by skye 4 years ago
- Update installation directions in README to mention expected CUDA location. (#3190) See https://github.com/google/jax/issues/989 — committed to NeilGirdhar/jax by skye 4 years ago
The folks at Lambda (maker of my TensorBook laptop) looked at the source code and suggested this fix:
This actually works 😃
Maybe worth updating the set of locations that JAX searches for libdevice.10.bc?
On Mon, Jul 8, 2019 at 10:41 AM Kevin Murphy murphyk@gmail.com wrote:
Here’s my understanding of this issue:
jax depends on XLA, which is built as part of TF and bundled up into the jaxlib package. By default, TF is compiled to look for cuda and cudnn in /usr/local/cuda: https://github.com/tensorflow/tensorflow/blob/master/third_party/gpus/cuda_configure.bzl#L14
So symlinking your cuda install to /usr/local/cuda should work. Make sure libdevice actually exists… I always have a hard time figuring out which Nvidia downloads contain libraries, but I think libdevice is shipped as part of https://developer.nvidia.com/cuda-toolkit.
Alternatively, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda should work. I recommend exporting this outside the Python interpreter to be sure it’s being picked up when jaxlib is loaded (there’s probably a more targeted way to do it, but this will limit mistakes).
Is anyone still having problems after trying these methods?
We should also potentially make a jax-specific environment variable to set a custom cuda install path, or at least document the XLA_FLAGS one more clearly… I can do that once we verify this actually works.
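For anyone who wants to check both suggestions before importing jax, here is a minimal sketch. It assumes libdevice lives at nvvm/libdevice/libdevice.10.bc under the CUDA root (that layout is my assumption, based on the nvvm folder and the libdevice.10.bc file name mentioned in this thread). The helper names find_cuda_root and xla_flags_for and the candidate directories are hypothetical, not part of jax.

```python
import os
from pathlib import Path
from typing import Iterable, Optional

# Assumed location of libdevice relative to a CUDA root directory.
LIBDEVICE_RELATIVE = Path("nvvm") / "libdevice" / "libdevice.10.bc"


def find_cuda_root(candidates: Iterable[str]) -> Optional[str]:
    """Return the first candidate directory that contains libdevice, or None."""
    for root in candidates:
        if (Path(root) / LIBDEVICE_RELATIVE).is_file():
            return root
    return None


def xla_flags_for(cuda_root: str, existing: str = "") -> str:
    """Append --xla_gpu_cuda_data_dir to any XLA_FLAGS value already set."""
    flag = f"--xla_gpu_cuda_data_dir={cuda_root}"
    return f"{existing} {flag}".strip()


if __name__ == "__main__":
    # Hypothetical candidate list; adjust it to your machine.
    root = find_cuda_root(["/usr/local/cuda", "/usr/lib/cuda"])
    if root is None:
        print("libdevice.10.bc not found in any candidate directory")
    else:
        # This has to happen before `import jax` so the flag is seen on load.
        os.environ["XLA_FLAGS"] = xla_flags_for(root, os.environ.get("XLA_FLAGS", ""))
        print("XLA_FLAGS =", os.environ["XLA_FLAGS"])
```

Exporting XLA_FLAGS in the shell, as recommended above, is still the less error-prone route; the script is mainly useful for confirming that the directory you point at really contains libdevice.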
Good news! As of jaxlib 0.1.66, which was just released yesterday, we now bundle libdevice inside the jaxlib CUDA wheels. JAX should now always find it successfully. Hope that helps!
I’m getting this same error with python3.7 and CUDA 10.0. It seems like it doesn’t actually check CUDA_DIR? Symlinking my CUDA_DIR to /usr/local/cuda solved the problem.
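To tell whether an installed jaxlib is new enough to bundle libdevice (0.1.66 or later, per the comment above), a small comparison like the following may help. It is only a sketch: the function names are hypothetical, and it assumes a plain dotted numeric version string such as the one printed by `pip show jaxlib`; strings with suffixes would need extra handling.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn a dotted release string such as '0.1.66' into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def bundles_libdevice(jaxlib_version: str) -> bool:
    """True if the release is 0.1.66 or newer, the cutoff named in this thread."""
    return parse_version(jaxlib_version) >= (0, 1, 66)
```

If this returns False for your install, the symlink or XLA_FLAGS workarounds in this thread still apply. Note that the bundling was announced for the CUDA wheels only.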
First check where your CUDA installation resides.
For instance, if the above command spits out /usr/lib/cuda, then that’s where your CUDA installation is. So, we now need to get the version number.
But JAX, by default, looks for the CUDA installation in /usr/local/cuda-<version>. So, we need to create a symlink with the specific version. This redirects to the actual CUDA installation location when JAX searches in /usr/local/cuda-<version>.
The above steps, in that order, should solve the issue; at least they did for me.
Note: Please keep in mind that the JAX binary installed on your machine should be compiled for the same CUDA version (e.g., 11.0) as the installation that /usr/local/cuda-<version> points to.
For people running into this problem after an install of Ubuntu 20.04 with Ubuntu’s cuda toolkit package, KeAWang’s suggestion works but you need cuda-10.1 instead:
sudo ln -s /usr/lib/cuda /usr/local/cuda-10.1
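The same symlink step can be sketched in Python, which makes it easy to refuse to overwrite an existing directory or link. The function name link_cuda and its prefix parameter are hypothetical and only here for illustration; the paths in the usage note are the ones from this thread.

```python
import os
from pathlib import Path


def link_cuda(actual: str, version: str, prefix: str = "/usr/local") -> Path:
    """Create <prefix>/cuda-<version> as a symlink to the actual CUDA directory."""
    link = Path(prefix) / f"cuda-{version}"
    if link.is_symlink() or link.exists():
        # Never overwrite an existing installation or link.
        raise FileExistsError(f"{link} already exists")
    os.symlink(actual, link, target_is_directory=True)
    return link
```

Calling link_cuda("/usr/lib/cuda", "10.1") should have the same effect as the ln -s command above, and it needs the same root privileges to write into /usr/local.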
Does /usr/local/cuda/bin/ptxas exist? You may need to install the CUDA toolkit if not.
Thanks @KeAWang and @iamlemec. Agreed this could be much clearer. I’ve filed an internal bug against XLA with your suggestions and some of my own 😃 These are the suggestions:
Please comment if I should correct anything or you have other suggestions!
This problem is addressed here:
https://github.com/google/jax/discussions/6479#discussioncomment-622839
The instructions need to clarify that there should be an nvvm folder inside the cuda directory for this to work.
@skye Both those methods are working for me. As @KeAWang documented, I’m also seeing that in the absence of XLA_FLAGS info it will look in /usr/local/cuda-XXX, depending on the CUDA version. Would be great if the XLA folks could either actually check CUDA_DIR or simply not have an error message claiming to do so.
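Pulling together the checks mentioned in this thread (bin/ptxas, the nvvm folder, and a libdevice .bc file), a quick diagnostic could look like the sketch below. The function name check_cuda_root is hypothetical, and the nvvm/libdevice layout is an assumption drawn from the comments above rather than from any official documentation.

```python
from pathlib import Path
from typing import Dict


def check_cuda_root(root: str) -> Dict[str, bool]:
    """Report which of the files discussed in this thread exist under root."""
    base = Path(root)
    libdevice_dir = base / "nvvm" / "libdevice"
    has_dir = libdevice_dir.is_dir()
    return {
        "bin/ptxas": (base / "bin" / "ptxas").is_file(),
        "nvvm/libdevice": has_dir,
        # Only glob when the directory exists, to avoid errors on missing paths.
        "libdevice .bc file": has_dir and any(libdevice_dir.glob("libdevice*.bc")),
    }


if __name__ == "__main__":
    for name, present in check_cuda_root("/usr/local/cuda").items():
        print(f"{name}: {'found' if present else 'MISSING'}")
```

Run it against whichever directory you pass to --xla_gpu_cuda_data_dir or symlink from /usr/local/cuda-XXX. Any MISSING entry suggests the CUDA toolkit is not fully installed at that location.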