jax: Jax doesn't see my GPU, even though Pytorch does

Jax sounds like an impressive project, thanks for working on it.

That said: on Ubuntu 18.04, this happens:

➜  python
Python 3.6.9 (default, Oct  8 2020, 12:12:24) 
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.

>>> import torch, jax; print(torch.cuda.is_available()); print(jax.devices())
True
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

I first tried to pip install jax and got various errors; the error messages said this was common with old versions of pip that don't support newer wheel formats, and directed me to upgrade pip, which I did (from 9 to 20). Now JAX seems to be installed (at least various numpy-compatible functions work), but it doesn't see my GPU, a laptop "GeForce" card by NVIDIA.
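One quick sanity check (the helper below is my own sketch, not part of JAX) is to ask pip's package metadata which jaxlib build is actually installed: a CPU-only wheel reports a plain version, while a CUDA wheel carries a local suffix like +cuda110.

```python
from importlib import metadata


def jaxlib_build():
    """Return the installed jaxlib version string, or None if absent.

    A CPU-only wheel reports a plain version such as '0.1.57'; a CUDA
    wheel carries a local suffix such as '0.1.57+cuda110'.
    """
    try:
        return metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None


print(jaxlib_build())
```

If the printed version has no +cudaXXX suffix, you installed the CPU-only wheel and no amount of driver configuration will make it see the GPU.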

I'm not sure what system diagnostics would help here. This is the name of my card:

✗  lspci | grep NVIDIA
01:00.0 3D controller: NVIDIA Corporation GP108M [GeForce MX150] (rev a1)

As best I understand it, these are the drivers:

✗  lsmod | grep nvidia
nvidia_uvm            970752  0
nvidia_drm             53248  3
nvidia_modeset       1212416  2 nvidia_drm
nvidia              27643904  103 nvidia_uvm,nvidia_modeset
drm_kms_helper        184320  2 nvidia_drm,i915
drm                   491520  8 drm_kms_helper,nvidia_drm,i915

Thanks for reading this anyway.

About this issue

  • State: closed
  • Created 4 years ago
  • Reactions: 13
  • Comments: 19 (7 by maintainers)

Most upvoted comments

Did you solve this issue?

Unfortunately not. I don't have sudo access on the cluster, which makes this hard. The best thing for JAX would be to ship a CUDA bundle with the installation, similar to PyTorch.

@morawi Curious to know if this is solved for you yet since I’m going through the same thing with JAX on a slurm cluster

I just stopped using it.

Hello 👋 Just to confirm - did you follow the Linux-specific installation instructions from the README? Also, have you tried installing JAX in a separate virtual environment that excludes PyTorch? 🤷‍♀️

https://github.com/google/jax#installation

On Linux, it is often necessary to first update pip to a version that supports manylinux2010 wheels.

If you want to install JAX with both CPU and GPU support, using existing CUDA and CUDNN7 installations on your machine (for example, preinstalled on your cloud VM), you can run

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

The jaxlib version must correspond to the version of the existing CUDA installation you want to use, with cuda110 for CUDA 11.0, cuda102 for CUDA 10.2, and cuda101 for CUDA 10.1. You can find your CUDA version with:

nvcc --version
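The version-to-suffix mapping above is mechanical, so it can be sketched as a tiny helper (the function name is mine, for illustration only):

```python
def cuda_wheel_tag(cuda_version: str) -> str:
    """Map an nvcc release string to the jaxlib wheel suffix.

    '11.0' -> 'cuda110', '10.2' -> 'cuda102', '10.1' -> 'cuda101'
    """
    major, minor = cuda_version.split(".")[:2]
    return f"cuda{major}{minor}"


print(cuda_wheel_tag("11.0"))
```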

Note that some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X, where X.X should be replaced with the CUDA version number (e.g. cuda-10.2). If CUDA is installed elsewhere on your system, you can either create a symlink:

sudo ln -s /path/to/cuda /usr/local/cuda-X.X

Or set the following environment variable before importing JAX:

XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda
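From Python, the same flag can be set with os.environ, as long as it happens before jax is imported (the CUDA path below is a placeholder for your install location):

```python
import os

# Must run before `import jax`; replace the path with your CUDA location.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/path/to/cuda"
```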

I had the same problem with jax not recognizing GPU. I did the following two steps:

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

and make a link as follows:

ln -s /usr/lib/nvidia-cuda-toolkit /usr/local/cuda-10.1

After this, jax still didn't recognize the GPU. Then I did the following steps, hinted at by jax's warning message about the GPU:

cd /usr/lib/nvidia-cuda-toolkit
mkdir nvvm
cd nvvm
sudo ln -s /usr/lib/nvidia-cuda-toolkit/libdevice libdevice

You would need to use “sudo” for the above steps. After these, jax recognises my GPU.

Similar to myjr52, I was able to solve this simply by replacing this:

pip install --upgrade jax jaxlib

with this (you'll have to change cuda111 based on your output of nvcc --version; mine is 11.1):

pip install --upgrade jax jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I didn’t need to do any of the additional steps mentioned by myjr52.

sudo ln -s /usr/lib/nvidia-cuda-toolkit/libdevice libdevice

I am remotely connected to a slurm cluster and do not have sudo rights; in fact, I do not even have permission to make a symbolic link. Plus, my environment has no GPU: the GPU is assigned via the sbatch job file using the directive "#SBATCH --gres=gpu:1". This is way too complicated. Yet PyTorch seems to work perfectly well.

I have no experience installing cuda in a specific env. It seems that the symbolic link wouldn’t work, see this thread: https://discuss.pytorch.org/t/where-is-cudatoolkit-path-when-installed-via-conda/47791/5

I think many new jax users will come from pytorch, so adding a nudge along the lines of "If you're coming from pytorch, make sure to install cuda separately, if you haven't yet" would help. Two more observations: I could run nvidia-smi but not nvcc, so checking for nvcc is a nice way to tell whether you only have pytorch's bundled cuda or a system-wide install. Furthermore, jupyter notebooks tend to die silently on these issues, so running things as a script gives you much more info.
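That nvidia-smi versus nvcc check can be sketched as a couple of shell lines (just a diagnostic, not an official JAX check):

```shell
# nvidia-smi ships with the driver; nvcc ships with the CUDA toolkit.
# Seeing the first but not the second usually means there is no
# system-wide toolkit (e.g. only PyTorch's bundled CUDA runtime).
for tool in nvidia-smi nvcc; do
  if command -v "$tool" >/dev/null 2>&1; then
    echo "$tool: found at $(command -v "$tool")"
  else
    echo "$tool: not on PATH"
  fi
done
```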

Guess it's a bit late for this, but I got mine fixed by specifying the exact whl link found in https://storage.googleapis.com/jax-releases/jax_releases.html. I just needed cuda 11.0.

The one I used was:

pip uninstall jax jaxlib -y
pip install https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.71+cuda110-cp38-none-manylinux2010_x86_64.whl
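The wheel filename encodes the jaxlib version, the CUDA tag, and the CPython ABI tag (cp38 means Python 3.8). A small sketch (the function name is mine) that assembles such a URL from those parts:

```python
def jaxlib_wheel_url(version: str, cuda_tag: str, py_tag: str) -> str:
    """Build a jax-releases wheel URL from its parts.

    version  -- jaxlib version, e.g. '0.1.71'
    cuda_tag -- CUDA build tag, e.g. 'cuda110'
    py_tag   -- CPython ABI tag, e.g. 'cp38' for Python 3.8
    """
    base = "https://storage.googleapis.com/jax-releases"
    return (f"{base}/{cuda_tag}/jaxlib-{version}+{cuda_tag}-"
            f"{py_tag}-none-manylinux2010_x86_64.whl")


# Reproduces the wheel URL installed above.
print(jaxlib_wheel_url("0.1.71", "cuda110", "cp38"))
```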

I had the same issue, but managed to solve it. It seems pytorch bundles its own cuda, which is why you don't have to install it separately and yet it sees your gpu and nvidia-smi works. Installing cuda for your GPU following these instructions solved the issue for me: https://developer.nvidia.com/cuda-downloads

@aquaresima Please open a new issue, please don’t ping long-closed issues.

If someone else stumbles into this, the CUDA wheel releases are now stored on https://storage.googleapis.com/jax-releases/jax_cuda_releases.html for some reason.

I solved this problem easily by following this issue, some googling, two SO questions I stumbled upon, and the readme of this project.

I had nvidia drivers installed in my laptop through the Pop OS store, and I installed nvidia-cuda-toolkit through apt, and then installed PyTorch (earlier).

My cuda version is 11.2.

I did not have to do any other installation.

  1. I upgraded pip.
  2. I installed jax[cuda11] instead of just jax.
  3. Followed other generic instructions from the ReadMe.
  4. Created two symlinks: one for nvcc and another for cuda.

Then it started working great.

FYI https://github.com/google/jax/pull/6581 bundles libdevice.10.bc with jaxlib wheels, which hopefully will help avoid this particular problem. If you are feeling motivated, you could try patching in that PR and building a jaxlib from source to see if it fixes your problems.
