jax: isinstance(numpy.zeros(1), jax.numpy.ndarray) returns True

Running an isinstance check of a numpy.ndarray instance against jax.numpy.ndarray currently returns True. I think I see how this happens: JAX doesn’t seem to use that type for its own arrays, so maybe it’s just NumPy’s class under another name? Either way, it’s a hassle when you’re checking where an array came from.

import jax.numpy
import numpy
def test_numpy_ndarray_is_not_instance_of_jax_numpy_ndarray():
    assert not isinstance(numpy.zeros(1), jax.numpy.ndarray)
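One general way an isinstance check can report True for an unrelated class is a metaclass that overrides __instancecheck__. The sketch below is hypothetical: _PermissiveMeta and FakeJaxNdarray are illustrative names, and it only shows the Python mechanism. It is not a claim about how JAX actually defined jax.numpy.ndarray.

```python
import numpy


class _PermissiveMeta(type):
    """Hypothetical metaclass whose isinstance check also accepts NumPy arrays."""

    def __instancecheck__(cls, instance):
        # Accept any numpy.ndarray; otherwise fall back to the normal type check.
        return isinstance(instance, numpy.ndarray) or super().__instancecheck__(instance)


class FakeJaxNdarray(metaclass=_PermissiveMeta):
    """Hypothetical stand-in for jax.numpy.ndarray (not JAX's real class)."""


print(isinstance(numpy.zeros(1), FakeJaxNdarray))  # True
print(isinstance([0.0], FakeJaxNdarray))           # False
```

Because the override lives on the class being checked against, NumPy itself is untouched; the "impostor" is the type on the right-hand side of isinstance.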

Btw, what’s the preferred way to convert data from JAX to NumPy? I’ve found jax.device_get() by poking around, but I don’t think it’s documented.
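A minimal sketch of two conversions, assuming a recent JAX release (where jax.device_get is part of the public API): numpy.asarray uses the array protocol on a single array, while jax.device_get also accepts pytrees such as dicts and lists. The variable names are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy

x = jnp.arange(3.0)

# numpy.asarray goes through the array protocol and returns a NumPy array.
a = numpy.asarray(x)

# jax.device_get brings data to the host; it also maps over pytree leaves.
b = jax.device_get(x)
tree = jax.device_get({"w": x, "b": jnp.ones(2)})

print(type(a), type(b), type(tree["w"]))
```

numpy.asarray is the natural choice for a single array; jax.device_get is convenient when a whole structure of arrays (for example, model parameters) needs to come back at once.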

Thanks for the great project!

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 20 (11 by maintainers)

Most upvoted comments

@mattjj Thanks for the quick reply!

I would’ve expected the following:

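import jax.numpy
import numpy

jax_arr = jax.numpy.zeros(1)
assert isinstance(jax_arr, numpy.ndarray)  # expected: a JAX array may claim to be a NumPy array
np_arr = numpy.zeros(1)
assert not isinstance(np_arr, jax.numpy.ndarray)  # expected: a NumPy array is not a JAX array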

This is the normal inheritance-like behaviour: jax.numpy.ndarray is the new impostor, and it can claim to be a kind of numpy.ndarray (even when that’s not literally true). But it’s odd to trick numpy.ndarray into believing it’s a kind of jax.numpy.ndarray, since that’s not true at all.
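Ordinary subclassing gives exactly this one-way relationship. The Impostor class below is hypothetical and only illustrates plain Python inheritance with a NumPy subclass; it is not how JAX arrays are implemented.

```python
import numpy


class Impostor(numpy.ndarray):
    """Hypothetical ndarray subclass standing in for a newer array type."""


base = numpy.zeros(1)
sub = base.view(Impostor)  # same data, viewed as the subclass

print(isinstance(sub, numpy.ndarray))  # True: a subclass instance is a base instance
print(isinstance(base, Impostor))      # False: a base instance is not a subclass instance
```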

Finally, I think it can be convenient to write something like from jax import numpy in a quick script, and it’s nice for that to work when the program only deals with JAX arrays or NumPy arrays, but not both. But if a user writes that import in a context where they’ll have a mixture of the two types, their code will have all sorts of bugs, and I don’t think that’s really JAX’s fault. So having isinstance(arr, numpy.ndarray) return False will be the least of their problems. Sure, they might have written code expecting it to return True, but the actual answer is False, and they’re better off knowing it.

@jakevdp We never wrote down that promise, but yes, I would consider it a bug for a JAX function to ever return a NumPy array. I made a series of changes to make sure that jax.numpy never does that. The main thing that convinced me was that the type promotion semantics of the two libraries are different.
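One concrete case where the promotion semantics differ is mixed int32/float32 arithmetic. This is an illustrative example, not necessarily the specific case the comment had in mind; it assumes a recent JAX release with default settings.

```python
import jax.numpy as jnp
import numpy as np

# Mixed int32/float32 arithmetic in NumPy promotes to float64 ...
np_result = np.ones(2, dtype=np.int32) + np.ones(2, dtype=np.float32)

# ... while JAX's promotion rules keep the result at float32.
jnp_result = jnp.ones(2, dtype=jnp.int32) + jnp.ones(2, dtype=jnp.float32)

print(np_result.dtype)   # float64
print(jnp_result.dtype)  # float32
```

If jax.numpy silently handed back NumPy arrays, later operations on them would follow NumPy's rules instead of JAX's, so the resulting dtypes could change depending on which library happened to produce an intermediate value.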

I agree; as I wrote in https://github.com/google/jax/pull/1081#issuecomment-517099099, this behavior surprised me. This sort of dynamic inheritance is rarely used, and my guess is that it could lead to bugs. My vote would definitely be for encouraging separate, explicit jax.numpy and numpy imports.
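With separate, explicit imports, a provenance check can be written directly. This sketch assumes a recent JAX release in which jax.Array is the public array type; array_provenance is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
import numpy as np


def array_provenance(x):
    """Hypothetical helper: report which library an array came from."""
    if isinstance(x, jax.Array):
        return "jax"
    if isinstance(x, np.ndarray):
        return "numpy"
    return "other"


print(array_provenance(jnp.zeros(1)))  # jax
print(array_provenance(np.zeros(1)))   # numpy
print(array_provenance([0.0]))         # other
```

Keeping np and jnp as distinct names means a reader can tell from each call site which library's semantics apply.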