jax: new version breaks arange?

Hi Jax team,

Working with @alipwong on this.

I just upgraded from jax 0.1.75 and jaxlib 0.1.52, where everything worked fine, to jax 0.2.7 and jaxlib 0.1.57.

I have a longer program that uses jax.numpy.arange and used to work fine. Now it breaks, and I can't figure out what changed.

Rather than the full program, here is a minimal example:

import jax.numpy as np
from jax import jit

def test_arange(n):
    return np.arange(0, n, 1)

test = jit(test_arange)

test(5)  # raises ConcretizationTypeError under jax 0.2.7

raises the following error:

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-27-613700c58fec> in <module>

----> 6 test(5)

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)

<ipython-input-27-613700c58fec> in test_arange(n)
      2 #     n = x.shape[0]
----> 3     return np.arange(0,n,1)
      4 

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)

FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

It arose in jax.numpy.arange argument `stop`.

While tracing the function test_arange at <ipython-input-27-613700c58fec>:1, this concrete value was not available in Python because it depends on the value of the arguments to test_arange at <ipython-input-27-613700c58fec>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


Is there a safe way to jit np.arange in the new version of JAX?

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 21 (2 by maintainers)

Most upvoted comments

I think this changed in https://github.com/google/jax/pull/4717. See the description of that PR for why we made this change.

Yes, you can use static_argnums so that n is treated as a compile-time constant rather than a traced value, which lets jit compile arange:

test = jit(test_arange, static_argnums=0)

Does that answer your question?