jax: Could not load library libcublasLt.so.12. Error

Description

I have some issues with a new jax install. While the installation completes without error, as soon as I try to perform any operation on an array I get the following message:

>>> import jax.numpy as jnp
>>> a = jnp.arange(4)
Could not load library libcublasLt.so.12. Error: libcublasLt.so.12: cannot open shared object file: No such file or directory
Aborted (core dumped)

This looks a lot like a faulty CUDA install; however, in another virtual environment I’m running an older jax installation without any error (jax=0.3.25, jaxlib=0.3.25+cuda11.cudnn82).

I’ve tried jax=0.4.6, jaxlib=0.4.6+cuda11.cudnn82 as well as the corresponding cudnn86 build; both give the same error.

It seems like jax is looking for a CUDA 12 library (libcublasLt.so.12)?
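One way to check that guess is to ask the dynamic loader directly which libcublasLt sonames it can open. Below is a minimal sketch using the standard-library ctypes module; `can_load` is a hypothetical helper name. It only consults the normal loader search path (LD_LIBRARY_PATH, the ldconfig cache), which may differ from wherever jaxlib itself looks.

```python
import ctypes


def can_load(soname: str) -> bool:
    """Return True if the dynamic loader can open `soname`.

    ctypes.CDLL uses dlopen() on Linux, so this follows the same search
    rules as any other shared-library load (LD_LIBRARY_PATH, ldconfig cache).
    """
    try:
        ctypes.CDLL(soname)
    except OSError:
        # Raised when the library (or one of its dependencies) is not found.
        return False
    return True


if __name__ == "__main__":
    for name in ("libcublasLt.so.11", "libcublasLt.so.12"):
        print(f"{name}: {'found' if can_load(name) else 'NOT found'}")
```

On a machine with only CUDA 11.x toolkits installed, one would expect the `.so.11` line to report found and the `.so.12` line to report NOT found.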

What jax/jaxlib version are you using?

jax=0.4.6, jaxlib=0.4.6+cuda11.cudnn82
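Because two virtual environments are involved, it can help to confirm from inside each one which jax/jaxlib build is actually installed. This is a sketch using the standard-library importlib.metadata; `installed_version` and `local_build_tag` are hypothetical helper names, and it assumes the CUDA/cuDNN build is encoded in the version's local label after `+` (as in `0.4.6+cuda11.cudnn82`).

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str) -> str | None:
    """Version string of `package` in the current environment, or None."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def local_build_tag(ver: str) -> str | None:
    """Local version label after '+', e.g. 'cuda11.cudnn82'; None if absent."""
    _, sep, local = ver.partition("+")
    return local if sep else None


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib"):
        v = installed_version(pkg)
        tag = local_build_tag(v) if v else None
        print(f"{pkg}: {v} (build tag: {tag})")
```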

Which accelerator(s) are you using?

GPU

Additional system info

Linux (Ubuntu)

NVIDIA GPU info

On my system I installed CUDA 11.8 and cuDNN 8.8.

$ nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
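To compare toolkit versions programmatically, the release number in the output above can be parsed. This is a sketch; the helper names are hypothetical, and it assumes the output contains a line of the form `release X.Y`. Note that this only identifies the toolkit that the nvcc on PATH belongs to, which is not necessarily the set of libraries jaxlib loads at runtime.

```python
from __future__ import annotations

import re
import shutil
import subprocess


def parse_nvcc_release(output: str) -> tuple[int, int] | None:
    """Extract (major, minor) from 'Cuda compilation tools, release 11.8, ...'."""
    m = re.search(r"release (\d+)\.(\d+)", output)
    return (int(m.group(1)), int(m.group(2))) if m else None


def toolkit_version() -> tuple[int, int] | None:
    """Run the nvcc found on PATH and parse its release, or None if absent."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    print("nvcc toolkit:", toolkit_version())
```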

$ nvidia-smi

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.05    Driver Version: 525.85.05    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:0C:00.0  On |                  N/A |
|  0%   53C    P5    47W / 340W |   9157MiB / 10240MiB |     37%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

About this issue

  • State: closed
  • Created a year ago
  • Comments: 15 (4 by maintainers)

Most upvoted comments

Regarding "nvidia-smi shows CUDA 12 (it does for me too)":

nvidia-smi shows the highest CUDA version your driver supports, not the CUDA version actually installed on your machine.
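The driver-side number can also be read directly, which makes the distinction concrete. This is a minimal sketch assuming Linux, where the driver library's soname is `libcuda.so.1`; it calls the CUDA driver API function `cuDriverGetVersion`, which reports the latest CUDA version the driver supports, encoded as 1000 * major + 10 * minor. The helper names are hypothetical.

```python
from __future__ import annotations

import ctypes


def decode_cuda_version(v: int) -> tuple[int, int]:
    """CUDA encodes versions as 1000 * major + 10 * minor (12000 -> 12.0)."""
    return v // 1000, (v % 1000) // 10


def driver_cuda_version() -> tuple[int, int] | None:
    """Highest CUDA version supported by the installed driver, or None."""
    try:
        libcuda = ctypes.CDLL("libcuda.so.1")
    except OSError:
        return None  # no NVIDIA driver library on the loader path
    version = ctypes.c_int()
    # CUresult cuDriverGetVersion(int *driverVersion); 0 means CUDA_SUCCESS.
    if libcuda.cuDriverGetVersion(ctypes.byref(version)) != 0:
        return None
    return decode_cuda_version(version.value)


if __name__ == "__main__":
    print("driver supports up to CUDA:", driver_cuda_version())
```

With the 525.85.05 driver shown above, this should report (12, 0) even though only CUDA 11.x toolkits are installed.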

I was able to resolve this only by downgrading:

$ pip install jax==0.4.2 https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.2+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl

Maybe I’m missing something; I’m not sure why nvidia-smi shows CUDA 12 (it does for me too), whereas nvcc reports CUDA 11.6:

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:18:20_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0

Also, all the libcublasLt files I can find are for CUDA 11:

$ find /usr/local -name 'libcublasLt.so.*'
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcublasLt.so.11
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcublasLt.so.11.9.2.110
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcublasLt.so.11
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcublasLt.so.11.11.3.6
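The same check can be scripted to summarize which libcublasLt major versions exist under a directory tree, mirroring the find command above. This is a sketch; `cublaslt_majors` is a hypothetical helper, and it assumes libraries follow the `libcublasLt.so.<major>[.<minor>...]` naming shown in the output.

```python
from __future__ import annotations

import re
from pathlib import Path

# Captures the major version right after ".so." in names like libcublasLt.so.11.11.3.6
_SONAME = re.compile(r"^libcublasLt\.so\.(\d+)")


def cublaslt_majors(root: str = "/usr/local") -> set[int]:
    """Set of libcublasLt major versions found anywhere below `root`."""
    majors: set[int] = set()
    for path in Path(root).rglob("libcublasLt.so.*"):
        m = _SONAME.match(path.name)
        if m:
            majors.add(int(m.group(1)))
    return majors


if __name__ == "__main__":
    print("libcublasLt majors under /usr/local:", sorted(cublaslt_majors()))
```

For the listing above this would give [11], i.e. no libcublasLt.so.12 anywhere under /usr/local.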

EDIT: nvidia-smi shows the highest version of CUDA that is supported by the driver, not the version installed (https://stackoverflow.com/a/70032121)