jax: RuntimeError: Internal: libdevice not found at ./libdevice.10.bc

I installed the GPU version of JAX, and I encountered the following error the first time I ran a program that uses JAX's automatic differentiation (AD) support.

The error output is as follows.

2020-10-05 16:30:40.039587: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:70] Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may result in compilation or runtime failures, if the program we try to run uses routines from libdevice.
2020-10-05 16:30:40.039604: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:71] Searched for CUDA in the following directories:
2020-10-05 16:30:40.039608: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   ./cuda_sdk_lib
2020-10-05 16:30:40.039611: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   /usr/local/cuda-10.1
2020-10-05 16:30:40.039613: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   .
2020-10-05 16:30:40.039616: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:76] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
i: 0
2020-10-05 16:30:40.769544: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:320] libdevice is required by this HLO module but was not found at ./libdevice.10.bc
Traceback (most recent call last):
  File "adpath.py", line 252, in <module>
...........
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 347, in fn
    return lax_fn(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/lax/lax.py", line 285, in sqrt
    return sqrt_p.bind(x)
jax.traceback_util.FilteredStackTrace: RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "adpath.py", line 252, in <module>
.......
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 900, in jacfun
    y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/traceback_util.py", line 137, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 1217, in batched_fun
    out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/batching.py", line 36, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 1681, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  ..........
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 347, in fn
    return lax_fn(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/lax/lax.py", line 285, in sqrt
    return sqrt_p.bind(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/core.py", line 266, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/core.py", line 574, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 224, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 264, in xla_primitive_callable
    compiled = backend_compile(backend, built_c, options)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 325, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
====================================================================================

Maybe I didn't install the correct CUDA version; I'm not sure. Can anyone help me figure it out?

About this issue

  • Original URL
  • State: closed
  • Created 4 years ago
  • Comments: 18 (3 by maintainers)

Most upvoted comments

If you are not able to ln -s your CUDA directory (for example, because of permissions on a public cluster), you can always set XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda before each command that needs CUDA:

XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda python3 your_script_with_cuda.py
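
The same thing can also be done from inside the script itself, as long as the variable is set before JAX triggers any GPU compilation. A minimal Python sketch, assuming your CUDA toolkit (the directory containing nvvm/libdevice) is at /usr/lib/cuda; adjust the path to your installation:

# Sketch: point XLA at the CUDA installation before JAX compiles anything.
# The CUDA path below is an assumption -- replace it with the directory
# that actually contains nvvm/libdevice on your machine.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/lib/cuda"

import jax
import jax.numpy as jnp

# Any primitive that needs libdevice (e.g. sqrt under grad) should now compile.
print(jax.grad(lambda x: jnp.sqrt(x))(2.0))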

Thank you very much for that post.

sudo ln -s /usr/lib/cuda /usr/local/cuda-10.1

This actually solves my problem. Thanks very much @skye

In case it’s helpful for anyone: I just struggled with this on my university’s managed cluster. I set the XLA_FLAGS environment variable, but to no avail. I wasn’t able to make the suggested symlink to /usr/local/cuda-11.1, as I don’t have admin rights. However, I noticed that JAX reported:

2021-05-28 14:08:11.947002: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:71] Searched for CUDA in the following directories:                                              
2021-05-28 14:08:11.947013: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   ./cuda_sdk_lib                                                                             
2021-05-28 14:08:11.947021: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   /usr/local/cuda-11.1                                                                       
2021-05-28 14:08:11.947028: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   . 

So I symlinked the CUDA directory to ./cuda_sdk_lib, and this did the trick! Admittedly it’s not the neatest solution, as I will have to make new symlinks whenever I change directories, but for me it’s a workaround I can live with, so I thought I’d share it.
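
For anyone wanting to script that workaround, here is a minimal sketch; the CUDA path is hypothetical, so substitute wherever your cluster's toolkit actually lives. It creates the ./cuda_sdk_lib symlink in the current working directory, since that relative path is the first location XLA searches:

import os

# Sketch: recreate the ./cuda_sdk_lib symlink that XLA checks first.
# The path below is a placeholder -- set it to the directory that actually
# contains nvvm/libdevice on your system.
cuda_dir = "/path/to/your/cuda-11.1"
if not os.path.lexists("./cuda_sdk_lib"):
    os.symlink(cuda_dir, "./cuda_sdk_lib")

Run this from the directory you launch your JAX script in; as noted above, the link is relative to the working directory, so it has to be repeated if you move elsewhere.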

How would one solve this on colab?

@miguelgargallo That’s a TensorFlow error, not a JAX one. Please file an issue on the TensorFlow github issue tracker.

sudo ln -s /usr/lib/cuda /usr/local/cuda-10.1

This actually solves my problem. Thanks very much @skye

@PostQuantum thanks. A similar command worked for me on ArchLinux:

sudo ln -s /opt/cuda /usr/local/cuda

Update: this has nothing to do with JAX; I just had a similar issue.

I have a Lambda Labs machine, and when I try to locate where CUDA is, I get the following:

whereis -b cuda
cuda: /usr/include/cuda.h /usr/include/cuda /usr/local/cuda011.1

The last one shows up because I tried to follow some of the online instructions and made an error along the way.

Given that this is where my CUDA is, I am not sure how to proceed.

Any help would be appreciated