jax: Cannot use GPU on Ubuntu 16.04, CUDA 11.0
I have a GeForce RTX 3090 with CUDA 11.0 installed on Ubuntu 16.04 and the installation works fine with TensorFlow.
The path /usr/local/cuda points to that installation.
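A quick way to confirm that the toolkit's bin directory is actually visible on PATH is sketched below, using only the standard library. The helper name cuda_bin_on_path is hypothetical, and the default location /usr/local/cuda is simply the path mentioned above; adjust it if your install lives elsewhere.

```python
import os


def cuda_bin_on_path(cuda_home="/usr/local/cuda", path_env=None):
    """Return True if <cuda_home>/bin (symlinks resolved) is an entry of PATH.

    cuda_bin_on_path is a hypothetical helper, not part of JAX or CUDA.
    """
    if path_env is None:
        path_env = os.environ.get("PATH", "")
    # Resolve symlinks so /usr/local/cuda and /usr/local/cuda-11.0 compare equal
    # when one is a link to the other.
    target = os.path.realpath(os.path.join(cuda_home, "bin"))
    entries = [os.path.realpath(p) for p in path_env.split(os.pathsep) if p]
    return target in entries


if __name__ == "__main__":
    print("CUDA bin on PATH:", cuda_bin_on_path())
```

If this prints False, tools such as ptxas and nvcc may not be found from the shell even though the driver (and nvidia-smi) works.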
I installed Jax into my Python 3.8.6 conda environment by running
pip3 install --upgrade jax jaxlib==0.1.62+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
I can import Jax from Python but the first operation throws an error.
from jax import numpy
numpy.zeros(4)
2021-03-12 21:26:30.353284: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:191] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2021-03-12 21:26:30.353307: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:194] Used ptxas at /usr/local/cuda-11.0/bin/ptxas
2021-03-12 21:26:30.353808: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:682] failed to get PTX kernel "broadcast_2" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2021-03-12 21:26:30.353849: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1881] Execution of replica 0 failed: Internal: Could not find the corresponding function
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1181, in __repr__
s = np.array2string(self._value, prefix=prefix, suffix=',',
File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1122, in _value
self._npy_value = _force(self).device_buffer.to_py()
File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1333, in _force
result = force_fun(x)
File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1357, in force_fun
return compiled.execute([x.device_buffer])[0]
RuntimeError: Internal: Could not find the corresponding function
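The two warnings at the top of that log carry the most useful details: the compute capability that ptxas rejected and the ptxas binary that was used. A minimal sketch for pulling both out of a captured log follows; summarize_ptxas_warnings is a hypothetical helper, and the regular expressions assume the message wording stays exactly as printed above.

```python
import re

# Patterns assume the exact wording of the warnings shown in the log above.
CC_RE = re.compile(r"ptxas does not support CC (\d+)\.(\d+)")
PTXAS_RE = re.compile(r"Used ptxas at (\S+)")


def summarize_ptxas_warnings(log_text):
    """Extract the rejected compute capability and the ptxas path, if present."""
    cc = CC_RE.search(log_text)
    ptxas = PTXAS_RE.search(log_text)
    return {
        "compute_capability": (int(cc.group(1)), int(cc.group(2))) if cc else None,
        "ptxas_path": ptxas.group(1) if ptxas else None,
    }


if __name__ == "__main__":
    sample = (
        "W asm_compiler.cc:191] Falling back to the CUDA driver for PTX "
        "compilation; ptxas does not support CC 8.6\n"
        "W asm_compiler.cc:194] Used ptxas at /usr/local/cuda-11.0/bin/ptxas\n"
    )
    print(summarize_ptxas_warnings(sample))
```

For the log above this reports compute capability (8, 6) and the ptxas under /usr/local/cuda-11.0/bin.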
Running nvcc --version prints
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Thu_Jun_11_22:26:38_PDT_2020
Cuda compilation tools, release 11.0, V11.0.194
Build cuda_11.0_bu.TC445_37.28540450_0
Is this a bug or am I doing something wrong?
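One possible reading, offered as an assumption rather than a confirmed diagnosis: the RTX 3090 has compute capability 8.6, and toolkit support for CC 8.6 arrived with CUDA 11.1, so the ptxas from CUDA 11.0 cannot target this card. The sketch below parses nvcc --version output and compares the release against that assumed minimum. The names parse_nvcc_release, toolkit_supports_cc86 and MIN_CUDA_FOR_CC86 are hypothetical.

```python
import re

# Assumption: CUDA 11.1 is the first toolkit release whose ptxas targets CC 8.6.
MIN_CUDA_FOR_CC86 = (11, 1)


def parse_nvcc_release(nvcc_output):
    """Return (major, minor) from the 'release X.Y' line of nvcc --version, or None."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def toolkit_supports_cc86(nvcc_output):
    """True if the reported toolkit release meets the assumed minimum for CC 8.6."""
    release = parse_nvcc_release(nvcc_output)
    return release is not None and release >= MIN_CUDA_FOR_CC86


if __name__ == "__main__":
    reported = "Cuda compilation tools, release 11.0, V11.0.194"
    print(parse_nvcc_release(reported), toolkit_supports_cc86(reported))
```

With the release line reported above, the parsed version is (11, 0) and the check returns False.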
About this issue
- State: closed
- Created 3 years ago
- Reactions: 2
- Comments: 15 (2 by maintainers)
In my case the problem was caused by my PATH. CUDA did not add itself to it automatically (although nvidia-smi worked), so I had to add this to my .bashrc.
Unfortunately not with CUDA 11.0. If there is a fix, I would be interested, too.
However, after upgrading to CUDA Toolkit 11.2 and using jaxlib==0.1.64+cuda112, everything seems to be working 🤷‍♂️
Not sure if it's related, but building from source fails with the following error: