flax: 0.6.3 introduced a circular dependency with orbax
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
- Python version:
- GPU/TPU model and memory:
- CUDA version (if applicable):
Problem you have encountered:
Flax 0.6.3 added a dependency on orbax, which itself depends on flax. This is causing https://github.com/tensorflow/tfjs/issues/7159 in the TensorFlow.js repository: TFJS resolves PyPI packages using Bazel, which does not support circular dependencies.
Was this change intentional? If so, I can file a bug with rules_python instead, although the last time this kind of circular dependency issue arose, it was determined to be a bug in the downstream package. I’m not sure if that’s true in this case, though.
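To see why a resolver that requires an acyclic graph rejects this release, here is a minimal Python sketch of cycle detection over a package dependency graph. The graphs are hand-written simplifications containing only the edges discussed in this issue (flax → orbax, orbax → flax, and later flax → orbax-checkpoint), not real package metadata, and `find_cycle` is a hypothetical helper, not part of Bazel or rules_python.

```python
from typing import Dict, List, Optional


def find_cycle(graph: Dict[str, List[str]]) -> Optional[List[str]]:
    """Return one dependency cycle as a list of package names, or None.

    `graph` maps a package name to the names it depends on (a simplified,
    hypothetical view; real resolvers read this from package metadata).
    """
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the current DFS path, finished
    color = {node: WHITE for node in graph}
    path: List[str] = []

    def visit(node: str) -> Optional[List[str]]:
        color[node] = GRAY
        path.append(node)
        for dep in graph.get(node, []):
            state = color.get(dep, WHITE)
            if state == GRAY:
                # dep is already on the current path, so we looped back to it.
                return path[path.index(dep):] + [dep]
            if state == WHITE:
                found = visit(dep)
                if found:
                    return found
        path.pop()
        color[node] = BLACK
        return None

    for node in graph:
        if color[node] == WHITE:
            found = visit(node)
            if found:
                return found
    return None


if __name__ == "__main__":
    # Simplified graph for flax 0.6.3 as described in this issue.
    print(find_cycle({"flax": ["orbax"], "orbax": ["flax"]}))  # ['flax', 'orbax', 'flax']
    # Simplified graph after flax switched to orbax-checkpoint.
    print(find_cycle({"flax": ["orbax-checkpoint"], "orbax-checkpoint": []}))  # None
```

A depth-first search that reaches a package already on its current path has found a cycle, which is exactly the shape a build system like Bazel refuses to turn into a build graph.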
What you expected to happen:
No circular dependency.
Logs, error messages, etc:
Steps to reproduce:
About this issue
- State: closed
- Created 2 years ago
- Comments: 16 (2 by maintainers)
Commits related to this issue
- Revert "python3Packages.flax: 0.6.1 -> 0.6.3" This reverts commit fe0048c7f8d869993b486b243dec94204bcf0c08 which broke flax due to https://github.com/google/flax/issues/2707. — committed to dotlambda/nixpkgs by dotlambda a year ago
- Revert "python3Packages.flax: 0.6.1 -> 0.6.3" This reverts commit fe0048c7f8d869993b486b243dec94204bcf0c08 which broke flax due to https://github.com/google/flax/issues/2707. — committed to gador/nixpkgs by dotlambda a year ago
For the future, here are some ideas:
- … `serialization.py`; this makes sense for a checkpointing library. On the Flax side, we add some wrappers.
- … `flax-serialization` library which both libraries can depend on to avoid the circular dependency.

As Flax depends on `orbax-checkpoint` now, there is no longer a circular dependency.

Hi all - sorry for the delay on this issue! The underlying issue is that orbax has been using the flax serialization routines, partly for backwards-compatibility reasons, but mainly because the simple flax “state dict” machinery was a common way to derive the “key paths” to each leaf in a pytree. The circular dependency arises because we’re trying to transition to using orbax for checkpoints.
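The key-path idea can be sketched in plain Python. `flatten_with_key_paths` below is a hypothetical helper, not Flax's `flax.serialization` code: assuming the tree is built from plain dicts, lists and tuples (everything else is treated as a leaf), it walks the containers the way a state dict sees them and pairs each leaf with a slash-joined path such as `layers/0/w`, the kind of name a checkpointing library needs in order to store each array.

```python
from typing import Any, List, Tuple


def flatten_with_key_paths(tree: Any, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, Any]]:
    """Flatten nested dicts/lists/tuples into ("a/b/0", leaf) pairs.

    Hypothetical, state-dict-style walk: dict keys are stringified and
    sorted for a deterministic order, sequence positions become "0", "1", ...
    """
    if isinstance(tree, dict):
        items = [(str(k), v) for k, v in sorted(tree.items(), key=lambda kv: str(kv[0]))]
    elif isinstance(tree, (list, tuple)):
        items = [(str(i), v) for i, v in enumerate(tree)]
    else:
        # Anything that is not a container is a leaf; its path is the joined prefix.
        return [("/".join(prefix), tree)]
    pairs: List[Tuple[str, Any]] = []
    for key, child in items:
        pairs.extend(flatten_with_key_paths(child, prefix + (key,)))
    return pairs


if __name__ == "__main__":
    params = {"layers": [{"w": 1.0}, {"w": 2.0}], "step": 3}
    print(flatten_with_key_paths(params))
    # [('layers/0/w', 1.0), ('layers/1/w', 2.0), ('step', 3)]
```

The limitation is that only the container types the walker knows about get meaningful names; any other type must first be converted into this dict form, which is the role the state-dict machinery ends up playing.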
We’re trying to resolve this issue fundamentally by adopting a mechanism in jax itself for defining the key paths to pytree leaves, so that other libraries needn’t use our relatively simple state-dict abstraction (and, ultimately, so that we can delete it ourselves).
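The difference from a fixed state-dict walk is that each container type declares how its own children are named. The sketch below shows that pattern with a hypothetical registry (`register_keys`, `key_paths`) and a made-up `Dense` dataclass; it illustrates the idea only and is not JAX's API.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

KeyFn = Callable[[Any], List[Tuple[str, Any]]]

# Hypothetical registry: exact container type -> function returning (key, child) pairs.
_KEY_FNS: Dict[type, KeyFn] = {}


def register_keys(cls: type, fn: KeyFn) -> None:
    """Declare how instances of `cls` name their children."""
    _KEY_FNS[cls] = fn


def _sequence_keys(xs: Any) -> List[Tuple[str, Any]]:
    return [(str(i), x) for i, x in enumerate(xs)]


# Built-in containers get default naming rules.
register_keys(dict, lambda d: [(str(k), v) for k, v in sorted(d.items(), key=lambda kv: str(kv[0]))])
register_keys(list, _sequence_keys)
register_keys(tuple, _sequence_keys)


def key_paths(tree: Any, prefix: Tuple[str, ...] = ()) -> List[Tuple[Tuple[str, ...], Any]]:
    """Return (path, leaf) pairs, where path is a tuple of child keys."""
    fn = _KEY_FNS.get(type(tree))  # exact-type lookup; subclasses are not matched
    if fn is None:
        # Unregistered types are leaves.
        return [(prefix, tree)]
    out: List[Tuple[Tuple[str, ...], Any]] = []
    for key, child in fn(tree):
        out.extend(key_paths(child, prefix + (key,)))
    return out


@dataclass
class Dense:  # made-up layer type, for illustration only
    kernel: Any
    bias: Any


# The type itself decides its child names; no conversion to a state dict is needed.
register_keys(Dense, lambda d: [("kernel", d.kernel), ("bias", d.bias)])

if __name__ == "__main__":
    print(key_paths({"dense": Dense(kernel="W", bias="b"), "step": 3}))
    # [(('dense', 'kernel'), 'W'), (('dense', 'bias'), 'b'), (('step',), 3)]
```

Because `Dense` registers its own naming rule, its children get meaningful keys (`kernel`, `bias`) directly, so a checkpointing library could name entries from the paths alone without a Flax-specific abstraction. In recent JAX releases, `jax.tree_util.register_pytree_with_keys` and `jax.tree_util.tree_flatten_with_path` fill these roles.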
Our sincere apologies for the build troubles with this circular dependency - we and the orbax maintainers are working to try to resolve it in the next week or so.