jax: inside pmap, grad(lambda x: psum(loss(x))) inconsistent with jvp(lambda x: psum(loss(x)))
Running this on a machine with >=2 devices:
```python
import jax.numpy as jnp
from jax import vmap, pmap, grad, lax

def loss(w, d):
  return w * d

def local(w, data):
  def local_loss(w):
    return vmap(lambda d: loss(w, d))(data).sum()
  return local_loss(w), grad(local_loss)(w)

print(local(1., jnp.arange(2.)))  # ==> (1., 1.)

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(loss(w, data), 'batch')
  return agg_loss(w), grad(agg_loss)(w)

print(pmap(distributed, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.)))  # ==> ([1., 1.], [0., 2.])
```
The losses are correct, but the gradients on each shard are not.
What should the intended behavior be here?
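For comparison, a forward-mode counterpart of the same computation (a sketch, not part of the snippet above; it reuses `loss` and the imports from that snippet) would be:

```python
from jax import jvp

def distributed_jvp(w, data):
  def agg_loss(w):
    return lax.psum(loss(w, data), 'batch')
  # forward mode: push the tangent 1.0 for w through agg_loss
  return jvp(agg_loss, (w,), (1.,))[1]

# expected (per the discussion below): [1., 1.] on two devices, unlike grad's [0., 2.]
print(pmap(distributed_jvp, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.)))
```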
About this issue
- State: open
- Created 4 years ago
- Reactions: 1
- Comments: 15 (13 by maintainers)
I think we shouldn’t necessarily expect invariants that are true in ordinary code to also hold for SPMD code; we often assume that SPMD code will behave like a “batch of (independent) programs”, but the presence of collectives means that can be a bad analogy! Instead, I like to think of SPMD programming as a different perspective on bulk array programming (the named axis is a real axis that just doesn’t have a position, and collectives are just bulk primitives, like reductions, that operate on that axis).
That means that a scalar -> scalar SPMD program is not actually a scalar -> scalar program, and nor is it a batch of scalar -> scalar programs. It's a vector -> vector program (one that acts elementwise/has a diagonal Jacobian if there aren't any collectives, and doesn't if there are). This is a particularly important distinction for autodiff, since a single application of forward-mode autodiff or finite differences can only give the derivative of a `scalar -> ...` function, while a single application of reverse-mode autodiff can only give the derivative of a `... -> scalar` function (with an exception in both cases if the Jacobian is known to be diagonal).

In bulk array terms, Matt's three examples are:
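A sketch of those bulk-array versions (reconstructed here, assuming `data = jnp.arange(2.)` and the replicated `w = 1.` expanded to a length-2 vector):

```python
import jax.numpy as jnp
from jax import grad, jvp

data = jnp.arange(2.)   # the values that were spread across devices
w = jnp.ones(2)         # the replicated scalar w = 1., as a bulk vector

def agg_loss(w):        # vector -> scalar: the psum as a plain reduce-sum
  return (w * data).sum()

f_grad = grad(agg_loss)(w)                        # ==> [0., 1.]
f_jvp = jvp(agg_loss, (w,), (jnp.ones(2),))[1]    # ==> 1.0
f_fd = (agg_loss(w + 1e-3) - agg_loss(w)) / 1e-3  # ==> ~1.0
```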
These results are a little less surprising! Here `agg_loss` is an ordinary vector -> scalar function and `f_grad` is computing the gradient, while `f_jvp` and `f_fd` are both computing something else (the product of the Jacobian with a vector of ones).

In this bulk array context, another way to compute that same Jacobian-vector product is to take the `grad` of a scalar -> scalar function that broadcasts its input (expressed with the same `agg_loss` Python code through NumPy rank polymorphism/implicit broadcasting):
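A sketch of that version (again a reconstruction; it is the same `agg_loss` code, now applied to a scalar `w` so that `w * data` broadcasts):

```python
import jax.numpy as jnp
from jax import grad

data = jnp.arange(2.)

def agg_loss(w):            # scalar -> scalar: `w * data` broadcasts w
  return (w * data).sum()

print(grad(agg_loss)(1.))   # ==> 1.0, the same Jacobian-vector product as above
```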
Here the forward pass performs a broadcast, multiply, then sum, so the transpose performs a broadcast, multiply, then sum too (broadcast is fan-out and its transpose is fan-in-sum). How to make sure the transpose contains a sum when the forward pass actually needed to broadcast, but doesn't when it didn't, is the implicit broadcasting problem in autodiff systems for dynamically-ranked languages like TF graphs and TorchScript.
My claim is that all of these concerns apply in the SPMD case, too. An SPMD program cannot run except with all of its named axes bound by corresponding maps, and the resulting program is a bulk array program with well-defined semantics. We need to treat those semantics as the source of truth for thinking about SPMD programs, rather than intuitions about batches of independent programs, because those intuitions break down when collectives are involved.
In bulk array programs, a particular logical axis can either be present or absent in each of the arguments and return values. In the context of bulk array programs where that axis was created by mapping an SPMD axis using JAX's map APIs, this corresponds to whether the `in_axes`/`out_axes` setting for each argument/return value is an integer or `None`.
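As a bulk-world illustration of that correspondence (a sketch using `vmap`, which accepts `None` for both `in_axes` and `out_axes`):

```python
import jax.numpy as jnp
from jax import vmap

data = jnp.arange(2.)

def f(w, d):
  return w * d, w + 1.

# `w` has no mapped axis (in_axes=None); `d` carries it (in_axes=0). The first
# output keeps the axis (out_axes=0); the second is independent of it, so it
# can be returned without the axis (out_axes=None).
per_example, replicated = vmap(f, in_axes=(None, 0), out_axes=(0, None))(2., data)
print(per_example)  # [0., 4.] -- contains the mapped axis
print(replicated)   # 3.0      -- a "scalar" with respect to that axis
```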
Here's what I think that means for the SPMD versions of our examples:

(That last one is Sharad's code.)
We don't currently produce the desired output in the last case, because we include a spurious `psum` at the beginning of the transpose (from [1. 1.] to [2. 2.]) and don't include a necessary one at the end. There might be multiple ways to fix the system and get the right answer, but my understanding is that the best way forward looks something like this.

First, we should distinguish between "scalars" and "batched scalars" with respect to a particular SPMD axis, and introduce a `pbroadcast` that takes a scalar to a batched scalar with the same value (while `psum` should take a batched scalar and produce a scalar). Then the transpose of `pbroadcast` should be `psum` and vice versa, the cotangent of a scalar should always be a scalar, and the cotangent of a batched scalar should always be a batched scalar. This means that (just like in the rank-polymorphic bulk case!) we should insert a `pbroadcast` when a scalar needs to be promoted to a batched scalar (as in `w * data`), and the backward pass should then have a `psum` there.
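For concreteness, here is a bulk-array sketch of what those semantics would compute (the `pbroadcast` and `psum` below are just ordinary bulk ops standing in for the proposed primitives, not existing APIs):

```python
import jax.numpy as jnp
from jax import grad

data = jnp.arange(2.)                                   # the mapped values
pbroadcast = lambda w: jnp.broadcast_to(w, data.shape)  # scalar -> batched scalar
psum = lambda x: x.sum()                                # batched scalar -> scalar

def agg_loss(w):              # scalar -> scalar under the proposed semantics
  return psum(pbroadcast(w) * data)

# The transpose of the broadcast is a sum, so the gradient comes out as a scalar:
print(grad(agg_loss)(1.))     # ==> 1.0, i.e. [1., 1.] once replicated per device
```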
Ideally we would also match the bulk-array error semantics, and constrain the return value of the function passed to `grad` to be a scalar, not a batched scalar; we should also support `out_axes=None` in maps so that we can return these scalars out to the bulk world and make the semantics less obscure.

As Matt points out, the semantics of SPMD programs at head are more consistent than I thought: they just correspond to different (and in my view less useful) bulk array semantics. Adopting my preferred approach would be breaking, and would need to be sequenced carefully with other enhancements we're planning to make.
In particular: at head, values in SPMD code always “contain” the pmapped axis (by which I mean that their bulk array counterparts always contain the corresponding logical axis). This has a few consequences:
1. `pmap` has to insert `pbroadcast`s of `in_axes=None` arguments or closed-over values at the beginning of the mapped region, rather than later on/when needed, because a value inside a `pmap` that hasn't been `pbroadcast`ed isn't representable.
2. `psum` has to have reduce-sum-then-broadcast semantics, because the output of a reduce-sum without a broadcast isn't representable.
3. `grad` inside a `pmap` has to correspond to bulk vjp-with-ones, because a bulk scalar isn't representable.

These are essentially the things that surprised Sharad. (1) meant that the output of the gradient function wasn't `psum`med, because the input of the forward function had already been broadcast; (2) meant that the values flowing back through the VJP were 2 rather than 1, because those values had been `psum`med; and (3) meant that the output values computed were no longer gradients, since the bulk version of the forward function had a non-scalar output.

Collectively, these mean that a pattern that would otherwise be natural and useful, and that is the SPMD counterpart of standard bulk code for NN training (taking the gradient of a psummed loss with respect to replicated parameters), can't be expressed in SPMD JAX. There are two alternatives: using `grad` of `pmap`, which has significant runtime overheads, and moving the `psum` outside the loss function, which is the most common approach in data-parallel JAX today but represents an unfortunate loss of expressiveness and a gap between SPMD code and bulk code.

These things and others would become easier, and the system more aligned with Sharad's (and my) mental model, if SPMD JAX were extended with the ability to represent values that don't contain the mapped axis. This is a strict increase in expressiveness, because it increases the set of bulk programs that have SPMD counterparts. It would not require a change in the semantics of existing programs, but it would open up the possibility of certain changes (and wouldn't be very useful without them):
1. `pmap` can now wait to insert `pbroadcast`s of `in_axes=None` arguments or closed-over values until they're mixed with values that contain the mapped axis.
2. `psum` can correspond to bulk reduce-sum.
3. `grad` inside a `pmap` can correspond to bulk `grad`.

With these changes, Sharad's code would work out of the box with his desired semantics, and the set of representable programs becomes more symmetric between the SPMD and bulk worlds.
But of course the "most common approach in data-parallel JAX today" (`psum`/`pmean` of the result of `grad` inside a `pmap`) assumes per-device `grad` with vjp-with-ones bulk semantics! So we have to be careful about how we make changes like these.
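For reference, that common pattern looks roughly like this (a minimal sketch with an illustrative per-device loss, not code from this thread):

```python
import jax.numpy as jnp
from jax import pmap, grad, lax

def train_step(w, data):
  # per-device gradient of the per-device loss...
  g = grad(lambda w: (w * data).sum())(w)
  # ...then averaged across devices with a collective outside of grad
  return lax.pmean(g, 'batch')

# usage (assumes >= 2 devices); each device sees one row of the batch:
# pmap(train_step, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(4.).reshape(2, 2))
```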
We could, for instance, introduce kwargs to control these behaviors (e.g., `keepdims` on `psum` could default to `True` at first). We could also sequence these changes to take place alongside new APIs we're adding or considering adding, like `gmap` and a potential `sum`-`psum` unification.

Okay, back to it!
Notice that the Jacobian of the `agg_loss` function written with the broadcast is `[[0, 1], [0, 1]]`:
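A sketch of that check (the original snippet isn't shown here), modeling `psum` in bulk as a reduce-sum followed by a broadcast:

```python
import jax.numpy as jnp
from jax import jacobian

data = jnp.arange(2.)

def agg_loss_bcast(w):        # vector -> vector: reduce-sum then broadcast
  return jnp.broadcast_to((w * data).sum(), w.shape)

print(jacobian(agg_loss_bcast)(jnp.ones(2)))  # ==> [[0., 1.], [0., 1.]]
```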
So, while keeping the current (at HEAD) bulk array definition of `psum` as a reduce-sum followed by a broadcast, the SPMD AD semantics is consistent so long as we take `grad` to mean "compute the VJP against a ones vector broadcast along all named axes":
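A sketch of such a vjp-with-ones `grad` (call it `grad2`, which is what "defining grad2 as above" refers to below; this is a reconstruction, not the original snippet), applied to the with-broadcast bulk `agg_loss`:

```python
import jax.numpy as jnp
from jax import vjp

data = jnp.arange(2.)

def agg_loss_bcast(w):
  return jnp.broadcast_to((w * data).sum(), w.shape)

def grad2(f):
  def gradded(x):
    y, f_vjp = vjp(f, x)
    return f_vjp(jnp.ones_like(y))[0]   # pull back a vector of ones
  return gradded

print(grad2(agg_loss_bcast)(jnp.ones(2)))  # ==> [0., 2.], matching grad inside the pmap
```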
(In Autograd, like TF today, we used to define `grad` as vjp-with-ones, but exactly this sort of confusion is why we made it raise an error for non-scalar outputs. Yet we didn't make that check work with SPMD functions, in the sense that `grad` will in effect happily allow broadcasting along named mapped axes!)

If the semantics at HEAD are self-consistent, except for the error semantics of `grad`, do we need to change anything, other than perhaps avoiding this potential confusion by making the `grad` error semantics consistent between the positional and SPMD worlds?

Maybe yes. One problem with the current semantics is that if we made @sharadmv's use of `grad` here an error (rather than a vjp-with-ones), not only would that have been surprising to him, but it would also break pretty much all existing SPMD neural net training; people would have to write vjp-with-ones themselves, e.g. by defining `grad2` as above. Even then, the answers can be surprising: within the SPMD function, it looks like we're calling `grad`/`grad2` on a scalar-input, scalar-output function (but for the closed-over value of `data`, which is different on each device) with the same primal input value on every device, yet getting different `grad2` results on different devices (perhaps not noticing that the primal output value is also different on each device, which might make getting different gradients less surprising). That is, in the expression `grad(agg_loss)(w)` in @sharadmv's original code, the function `agg_loss` is a different function on each device because it closes over mapped data, which is why we should expect to get different answers.
In any case, while I now think the semantics at HEAD are actually consistent (modulo the error semantics for `grad`) and defensible, this example has clearly shown that they can be confusing, especially when differentiating functions that close over mapped values. The main difference with @jekbradbury's proposed semantics (as he's pointed out too) is whether `psum` corresponds in bulk array land to a reduce-sum, or a reduce-sum-then-broadcast.

I have thought about this issue a bit in another context, and I think you need the notion/type of an SPMD-global scalar to make `grad` make sense with an internal `pmap`. The API difficulty with `pmap` is that it has no way to return something of rank < 1, so semantically the program never has a function on which `grad` should be defined, since the output is never a scalar. Having a ReplicatedArray be a tensor in some contexts and a scalar in others seems like it could lead to semantic problems.
By "replicated thing" I mean a guaranteed-symmetric value, i.e. one where `out_axes=None` would be allowed if it were a return value of the map. In terms of functions, I mean those whose return values are replicated.