jax: `jnp.linalg.svd` etc. do not respect `__jax_array__`
A couple of things going on here. First of all, the following is an example of `jnp.linalg.svd` failing to respect `__jax_array__`:
import jax
import jax.numpy as jnp

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

with jax.disable_jit():
    jnp.linalg.svd(MyArray())
# TypeError: Value '<__main__.MyArray object at 0x7f5ec0f584c0>' with dtype object is not a valid
# JAX array type. Only arrays of numeric types are supported by JAX.
Remove the `disable_jit` and this works.
The reason it works without `disable_jit` is that `jnp.linalg.svd` and friends all have `jax.jit` wrappers, which is what spots the `__jax_array__` and handles things appropriately… unless the JAX arraylike is also a PyTree, in which case they don't. So this also fails (with a different error message this time):
import jax
import jax.numpy as jnp
from typing import NamedTuple

class MyArray(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.linalg.svd(MyArray())
# ValueError: Argument to singular value decomposition must have ndims >= 2
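To make the PyTree case a bit more concrete, here is a small sketch of what JAX's tree flattening sees for this class. I'm assuming (from the behaviour above, not from reading the JAX source) that the `jax.jit` wrapper flattens its arguments as PyTrees before anything looks at `__jax_array__`. A field-less `NamedTuple` is just an empty tuple, so there are no leaves left to treat as an array:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class MyArray(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

x = MyArray()

# NamedTuples are PyTree nodes, so flattening descends into the (empty) tuple
# rather than treating the object as a single array-like leaf.
leaves = jax.tree_util.tree_leaves(x)
print(leaves)

# The object is also an ordinary empty tuple as far as Python is concerned.
print(isinstance(x, tuple), len(x))
```

That would be consistent with the `ndims >= 2` error: the function ends up looking at something like an empty sequence rather than the 2D array returned by `__jax_array__`.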
So whilst it takes either a `disable_jit` or a PyTree to actually trigger it, I think the fundamental issue here is that `jnp.linalg.svd` and friends do not check for JAX arraylikes.
About this issue
- State: open
- Created 2 years ago
- Comments: 27 (6 by maintainers)
Commits related to this issue
- Expand support for __jax_array__ in jnp.array. This relates to the long discussion in #4725 and #10065. — committed to gnecula/jax by gnecula 2 years ago
@mattjj In favour of not removing `__jax_array__`: it allows one to write code that is backend-agnostic using NEP 47.