jax: problems installing JAX on a GCP deep learning VM with GPU
I have created a GCP VM with an A100 GPU and this default image: c0-deeplearning-common-cu113-v20211219-debian-10. This is CUDA 11.3, cuDNN 8.2, Debian 10, Python 3.7. I installed JAX thus:
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Inside Python 3.7 I type 'import jax', but I get this error:
version `GLIBCXX_3.4.26' not found
According to this issue, I can solve this by first creating a venv and then installing:
python -m venv env
source env/bin/activate
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
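Before running anything heavier, it is worth confirming which backend the venv install actually initialized. A minimal sanity check (a sketch using the standard jax.default_backend() and jax.devices() calls) looks like this:

```python
# Quick sanity check after installing jax in the venv (sketch):
# default_backend() reports which platform XLA initialized.
import jax

backend = jax.default_backend()  # "gpu" on a working CUDA setup, else "cpu"
print(backend)
print(jax.devices())
```

If this prints "cpu" on a GPU VM, jaxlib did not find a usable CUDA installation and the GPU will be silently ignored.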
This partly works, in that I can now 'import jax' and run it. However, it fails when I use 'jax.lax.scan'. In particular, the code snippet below gives this error:
2022-01-17 19:46:23.259785: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2086] Execution of replica 0 failed: INTERNAL: CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: the provided PTX was compiled with an unsupported toolchain.
Here is the code:
import jax
import jax.numpy as jnp
from jax.scipy.special import logit

# sample from a Markov chain
init_dist = jnp.array([0.8, 0.2])
trans_mat = jnp.array([[0.9, 0.1], [0.5, 0.5]])
rng_key = jax.random.PRNGKey(0)

seq_len = 15
initial_state = jax.random.categorical(rng_key, logits=logit(init_dist), shape=(1,))

def draw_state(prev_state, key):
    logits = logit(trans_mat[:, prev_state])
    state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
    return state, state

rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)
keys = jax.random.split(rng_state, seq_len - 1)
final_state, states = jax.lax.scan(draw_state, initial_state, keys)
print(states)
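As a stopgap while the GPU PTX mismatch is unresolved (this is not from the thread, so treat it as an assumption): the jax_platform_name config flag can pin JAX to the CPU backend, which avoids the failing cuda_prng kernel entirely, at the cost of speed. It must be set before the first computation initializes the backend:

```python
import jax

# Hedged workaround: force the CPU backend so the CUDA PRNG kernel is
# never invoked. Must run before any jax computation touches a device.
jax.config.update("jax_platform_name", "cpu")

key = jax.random.PRNGKey(0)
print(jax.default_backend())  # expected: "cpu"
print(jax.random.split(key, 3))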
About this issue
- State: closed
- Created 2 years ago
- Reactions: 3
- Comments: 17 (9 by maintainers)
The CUDA driver version incompatibility problem should be fixed in the next jaxlib release. JAX will automatically fall back to not using parallel compilation if the NVIDIA driver is too old.
Unfortunately we had to revert the workaround for "version GLIBCXX_3.4.26 not found", because the workaround was to import scipy ourselves, and that turns out to be too slow to do every time jax is imported. If you still see that problem, I recommend one of the workarounds above. Note that jax is also available via conda-forge (https://github.com/google/jax#conda-installation), and using the conda installation of JAX will not have this issue.

@sayakpaul the issue is you have CUDA 11.0 installed. JAX doesn't support CUDA 11.0. Install a newer CUDA.
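The conda route mentioned above would look roughly like this (a sketch; the exact command and channels are in the linked README section, and GPU builds may require additional channels):

```shell
# CPU build of jax from conda-forge, per the README link above.
conda install -c conda-forge jax
```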
On Tue, Jan 18, 2022 at 6:12 AM Peter Hawkins wrote:
How can we do that? Given that this is what every GCP user (who requests a GPU VM) is going to experience, I think the instructions should be clear, otherwise it will just drive users towards other cloud providers.