jax: xmap regression

Possibly related to #6959, but there seem to be additional regressions in xmap unrelated to the multi-host case. A simplified example is provided in the colab below: https://colab.research.google.com/drive/18nU5rz8CF7YDYPSuPqJH38Dc7HkroYOn?usp=sharing

The example works with jax == 0.2.12, but not with the latest version.
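
For reference, here is a minimal sketch of the pattern that triggers the error (toy shapes and names of my own, not the exact colab code): the result depends on both named axes, but out_axes only declares 'shard'. axis_resources is left out so the sketch runs without a mesh; the failing check is about out_axes, not about resource assignment.

import jax.numpy as jnp
from jax.experimental.maps import xmap

def layer(params, x):
    # Inside xmap the named axes are stripped: params is (8, 8), x is (8,).
    # The result uses both inputs, so it is mapped along 'shard' and 'batch'.
    return jnp.dot(x, params)

run_xmap = xmap(layer,
                in_axes=(['shard', ...], ['batch', ...]),
                out_axes=['shard', ...])  # 'batch' is not declared here

params = jnp.ones((4, 8, 8))  # leading dim mapped to the named axis 'shard'
x = jnp.ones((2, 8))          # leading dim mapped to the named axis 'batch'

# This kind of spec was accepted by jax 0.2.12; newer versions raise the
# TypeError shown in the stack trace below.
run_xmap(params, x)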

Full stack trace:

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-2-2f2533018f84> in <module>()
     65                                             axis_resources={'shard': 'mp', 'batch': 'dp'})
     66 
---> 67     run_xmap(params, x)

8 frames

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in fun_mapped(*args)
    577       backend=backend,
    578       spmd_in_axes=None,
--> 579       spmd_out_axes_thunk=None)
    580     if has_output_rank_assertions:
    581       for out, spec in zip(out_flat, out_axes_thunk()):

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in bind(self, fun, *args, **params)
    764   def bind(self, fun, *args, **params):
    765     assert len(params['in_axes']) == len(args)
--> 766     return core.call_bind(self, fun, *args, **params)  # type: ignore
    767 
    768   def process(self, trace, fun, tracers, params):

/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1549       params_tuple, out_axes_transforms)
   1550   tracers = map(top_trace.full_raise, args)
-> 1551   outs = primitive.process(top_trace, fun, tracers, params)
   1552   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1553 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in process(self, trace, fun, tracers, params)
    767 
    768   def process(self, trace, fun, tracers, params):
--> 769     return trace.process_xmap(self, fun, tracers, params)
    770 
    771   def post_process(self, trace, out_tracers, params):

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    604 
    605   def process_call(self, primitive, f, tracers, params):
--> 606     return primitive.impl(f, *tracers, **params)
    607   process_map = process_call
    608 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in xmap_impl(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, *args)
    596       axis_resources, resource_env, backend,
    597       spmd_in_axes, spmd_out_axes_thunk,
--> 598       *in_avals)
    599   distributed_debug_log(("Running xmapped function", name),
    600                         ("python function", fun.f),

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    260       fun.populate_stores(stores)
    261     else:
--> 262       ans = call(fun, *args)
    263       cache[key] = (ans, fun.stores)
    264 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, *in_avals)
    619     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
    620   out_axes = out_axes_thunk()
--> 621   _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
    622   # NOTE: We don't use avals and all params, so only pass in the relevant parts (too lazy...)
    623   _resource_typing_xmap([], dict(axis_resources=axis_resources,

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
   1401     if undeclared_axes:
   1402       undeclared_axes_str = sorted([str(axis) for axis in undeclared_axes])
-> 1403       raise TypeError(f"One of xmap results has an out_axes specification of "
   1404                       f"{axes.user_repr}, but is actually mapped along more axes "
   1405                       f"defined by this xmap call: {', '.join(undeclared_axes_str)}")

TypeError: One of xmap results has an out_axes specification of ['shard', ...], but is actually mapped along more axes defined by this xmap call: batch

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 1
  • Comments: 19 (9 by maintainers)

Most upvoted comments

Hi! Wonderful comment, and you’re not alone. I’m sending this mostly as a reminder to myself not to let you fall through the cracks – it’s 11pm right now, so I won’t be able to get to this until some other time. But if it goes too long without you getting a reply, please ping me or DM me on twitter (https://twitter.com/theshawwn).

The JAX team really is awesome, and I suspect they’ll be along with an answer, but this darn issue is a tricky one. Basically, it was a situation where there was a breaking API change (yuck!) for good reason (oh…) and it happened to break the only popular thing using xmap (GPT-J).

I actually complained a lot about the breaking change, but was dragged kicking and screaming to the conclusion that it was a good thing. I hate when developers take the attitude of “correctness matters more than dev experience,” but in this case the correctness was sort of important.

Anyway, I’ll try to make sure someone comes along with a detailed answer about what to do next. Maybe it’ll be me.

On Fri, Sep 17, 2021 at 8:02 PM Matthew Piziak @.***> wrote:

Hej @apaszke https://github.com/apaszke (and others 👋),

I’m interested in experimenting with text-generation models, and to that end I’m trying to run GPT-J-6B inference on a TPU Research Cloud machine. I’m a bit out of my depth, but if you don’t mind, here are my observations on how this issue impacts the GPT-J-6B Inference Demo https://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb hosted on Colab.

Observations

First, I see a reference to a so-called “regression” with xmap in jax version 0.2.13.

jax 0.2.12 is required due to a regression with xmap in 0.2.13

!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0

Second, I see a seeming incompatibility between the 0.2.12 version and the TPU environment:

TpuTransferManager_ReadDynamicShapes not available in this library. Aborted (core dumped)

That issue is also reported in #7334 https://github.com/google/jax/issues/7334.

Catch-22

Based on my reading of this issue, xmap has stopped letting one return a piece of a sharded tensor, and the old behavior should not be depended upon. That makes sense and indeed I wouldn’t expect any further mitigation from the jax maintenance team, since xmap warns the user that it is experimental. This is fully an issue of downstream usage, but nevertheless if you’d offer me any wisdom here I’d appreciate it.
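
To make that concrete, here is a toy sketch of my own (not the mesh-transformer-jax code) of the two ways I understand the stricter check can be satisfied: declare every mapped axis in out_axes, or explicitly reduce over the axis you don’t want in the output.

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def layer(params, x):
    # The result is mapped along both 'shard' (from params) and 'batch' (from x).
    return jnp.dot(x, params)

# Option A: declare both mapped axes in out_axes.
run_a = xmap(layer,
             in_axes=(['shard', ...], ['batch', ...]),
             out_axes=['shard', 'batch', ...])

# Option B: reduce over the unwanted axis explicitly (here with a psum
# collective) instead of relying on xmap to hand back a single shard implicitly.
def layer_summed(params, x):
    return jax.lax.psum(jnp.dot(x, params), 'shard')

run_b = xmap(layer_summed,
             in_axes=(['shard', ...], ['batch', ...]),
             out_axes=['batch', ...])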

The issue:

  • As far as I know GPT-J-6B is the most powerful public autoregressive text-generation weight set.
  • GPT-J-6B only works on model v1 of mesh-transformer-jax.
  • v1 only works with JAX 0.2.12.
  • Dependency set mesh-transformer-jax+jax==0.2.12+tensorflow==2.5.0 aborts (as above, tested in a TPU v3 environment).
  • When used with jax>=0.2.13, it fails as above.

In summary, the model I planned to use is semi-deprecated. This comment https://github.com/kingoflolz/mesh-transformer-jax/issues/67#issuecomment-886395688 says it’s not likely to be fixed by the mesh-transformer-jax team.

Gosh, I don’t know whether this is a mesh-transformer-jax issue or not. I feel a bit embarrassed, because there doesn’t seem to be a right place to post this. I suppose I’m really looking for big-picture advice; this is where it all began, and you have the big-picture context.

Ultimately, what would you recommend for an experimenter who does not have access to proprietary models?

Options:

  • Find and fix the xmap usage in model v1 of mesh-transformer-jax, restoring GPT-J-6B.
  • Use an alternative model.
  • Attempt to train a new model on v2 of mesh-transformer-jax.
  • Something else?

I appreciate your patience with these notes from someone still getting familiar with JAX. You have good documentation; I ought to read through it, and I will do that shortly. Thanks and best wishes.
