jax: JAX at head recurses infinitely when printing an array
Description
Checked-out commits: TensorFlow d06d16ef29767129aeaaefd007481aeb759e5400, JAX dc6bf9b725fcdd96f8152d6d3a1ce3f0c14d9ced
This simple test program:
```python
import jax
# Do once and print.
a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
b = a
for i in range(100):
    b = jax.numpy.asarray([i]) * a + b
print(b)
```
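For reference, the loop has a closed-form result: the values of `i` sum to 0 + 1 + ... + 99 = 4950, so `b` ends up equal to 4951 * `a`. A minimal regression check, assuming a JAX build in which this bug is fixed (so conversion and printing work), could compare against that value with NumPy. The `jnp` alias and the use of `np.testing` are choices made for this sketch, not part of the original report.

```python
import numpy as np
import jax.numpy as jnp

# Same computation as the repro program.
a = jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
b = a
for i in range(100):
    b = jnp.asarray([i]) * a + b

# sum(range(100)) == 4950, so b == (1 + 4950) * a == 4951 * a.
expected = 4951 * np.arange(1, 10)

# np.asarray(b) goes through Array.__array__, the path that recursed in the traceback.
np.testing.assert_array_equal(np.asarray(b), expected)
print(b)
```

On an affected build, the `np.asarray(b)` line would be expected to fail the same way `print(b)` does.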
Fails with:
```
...
File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 351, in __array__
return np.asarray(self._value, dtype=dtype)
File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 497, in _value
self._npy_value = np.asarray(self._arrays[0]) # type: ignore
File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 351, in __array__
return np.asarray(self._value, dtype=dtype)
File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 496, in _value
if self.is_fully_replicated:
File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 325, in is_fully_replicated
return self.shape == self._arrays[0].shape
File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 196, in shape
return self.aval.shape
RecursionError: maximum recursion depth exceeded while calling a Python object
```
I can reproduce this on all backends, including CPU.
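The traceback alternates between `__array__` and `_value` in `array.py`: `_value` calls `np.asarray(self._arrays[0])`, which lands back in the same `__array__`, so the shard being converted is apparently itself an array object rather than a raw device buffer. The toy model below reproduces that shape of failure using only NumPy. `ToyArray`, `recurses_on_conversion`, and the self-referential shard are hypothetical illustrations; this is not JAX's code and not a diagnosis of the actual bug.

```python
import numpy as np


class ToyArray:
    """Hypothetical stand-in for jax.Array (NOT JAX's real class).

    As in the traceback, __array__ delegates to _value, and _value converts
    the first per-device shard with np.asarray.
    """

    def __init__(self, shards):
        self._arrays = shards   # expected to hold plain buffers (np.ndarray here)
        self._npy_value = None  # cached host copy, filled on first access

    @property
    def _value(self):
        if self._npy_value is None:
            # If _arrays[0] leads back to a ToyArray on this path, this call
            # re-enters __array__ before the cache is ever filled.
            self._npy_value = np.asarray(self._arrays[0])
        return self._npy_value

    def __array__(self, dtype=None, copy=None):
        # copy is accepted (and ignored) for NumPy 2 compatibility.
        return np.asarray(self._value, dtype=dtype)


def recurses_on_conversion(arr):
    """Return True if np.asarray(arr) hits Python's recursion limit."""
    try:
        np.asarray(arr)
    except RecursionError:
        return True
    return False


healthy = ToyArray([np.arange(3)])
broken = ToyArray([])
broken._arrays = [broken]  # hypothetical cycle: the shard is the array itself

print(np.asarray(healthy))             # [0 1 2]
print(recurses_on_conversion(broken))  # True
```

Because the cache is only written after the inner `np.asarray` returns, a cycle through `_arrays[0]` never terminates and surfaces as a `RecursionError`, matching the final line of the reported traceback.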
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response
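The version questions above were left unanswered. One way to collect that information is sketched below, assuming a standard pip install in which the `jaxlib` package metadata is registered; a source build such as the `.clones/jax` checkout in the traceback may not register it, in which case the commit hashes at the top of the report are the more precise record.

```python
import importlib.metadata

import jax

# Version of the jax Python package.
print("jax:", jax.__version__)

# Version of the compiled jaxlib wheel (assumes pip-installed metadata).
print("jaxlib:", importlib.metadata.version("jaxlib"))

# Backend and devices this process actually runs on (e.g. "cpu", "gpu").
print("backend:", jax.default_backend())
print("devices:", jax.devices())
```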
About this issue
- State: open
- Created a year ago
- Comments: 15 (13 by maintainers)
I think this PR rollback should fix it: https://github.com/google/jax/pull/14421