jax: `jnp.linalg.svd` etc. do not respect `__jax_array__`

There are a couple of things going on here. First, here is an example of `jnp.linalg.svd` failing to respect `__jax_array__`:

```python
import jax
import jax.numpy as jnp

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

with jax.disable_jit():
    jnp.linalg.svd(MyArray())

# TypeError: Value '<__main__.MyArray object at 0x7f5ec0f584c0>' with dtype object is not a valid
# JAX array type. Only arrays of numeric types are supported by JAX.
```

Remove the `jax.disable_jit()` context and this works.

The reason it works without `disable_jit` is that `jnp.linalg.svd` and friends all have `jax.jit` wrappers, which is what spots the `__jax_array__` and handles things appropriately… unless the JAX arraylike is also a PyTree, in which case they don’t. So this also fails (with a different error message this time):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class MyArray(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.linalg.svd(MyArray())
# ValueError: Argument to singular value decomposition must have ndims >= 2
```
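To illustrate the PyTree case: my understanding (an assumption about the mechanism, not something confirmed above) is that `jax.jit` flattens its arguments into PyTree leaves before looking at them. A field-less `NamedTuple` is a PyTree node with no leaves, so there is nothing left on which `__jax_array__` could be noticed. A minimal sketch of that flattening:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class MyArray(NamedTuple):
    # No fields are declared, so this PyTree node has no children.
    def __jax_array__(self):
        return jnp.array([[1.]])

# Flatten the object the same way JAX flattens PyTree arguments.
leaves = jax.tree_util.tree_leaves(MyArray())
print(leaves)  # an empty list: the wrapper object itself is not a leaf
```

If that reading is right, the object is treated as a container rather than as an arraylike, which would explain why a different error surfaces in this case.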

So whilst it takes either a `disable_jit` or a PyTree to actually trigger it, I think the fundamental issue here is that `jnp.linalg.svd` and friends do not check for JAX arraylikes.
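Until that check exists, one possible workaround is to unwrap the object by hand before calling into `jnp.linalg`. The sketch below is only an illustration: `as_jax_array` is a hypothetical helper name, not part of JAX, and it calls `__jax_array__` directly rather than relying on any JAX function to do so.

```python
import jax
import jax.numpy as jnp

def as_jax_array(x):
    """Hypothetical helper: unwrap objects that expose __jax_array__."""
    if hasattr(x, "__jax_array__"):
        x = x.__jax_array__()
    return jnp.asarray(x)

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

# The explicit conversion happens before svd is called, so svd only ever
# sees a real JAX array, even with jit disabled.
with jax.disable_jit():
    u, s, vh = jnp.linalg.svd(as_jax_array(MyArray()))

print(s.shape)
```

This moves the conversion to the call site, so it does not address the underlying issue; it only avoids the failure for code that controls its own call to `svd`.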

About this issue

  • State: open
  • Created 2 years ago
  • Comments: 27 (6 by maintainers)

Most upvoted comments

@mattjj In favour of not removing `__jax_array__`: it allows writing backend-agnostic code using NEP 47.