jax: install command doesn't work
What should the command to install be?
` (pansatz) [amawi@sylg nn_ansatz]$ pip install --upgrade pip
Requirement already satisfied: pip in /home/energy/amawi/miniconda3/envs/pansatz/lib/python3.7/site-packages (21.3)
(pansatz) [amawi@sylg nn_ansatz]$ pip install jax[cuda==11, cudnn==82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
ERROR: Invalid requirement: ‘jax[cuda==11,’
(pansatz) [amawi@sylg nn_ansatz]$ pip install jax[cuda=11, cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
ERROR: Invalid requirement: ‘jax[cuda=11,’ Hint: = is not a valid operator. Did you mean == ?
(pansatz) [amawi@sylg nn_ansatz]$ pip install “jax[cuda=11, cudnn=82]” -f https://storage.googleapis.com/jax-releases/jax_releases.html
ERROR: unknown command “install jax[cuda=11, cudnn=82]” `
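Both failures happen before pip gets a usable requirement. Without straight ASCII quotes, the space after the comma makes the shell split `jax[cuda==11, cudnn==82]` into two arguments, so pip only sees `jax[cuda==11,`. Curly quotes (“ ”) are ordinary characters to the shell and do not quote anything. Even when passed as one argument, extras in a PEP 508 requirement are bare names such as `cuda`, not `name==version` pins. The bash sketch below illustrates both points. `show_args` and `is_valid_requirement` are hypothetical helpers, and the regex is a simplified approximation of the PEP 508 extras syntax, not pip's actual parser.

```bash
#!/usr/bin/env bash

# Print each argument the shell passes to a command, one per line.
show_args() {
  local i=0 arg
  for arg in "$@"; do
    printf 'argv[%d]: %s\n' "$i" "$arg"
    i=$((i + 1))
  done
}

# Simplified PEP 508 shape: a name, optionally followed by [extra1, extra2, ...],
# where every name and extra is an identifier (letters, digits, '.', '_', '-').
# Approximation for illustration only; this is not pip's parser.
ident='[[:alnum:]]([[:alnum:]._-]*[[:alnum:]])?'
sp='[[:space:]]*'
req_re="^${ident}(\\[${sp}(${ident}(${sp},${sp}${ident})*)?${sp}])?\$"

# Succeeds (exit 0) when the requirement string matches the simplified shape.
is_valid_requirement() {
  [[ $1 =~ $req_re ]]
}

# Unquoted: the space after the comma splits the requirement into two words,
# so pip's first requirement is 'jax[cuda==11,'.
show_args pip install jax[cuda==11, cudnn==82]

# Straight ASCII quotes keep the requirement as a single argument.
show_args pip install "jax[cuda]"

# Extras must be bare names; version pins inside the brackets are rejected.
for req in 'jax[cuda==11, cudnn==82]' 'jax[cuda]' 'jax[cuda11]'; do
  if is_valid_requirement "$req"; then
    echo "valid syntax:   $req"
  else
    echo "invalid syntax: $req"
  fi
done
```

Note that `jax[cuda11]` passes this syntax check; whether a given jax release actually declares a `cuda11` extra is a separate question that the check cannot answer.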
About this issue
- State: closed
- Created 3 years ago
- Comments: 15 (6 by maintainers)
UPDATE: I managed to get it to work! Installing via `jax[cuda]` instead of `jax[cuda11]` somehow got pip to correctly find and install jaxlib.

I’m getting a similar error trying to install jax, on a Singularity image running Ubuntu 18.04.
CUDA version: 11.1, cuDNN version: 8.0
The command I ran was
`pip install --prefix=[omitted] --upgrade "jax[cuda11]" -f https://storage.googleapis.com/jax-releases/jax_releases.html`
and the output is:

However, when attempting to load the package I get a `ModuleNotFoundError: No module named 'jaxlib'`.
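This error is typically raised from inside `import jax` when the `jax` package is present but `jaxlib` is not visible to the same interpreter. A minimal diagnostic sketch, assuming `PYTHON` points at the interpreter you actually run: it prints that interpreter's `sys.path`, checks whether `jax` and `jaxlib` can be found, and lists the extras the installed jax distribution declares. With `--prefix` installs, one thing worth ruling out is that the prefix's `site-packages` directory is not on `sys.path` (for example, not added via `PYTHONPATH`). `has_module` is a hypothetical helper; `importlib.metadata` needs Python 3.8+ (on 3.7 the `importlib_metadata` backport provides the same `metadata()` call).

```bash
#!/usr/bin/env bash
# Assumption: override PYTHON if your environment uses `python` or a full path.
PYTHON="${PYTHON:-python3}"

# Succeeds (exit 0) if the interpreter can locate the named top-level module.
has_module() {
  "$PYTHON" -c 'import importlib.util, sys; sys.exit(importlib.util.find_spec(sys.argv[1]) is None)' "$1"
}

# Which interpreter is this, and where does it look for packages?
"$PYTHON" -c 'import sys; print("interpreter:", sys.executable); print("sys.path:", *sys.path, sep="\n  ")'

for mod in jax jaxlib; do
  if has_module "$mod"; then
    echo "$mod: found"
  else
    echo "$mod: NOT found"
  fi
done

# List the extras the installed jax distribution declares (Python 3.8+).
if has_module jax; then
  "$PYTHON" -c 'from importlib.metadata import metadata; print("jax extras:", metadata("jax").get_all("Provides-Extra"))'
fi
```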
Probably. Doing `conda update conda` and then redoing everything from scratch (a new environment, first `pip install -U jax`, then `pip install -U jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html`) worked. Thanks!
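The from-scratch recipe in the last comment, written as a bash sketch. `ENV_NAME`, `PY_VERSION`, and the `run`/`DRY_RUN` wrapper are hypothetical additions, not from the thread. By default the script only prints the commands; set `DRY_RUN=0` to execute them. It uses `conda run` instead of `conda activate` so it works in a non-interactive script. The `cuda` extra and the `-f` URL are the ones used in this thread; newer JAX releases have changed both the extra names and the release index, so check the current JAX installation instructions before copying them.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical defaults; override via the environment.
ENV_NAME="${ENV_NAME:-jax-gpu}"
PY_VERSION="${PY_VERSION:-3.9}"
JAX_RELEASES="https://storage.googleapis.com/jax-releases/jax_releases.html"

# Print each command; execute it only when DRY_RUN=0 (dry run is the default).
run() {
  printf '+ %s\n' "$*"
  if [[ "${DRY_RUN:-1}" == 0 ]]; then
    "$@"
  fi
}

run conda update -y conda
run conda create -y -n "$ENV_NAME" "python=$PY_VERSION"
# Install plain jax first, then the CUDA extra, as described in the thread.
run conda run -n "$ENV_NAME" pip install -U jax
# Quote the extra so the shell passes it to pip as one unchanged argument.
run conda run -n "$ENV_NAME" pip install -U "jax[cuda]" -f "$JAX_RELEASES"
```

For example, `DRY_RUN=0 ENV_NAME=pansatz bash setup_jax.sh` would run the steps for real (the script file name is hypothetical).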