jax: jax.tree_util does not keep dict key order

Since Python 3.7, dicts are guaranteed to preserve insertion order, just like OrderedDict (CPython 3.6 already did so as an implementation detail).
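A quick plain-Python check of that guarantee (assumes Python 3.7+; the variable names are only for illustration). It also shows the one place where dict and OrderedDict still differ: equality between two OrderedDicts is order-sensitive, while dict equality ignores order.

```python
from collections import OrderedDict

d = {'z': None, 'a': None}
od = OrderedDict(d)

# Both iterate in insertion order.
print(list(d))   # ['z', 'a']
print(list(od))  # ['z', 'a']

# Equality: dict ignores key order, OrderedDict does not.
print(d == {'a': None, 'z': None})                    # True
print(od == OrderedDict([('a', None), ('z', None)]))  # False
```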

Both DeepMind's tree (dm-tree) and tf.nest preserve dict key order, but jax.tree_util does not.

import tensorflow as tf
import tree
import jax

data = {'z': None, 'a': None}

print(tf.nest.map_structure(lambda _: None, data))  # {'z': None, 'a': None}
print(tree.map_structure(lambda _: None, data))     # {'z': None, 'a': None}
print(jax.tree_util.tree_map(lambda _: None, data))  # {'a': None, 'z': None}  <- Oops, keys come back sorted

Since a plain dict now guarantees insertion order, it feels inconsistent that dict and OrderedDict are treated differently.

About this issue

  • State: open
  • Created 4 years ago
  • Reactions: 4
  • Comments: 18 (11 by maintainers)

Most upvoted comments

Commenting to add that I just encountered this behavior and I find it quite annoying.

If I were to implement this, I'd use as the treedef a dictionary with the same keys, filled with None values. That way the implementation would piggyback entirely on Python's own dict ordering and stay consistent in all circumstances.
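A minimal pure-Python sketch of that idea, assuming plain (possibly nested) dicts only. flatten_dict and unflatten_dict are hypothetical helpers, not part of jax.tree_util, and storing a sub-dict's treedef in place of None is an extension of the comment's suggestion:

```python
def flatten_dict(d):
    """Hypothetical helper: flatten a (possibly nested) dict into (leaves, treedef).

    The treedef is a dict with the same keys, in the same insertion order,
    whose values are None for leaves or a nested treedef for sub-dicts.
    """
    leaves = []
    treedef = {}
    for key, value in d.items():  # insertion order, no sorting
        if isinstance(value, dict):
            sub_leaves, sub_def = flatten_dict(value)
            leaves.extend(sub_leaves)
            treedef[key] = sub_def
        else:
            leaves.append(value)
            treedef[key] = None
    return leaves, treedef


def unflatten_dict(treedef, leaves):
    """Hypothetical helper: rebuild a dict from a treedef made by flatten_dict."""
    it = iter(leaves)

    def build(node):
        # Walk the treedef in its own key order, consuming leaves in sequence.
        return {key: build(sub) if isinstance(sub, dict) else next(it)
                for key, sub in node.items()}

    return build(treedef)


data = {'z': 1, 'a': {'y': 2, 'b': 3}}
leaves, treedef = flatten_dict(data)
print(leaves)   # [1, 2, 3]
print(treedef)  # {'z': None, 'a': {'y': None, 'b': None}}
print(unflatten_dict(treedef, [10, 20, 30]))  # {'z': 10, 'a': {'y': 20, 'b': 30}}
```

Because the treedef is itself an ordinary dict, key order survives the round trip without any special handling.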

Adding my support for preserving dict key order across flattening operations.

A related question: for a dictionary d, does tuple(d.values()) use tree_flatten internally? That operation also does not seem to preserve key order when building the tuple.
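In plain Python, tuple(d.values()) simply iterates the dict in insertion order and does not call into jax.tree_util. One possible explanation for the reordering, offered here as an assumption, is that the dict had already been rebuilt with sorted keys (for example, after passing through a JAX transformation that flattens and unflattens its arguments). The sketch below only simulates such a rebuild with a dict comprehension:

```python
d = {'z': 1, 'a': 2}

# Plain Python: values() follows insertion order; no pytree code is involved.
print(tuple(d.values()))  # (1, 2)

# A stand-in for a dict rebuilt with sorted keys: equal to d as a mapping,
# but it iterates in the new order.
rebuilt = {k: d[k] for k in sorted(d)}
print(tuple(rebuilt.values()))  # (2, 1)
print(rebuilt == d)             # True
```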

I believe JAX should behave like dm-tree, where flattening any dict sorts the keys but unflattening (packing) restores the dict's original key order.

import tree

x = {'z': 'z', 'a': 'a'}

print(tree.flatten(x))  # Keys sorted: ['a', 'z']
print(tree.unflatten_as(x, [0, 1]))  # Key order restored: {'z': 1, 'a': 0}
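The dm-tree behaviour above can be sketched in plain Python for a single-level mapping. flatten_sorted and unflatten_as_sorted are hypothetical names, not dm-tree's implementation; the sketch assumes the keys are mutually comparable (so they can be sorted) and always returns a plain dict rather than the input's mapping type:

```python
import collections


def flatten_sorted(mapping):
    """Hypothetical: return the values ordered by sorted key."""
    return [mapping[k] for k in sorted(mapping)]


def unflatten_as_sorted(structure, flat):
    """Hypothetical: assign flat values to keys in sorted order, then
    rebuild the dict in the key order of `structure`."""
    by_key = dict(zip(sorted(structure), flat))
    return {k: by_key[k] for k in structure}  # original insertion order


x = {'z': 'z', 'a': 'a'}
print(flatten_sorted(x))               # ['a', 'z']
print(unflatten_as_sorted(x, [0, 1]))  # {'z': 1, 'a': 0}

# Flattening depends only on the sorted keys, so other mapping types
# with the same contents produce the same leaves.
print(flatten_sorted(collections.defaultdict(int, x)))  # ['a', 'z']
```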

Sorting on flatten gives every dict type the same flattened representation, so different mapping types can be mixed together. In JAX, however, this does not hold:

import collections
import jax

d0 = {'z': 'z', 'a': 'a'}
d1 = collections.defaultdict(int, d0)

assert jax.tree_util.tree_leaves(d0) == jax.tree_util.tree_leaves(d1)  # AssertionError: Oops, ['a', 'z'] != ['z', 'a']

Hmm… it looks like OrderedDict and defaultdict keep their keys in insertion order: https://github.com/google/jax/blob/bf041fbdb16cb1360d3b914ab44d8c3a799d566b/jax/tree_util.py#L247-L255

while standard dicts have their keys sorted: https://github.com/google/jax/blob/c7aff1da06072db8fb074f09a8215615d607adc2/jaxlib/pytree.cc#L132-L134
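To make the contrast concrete, here is a hypothetical plain-Python sketch of the two strategies: one stores the keys in insertion order as auxiliary data (the pattern the linked tree_util.py registration appears to follow for defaultdict), the other sorts the keys and rebuilds a plain dict in sorted order. None of these functions are JAX APIs:

```python
import collections


# Strategy 1 (hypothetical): keep insertion order, store keys as aux data.
def flatten_insertion_order(dd):
    return list(dd.values()), (dd.default_factory, tuple(dd.keys()))


def unflatten_insertion_order(aux, children):
    default_factory, keys = aux
    return collections.defaultdict(default_factory, zip(keys, children))


# Strategy 2 (hypothetical): sort the keys, rebuild in sorted order.
def flatten_sorted_keys(d):
    keys = tuple(sorted(d))
    return [d[k] for k in keys], keys


def unflatten_sorted_keys(keys, children):
    return dict(zip(keys, children))


dd = collections.defaultdict(int, {'z': 1, 'a': 2})
dd_children, dd_aux = flatten_insertion_order(dd)
dd_restored = unflatten_insertion_order(dd_aux, dd_children)
print(dd_children)        # [1, 2]
print(list(dd_restored))  # ['z', 'a']

plain = {'z': 1, 'a': 2}
plain_children, plain_keys = flatten_sorted_keys(plain)
plain_restored = unflatten_sorted_keys(plain_keys, plain_children)
print(plain_children)        # [2, 1]
print(list(plain_restored))  # ['a', 'z']
```

The same contents thus flatten to different leaf orders depending on the mapping type, and only the first strategy gives back the original key order.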

That is surprising behavior: it would be nice to make this more consistent.