jax: Jax doesn't see my GPU, even though Pytorch does
Jax sounds like an impressive project, thanks for working on it.
That said: on Ubuntu 18.04, this happens:
➜ python
Python 3.6.9 (default, Oct 8 2020, 12:12:24)
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch, jax; print(torch.cuda.is_available()); print(jax.devices())
True
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
I first tried to pip install jax and got various errors; the error messages said this was common with old versions of pip that don’t support newer kinds of wheels, and directed me to upgrade pip, which I did (from 9 to 20). Now JAX seems to be installed (at least various numpy-compatible functions work), but it does not appear to see my GPU, a laptop GeForce card by Nvidia.
I’m not sure what system diagnostics I can bring to help. This is the name of my card:
✗ lspci | grep NVIDIA
01:00.0 3D controller: NVIDIA Corporation GP108M [GeForce MX150] (rev a1)
As best I understand it, these are the drivers:
✗ lsmod | grep nvidia
nvidia_uvm 970752 0
nvidia_drm 53248 3
nvidia_modeset 1212416 2 nvidia_drm
nvidia 27643904 103 nvidia_uvm,nvidia_modeset
drm_kms_helper 184320 2 nvidia_drm,i915
drm 491520 8 drm_kms_helper,nvidia_drm,i915
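For completeness, the warning at the top also points at a way to get more detail. As a rough sketch (assuming a standard bash shell), re-running with verbose logging and checking which jaxlib build is installed usually narrows down why the GPU backend failed to load:
# The log level the warning mentions; the extra output often includes the reason
# the GPU backend was not initialized.
TF_CPP_MIN_LOG_LEVEL=0 python -c "import jax; print(jax.__version__); print(jax.devices())"
# Confirms whether the installed jaxlib is the CPU-only build or a +cuda build.
pip show jaxlib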
Thanks for reading this anyway.
I just stopped using it.
Hello 👋 Just to confirm - did you follow the Linux-specific installation instructions from the README? Also, have you tried installing JAX in a separate virtual environment that excludes PyTorch? 🤷♀️
https://github.com/google/jax#installation
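As a minimal sketch of that suggestion (the jaxlib version and cuda110 tag below are assumptions taken from a later comment; pick the ones matching your CUDA install per the README):
# Create a clean virtual environment without PyTorch, then install a CUDA build of jaxlib.
python3 -m venv jax-env
source jax-env/bin/activate
pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html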
I had the same problem with jax not recognizing GPU. I did the following two steps:
pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
and made a link as follows:
ln -s /usr/lib/nvidia-cuda-toolkit /usr/local/cuda-10.1
After this, jax still didn’t recognize the GPU. Then I did the following steps, hinted at by jax’s warning message about the GPU:
cd /usr/lib/nvidia-cuda-toolkit
mkdir nvvm
cd nvvm
sudo ln -s /usr/lib/nvidia-cuda-toolkit/libdevice libdevice
You would need to use “sudo” for the above steps. After these, jax recognises my GPU.
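A quick way to check whether the fix took effect is the same one-liner as in the original report:
# Should now list a GPU device instead of falling back to CpuDevice.
python -c "import jax; print(jax.devices())"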
Similar to myjr52, I was able to solve this simply by replacing the plain jaxlib install with the CUDA-specific one (you’ll have to change cuda111 based on your output of nvcc --version - mine is 11.1). I didn’t need to do any of the additional steps mentioned by myjr52.
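For illustration, the replacement install line would look roughly like this (a sketch assuming the same jaxlib version as in the earlier comment; check the release index for the cuda tags that actually exist):
# Swap cuda111 for the tag matching your `nvcc --version` output.
pip install --upgrade jax jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html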
I am remotely connected to a slurm cluster and do not have sudo rights. In fact, I do not even have permission to make a symbolic link. Plus, my environment has no GPU, the GPU is assigned via the sbatch job file using the command “#SBATCH --gres=gpu:1” . This is way too complicated. Yet, PyTorch seems to work perfectly well.
I have no experience installing cuda in a specific env. It seems that the symbolic link wouldn’t work, see this thread: https://discuss.pytorch.org/t/where-is-cudatoolkit-path-when-installed-via-conda/47791/5
I think many new jax users will come from pytorch, so a nudge along the lines of ‘If you’re coming from pytorch, make sure to install cuda separately, if you haven’t yet’ would help. Two more observations: I could run nvidia-smi but not nvcc, so this might be a nice check to see whether you only have pytorch’s bundled cuda or a systemwide install. Furthermore, jupyter notebooks tend to die silently with these issues, so running things as a script gives you much more info.
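The check mentioned above is cheap to run (a sketch; nvcc is only present with a system-wide CUDA toolkit):
# nvidia-smi only needs the driver; nvcc needs a full CUDA toolkit on the system.
nvidia-smi
nvcc --version   # if this fails, CUDA is probably only bundled inside PyTorch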
Guess it’s a bit late for this, but I got mine fixed by specifying the exact whl link found at https://storage.googleapis.com/jax-releases/jax_releases.html. I just needed cuda 11.0.
The one I used was:
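For illustration only, such a command has roughly this shape (a hypothetical example, not the exact wheel used here; browse the index for the filename matching your Python and CUDA versions):
# Hypothetical wheel URL; the real filename comes from the release index.
pip install --upgrade https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.57+cuda110-cp36-none-manylinux2010_x86_64.whl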
I had the same issue, but managed to solve it. It seems pytorch bundles its own cuda, so that’s why you don’t have to install it separately but it still sees your gpu and nvidia-smi works. Installing cuda for your GPU following these instructions solved the issue for me: https://developer.nvidia.com/cuda-downloads
@aquaresima Please open a new issue, please don’t ping long-closed issues.
If someone else stumbles into this, the CUDA wheel releases are now stored on https://storage.googleapis.com/jax-releases/jax_cuda_releases.html for some reason.
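A sketch of an install using that index (the exact extras name and version pins have changed across jax releases; the one below is an assumption, so check the README current at the time):
# Newer index for CUDA wheels; adjust the extras name per the README.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html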
I solved this problem very easily just by following this issue, some googling, stumbling upon two SO questions, and the readme of this project.
I had nvidia drivers installed on my laptop through the Pop OS store, and I installed nvidia-cuda-toolkit through apt, and then installed PyTorch (earlier). My cuda version is 11.2. I did not have to do any other installation. I just had to install jax[cuda11] instead of just jax. Then it started working great.
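In other words, roughly (a sketch; jax[cuda11] is the extras name used in this comment and may differ in other jax versions):
# Quote the extras so the shell doesn't glob the brackets.
pip install --upgrade "jax[cuda11]"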
FYI https://github.com/google/jax/pull/6581 bundles libdevice.10.bc with jaxlib wheels, which hopefully will help avoid this particular problem. If you are feeling motivated, you could try patching in that PR and building a jaxlib from source to see if it fixes your problems.
Unfortunately not. I don’t have sudo control over the cluster and this makes it hard. The best solution would be for JAX to ship a cuda bundle with the installation, similar to PyTorch.