grok-1: Error "Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89"

Hello everyone.

I’ve tried to run the pip install, but I’m facing the following error:

ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"

I’m installing on macOS and on Ubuntu, and I hit the same issue on both.

Has anyone else gotten the same error?

About this issue

  • State: open
  • Created 3 months ago
  • Reactions: 13
  • Comments: 20

Most upvoted comments

Quick update: I had to install JAX first by running pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. After that, pip install -r requirements.txt finished okay.

(py390) [zcobol@dallas grok-1]$ pip list
Package                  Version
------------------------ ---------------------
absl-py                  2.1.0
chex                     0.1.85
dm-haiku                 0.0.12
etils                    1.5.2
flax                     0.8.2
fsspec                   2024.3.0
importlib_metadata       7.0.2
importlib_resources      6.3.1
jax                      0.4.25
jaxlib                   0.4.25+cuda12.cudnn89
jmp                      0.0.4
markdown-it-py           3.0.0
mdurl                    0.1.2
ml-dtypes                0.3.2
msgpack                  1.0.8
nest-asyncio             1.6.0
numpy                    1.26.4
nvidia-cublas-cu12       12.4.2.65
nvidia-cuda-cupti-cu12   12.4.99
nvidia-cuda-nvcc-cu12    12.4.99
nvidia-cuda-nvrtc-cu12   12.4.99
nvidia-cuda-runtime-cu12 12.4.99
nvidia-cudnn-cu12        8.9.7.29
nvidia-cufft-cu12        11.2.0.44
nvidia-cusolver-cu12     11.6.0.99
nvidia-cusparse-cu12     12.3.0.142
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.4.99
opt-einsum               3.3.0
optax                    0.2.1
orbax-checkpoint         0.5.6
pip                      23.3.1
protobuf                 5.26.0
Pygments                 2.17.2
PyYAML                   6.0.1
rich                     13.7.1
scipy                    1.12.0
sentencepiece            0.2.0
setuptools               68.2.2
tabulate                 0.9.0
tensorstore              0.1.56
toolz                    0.12.1
typing_extensions        4.10.0
wheel                    0.41.2
zipp                     3.18.1
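
To confirm that the CUDA build of jaxlib was actually picked up (as in the pip list output above, where jaxlib shows 0.4.25+cuda12.cudnn89), a small check like the following may help. This is only a sketch: the helper names are made up for illustration, and it assumes the CUDA wheel can be recognized by a local version tag starting with "cuda", which is what the listing above suggests.

```python
from importlib import metadata


def local_build_tag(version: str) -> str:
    """Return the part after '+' in a version string, or '' if there is none."""
    _, _, tag = version.partition("+")
    return tag


def is_cuda_build(version: str) -> bool:
    """Assumption: a CUDA wheel carries a local tag that starts with 'cuda'."""
    return local_build_tag(version).startswith("cuda")


def installed_version(package: str):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name in ("jax", "jaxlib"):
        version = installed_version(name)
        if version is None:
            print(f"{name}: not installed")
        else:
            print(f"{name}: {version} (CUDA build: {is_cuda_build(version)})")
```

If jaxlib is reported without a CUDA tag, pip most likely resolved the plain CPU wheel from PyPI instead of the one from the -f index.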

Validated on Ubuntu, which is great.

Apple silicon is not supported. There is a Metal plugin for JAX (which you would have to swap in for the CUDA JAX build), but you will run into problems with the dm_haiku dependency as well. Also, I don’t think any Apple silicon configuration would have enough memory to run this model anyway.

Not OK for me! Apple M1 Pro, macOS 13.2.1 (22D68).

I’m on Windows 11 with WSL2/Ubuntu. It comes with Python 3.8. I got the same 0.4.25 error. After installing 3.9 with venv, it just worked fine for me. HTH.
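
For anyone wanting to rule out the interpreter version before running pip, here is a minimal sketch. The 3.9 threshold is an assumption taken only from the report above (3.8 failed, 3.9 worked), and the function name is hypothetical.

```python
import sys

# Assumption: 3.9 is used as the minimum only because the comment above
# reports that Python 3.8 failed and Python 3.9 worked.
MIN_PYTHON = (3, 9)


def python_is_new_enough(version_info=None, minimum=MIN_PYTHON) -> bool:
    """Compare (major, minor) of the given or running interpreter to `minimum`."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    current = ".".join(str(part) for part in sys.version_info[:3])
    if python_is_new_enough():
        print(f"Python {current} should be fine")
    else:
        print(f"Python {current} is older than {MIN_PYTHON[0]}.{MIN_PYTHON[1]}; try a newer venv")
```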

My operating system is Windows 11, and I am facing the same problem.

WARNING: jax 0.4.25 does not provide the extra 'cuda12-pip'
INFO: pip is looking at multiple versions of jax[cuda12-pip] to determine which version is compatible with other requirements. This could take a while.
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"

pip list result:

D:\workspace\github\grok-1>pip list
Package            Version
------------------ -------
importlib_metadata 7.0.2
jax                0.4.6
jaxlib             0.4.25
ml-dtypes          0.3.2
numpy              1.26.4
opt-einsum         3.3.0
pip                24.0
scipy              1.12.0
setuptools         49.2.1
wheel              0.43.0
zipp               3.18.1

Windows 11, I got the same error.

It needs either a TPU or a GPU (NVIDIA/AMD only). There have to be 8 devices.

8 devices?
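
To see how many accelerator devices JAX can find on a given machine, something like the sketch below could be used. The number 8 comes from the comment above; treating more than 8 devices as acceptable is an assumption, and the helper names are hypothetical. jax is imported lazily so the script still runs (and reports the problem) when JAX is not installed.

```python
REQUIRED_DEVICES = 8  # taken from the comment above, not verified here


def enough_devices(found: int, required: int = REQUIRED_DEVICES) -> bool:
    """Assumption: having at least `required` devices counts as sufficient."""
    return found >= required


def visible_device_count() -> int:
    """Ask JAX how many devices the default backend exposes."""
    import jax  # lazy import so the helpers above work without JAX

    return jax.device_count()


if __name__ == "__main__":
    try:
        count = visible_device_count()
    except ImportError:
        print("JAX is not installed in this environment")
    else:
        status = "enough" if enough_devices(count) else "not enough"
        print(f"JAX sees {count} device(s): {status} for {REQUIRED_DEVICES}")
```

On a CPU-only install this will typically report a single device, so the output is a quick way to tell whether the GPU or TPU backend was picked up at all.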