jax: inside pmap, grad(lambda x: psum(loss(x))) inconsistent with jvp(lambda x: psum(loss(x)))

Running this on a machine with >=2 devices:

import jax.numpy as jnp
from jax import vmap, pmap, grad, lax

def loss(w, d):
  return w * d

def local(w, data):
  def local_loss(w):
    return vmap(lambda d: loss(w, d))(data).sum()
  return local_loss(w), grad(local_loss)(w)
print(local(1., jnp.arange(2.))) # ==> (1., 1.)


def distributed(w, data):
  def agg_loss(w):
    return lax.psum(loss(w, data), 'batch')
  return agg_loss(w), grad(agg_loss)(w)
print(pmap(distributed, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.))) # ==> ([1., 1.], [0., 2.])

The losses match the single-device result, but the gradients do not: each shard reports a different gradient, and neither matches the local value of 1.0.

What should be the intended behavior here?


Most upvoted comments

I think we shouldn’t necessarily expect invariants that are true in ordinary code to also hold for SPMD code; we often assume that SPMD code will behave like a “batch of (independent) programs”, but the presence of collectives means that can be a bad analogy! Instead, I like to think of SPMD programming as a different perspective on bulk array programming (the named axis is a real axis that just doesn’t have a position, and collectives are just bulk primitives, like reductions, that operate on that axis).

That means that a scalar -> scalar SPMD program is not actually a scalar -> scalar program, nor is it a batch of scalar -> scalar programs. It’s a vector -> vector program (one that acts elementwise/has a diagonal Jacobian if there aren’t any collectives, and doesn’t if there are). This is a particularly important distinction for autodiff, since a single application of forward-mode autodiff or finite differences can only give the derivative of a function with a scalar input (scalar -> anything), while a single application of reverse-mode autodiff can only give the derivative of a function with a scalar output (anything -> scalar), with an exception in both cases if the Jacobian is known to be diagonal.
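
To make the diagonal-vs-non-diagonal Jacobian point concrete, here is a small bulk-array sketch (mine, not from the thread): an elementwise program has a diagonal Jacobian, while adding a reduction, the bulk analogue of a collective, makes it non-diagonal.

import jax.numpy as jnp
from jax import jacfwd

data = jnp.arange(2.)
elementwise = lambda w: w * data                                       # no collective
with_collective = lambda w: jnp.broadcast_to(jnp.sum(w * data), (2,))  # psum-like

print(jacfwd(elementwise)(jnp.ones(2)))
# [[0. 0.]
#  [0. 1.]]   <- diagonal
print(jacfwd(with_collective)(jnp.ones(2)))
# [[0. 1.]
#  [0. 1.]]   <- not diagonal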

In bulk array terms, Matt’s three examples are:

import jax.numpy as jnp
from jax import vmap, pmap, grad, lax, jvp

def f_jvp(w, data):
  def agg_loss(w):
    return jnp.sum(w * data, 0)
  return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]  # tangent must match w's shape
print(f_jvp(jnp.ones(2), jnp.arange(2)))
# prints 1.0

def f_fd(w, data):
  def agg_loss(w):
    return jnp.sum(w * data, 0)
  return (agg_loss(w+1e-3) - agg_loss(w)) / 1e-3
print(f_fd(jnp.ones(2), jnp.arange(2)))
# prints 1.0000467

def f_grad(w, data):
  def agg_loss(w):
    return jnp.sum(w * data, 0)
  return grad(agg_loss)(w)
print(f_grad(jnp.ones(2), jnp.arange(2)))
# prints [0. 1.]

These results are a little less surprising! Here agg_loss is an ordinary vector -> scalar function and f_grad is computing the gradient, while f_jvp and f_fd are both computing something else (the product of the Jacobian with a vector of ones).
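
As a quick added check of that claim (same bulk agg_loss): the Jacobian at w = [1., 1.] is [0. 1.], and its product with a vector of ones is 1.0, matching f_jvp and (up to float error) f_fd.

import jax.numpy as jnp
from jax import jacfwd

data = jnp.arange(2)
agg_loss = lambda w: jnp.sum(w * data, 0)

J = jacfwd(agg_loss)(jnp.ones(2))   # for a vector -> scalar function this is the gradient
print(J, J @ jnp.ones(2))           # [0. 1.] 1.0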

In this bulk array context, another way to compute that same Jacobian-vector product is to take the grad of a scalar -> scalar function that broadcasts its input (expressed with the same agg_loss Python code through NumPy rank polymorphism/implicit broadcasting):

def f_grad_bcast(w, data):
  def agg_loss(w):
    return jnp.sum(w * data, 0)
  return grad(agg_loss)(w)
print(f_grad_bcast(1., jnp.arange(2)))
# prints 1.0

Here the forward pass performs a broadcast, multiply, then sum, so the transpose performs a broadcast, multiply, then sum too (broadcast is fan-out and its transpose is fan-in-sum). How to make sure the transpose contains a sum when the forward pass actually needed to broadcast, but doesn’t when it didn’t, is the implicit broadcasting problem in autodiff systems for dynamically-ranked languages like TF graphs and TorchScript.
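
As an added illustration of that fan-out/fan-in-sum duality, here is what jax.vjp gives for a plain broadcast and a plain sum in the bulk world:

import jax.numpy as jnp
from jax import vjp

# Transposing a broadcast gives a sum...
out, bcast_vjp = vjp(lambda s: jnp.broadcast_to(s, (2,)), jnp.float32(1.0))
print(bcast_vjp(jnp.ones_like(out)))   # (2.0,): fan-out transposes to fan-in-sum

# ...and transposing a sum gives a broadcast.
out, sum_vjp = vjp(jnp.sum, jnp.ones(2))
print(sum_vjp(jnp.ones_like(out)))     # ([1., 1.],): fan-in-sum transposes to fan-out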

My claim is that all of these concerns apply in the SPMD case, too. An SPMD program cannot run except with all of its named axes bound by corresponding maps, and the resulting program is a bulk array program with well-defined semantics. We need to treat those semantics as the source of truth for thinking about SPMD programs, rather than intuitions about batches of independent programs, because those intuitions break down when collectives are involved.

In bulk array programs, a particular logical axis can either be present or absent in each of the arguments and return values. In the context of bulk array programs where that axis was created by mapping an SPMD axis using JAX’s map APIs, this corresponds to whether the in_axes/out_axes setting for each argument/return is an integer or None. Here’s what I think that means for the SPMD versions of our examples:

import jax.numpy as jnp
from jax import vmap, pmap, grad, lax, jvp

def f_jvp(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(f_jvp, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1. 1.]

def f_fd(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return (agg_loss(w+1e-3) - agg_loss(w)) / 1e-3
print(pmap(f_fd, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1.0000466 1.0000466]

def f_grad(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f_grad, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# should print [0. 1.]

def f_grad_bcast(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f_grad_bcast, axis_name='batch', in_axes=(None, 0))(1., jnp.arange(2)))
# should print [1. 1.], or just 1.0 with out_axes=None

(That last one is Sharad’s code.)

We don’t currently produce the desired output in the last case, because we include a spurious psum at the beginning of the transpose (from [1. 1.] to [2. 2.]) and don’t include a necessary one at the end. There might be multiple ways to fix the system and get the right answer, but my understanding is that the best way forward looks something like this:

First, we should distinguish between “scalars” and “batched scalars” with respect to a particular SPMD axis, and introduce a pbroadcast to take a scalar to a batched scalar with the same value (while psum should take a batched scalar and produce a scalar). Then, the transpose of pbroadcast should be psum and vice versa, and the cotangent of a scalar should always be a scalar and the cotangent of a batched scalar should always be a batched scalar. This means that (just like in the rank-polymorphic bulk case!) we should insert a pbroadcast when a scalar needs to be promoted to a batched scalar (as in w * data) and the backward pass should have a psum there.
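
Here is a bulk-array sketch of what that would mean on the running example, with jnp.broadcast_to standing in for the proposed pbroadcast and a plain jnp.sum standing in for psum: the forward pass broadcasts the replicated scalar w before it meets data, so the transpose contains the matching sum and the gradient comes out as 1.0.

import jax.numpy as jnp
from jax import grad

def agg_loss_explicit(w, data):
  w_b = jnp.broadcast_to(w, data.shape)  # bulk stand-in for the proposed pbroadcast
  return jnp.sum(w_b * data)             # bulk stand-in for psum (a plain reduce-sum)

print(grad(agg_loss_explicit)(1., jnp.arange(2.)))  # 1.0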

Ideally we would also match the bulk-array error semantics too, and constrain the return value of the function passed to grad to be a scalar, not a batched scalar; we should also support out_axes=None in maps so that we can return these scalars out to the bulk world and make the semantics less obscure.

As Matt points out, the semantics of SPMD programs at head are more consistent than I thought—they just correspond to different (and in my view less useful) bulk array semantics. Adopting my preferred approach would be breaking, and would need to be sequenced carefully with other enhancements we’re planning to make.

In particular: at head, values in SPMD code always “contain” the pmapped axis (by which I mean that their bulk array counterparts always contain the corresponding logical axis). This has a few consequences:

  1. pmap has to insert pbroadcast of in_axes=None or closed-over values at the beginning of the mapped region, rather than later on/when needed, because a value inside a pmap that hasn’t been pbroadcasted isn’t representable
  2. psum has to have reduce-sum-broadcast semantics, because the output of reduce-sum without a broadcast isn’t representable (see the small check after this list)
  3. grad inside a pmap has to correspond to bulk vjp-with-ones because a bulk scalar isn’t representable
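
A small check of consequence (2), added here for concreteness: inside pmap, psum leaves every device holding the total, which is exactly a bulk reduce-sum followed by a broadcast.

import jax.numpy as jnp
from jax import pmap, lax

# Requires >= 2 devices, like the other examples in this thread.
print(pmap(lambda x: lax.psum(x, 'i'), axis_name='i')(jnp.arange(2.)))
# [1. 1.]  -- the total appears on every device, not as a single scalar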

These are essentially the things that surprised Sharad. (1) meant that the output of the gradient function wasn’t psummed, because the input of the forward function had already been broadcasted; (2) meant that the values flowing back through the VJP were 2 rather than 1, because those values had been psummed; and (3) meant that the output values computed were no longer gradients, since the bulk version of the forward function had a non-scalar output.

Collectively they mean that a pattern that would otherwise be natural and useful, and represents the SPMD counterpart of standard bulk code for NN training—taking the gradient of a psummed loss with respect to replicated parameters—can’t be expressed in SPMD JAX. There are two alternatives: using grad of pmap, which has significant runtime overheads, and moving the psum outside the loss function, which is the most common approach in data-parallel JAX today, but represents an unfortunate loss of expressiveness and a gap between SPMD code and bulk code.
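
For reference, here is a sketch of that second alternative, the common data-parallel pattern of differentiating the local (un-psummed) loss and then psumming the resulting gradients (my example, using the issue's toy loss):

import jax.numpy as jnp
from jax import pmap, grad, lax

def local_loss(w, d):
  return jnp.sum(w * d)          # no collective inside the differentiated function

def train_step(w, d):
  g = grad(local_loss)(w, d)     # per-device gradient of the local loss
  return lax.psum(g, 'batch')    # aggregate the gradients across devices afterwards

print(pmap(train_step, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.)))
# [1. 1.]  -- every device ends up with the summed gradient, matching local() above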

These things and others would become easier, and the system more aligned with Sharad’s (and my) mental model, if SPMD JAX were extended with the ability to represent values that don’t contain the mapped axis. This is a strict increase in expressiveness, because it increases the set of bulk programs that have SPMD counterparts. It would not require a change in the semantics of existing programs, but it would open up the possibility of certain changes (and wouldn’t be very useful without them):

  1. pmap can now wait to insert pbroadcast of in_axes=None or closed-over values until they’re mixed with values that contain the mapped axis
  2. psum can correspond to bulk reduce-sum
  3. grad inside a pmap can correspond to bulk grad

With these changes, Sharad’s code would work out of the box with his desired semantics, and the set of representable programs becomes more symmetric between the SPMD and bulk worlds.

But of course the “most common approach in data-parallel JAX today” (psum/pmean of the result of grad inside a pmap) assumes per-device grad with vjp-with-ones bulk semantics! So we have to be careful about how we make changes like these.

We could, for instance, introduce kwargs to control these behaviors (e.g., keepdims on psum could default to True at first). We could also sequence these changes to take place alongside new APIs we’re adding or considering adding, like gmap and a potential sum-psum unification.

Okay, back to it!

Notice the Jacobian of the agg_loss function written with the broadcast is [[0, 1], [0, 1]]:

import jax.numpy as jnp
from jax import lax
from jax import jacfwd

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(0), (2,))
  return jacfwd(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# [[0. 1.]
#  [0. 1.]]

So, while keeping the current (at HEAD) bulk array definition of psum as a reduce-sum followed by a broadcast, the SPMD AD semantics is consistent so long as we take grad to mean “compute the VJP against a ones vector broadcast along all named axes”:

import jax.numpy as jnp
from jax import vjp, jvp, pmap, grad
from jax import lax


### reverse-mode

# At HEAD, we define this SPMD program:
def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

# To mean the same as this bulk array vjp-with-ones program (`grad` is always
# defined as vjp-with-ones plus an error check for scalar outputs that we don't
# include in the definition of SPMD semantics):
def grad2(f):
  def gradfun(x):
    ans, f_vjp = vjp(f, x)
    x_bar, = f_vjp(jnp.ones_like(ans))
    return x_bar
  return gradfun

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
  return grad2(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# [0. 2.]


### forward-mode

# At HEAD, we define this SPMD program:
def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

# To mean the same as this bulk array jvp-with-ones program:
def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
  return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]
print(f(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

(In Autograd, like TF today, we used to define grad as vjp-with-ones, but exactly this sort of confusion is why we made it raise an error for non-scalar outputs. Yet we didn’t make that check work with SPMD functions, in the sense that grad will in effect happily allow broadcasting along named mapped axes!)

If the semantics at HEAD are self-consistent, modulo the error semantics for grad, do we need to change anything at all, beyond perhaps making grad’s error semantics consistent between the positional and SPMD worlds so as to avoid this potential confusion?

Maybe yes. One problem with the current semantics is that if we made @sharadmv’s use of grad here an error (rather than a vjp-with-ones), not only would that have been surprising to him, but it would also break pretty much all existing SPMD neural net training; users would have to write vjp-with-ones themselves, e.g. by defining grad2 as above. Even then, the answers can be surprising: within the SPMD function it looks like we’re calling grad/grad2 on a scalar-input scalar-output function (except for the closed-over value of data, which is different on each device) with the same primal input value on every device, yet getting different grad2 results on different devices (perhaps not noticing that if we looked at the primal output value we’d also have a different value on each device, which might make getting different gradients less surprising). That is, in the expression grad(agg_loss)(w) in @sharadmv’s original code, the function agg_loss is a different function on each device because it closes over mapped data, which is why we should expect to get different answers.

In any case, while I now think the semantics at HEAD are actually consistent (modulo error semantics for grad) and defensible, this example has clearly shown that they can be confusing, especially when differentiating functions that close over mapped values. The main difference from @jekbradbury’s proposed semantics (as he’s pointed out too) is whether psum corresponds in bulk array land to a reduce-sum, or a reduce-sum-then-broadcast.

I have thought about this issue a bit in another context, and I think you need the notion/type of an SPMD-global scalar to make grad make sense with an internal pmap. The API difficulty with pmap is that it has no way to return something of rank < 1, so semantically the program never contains a function on which grad should be defined, since the output is never a scalar. Having a ReplicatedArray be a tensor in some contexts and a scalar in others seems like it could lead to semantic problems.

By “replicated thing” I mean a guaranteed-symmetric value, i.e. one where out_axes=None would be allowed if it were a return value of the map. In terms of functions, I mean those whose return values are replicated.