flax: Flax doesn't work on Google Colab

It seems like Flax just stopped working on Google Colab. Simply running

import jax
!pip install --quiet flax
import flax

yields the error

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.9.0 requires jedi>=0.10, which is not installed.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-41b8697fa850> in <module>
      1 import jax
      2 get_ipython().system('pip install --quiet flax')
----> 3 import flax

4 frames
/usr/local/lib/python3.8/dist-packages/flax/core/meta.py in Partitioned()
    263     return self.replace(names=tuple(names))
    264 
--> 265   def get_partition_spec(self) -> jax.sharding.PartitionSpec:
    266     """Returns the ``Partitionspec`` for this partitioned value."""
    267     return jax.sharding.PartitionSpec(*self.names)

AttributeError: module 'jax.sharding' has no attribute 'PartitionSpec'

This can be worked around by downgrading Flax to 0.6.4.
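Before downgrading, you can confirm that you are hitting this particular incompatibility by checking whether the JAX in your session exposes jax.sharding.PartitionSpec, the attribute the failing line in flax/core/meta.py uses. This is a minimal sketch; the helper name has_partition_spec is hypothetical. Run it in the same session, because importlib.import_module returns the already-loaded module if jax was imported earlier, which is the one Flax will see.

```python
import importlib


def has_partition_spec() -> bool:
    """Return True if the loaded JAX provides jax.sharding.PartitionSpec."""
    try:
        # Reuses the module from sys.modules if jax was already imported.
        sharding = importlib.import_module("jax.sharding")
    except ImportError:
        # JAX is not installed, or is too old to have jax.sharding at all.
        return False
    # The traceback above shows jax.sharding existing without PartitionSpec.
    return hasattr(sharding, "PartitionSpec")


print("jax.sharding.PartitionSpec available:", has_partition_spec())
```

If it prints False, this Flax release probably needs a newer JAX than the one loaded, and pinning Flax to 0.6.4 is the workaround described above.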

About this issue

  • State: open
  • Created a year ago
  • Reactions: 2
  • Comments: 15 (1 by maintainers)


Most upvoted comments

Copied from https://github.com/google/flax/issues/2950#issuecomment-1479258169

Yes, indeed, the Colab TPU runtime no longer supports new JAX versions.

So I would recommend one of the following:

  1. Use the CPU or GPU runtime (both work with the latest Flax release, 0.6.7).
  2. On a TPU runtime, install Flax while keeping JAX pinned to a version the runtime supports: !pip install flax==0.6.4 jax==0.3.25 jaxlib==0.3.25
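The two recommendations can be folded into a single install cell. This is a minimal sketch, not an official helper: the function names are hypothetical, and it assumes the legacy Colab TPU runtime can be recognized by the COLAB_TPU_ADDR environment variable, so adjust that check if your runtime differs. It uses typing.List so it also runs on the Python 3.8 shown in the traceback.

```python
import os
import subprocess
import sys
from typing import List


def flax_requirements(on_tpu: bool) -> List[str]:
    """Package specs following the maintainer's recommendation in this thread."""
    if on_tpu:
        # TPU runtime: keep JAX pinned to a version the runtime supports.
        return ["flax==0.6.4", "jax==0.3.25", "jaxlib==0.3.25"]
    # CPU/GPU runtimes work with the latest Flax (0.6.7 when this was written).
    return ["flax"]


def running_on_colab_tpu() -> bool:
    # Assumption: the legacy Colab TPU runtime sets COLAB_TPU_ADDR.
    return "COLAB_TPU_ADDR" in os.environ


def install_flax() -> None:
    """Install the matching packages with the current interpreter's pip."""
    reqs = flax_requirements(running_on_colab_tpu())
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", *reqs])
```

Call install_flax() in its own cell. If jax was already imported in the session (as in the snippet at the top of this issue), restart the runtime afterwards, because an already-imported module is not replaced by a newly installed version.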

We're sorry for the inconvenience, but making the Colab TPU runtime infrastructure compatible with new JAX versions is beyond what we can currently fix.