jax: install command doesn't work

What should the command to install be?

`(pansatz) [amawi@sylg nn_ansatz]$ pip install --upgrade pip

Requirement already satisfied: pip in /home/energy/amawi/miniconda3/envs/pansatz/lib/python3.7/site-packages (21.3)

(pansatz) [amawi@sylg nn_ansatz]$ pip install jax[cuda==11, cudnn==82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

ERROR: Invalid requirement: 'jax[cuda==11,'

(pansatz) [amawi@sylg nn_ansatz]$ pip install jax[cuda=11, cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

ERROR: Invalid requirement: 'jax[cuda=11,' Hint: = is not a valid operator. Did you mean == ?

(pansatz) [amawi@sylg nn_ansatz]$ pip install "jax[cuda=11, cudnn=82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

ERROR: unknown command "install jax[cuda=11, cudnn=82]"`
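
A note on why the first two attempts fail: pip's error quotes only `jax[cuda==11,`, which suggests the shell split the requirement at the space after the comma, so pip never received it as a single argument. Separately, the names inside the brackets are extras (plain names such as `cuda`), not version specifiers, so `cuda==11` would not be accepted even as one argument. The sketch below uses Python's `shlex.split` as an approximation of POSIX shell word splitting (it does not model globbing or other bash specifics) to compare the unquoted form with a quoted, space-free one.

```python
import shlex

URL = "https://storage.googleapis.com/jax-releases/jax_releases.html"

# Unquoted, with a space after the comma: the requirement is split in two.
unquoted = f"pip install jax[cuda==11, cudnn==82] -f {URL}"
# Quoted, no spaces, and an extra name instead of a version specifier.
quoted = f'pip install "jax[cuda]" -f {URL}'

unquoted_args = shlex.split(unquoted)
quoted_args = shlex.split(quoted)

print(unquoted_args[2:4])  # ['jax[cuda==11,', 'cudnn==82]']
print(quoted_args[2])      # jax[cuda]
```

The first fragment printed matches the text pip reported as the invalid requirement. Quoting the requirement also keeps shells that treat `[` and `]` as glob characters from altering it.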

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 15 (6 by maintainers)

Most upvoted comments

UPDATE: I managed to get it to work! Installing via `jax[cuda]` instead of `jax[cuda11]` somehow got pip to correctly find and install jaxlib:

Collecting jaxlib==0.1.73+cuda11.cudnn82
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.73%2Bcuda11.cudnn82-cp37-none-manylinux2010_x86_64.whl (138.6 MB)

I'm getting a similar error trying to install jax on a Singularity image running Ubuntu 18.04.
CUDA version: 11.1, cuDNN version: 8.0

The command I ran was `pip install --prefix=[omitted] --upgrade "jax[cuda11]" -f https://storage.googleapis.com/jax-releases/jax_releases.html` and the output was:

Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting jax[cuda11]
  Downloading jax-0.2.24.tar.gz (786 kB)
     |████████████████████████████████| 786 kB 19.7 MB/s
WARNING: jax 0.2.24 does not provide the extra 'cuda11'
Collecting absl-py
  Downloading absl_py-0.15.0-py3-none-any.whl (132 kB)
     |████████████████████████████████| 132 kB 54.1 MB/s
Collecting numpy>=1.18
  Downloading numpy-1.21.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.7 MB)
     |████████████████████████████████| 15.7 MB 41.1 MB/s
Collecting opt_einsum
  Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)
     |████████████████████████████████| 65 kB 1.3 MB/s
Collecting scipy>=1.2.1
  Downloading scipy-1.7.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (28.5 MB)
     |████████████████████████████████| 28.5 MB 114 kB/s
Collecting six
  Downloading six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting typing_extensions
  Downloading typing_extensions-3.10.0.2-py3-none-any.whl (26 kB)
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.2.24-py3-none-any.whl size=903114 sha256=9100eaa1bb6616d51d1b79e20451cdfab8f0a5d46994b63ad88c1253b86a61a1
  Stored in directory: /<omitted>/.cache/pip/wheels/28/a9/0f/3497740c85f6e1de8f4d291fd2f77d046d66a87620143d0d0e
Successfully built jax
Installing collected packages: six, numpy, typing-extensions, scipy, opt-einsum, absl-py, jax

However, when I try to import the package, I get `ModuleNotFoundError: No module named 'jaxlib'`.
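
The log above has two symptoms worth checking together: pip warns that jax 0.2.24 does not provide the extra `cuda11`, and `jaxlib` is missing from the list of installed packages. A minimal sketch for inspecting both from Python follows. The helper names `declared_extras` and `is_importable` are made up for illustration, and `importlib.metadata` needs Python 3.8 or newer (the Python 3.7 environment shown earlier would need the `importlib_metadata` backport instead).

```python
import importlib.util
from importlib import metadata  # standard library from Python 3.8


def declared_extras(dist_name):
    """Return the extras an installed distribution declares, or None if it is not installed."""
    try:
        meta = metadata.metadata(dist_name)
    except metadata.PackageNotFoundError:
        return None
    # "Provides-Extra" appears once per extra; get_all returns None when absent.
    return meta.get_all("Provides-Extra") or []


def is_importable(module_name):
    """Report whether a top-level module can be found without importing it."""
    return importlib.util.find_spec(module_name) is not None


print("jax extras:", declared_extras("jax"))
print("jaxlib importable:", is_importable("jaxlib"))
```

If the extra requested on the command line is not among the declared extras, pip installs jax alone and skips the jaxlib dependency, which is consistent with the import error reported here.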

Probably. Running `conda update conda` and then redoing everything from scratch worked: a new environment, first `pip install -U jax`, then `pip install -U jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html`. Thanks!
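
After a reinstall like the one described above, a quick check can confirm that JAX imports and show which devices it sees. This is only a sketch: the function name `jax_device_summary` is hypothetical, and whether a GPU appears depends on the installed jaxlib build and the local CUDA setup.

```python
def jax_device_summary():
    """Describe the devices JAX can see, or explain why JAX could not be imported."""
    try:
        import jax
    except ImportError as exc:
        # A missing jaxlib surfaces here as ModuleNotFoundError, a subclass of ImportError.
        return f"JAX is not usable: {exc}"
    # Each device exposes its platform (for example "cpu" or "gpu") and an integer id.
    return ", ".join(f"{d.platform}:{d.id}" for d in jax.devices())


print(jax_device_summary())
```

If only CPU devices are listed on a machine with a GPU, the CPU-only jaxlib wheel was probably installed rather than the CUDA one.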