jax: new version breaks arange?
Hi Jax team,
Working with @alipwong on this.
Just upgraded from jax 0.1.75 & jaxlib 0.1.52 where everything was working fine, to jax-0.2.7 and jaxlib-0.1.57.
I have a longer program using jax.numpy.arange (imported as np) that used to work just fine. Now it breaks, and I can’t understand what changed.
Rather than posting the full program, here is a simple example that reproduces the problem:
def test_arange(n):
return np.arange(0,n,1)
test = jit(test_arange)
test(5)
This returns the following error:
---------------------------------------------------------------------------
FilteredStackTrace Traceback (most recent call last)
<ipython-input-27-613700c58fec> in <module>
----> 6 test(5)
~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
<ipython-input-27-613700c58fec> in test_arange(n)
2 # n = x.shape[0]
----> 3 return np.arange(0,n,1)
4
~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)
FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
It arose in jax.numpy.arange argument `stop`.
While tracing the function test_arange at <ipython-input-27-613700c58fec>:1, this concrete value was not available in Python because it depends on the value of the arguments to test_arange at <ipython-input-27-613700c58fec>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ConcretizationTypeError Traceback (most recent call last)
<ipython-input-27-613700c58fec> in <module>
6
7 # test(np.arange(5))
----> 8 test(5)
~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
369 c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
370 xla_consts = map(partial(xb.constant, c), consts)
--> 371 xla_args = xla._xla_callable_args(c, avals, tuple_args)
372 outs = xla.jaxpr_subcomp(
373 c, jaxpr, backend, axis_env_, xla_consts,
~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in cache_miss(*args, **kwargs)
282 sine.4 = f32[] sine(cosine.3)
283 ROOT tuple.5 = (f32[]) tuple(sine.4)
--> 284 }
285 <BLANKLINE>
286 <BLANKLINE>
~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1227 else:
1228 msg, = e.args
-> 1229 jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20))
1230 msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
1231 raise JaxprTypeError(msg) from None
~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1218
1219 Raises `TypeError` if `jaxpr` is determined invalid. Returns `None` otherwise.
-> 1220 """
1221 try:
1222 _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1230 msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
1231 raise JaxprTypeError(msg) from None
-> 1232
1233 def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
1234
~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
596
597 def __hash__(self) -> int:
--> 598 return hash((self.level, self.trace_type))
599
600 def __eq__(self, other: object) -> bool:
~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
568 else:
569 assert s.is_tuple()
--> 570 for i, sub in enumerate(s.tuple_shapes()):
571 subindex = index + (i,)
572 if sub.is_tuple():
~/opt/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
249 # store 1 was occupied, so let's check store 2 is not occupied
250 try:
--> 251 out2 = aux2()
252 except StoreException:
253 return True, out1
~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
643 out_nodes = jaxpr_subcomp(
644 c, jaxpr, backend, AxisEnv(nreps, (), (), None), xla_consts,
--> 645 extend_name_stack(wrap_name(name, 'jit')), *xla_args)
646 out_tuple = xops.Tuple(c, out_nodes)
647 backend = xb.get_backend(backend)
~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
~/opt/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
158
159 return ans
--> 160
161 def __repr__(self):
162 def transform_to_str(x):
<ipython-input-27-613700c58fec> in test_arange(n)
1 def test_arange(n):
2 # n = x.shape[0]
----> 3 return np.arange(0,n,1)
4
5 test = jit(test_arange)
~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)
~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in concrete_or_error(force, val, context)
920 @property
921 def shape(self):
--> 922 msg = ("UnshapedArray has no shape. Please open an issue at "
923 "https://github.com/google/jax/issues because it's unexpected for "
924 "UnshapedArray instances to ever be produced.")
~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in raise_concretization_error(val, context)
897 complex, "Try using `x.astype(complex)` instead.")
898 _hex = concretization_function_error(hex)
--> 899 _oct = concretization_function_error(oct)
900
901 def at_least_vspace(self) -> AbstractValue:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
It arose in jax.numpy.arange argument `stop`.
While tracing the function test_arange at <ipython-input-27-613700c58fec>:1, this concrete value was not available in Python because it depends on the value of the arguments to test_arange at <ipython-input-27-613700c58fec>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Is there a safe way to jit np.arange in the new version of Jax?
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 21 (2 by maintainers)
I think this changed in https://github.com/google/jax/pull/4717 . See the description of that PR for details as to why we changed this.
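Yes, you can use `static_argnums` to jit-compile arange, for example `test = jit(test_arange, static_argnums=0)`, which treats `n` as a static (compile-time) argument at the cost of a recompile for each new value of `n`.
Does that answer your question?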