jax: isinstance(numpy.zeros(1), jax.numpy.ndarray) returns True
numpy.ndarray instances currently return True if you run an isinstance check against jax.numpy.ndarray. I guess I see how this happens: I think JAX doesn't actually use that type, so maybe it's the actual one from NumPy? It's a bit of a hassle when you're checking array provenance, though.
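```python
import numpy
import jax.numpy  # also binds the name `jax`


def test_numpy_ndarray_is_not_instance_of_jax_numpy_ndarray():
    # A plain NumPy array should not pass an isinstance check against the JAX type.
    assert not isinstance(numpy.zeros(1), jax.numpy.ndarray)
```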
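The surprising result can be reproduced without JAX at all. Below is a minimal sketch, assuming (this thread does not confirm it) that the check is routed through a metaclass __instancecheck__ hook. HostArray, DeviceArray, SubclassedArray and JaxNdarray are hypothetical stand-ins for numpy.ndarray, a JAX-side array type, an ordinary subclass, and jax.numpy.ndarray. The sketch contrasts the hook with ordinary subclassing, where the relationship only runs one way.

```python
# Hypothetical stand-in classes; none of these come from NumPy or JAX.
class HostArray:
    """Plays the role of numpy.ndarray."""


class DeviceArray:
    """Plays the role of a JAX-side array type."""


class SubclassedArray(HostArray):
    """Ordinary subclassing: the is-a relationship only runs one way."""


# Types that isinstance(x, JaxNdarray) will accept (assumed for illustration).
_ACCEPTED = (HostArray, DeviceArray)


class _ArrayMeta(type):
    """Metaclass that overrides isinstance() checks against its classes."""

    def __instancecheck__(cls, instance):
        # Accept anything "array-like", including plain host arrays.
        return isinstance(instance, _ACCEPTED)


class JaxNdarray(metaclass=_ArrayMeta):
    """Plays the role of jax.numpy.ndarray."""


normal_one_way = (
    isinstance(SubclassedArray(), HostArray),  # subclass instance is a HostArray
    isinstance(HostArray(), SubclassedArray),  # but not the other way round
)
override = (
    isinstance(DeviceArray(), JaxNdarray),  # accepted via __instancecheck__
    isinstance(HostArray(), JaxNdarray),    # also accepted: the reported surprise
)
print(normal_one_way, override)
```

Running it prints (True, False) (True, True): the hook makes the host type pass a check that plain inheritance would reject.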
Btw, what's the preferred way to convert data from JAX to NumPy? I've found jax.device_get() by poking around, but I don't think it's documented.
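For the conversion question, here is a minimal sketch assuming a reasonably recent JAX install; it does not settle which route is officially preferred. numpy.asarray works because JAX arrays implement NumPy's __array__ protocol, while jax.device_get also maps over containers (pytrees) such as dicts of arrays. The dict keys are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)  # a JAX array on the default device

# Option 1: NumPy's own conversion, via the __array__ protocol.
a = np.asarray(x)

# Option 2: jax.device_get, which also walks containers (pytrees).
# The keys "weights" and "bias" are hypothetical.
tree = jax.device_get({"weights": x, "bias": jnp.zeros(2)})

print(type(a), type(tree["weights"]))
```

Option 2 is handy when a whole structure of results needs to be pulled back to the host in one call.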
Thanks for the great project!
About this issue
- State: closed
- Created 4 years ago
- Comments: 20 (11 by maintainers)
@mattjj Thanks for the quick reply!
I would’ve expected the following:
This is the normal inheritance-ish behaviour: jax.numpy.ndarray is the new impostor, and it can claim to be a type of numpy.ndarray (even when that's not literally true). But it's kind of weird to trick numpy.ndarray into believing it's a type of jax.numpy.ndarray, since that's not true at all.

Finally, I think it can be convenient to do something like
from jax import numpy
in a quick script, and it's nice for that to work fine when the program only has to deal with JAX arrays or NumPy arrays, but not both. But if a user writes that import in a context where they'll have a mixture of the two types, their code will have all sorts of bugs, and I think that's not really JAX's fault. So I think having isinstance(arr, numpy.ndarray) return False will be the least of their problems. Sure, they might have written code expecting it to return True, but the actual truth is False, and they're better off knowing it.

@jakevdp We never wrote down that promise, but yes, I would consider it a bug for a JAX function to ever return a NumPy array. I made a sequence of changes to ensure that jax.numpy never does that. The main thing that convinced me was that the type promotion semantics are different.

I agree; as I wrote over in https://github.com/google/jax/pull/1081#issuecomment-517099099, this behavior surprised me. This sort of dynamic inheritance is rarely used, and my guess is that it could lead to bugs. My vote would definitely be for encouraging separate, explicit jax.numpy and numpy imports.
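To illustrate separate, explicit imports, here is a sketch assuming a recent JAX release where jax.Array is the public array type. It is not necessarily how this issue was resolved. It checks provenance explicitly, shows a jax.numpy function returning a JAX array even when given NumPy inputs, and compares one type promotion rule between the two libraries. The provenance helper is hypothetical.

```python
import jax
import jax.numpy as jnp  # explicit alias, so NumPy is never shadowed
import numpy as np


def provenance(arr):
    """Hypothetical helper: report which library an array came from."""
    if isinstance(arr, jax.Array):
        return "jax"
    if isinstance(arr, np.ndarray):
        return "numpy"
    return "other"


host = np.ones(3)
device = jnp.add(host, host)  # NumPy inputs, but a JAX array comes back

print(provenance(host), provenance(device))

# Mixing int32 with float16: NumPy widens to float64,
# while JAX prefers the floating-point operand's type (float16).
np_result = np.promote_types(np.int32, np.float16)
jax_result = jnp.promote_types(jnp.int32, jnp.float16)
print(np_result, jax_result)
```

The provenance check tests jax.Array first, so a JAX array is never misreported as "numpy".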