jax: TypeError when returning None as a custom_vjp cotangent for a custom pytree that calls asarray

import jax.numpy as jnp
from jax import custom_vjp, vjp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        self.a = jnp.asarray(a, dtype=float)  # TypeError: float() argument must be a string or a real number, not 'object'
        # self.a = jnp.asarray(a)  # TypeError: Value '<object object at 0x7fff5220fc20>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@custom_vjp
def f(x, y):
    return 2. * x  # doesn't depend on y in this simple example


def f_fwd(x, y):
    z = f(x, y)
    res = None
    return z, res

def f_bwd(res, z_cot):
    x_cot = 2. * z_cot
    y_cot = None
    return x_cot, y_cot

f.defvjp(f_fwd, f_bwd)

c = C()
vjp(f, 3., c)[1](3.)  # TypeError with different messages (see above) depending on whether c.a is weakly typed or not
vjp(f, 3., [1., 2])[1](3.)  # this works
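The failure can be reproduced without custom_vjp. The sketch below assumes, following the maintainer reply further down, that the failing step is JAX rebuilding the input pytree from placeholder leaves. It calls tree_unflatten with object() leaves directly, which routes them through C.__init__ and raises the same kind of TypeError. The class mirrors C above, using the jnp.asarray(a) variant.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        self.a = jnp.asarray(a)  # rejects non-numeric placeholder leaves

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)  # every unflatten goes through __init__


# A normal flatten/unflatten round trip works.
leaves, treedef = tree_flatten(C(2.))
roundtrip = tree_unflatten(treedef, leaves)

# Rebuilding the same structure with object() placeholders, as some
# transforms do, passes the placeholder to C.__init__ and fails.
error = None
try:
    tree_unflatten(treedef, [object()] * treedef.num_leaves)
except TypeError as e:
    error = e
```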

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 15 (11 by maintainers)

Most upvoted comments

Hi @eelregit, this is a common issue when using custom PyTrees with JAX transforms. Various transforms pass None or object() placeholder values to the PyTree constructor (in this report, via tree_unflatten calling cls(*children)), which breaks pytrees that validate their inputs too strictly at initialization. Here’s how we deal with this in the PyTree used to represent sparse matrices: https://github.com/google/jax/blob/3136004c623be4cc7b25f8477ffdce0b3a110a2e/jax/experimental/sparse/bcoo.py#L1716-L1720

https://github.com/google/jax/blob/3136004c623be4cc7b25f8477ffdce0b3a110a2e/jax/experimental/sparse/util.py#L44-L60

Any custom pytree you want to use with JAX transformations needs a similar check.
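A minimal sketch of such a check for the C class from the report, in the spirit of the linked sparse helpers but not a copy of them. The helper _is_placeholder is hypothetical: it lets None and bare object() sentinels pass through __init__ unchanged and converts everything else with jnp.asarray. With it, the original custom_vjp example is expected to run, with the None cotangent for c coming back as zeros.

```python
import jax.numpy as jnp
from jax import custom_vjp, vjp
from jax.tree_util import register_pytree_node_class


def _is_placeholder(x):
    # Hypothetical helper (not a JAX API): JAX transforms may rebuild a
    # pytree with None or bare object() sentinels as leaves.
    return x is None or type(x) is object


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        # Leave placeholders untouched; convert real values as before.
        self.a = a if _is_placeholder(a) else jnp.asarray(a, dtype=float)

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@custom_vjp
def f(x, y):
    return 2. * x  # does not depend on y


def f_fwd(x, y):
    return f(x, y), None  # no residuals needed


def f_bwd(res, z_cot):
    return 2. * z_cot, None  # None means "zero cotangent" for y


f.defvjp(f_fwd, f_bwd)

x_cot, c_cot = vjp(f, 3., C())[1](3.)
```

The trade-off is that C(None) no longer fails at construction, so a genuine user mistake may only surface later.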

@jakevdp Could we check for placeholders only in tree_unflatten? That way we could keep the strictest validation for normal initialization. Also, do you think we need a dedicated class to represent placeholders? Using None and object() may surprise users and produce vague error messages like the ones in this issue.
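One way to realize this proposal, sketched here as an assumption rather than a JAX recommendation: keep the strict jnp.asarray(a, dtype=float) in __init__ and have tree_unflatten bypass __init__ via object.__new__, assigning the children directly. Placeholders then only appear on instances rebuilt by JAX, while C(object()) still fails immediately.

```python
import jax.numpy as jnp
from jax import custom_vjp, vjp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        # Strict validation, applied to user-facing construction only.
        self.a = jnp.asarray(a, dtype=float)

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: children may be tracers or placeholders
        # (None, object()) supplied by JAX transforms.
        obj = object.__new__(cls)
        obj.a = children[0]
        return obj


@custom_vjp
def f(x, y):
    return 2. * x  # does not depend on y


def f_fwd(x, y):
    return f(x, y), None


def f_bwd(res, z_cot):
    return 2. * z_cot, None


f.defvjp(f_fwd, f_bwd)

x_cot, c_cot = vjp(f, 3., C())[1](3.)

# Direct construction is still strict.
strict = False
try:
    C(object())
except TypeError:
    strict = True
```

The cost of this design is that tree_unflatten must set every attribute __init__ would have set, and any invariants __init__ enforces are not checked on instances that JAX rebuilds.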