jax: inside pmap, grad(lambda x: psum(loss(x))) inconsistent with jvp(lambda x: psum(loss(x)))
Running this on a machine with >=2 devices:
```python
import jax.numpy as jnp
from jax import vmap, pmap, grad, lax

def loss(w, d):
    return w * d

def local(w, data):
    def local_loss(w):
        return vmap(lambda d: loss(w, d))(data).sum()
    return local_loss(w), grad(local_loss)(w)

print(local(1., jnp.arange(2.)))  # ==> (1., 1.)

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(loss(w, data), 'batch')
    return agg_loss(w), grad(agg_loss)(w)

print(pmap(distributed, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.)))
# ==> ([1., 1.], [0., 2.])
```
The losses are correct, but the gradients are incorrect on each shard.
What should be the intended behavior here?
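For reference, here is the forward-mode side of the inconsistency named in the title, as a sketch (not part of the original report) reusing `loss`, `pmap`, and `lax` from the snippet above:

```python
from jax import jvp

def distributed_jvp(w, data):
    def agg_loss(w):
        return lax.psum(loss(w, data), 'batch')
    # forward-mode derivative of agg_loss at w, with tangent 1.0
    return jvp(agg_loss, (w,), (1.,))

print(pmap(distributed_jvp, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.)))
# ==> ([1., 1.], [1., 1.])  -- jvp says 1.0 per device, grad said [0., 2.]
```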
I think we shouldn’t necessarily expect invariants that are true in ordinary code to also hold for SPMD code; we often assume that SPMD code will behave like a “batch of (independent) programs”, but the presence of collectives means that can be a bad analogy! Instead, I like to think of SPMD programming as a different perspective on bulk array programming (the named axis is a real axis that just doesn’t have a position, and collectives are just bulk primitives, like reductions, that operate on that axis).
That means that a scalar -> scalar SPMD program is not actually a scalar -> scalar program, and nor is it a batch of scalar -> scalar programs. It’s a vector -> vector program (one that acts elementwise/has a diagonal Jacobian if there aren’t any collectives, and doesn’t if there are). This is a particularly important distinction for autodiff, since a single application of forward-mode autodiff or finite differences can only give the derivative of a `scalar -> anything` function, while a single application of reverse-mode autodiff can only give the derivative of an `anything -> scalar` function (with an exception in both cases if the Jacobian is known to be diagonal).

In bulk array terms, Matt’s three examples are:
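(Roughly sketched here, with `w` now an explicit vector carrying the mapped axis; the names `f_grad`, `f_jvp`, and `f_fd` are the ones discussed below.)

```python
import jax.numpy as jnp
from jax import grad, jvp

data = jnp.arange(2.)

def agg_loss(w):
    # bulk version: w is a vector with one entry per device
    return (w * data).sum()

w = jnp.ones(2)
f_grad = grad(agg_loss)(w)                      # ==> [0., 1.]
f_jvp = jvp(agg_loss, (w,), (jnp.ones(2),))[1]  # ==> 1.0
eps = 1e-3
f_fd = (agg_loss(w + eps) - agg_loss(w)) / eps  # ==> ~1.0
```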
These results are a little less surprising! Here `agg_loss` is an ordinary vector -> scalar function and `f_grad` is computing the gradient, while `f_jvp` and `f_fd` are both computing something else (the product of the Jacobian with a vector of ones).

In this bulk array context, another way to compute that same Jacobian-vector product is to take the
`grad` of a scalar -> scalar function that broadcasts its input (expressed with the same `agg_loss` Python code through NumPy rank polymorphism/implicit broadcasting), as in the sketch below. Here the forward pass performs a broadcast, multiply, then sum, so the transpose performs a broadcast, multiply, then sum too (broadcast is fan-out and its transpose is fan-in-sum). How to make sure the transpose contains a sum when the forward pass actually needed to broadcast, but doesn’t when it didn’t, is the implicit broadcasting problem in autodiff systems for dynamically-ranked languages like TF graphs and TorchScript.
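A minimal sketch of that broadcasting version:

```python
import jax.numpy as jnp
from jax import grad

data = jnp.arange(2.)

def agg_loss(w):
    # same Python code, but now w is a scalar that broadcasts against data
    return (w * data).sum()

print(grad(agg_loss)(1.))  # ==> 1.0, the same Jacobian-vector product as above
```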
My claim is that all of these concerns apply in the SPMD case, too. An SPMD program cannot run except with all of its named axes bound by corresponding maps, and the resulting program is a bulk array program with well-defined semantics. We need to treat those semantics as the source of truth for thinking about SPMD programs, rather than intuitions about batches of independent programs, because those intuitions break down when collectives are involved.
In bulk array programs, a particular logical axis can either be present or absent in each of the arguments and return values. In the context of bulk array programs where that axis was created by mapping an SPMD axis using JAX’s map APIs, this corresponds to whether the `in_axes`/`out_axes` setting for each argument/return is an integer or `None`. Here’s what I think that means for the SPMD versions of our examples (the last one being Sharad’s code).
We don’t currently produce the desired output in the last case, because we include a spurious `psum` at the beginning of the transpose (from `[1. 1.]` to `[2. 2.]`) and don’t include a necessary one at the end.
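That spurious leading `psum` is easy to see in isolation; this sketch (not from the original comment) differentiates a bare `psum` inside a `pmap`:

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(w, data):
    # the transpose of psum at head is itself a psum, so the incoming
    # cotangent of 1.0 on each device comes back as 2.0 on each device
    return jax.grad(lambda x: lax.psum(x, 'batch'))(w * data)

print(jax.pmap(f, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.)))
# ==> [2., 2.]
```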
There might be multiple ways to fix the system and get the right answer, but my understanding is that the best way forward looks something like this:

First, we should distinguish between “scalars” and “batched scalars” with respect to a particular SPMD axis, and introduce a `pbroadcast` to take a scalar to a batched scalar with the same value (while `psum` should take a batched scalar and produce a scalar). Then, the transpose of `pbroadcast` should be `psum` and vice versa, and the cotangent of a scalar should always be a scalar and the cotangent of a batched scalar should always be a batched scalar. This means that (just like in the rank-polymorphic bulk case!) we should insert a `pbroadcast` when a scalar needs to be promoted to a batched scalar (as in `w * data`) and the backward pass should have a `psum` there.
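One way to make the proposed transpose pairing concrete is to emulate it with `custom_vjp`; here `pbroadcast` is hypothetical (not a real JAX API in this discussion), and since at head every value is already batched, both forward passes are just the existing behavior:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def pbroadcast(x, axis_name):
    return x  # scalar -> batched scalar; a runtime no-op on replicated values

def pbroadcast_fwd(x, axis_name):
    return x, None

def pbroadcast_bwd(axis_name, _, ct):
    return (lax.psum(ct, axis_name),)  # transpose of pbroadcast is psum

pbroadcast.defvjp(pbroadcast_fwd, pbroadcast_bwd)

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def psum_scalar(x, axis_name):
    return lax.psum(x, axis_name)  # batched scalar -> scalar

def psum_scalar_fwd(x, axis_name):
    return lax.psum(x, axis_name), None

def psum_scalar_bwd(axis_name, _, ct):
    return (ct,)  # transpose of psum is pbroadcast, a runtime no-op

psum_scalar.defvjp(psum_scalar_fwd, psum_scalar_bwd)

def distributed(w, data):
    def agg_loss(w):
        return psum_scalar(pbroadcast(w, 'batch') * data, 'batch')
    return agg_loss(w), jax.grad(agg_loss)(w)

print(jax.pmap(distributed, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.)))
# ==> ([1., 1.], [1., 1.]), matching the non-distributed `local` above
```

With these transposes the backward pass `psum`s exactly once, at the point where the forward pass broadcast.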
Ideally we would also match the bulk-array error semantics too, and constrain the return value of the function passed to `grad` to be a scalar, not a batched scalar; we should also support `out_axes=None` in maps so that we can return these scalars out to the bulk world and make the semantics less obscure.

As Matt points out, the semantics of SPMD programs at head are more consistent than I thought—they just correspond to different (and in my view less useful) bulk array semantics. Adopting my preferred approach would be breaking, and would need to be sequenced carefully with other enhancements we’re planning to make.
In particular: at head, values in SPMD code always “contain” the pmapped axis (by which I mean that their bulk array counterparts always contain the corresponding logical axis). This has a few consequences:
1. `pmap` has to insert `pbroadcast` of `in_axes=None` or closed-over values at the beginning of the mapped region, rather than later on/when needed, because a value inside a `pmap` that hasn’t been `pbroadcast`ed isn’t representable
2. `psum` has to have reduce-sum-broadcast semantics, because the output of reduce-sum without a broadcast isn’t representable
3. `grad` inside a `pmap` has to correspond to bulk vjp-with-ones, because a bulk scalar isn’t representable

These are essentially the things that surprised Sharad. (1) meant that the output of the gradient function wasn’t `psum`med, because the input of the forward function had already been broadcasted; (2) meant that the values flowing back through the VJP were 2 rather than 1, because those values had been `psum`med; and (3) meant that the output values computed were no longer gradients, since the bulk version of the forward function had a non-scalar output.

Collectively they mean that a pattern that would otherwise be natural and useful, and represents the SPMD counterpart of standard bulk code for NN training—taking the gradient of a `psum`med loss with respect to replicated parameters—can’t be expressed in SPMD JAX. There are two alternatives: using
`grad` of `pmap`, which has significant runtime overheads, and moving the `psum` outside the loss function, which is the most common approach in data-parallel JAX today, but represents an unfortunate loss of expressiveness and a gap between SPMD code and bulk code.

These things and others would become easier, and the system more aligned with Sharad’s (and my) mental model, if SPMD JAX were extended with the ability to represent values that don’t contain the mapped axis. This is a strict increase in expressiveness, because it increases the set of bulk programs that have SPMD counterparts. It would not require a change in the semantics of existing programs, but it would open up the possibility of certain changes (and wouldn’t be very useful without them):
1. `pmap` can now wait to insert `pbroadcast` of `in_axes=None` or closed-over values until they’re mixed with values that contain the mapped axis
2. `psum` can correspond to bulk reduce-sum
3. `grad` inside a `pmap` can correspond to bulk `grad`

With these changes, Sharad’s code would work out of the box with his desired semantics, and the set of representable programs becomes more symmetric between the SPMD and bulk worlds.
But of course the “most common approach in data-parallel JAX today” (`psum`/`pmean` of the result of `grad` inside a `pmap`) assumes per-device `grad` with vjp-with-ones bulk semantics! So we have to be careful about how we make changes like these.
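For the running example, that pattern looks roughly like this sketch:

```python
import jax
import jax.numpy as jnp
from jax import lax

def data_parallel(w, data):
    def local_loss(w):
        return (w * data).sum()  # no collective inside the loss
    loss = lax.psum(local_loss(w), 'batch')
    grads = lax.psum(jax.grad(local_loss)(w), 'batch')  # psum of per-device grad
    return loss, grads

print(jax.pmap(data_parallel, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.)))
# ==> ([1., 1.], [1., 1.])
```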
We could, for instance, introduce kwargs to control these behaviors (e.g., `keepdims` on `psum` could default to `True` at first). We could also sequence these changes to take place alongside new APIs we’re adding or considering adding, like `gmap` and a potential `sum`/`psum` unification.

Okay, back to it!
Notice the Jacobian of the `agg_loss` function written with the broadcast is `[[0, 1], [0, 1]]`. So, while keeping the current (at HEAD) bulk array definition of `psum` as a reduce-sum followed by a broadcast, the SPMD AD semantics is consistent so long as we take `grad` to mean “compute the VJP against a ones vector broadcast along all named axes”:
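A bulk-array sketch of both claims, including a plausible definition of the `grad2` helper referenced below:

```python
import jax.numpy as jnp
from jax import jacfwd, vjp

data = jnp.arange(2.)

def agg_loss_bulk(w):
    # bulk counterpart of the SPMD agg_loss at head:
    # psum is a reduce-sum followed by a broadcast
    return jnp.broadcast_to((w * data).sum(), data.shape)

print(jacfwd(agg_loss_bulk)(jnp.ones(2)))  # ==> [[0. 1.]
                                           #      [0. 1.]]

def grad2(f):
    # grad as vjp-with-ones: pull a ones vector back through the Jacobian
    def gradfun(x):
        y, f_vjp = vjp(f, x)
        return f_vjp(jnp.ones_like(y))[0]
    return gradfun

print(grad2(agg_loss_bulk)(jnp.ones(2)))   # ==> [0. 2.], matching the pmap result
```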
(In Autograd, like TF today, we used to define `grad` as vjp-with-ones, but exactly this sort of confusion is why we made it raise an error for non-scalar outputs. Yet we didn’t make that check work with SPMD functions, in the sense that `grad` will in effect happily allow broadcasting along named mapped axes!)
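For comparison, the positional-world check in question (the exact error text may vary across JAX versions):

```python
import jax.numpy as jnp
from jax import grad

try:
    grad(lambda x: x * jnp.arange(2.))(1.)  # non-scalar output
except TypeError as e:
    print(e)  # "Gradient only defined for scalar-output functions. ..."
```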
If the semantics at HEAD are self-consistent, except the error semantics for `grad`, do we need to change anything, except perhaps to avoid this potential confusion by making `grad` error semantics consistent in the positional and SPMD worlds?

Maybe yes. One problem with the current semantics is that if we make @sharadmv’s use of `grad` here an error (rather than a vjp-with-ones), not only would that have been surprising to him, but it would also break pretty much all existing SPMD neural net training; they’d have to write vjp-with-ones themselves, e.g. by defining `grad2` as above. Even then, the answers can be surprising: within the context of the SPMD function, it looks like we’re calling `grad`/`grad2` on a scalar-input scalar-output function (but for the closed-over value of `data`, which is different on each device) with the same primal input value on every device, yet getting different `grad2` results on different devices (perhaps not noticing that if we looked at the primal output value we’d also have a different value on each device, which might make getting different gradients less surprising). That is, in the expression `grad(agg_loss)(w)` in @sharadmv’s original code, the function `agg_loss` is a different function on each device because it closes over mapped data, which is why we should expect to get different answers.

In any case, while I now think the semantics at HEAD are actually consistent (modulo error semantics for `grad`) and defensible, this example has clearly shown that they can be confusing, especially when differentiating functions that close over mapped values. The main difference with @jekbradbury’s proposed semantics (as he’s pointed out too) is whether `psum` corresponds in bulk array land to a reduce-sum, or a reduce-sum-then-broadcast.

I have thought about this issue a bit in another context, and I think you need the notion/type of an SPMD-global scalar to make `grad` make sense with an internal `pmap`. The API difficulty with `pmap` is that it has no way to return something of rank < 1, so semantically the program never contains a function on which `grad` should be defined, since the output is never a scalar. Having a `ReplicatedArray` be a tensor in some contexts and a scalar in others seems like it could lead to semantic problems.
By “replicated thing” I mean a guaranteed-symmetric value, i.e. one where `out_axes=None` would be allowed if it were a return value of the map. In terms of functions, I mean those whose return values are replicated.
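For instance, the output of a `psum` is guaranteed-symmetric in this sense:

```python
import jax.numpy as jnp
from jax import pmap, lax

out = pmap(lambda x: lax.psum(x, 'i'), axis_name='i')(jnp.arange(2.))
print(out)  # ==> [1. 1.], identical on every device, hence "replicated"
```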