jax: problems installing JAX on a GCP deep learning VM with GPU

I have created a GCP VM with an A100 GPU and this default image: c0-deeplearning-common-cu113-v20211219-debian-10. This is CUDA 11.3, cuDNN 8.2, Debian 10, and Python 3.7. I installed JAX thus:

pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html  
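
Note: the release index also serves wheels pinned to a specific CUDA/cuDNN pair; assuming the naming convention holds, the extra matching this image (CUDA 11.3, cuDNN 8.2) would be something like:

pip install --upgrade "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html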

Inside Python 3.7 I type ‘import jax’, but I get this error:

 version `GLIBCXX_3.4.26' not found
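
As the maintainers note in the comments below, this comes from scipy needing a newer libstdc++ than the image ships. One way to list the GLIBCXX versions the system libstdc++ actually exports (assuming Debian 10's default library path):

strings /usr/lib/x86_64-linux-gnu/libstdc++.so.6 | grep GLIBCXX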

According to this issue, I can solve this by first creating a venv and then installing:

python -m venv env
source env/bin/activate
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html  
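
A quick sanity check that the GPU backend actually loads (just an illustrative one-liner):

python -c "import jax; print(jax.devices())"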

This partly works, in that I can now ‘import jax’ and run it. However, it fails when I use ‘jax.lax.scan’: in particular, the code snippet below gives this error:

2022-01-17 19:46:23.259785: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2086] Execution of replica 0 failed: INTERNAL: CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: the provided PTX was compiled with an unsupported toolchain.

Here is the code:

import jax
import jax.numpy as jnp
from jax.scipy.special import logit

# Sample a sequence from a two-state Markov chain.
init_dist = jnp.array([0.8, 0.2])                # initial state distribution
trans_mat = jnp.array([[0.9, 0.1], [0.5, 0.5]])  # row i = P(next state | current state i)
rng_key = jax.random.PRNGKey(0)
seq_len = 15

# Draw the initial state from the initial distribution.
initial_state = jax.random.categorical(rng_key, logits=logit(init_dist), shape=(1,))

def draw_state(prev_state, key):
    # Row prev_state of trans_mat holds the next-state distribution.
    logits = logit(trans_mat[prev_state])
    state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
    return state, state

rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)  # rng_obs unused here
keys = jax.random.split(rng_state, seq_len - 1)

# This jax.lax.scan call is what triggers the PTX error above.
final_state, states = jax.lax.scan(draw_state, initial_state, keys)

print(states)

About this issue

  • State: closed
  • Created: January 2022
  • Reactions: 3
  • Comments: 17 (9 by maintainers)

Most upvoted comments

The CUDA driver version incompatibility problem should be fixed in the next jaxlib release. JAX will automatically fall back to not using parallel compilation if the NVIDIA driver is too old.

Unfortunately we had to revert the workaround for the version `GLIBCXX_3.4.26' not found error: the workaround was to import scipy ourselves, and that turns out to be too slow to do on every jax import. If you still see that problem, I recommend one of the workarounds above. Note that jax is also available via conda-forge (https://github.com/google/jax#conda-installation), and the conda installation of JAX does not have this issue.
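
For reference, the CPU-only conda install is roughly the following; the exact spec for a CUDA-enabled jaxlib build is in the README section linked above:

conda install jax -c conda-forge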

@sayakpaul the issue is you have CUDA 11.0 installed. JAX doesn’t support CUDA 11.0. Install a newer CUDA.
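
To check which CUDA toolkit and driver a VM actually has (paths can vary by image; nvidia-smi also reports the newest CUDA version the installed driver supports):

nvcc --version
nvidia-smi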

On Tue, Jan 18, 2022 at 6:12 AM Peter Hawkins wrote:

The GLIBCXX version issue was, for the last user, a scipy issue. There’s not much we can do other than drop our scipy dependency or make it optional. That may be possible; we’d have to look into it.

“the provided PTX was compiled with an unsupported toolchain” means that the driver version on the VM is too old for the JAX binary. We may need to build with an older CUDA release. Another option would be for JAX to warn if the CUDA release is too old and hint that the user needs to upgrade their driver.

How can we do that? Given that this is what every GCP user (who requests a GPU VM) is going to experience, I think the instructions should be clear; otherwise it will just drive users towards other cloud providers.
