flax: 0.6.3 introduced a circular dependency with orbax


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
  • Python version:
  • GPU/TPU model and memory:
  • CUDA version (if applicable):

Problem you have encountered:

Flax 0.6.3 added a dependency on orbax, which itself depends on flax. This circular dependency is causing https://github.com/tensorflow/tfjs/issues/7159 in the TensorFlow.js repository: TFJS resolves PyPI packages using Bazel, which does not support circular dependencies.
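To make the failure mode concrete, here is a small, self-contained Python sketch of cycle detection over a package dependency graph, which is the kind of check a build tool that requires an acyclic graph (as Bazel does) effectively enforces. The graphs are simplified and hypothetical: they list only the edges discussed in this issue, not the real install requirements of any release.

```python
from typing import Dict, List, Optional


def find_cycle(deps: Dict[str, List[str]]) -> Optional[List[str]]:
    """Hypothetical helper: return one dependency cycle as a list of names, or None."""
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the current DFS path, finished
    color = {pkg: WHITE for pkg in deps}
    path: List[str] = []

    def visit(pkg: str) -> Optional[List[str]]:
        color[pkg] = GRAY
        path.append(pkg)
        for dep in deps.get(pkg, []):
            state = color.get(dep, WHITE)
            if state == GRAY:
                # dep is already on the current path, so we found a back edge.
                return path[path.index(dep):] + [dep]
            if state == WHITE:
                cycle = visit(dep)
                if cycle:
                    return cycle
        path.pop()
        color[pkg] = BLACK
        return None

    for pkg in list(deps):
        if color[pkg] == WHITE:
            cycle = visit(pkg)
            if cycle:
                return cycle
    return None


# Simplified, hypothetical view of the 0.6.3 situation: flax <-> orbax.
before = {"flax": ["jax", "orbax"], "orbax": ["jax", "flax"], "jax": []}
# Simplified, hypothetical view after flax switched to orbax-checkpoint.
after = {"flax": ["jax", "orbax-checkpoint"], "orbax-checkpoint": ["jax"], "jax": []}

print(find_cycle(before))  # ['flax', 'orbax', 'flax']
print(find_cycle(after))   # None
```

find_cycle is not part of Bazel or rules_python; it is a standard three-color depth-first search that reports the first back edge it meets.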

Was this change intentional? If so, I can file a bug with rules_python instead, although the last time this kind of circular dependency issue arose, it was determined to be a bug in the downstream package. I’m not sure whether that is true in this case, though.

What you expected to happen:

No circular dependency.

Logs, error messages, etc:

Steps to reproduce:


About this issue

  • State: closed
  • Created 2 years ago
  • Reactions: 3
  • Comments: 16 (2 by maintainers)


Most upvoted comments

For the future, here are some ideas:

  • Orbax becomes the owner of serialization.py, which makes sense for a checkpointing library. On the Flax side we add some wrappers.
  • Create a new flax-serialization library which both libraries can depend on to avoid the circular dependency.

Now that Flax depends on orbax-checkpoint, there is no longer a circular dependency.

Hi all - sorry for the delay on this issue! The underlying problem is that orbax has been using the flax serialization routines, partly for backwards-compatibility reasons but mainly because the simple flax “state dict” machinery was a common way to derive the “key paths” to each leaf in a pytree. The circular dependency appeared because we’re transitioning to using orbax for checkpoints.
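For readers unfamiliar with the idea, here is a pure-Python sketch of what deriving “key paths” through a state dict means: convert a nested container into nested dicts with string keys, then walk it to get a path string for every leaf. This is not Flax’s implementation; to_state_dict_sketch and key_paths are hypothetical helpers, and the choice to map list and tuple indices to string keys and join paths with "/" is our own assumption for illustration.

```python
from typing import Any, List, Tuple


def to_state_dict_sketch(tree: Any) -> Any:
    """Hypothetical helper: turn dicts/lists/tuples into nested dicts with string keys."""
    if isinstance(tree, dict):
        return {str(k): to_state_dict_sketch(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Assumption: sequence positions become the keys "0", "1", ...
        return {str(i): to_state_dict_sketch(v) for i, v in enumerate(tree)}
    return tree  # anything else is treated as a leaf


def key_paths(state: Any, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, Any]]:
    """Hypothetical helper: walk a nested state dict and return ("a/b/c", leaf) pairs."""
    if isinstance(state, dict):
        pairs: List[Tuple[str, Any]] = []
        for key, value in state.items():
            pairs.extend(key_paths(value, prefix + (key,)))
        return pairs
    return [("/".join(prefix), state)]


train_state = {"params": {"Dense_0": {"kernel": 1.0, "bias": 0.0}}, "step": 3, "layers": [10, 20]}
for path, leaf in key_paths(to_state_dict_sketch(train_state)):
    print(path, leaf)
```

A checkpointing library can use such path strings as stable names for each leaf on disk, which is why reusing the serialization routines was convenient for orbax.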

We’re trying to resolve this issue fundamentally by adopting a mechanism in jax itself for defining key paths to pytree leaves, so that other libraries needn’t rely on our relatively simple state-dict abstraction (and so that we can ultimately delete it ourselves).
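The comment does not name the exact API, but current JAX releases ship a pytree key-path API, including jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr, which is the kind of mechanism described here. The sketch below assumes a JAX version that provides both functions; named_leaves is a hypothetical helper, not part of JAX, Flax, or orbax.

```python
import jax


def named_leaves(tree):
    """Hypothetical helper: return (path string, leaf) pairs via JAX's key-path API."""
    # tree_flatten_with_path returns ([(key_path, leaf), ...], treedef).
    leaves_with_paths, _treedef = jax.tree_util.tree_flatten_with_path(tree)
    # keystr renders a key path (a tuple of key entries) as a readable string.
    return [(jax.tree_util.keystr(path), leaf) for path, leaf in leaves_with_paths]


train_state = {"params": {"Dense_0": {"kernel": 1.0, "bias": 0.0}}, "step": 3}
for name, leaf in named_leaves(train_state):
    print(name, leaf)
```

Unlike the state-dict approach, the key paths come straight from JAX’s pytree registry, so a library like orbax can name leaves without importing Flax. Note that JAX flattens dicts in sorted key order, so "bias" comes before "kernel".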

Our sincere apologies for the build troubles caused by this circular dependency. We and the orbax maintainers are working to resolve it within the next week or so.