jax: isinstance(numpy.zeros(1), jax.numpy.ndarray) returns True
`numpy.ndarray` instances currently return `True` when you run an `isinstance` check against `jax.numpy.ndarray`. I think I can see how this happens: I suspect JAX doesn't actually use that type for its own arrays, so maybe it's just NumPy's type under another name? Either way, it's a hassle when you're checking where an array came from.
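```python
import numpy
import jax.numpy  # also binds the top-level name `jax`


def test_numpy_ndarray_is_not_instance_of_jax_numpy_ndarray():
    # A plain NumPy array should not pass an isinstance check against JAX's array type.
    assert not isinstance(numpy.zeros(1), jax.numpy.ndarray)
```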
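How can a type accept instances of a class that never inherited from it? The sketch below is a minimal pure-Python version of one way this can happen. It assumes the mechanism is a metaclass that overrides `__instancecheck__`, which is what the thread later calls "dynamic inheritance". `BaseArray`, `ImpostorArray`, and `_InstanceCheckMeta` are hypothetical stand-ins for `numpy.ndarray`, `jax.numpy.ndarray`, and whatever JAX used internally. This is not JAX's actual code.

```python
class BaseArray:
    """Hypothetical stand-in for numpy.ndarray."""


class _InstanceCheckMeta(type):
    """Hypothetical metaclass that decides isinstance() results dynamically."""

    def __instancecheck__(cls, instance):
        # Accept any BaseArray, even though BaseArray does not inherit
        # from ImpostorArray; otherwise fall back to the normal check.
        return isinstance(instance, BaseArray) or super().__instancecheck__(instance)


class ImpostorArray(metaclass=_InstanceCheckMeta):
    """Hypothetical stand-in for jax.numpy.ndarray."""


plain = BaseArray()
print(isinstance(plain, ImpostorArray))       # True: the metaclass says so
print(issubclass(BaseArray, ImpostorArray))   # False: no real inheritance
```

The sketch leaves `__subclasscheck__` untouched, so `isinstance` answers yes while the class hierarchy answers no. That one-sided answer is the same asymmetry the test above trips over.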
By the way, what's the preferred way to convert data from JAX to NumPy? I found `jax.device_get()` by poking around, but I don't think it's documented.
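On the conversion question, the sketch below uses two calls that exist in recent JAX releases: `numpy.asarray` applied to a JAX array, and `jax.device_get`. It assumes a current JAX version. The older release this issue was filed against may have behaved differently.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)  # a JAX array (float32 unless x64 mode is enabled)

# Option 1: numpy.asarray goes through the array protocol and
# returns a host-side numpy.ndarray.
a = np.asarray(x)

# Option 2: jax.device_get copies to host; it also accepts pytrees
# (nested dicts/lists/tuples of arrays) and converts every leaf.
b = jax.device_get(x)
tree = jax.device_get({"w": x, "b": jnp.zeros(2)})

print(type(a), type(b), type(tree["w"]))
```

Use `numpy.asarray` for a single array. `jax.device_get` is the more convenient choice for a whole nested structure of arrays.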
Thanks for the great project!
About this issue
- State: closed
- Created 4 years ago
- Comments: 20 (11 by maintainers)
@mattjj Thanks for the quick reply!
I would’ve expected the following:
This is the normal inheritancish behaviour:
`jax.numpy.ndarray` is the new impostor, and it can claim to be a type of `numpy.ndarray` (even when that's not literally true). But it's kind of weird to trick `numpy.ndarray` into believing it's a type of `jax.numpy.ndarray`, since that's really not true at all.

Finally, I think it can be convenient to do something like `from jax import numpy` in a quick script, and it's nice for that to work when the program only has to deal with JAX arrays or NumPy arrays, but not both. But if a user writes that import in a context where they'll have a mixture of the two types, their code will have all sorts of bugs, and I think that's not really JAX's fault? So I think having `isinstance(arr, numpy.ndarray)` return `False` will be the least of their problems. Like, yeah, they might have written code expecting that to return `True`, but the actual truth is `False`, and they're better off knowing it.

@jakevdp We never wrote down that promise, but yes, I would consider it a bug to ever return a NumPy array from a JAX function. I did a sequence of changes to make sure that
`jax.numpy` never does that. The main thing that convinced me was that the promotion semantics are different.

I agree. As I wrote over in https://github.com/google/jax/pull/1081#issuecomment-517099099, this behavior surprised me. This sort of dynamic inheritance is rarely used, and my guess is that it could lead to bugs. My vote would definitely be for encouraging separate, explicit `jax.numpy` and `numpy` imports.
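Here is a small, concrete case where the promotion semantics differ. It assumes a recent JAX release, whose rules are described on the "Type promotion semantics" documentation page; older versions may have differed. Adding an int32 array to a float16 array produces different result dtypes in the two libraries.

```python
import numpy as np
import jax.numpy as jnp

# Same operation, same input dtypes, different result dtypes.
np_sum = np.zeros(3, dtype=np.int32) + np.zeros(3, dtype=np.float16)
jnp_sum = jnp.zeros(3, dtype=jnp.int32) + jnp.zeros(3, dtype=jnp.float16)

print(np_sum.dtype)   # float64: NumPy picks a float wide enough for every int32
print(jnp_sum.dtype)  # float16: JAX keeps the float operand's width
```

Because of differences like this, a NumPy array slipping out of a `jax.numpy` function could silently change the dtype of later arithmetic.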