jax: jax.tree_util does not keep dict key order
Since Python 3.7 (and in CPython 3.6 as an implementation detail), `dict`s are guaranteed to keep insertion order, similarly to `OrderedDict`.
Both DeepMind's `tree` (dm-tree) and `tf.nest` keep dict order, but `jax.tree_util` does not.
```python
import tensorflow as tf
import tree  # dm-tree
import jax

data = {'z': None, 'a': None}
print(tf.nest.map_structure(lambda _: None, data))   # {'z': None, 'a': None}
print(tree.map_structure(lambda _: None, data))      # {'z': None, 'a': None}
print(jax.tree_util.tree_map(lambda _: None, data))  # {'a': None, 'z': None}  << Oops, key order inverted
```
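To see where the inversion comes from, here is a pure-Python model of the behavior reported above: flattening sorts the dict keys, and unflattening rebuilds the dict in that sorted order, so the original insertion order is lost. The helpers `jax_like_flatten`, `jax_like_unflatten` and `jax_like_map` are hypothetical and only approximate what `jax.tree_util` does for a single flat dict; they are not JAX's actual implementation.

```python
from __future__ import annotations

from typing import Any, Callable


def jax_like_flatten(d: dict) -> tuple[list[Any], tuple]:
    """Hypothetical model: flatten a dict with its keys sorted."""
    keys = tuple(sorted(d))
    leaves = [d[k] for k in keys]
    # Only the sorted keys are kept in the "treedef"; insertion order is dropped.
    return leaves, keys


def jax_like_unflatten(treedef: tuple, leaves: list[Any]) -> dict:
    """Hypothetical model: rebuild the dict in sorted key order."""
    return dict(zip(treedef, leaves))


def jax_like_map(f: Callable[[Any], Any], d: dict) -> dict:
    leaves, treedef = jax_like_flatten(d)
    return jax_like_unflatten(treedef, [f(x) for x in leaves])


data = {'z': 1, 'a': 2}
result = jax_like_map(lambda x: x * 10, data)
print(result)        # {'a': 20, 'z': 10}
print(list(data))    # ['z', 'a']
print(list(result))  # ['a', 'z']
```

Each value still ends up under the right key (the two dicts compare equal after mapping), so only the key order changes, which is exactly what makes the difference easy to miss.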
The fact that `dict` and `OrderedDict` behave differently, even though `dict` is guaranteed to keep insertion order, feels inconsistent.
About this issue
- State: open
- Created 4 years ago
- Reactions: 4
- Comments: 18 (11 by maintainers)
Commenting to add that I just encountered this behavior and I find it quite annoying.
If I were to implement this, I'd use as the treedef a dictionary with the same keys but with all values set to None. This way the implementation would completely piggyback on Python and remain consistent under all circumstances.
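A minimal sketch of that proposal in plain Python, covering a single flat dict only: the "treedef" is `dict.fromkeys(d)`, a dict with the same keys as the input but every value set to None, so unflattening simply refills it in the original insertion order. The names `flatten_with_template` and `unflatten_from_template` are hypothetical and not part of `jax.tree_util`.

```python
from __future__ import annotations

from typing import Any


def flatten_with_template(d: dict) -> tuple[list[Any], dict]:
    # The template keeps the keys (in insertion order) with every value set to None.
    template = dict.fromkeys(d)
    leaves = list(d.values())
    return leaves, template


def unflatten_from_template(template: dict, leaves: list[Any]) -> dict:
    if len(template) != len(leaves):
        raise ValueError(f"expected {len(template)} leaves, got {len(leaves)}")
    # Iterating over the template yields keys in their original insertion order.
    return dict(zip(template, leaves))


data = {'z': 1, 'a': 2}
leaves, treedef = flatten_with_template(data)
print(leaves)   # [1, 2]
print(treedef)  # {'z': None, 'a': None}
print(unflatten_from_template(treedef, [x * 10 for x in leaves]))  # {'z': 10, 'a': 20}
```

One caveat of this sketch: plain dict equality ignores order (`{'z': None, 'a': None} == {'a': None, 'z': None}` is `True`), so two templates with the same keys in different orders would compare equal while their leaves are laid out differently. Mapping over two such trees at once would then pair the wrong leaves unless structure checks also compared key order.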
Adding my support for maintaining dict key ordering across flattening operations.
A related question: for a dictionary `d`, does `tuple(d.values())` internally use `tree_flatten`? I ask because that operation also does not maintain key ordering when building the tuple.

I believe JAX should behave like `dm-tree`, where flattening any dict sorts the keys, but packing a dict restores the original key order. This allows all dicts to have the same flattened representation, so they can be mixed together.
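The dm-tree-style behavior described in that comment can be sketched in plain Python: leaves are always produced in sorted key order, so dicts with the same keys flatten identically regardless of insertion order, while the treedef remembers the original order and unflattening restores it. The helper names are hypothetical, the sketch only handles flat dicts with mutually comparable keys (so that `sorted` works), and it is not dm-tree's actual code.

```python
from __future__ import annotations

from typing import Any, Callable


def flatten_sorted(d: dict) -> tuple[list[Any], tuple]:
    # Leaves follow sorted key order, giving every dict with these keys the same layout.
    leaves = [d[k] for k in sorted(d)]
    # The treedef keeps the original insertion order for unflattening.
    return leaves, tuple(d)


def unflatten_restoring_order(treedef: tuple, leaves: list[Any]) -> dict:
    # Pair the leaves with the sorted keys, then re-emit them in the original order.
    by_key = dict(zip(sorted(treedef), leaves))
    return {k: by_key[k] for k in treedef}


def map_structure(f: Callable[..., Any], first: dict, *rest: dict) -> dict:
    leaves, treedef = flatten_sorted(first)
    other_leaves = []
    for d in rest:
        if set(d) != set(first):
            raise ValueError("all dicts must have the same keys")
        other_leaves.append(flatten_sorted(d)[0])
    return unflatten_restoring_order(treedef, [f(*xs) for xs in zip(leaves, *other_leaves)])


x = {'z': 1, 'a': 2}
y = {'a': 20, 'z': 10}  # same keys, different insertion order
print(map_structure(lambda u, v: u + v, x, y))  # {'z': 11, 'a': 22}
```

Because both dicts flatten to the same sorted layout, their leaves line up even though `x` and `y` were built in different orders, and the result still follows the key order of the first argument.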
Hmmm… Looks like `OrderedDict` and `defaultdict` keep keys in order: https://github.com/google/jax/blob/bf041fbdb16cb1360d3b914ab44d8c3a799d566b/jax/tree_util.py#L247-L255

whereas standard dicts sort their keys: https://github.com/google/jax/blob/c7aff1da06072db8fb074f09a8215615d607adc2/jaxlib/pytree.cc#L132-L134
That is surprising behavior: it would be nice to make this more consistent.
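Given that difference, one possible workaround is to wrap a dict in `collections.OrderedDict` before mapping over it, since `jax.tree_util` keeps `OrderedDict` keys in insertion order. A minimal sketch, assuming a JAX version that provides `jax.tree_util.tree_map`; it does not check the plain-dict case, because that is exactly the behavior under discussion.

```python
from collections import OrderedDict

import jax

data = {'z': 1, 'a': 2}

# Wrapping the dict in an OrderedDict keeps its insertion order through the tree map.
ordered = jax.tree_util.tree_map(lambda x: x * 10, OrderedDict(data))
print(list(ordered.keys()))  # ['z', 'a']

# Convert back to a plain dict if needed; dict() preserves the OrderedDict's order.
print(dict(ordered))  # {'z': 10, 'a': 20}
```

The trade-off is that every caller has to remember to convert, and an `OrderedDict` and a plain `dict` with the same keys do not share a tree structure in JAX, so the two cannot be freely mixed in a single `tree_map` call.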