jax: Jax at head infinite loops printing an array

Description

TensorFlow checked out at commit d06d16ef29767129aeaaefd007481aeb759e5400; JAX at commit dc6bf9b725fcdd96f8152d6d3a1ce3f0c14d9ced.

This simple test program:

import jax

# Build up `b` over 100 iterations, then print it once.
a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
b = a
for i in range(100):
  b = jax.numpy.asarray([i]) * a + b
print(b)
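For reference, the loop only adds scaled copies of `a`: b = a + (0 + 1 + ... + 99) * a = 4951 * a. On a working build, the program should therefore print the nine multiples of 4951 from 4951 to 44559. The following pure-Python sketch (no JAX; `expected_b` is a hypothetical helper name) replays the same update and checks it against that closed form:

```python
def expected_b(n_iters=100):
    """Replay the repro's update in pure Python: b = [i] * a + b."""
    a = list(range(1, 10))
    b = list(a)
    for i in range(n_iters):
        # jax.numpy.asarray([i]) has shape (1,), so it broadcasts against a.
        b = [i * x + y for x, y in zip(a, b)]
    return b


# Closed form: b = (1 + sum(range(100))) * a = 4951 * a.
closed_form = [(1 + sum(range(100))) * x for x in range(1, 10)]
print(expected_b() == closed_form)  # True
print(expected_b()[-1])             # 44559
```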

fails with the following traceback (truncated):

...
  File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 351, in __array__
    return np.asarray(self._value, dtype=dtype)
  File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 497, in _value
    self._npy_value = np.asarray(self._arrays[0])  # type: ignore
  File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 351, in __array__
    return np.asarray(self._value, dtype=dtype)
  File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 496, in _value
    if self.is_fully_replicated:
  File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 325, in is_fully_replicated
    return self.shape == self._arrays[0].shape
  File "/usr/local/google/home/laurenzo/src/openxla-pjrt-plugin/external/.clones/jax/jax/_src/array.py", line 196, in shape
    return self.aval.shape
RecursionError: maximum recursion depth exceeded while calling a Python object

I can reproduce this on all backends, including CPU.
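The traceback shows a cycle. `__array__` evaluates `self._value`, and `_value` calls `np.asarray(self._arrays[0])`, which lands back in `__array__`. The toy model below is a hypothetical sketch, not JAX code: `ToyArray`, `HostBuffer`, and `to_host` (a stand-in for `np.asarray`) are made up for illustration. It also *assumes* the failure mode is that `_arrays[0]` resolves to an Array wrapper rather than a raw device buffer. Under that assumption, the conversion never bottoms out and Python raises RecursionError, while a wrapper over a real buffer converts normally.

```python
class HostBuffer:
    """Hypothetical stand-in for a per-device buffer that already holds host data."""

    def __init__(self, values):
        self.values = list(values)
        self.shape = (len(self.values),)

    def __iter__(self):
        return iter(self.values)


def to_host(x):
    """Stand-in for np.asarray: defer to __array__ when the object defines it."""
    if hasattr(x, "__array__"):
        return x.__array__()
    return list(x)


class ToyArray:
    """Hypothetical model of the Array members named in the traceback."""

    def __init__(self, shape, arrays=()):
        self.shape = shape            # models the `shape` property
        self._arrays = list(arrays)   # per-device shards
        self._npy_value = None        # cached host copy

    @property
    def is_fully_replicated(self):
        return self.shape == self._arrays[0].shape

    @property
    def _value(self):
        if self._npy_value is None:
            if not self.is_fully_replicated:
                raise NotImplementedError("sharded case is not modelled")
            # Mirrors `self._npy_value = np.asarray(self._arrays[0])`.
            self._npy_value = to_host(self._arrays[0])
        return self._npy_value

    def __array__(self):
        # Mirrors `return np.asarray(self._value, dtype=dtype)`.
        return to_host(self._value)


def convert(arr):
    """Return the host values, or the name of the error the conversion raised."""
    try:
        return to_host(arr)
    except RecursionError as err:
        return type(err).__name__


healthy = ToyArray((3,), [HostBuffer([1, 2, 3])])
print(convert(healthy))  # [1, 2, 3]

broken = ToyArray((9,))
broken._arrays = [broken]  # assumed failure mode: the shard points back at the wrapper
print(convert(broken))     # RecursionError
```

In this model, the cache in `_npy_value` never helps: it is only written after the inner conversion returns, so a self-referencing shard re-enters `_value` forever.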

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

About this issue

  • State: open
  • Created a year ago
  • Comments: 15 (13 by maintainers)

Most upvoted comments

I think this PR rollback should fix it: https://github.com/google/jax/pull/14421