jax: vmap of cond's predicate results in select, leading to unexpected compute/memory use

I have been playing around with converting diffmpm from the difftaichi package into a jax version, and while the forward pass has been working wonderfully, the backward pass has been using way too much GPU memory.

Today, I was able to track that memory usage down to the grid op. The grid op step is a series of nested if statements. At first, I was using jnp.where, which evaluates all branches. That is extremely inefficient and can lead to OOM errors. I simplified my code and switched to jax.lax.cond, but my only conclusion is that cond is also evaluating both branches; otherwise I cannot see why this would run into OOM issues.

Below is a modified version of the grid op that is composed with itself 4,000 times, like a simulation. Even when run with XLA_PYTHON_CLIENT_PREALLOCATE=false set, this quickly leads to the whole GPU being used, and more if the loop length is increased. This is not true if every line from lin = ... until right before the return of grid_op is commented out. In that case, memory usage is practically negligible. Note that because bound = 0, literally every predicate in the v_out = jax.lax.cond ... lines evaluates to False by definition, and so most of the expressions, including the v_out_gate values and their dependencies, shouldn’t even need to be evaluated in the jitted function.

Maybe I am misunderstanding cond; if so, what is the proper way to get this sparse branching behavior? I don’t want to evaluate and hang onto a bunch of expensive tensors that are never actually needed and crash my GPU with OOM, especially in a backward pass. This is a core bottleneck to practical deployment of my code and a feature that I think should be supported. FWIW, I am using Version: 0.1.69+cuda101

Code to reproduce is below.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import os
import math
import numpy as np
import matplotlib.pyplot as plt
import jax.nn as jnn
import jax.lax as jlax
import timeit
import jax

dim = 2
n_grid = 128
dt = 1e-3
gravity = 3.8


def allocate_arrays():
  global grid_m_in, grid_v_in, grid_v_out, loss, index_array
  grid_m_in = jnp.ones((n_grid, n_grid))
  grid_v_in = jnp.zeros((n_grid, n_grid, dim))
  grid_v_out = jnp.zeros((n_grid, n_grid, dim))


  index_array = np.zeros((n_grid, n_grid, dim))
  
  for i in range(n_grid):
    for j in range(n_grid):
      index_array[i, j] = np.array([i, j])
 
  index_array = jnp.array(index_array)

  

def grid_op(grid_v_in, grid_m_in, index_tuple):
  bound = 0
  coeff = 0.5
  
  i = index_tuple[0]
  j = index_tuple[1]
  
  normal = jnp.array([0., 1.])
  
  inv_m = 1 / (grid_m_in + 1e-10)
  v_out = jnp.expand_dims(inv_m, -1) * grid_v_in
  v_out -= dt * gravity * jnp.array([0., 1.])
  
  v_out = jax.lax.cond(jnp.logical_and(i < bound, v_out[0] < 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  
  v_out = jax.lax.cond(jnp.logical_and(i > n_grid - bound, v_out[0] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  lin = (v_out.transpose() @ normal)
  
  vit = v_out - lin * normal
  lit = jnp.linalg.norm(vit + 1e-10)  + 1e-10
  
  
  v_out_gate_2 = jax.lax.cond(lit + coeff * lin <= 0, lambda _: jnp.zeros_like(v_out), lambda _: (1 + coeff * lin / lit) * vit, operand=None)
  v_out_gate_1 = jax.lax.cond(lin < 0, lambda _: v_out_gate_2, lambda _: jnp.zeros_like(v_out), operand=None)
  v_out = jax.lax.cond(jnp.logical_and(j < bound, v_out[1] < 0), lambda _: v_out_gate_1, lambda _: v_out, operand=None)          
  v_out = jax.lax.cond(jnp.logical_and(j > n_grid - bound, v_out[1] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  return v_out

go_j = jit(vmap(vmap(grid_op)))


def advance2(t, args):
  grid_v_in = args[0]
  grid_m_in = args[1]
  index_array = args[2]
  grid_v_in = go_j(grid_v_in, grid_m_in, index_array)
  
  return grid_v_in, grid_m_in, index_array
  
  
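# NOTE: leftover from the full simulation. p1_j and actuator_id are not
# defined in this snippet, and this function is never called below.
def advance(t, args):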
  x = args[0]
  v = args[1]
  C = args[2]
  F = args[3]
  x, v, C, F = p1_j(x, v, C, F, actuator_id)
  
  return x, v, C, F
  
a = jit(advance)

def forward2(grid_v_in, grid_m_in, index_array):
  grid_v_in, grid_m_in, index_array = jlax.fori_loop(0, 4000, advance2, (grid_v_in, grid_m_in, index_array))
  return jnp.mean(grid_v_in)


def main():
  # initialization
  allocate_arrays()
  
  f2 = jit(forward2)
  forward_grad2 = jit(grad(forward2))

  number = 10
  

  print(timeit.timeit(lambda : f2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : f2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : forward_grad2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : forward_grad2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  
if __name__ == "__main__":
  main()

About this issue

  • State: open
  • Created 3 years ago
  • Comments: 25 (3 by maintainers)

Most upvoted comments

We’re in agreement here by and large. This is something that we’ve thought about improving before, whether at the JAX or XLA level. I can’t find an open issue for it, so let’s use this one for it.

A loop may be possible in some cases, but not all, and I think that makes this problematic. And, if I can, I would like to make an argument for why I think somehow supporting conds in vmap’d scenarios is so important.

The wonderful thing about vmap, as advertised, is that anything should be easily vmappable. It’s essentially the beauty of abstraction and modularity, as described. Someone writes a module, foo, and if you want to parallelize it, it should be as easy as just calling vmap(foo). You don’t need to understand foo, it should just work. In this case, it seems like cond, a very critical dataflow construct, has different behaviors whether vmap’d or not. The problem, of course, is that people who use a module might not even know if cond is being used inside. This leads to functionally very different performance profiles when vmap’d or not that might require deep investigation of foo, which could be a quite complex piece of code. I don’t think this is good for the mission of vmap.

In some cases (as I would argue here), vmap appears to be the correct way to process such code, rather than a for loop, since as far as I understand, indexing into arrays via at is runtime expensive in jax (correct me if I’m mistaken), and creating large, sparse one-hot masks would be prohibitively memory-expensive without support for sparse matrices.

Hi there! I was wondering what the JAX team’s latest thinking is regarding the behavior of lax.cond when batched via vmap.

I find myself often running into the design pattern of conditionally branching into two subroutines, one expensive, and the other a “placeholder,” for example, returning a dummy zero tensor.

I see. Is there any reason why vmap’d conds are not supported? What is the logic there? I think this would be very important for a lot of people (and it would be great if these nuances were documented). I don’t know the internals of the jax compiler, but I do know that other similar systems support some form of branching. Is there any way to add this capability to XLA or somehow jax?

Actually - since I could, if I were careful about my code, manually batch everything, including conditionals, I see no reason why this couldn’t be supported with vmap?
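To make the manual-batching idea concrete, here is a minimal sketch, not anything endorsed by the JAX team. It assumes the function is written over the whole batch, guards the expensive computation with a single scalar predicate (jnp.any of the per-element predicates), and applies the per-element choice with jnp.where inside the taken branch. The names expensive and batched_gate are hypothetical. Note that this only saves work when no element in the batch needs the expensive branch; if any element does, the whole batch pays for it.

```python
import jax
import jax.numpy as jnp


def expensive(x):
    # Hypothetical stand-in for a costly per-element computation.
    return jnp.sin(x) ** 2 + jnp.cos(x)


def batched_gate(pred, x):
    # pred: boolean array of shape (n,); x: float array of shape (n,).
    def run(pred, x):
        # Per-element choice, only reached if some element needs it.
        return jnp.where(pred, expensive(x), jnp.zeros_like(x))

    def skip(pred, x):
        # Cheap placeholder with the same shape and dtype as run's output.
        return jnp.zeros_like(x)

    # The predicate is a scalar, so no batching of the predicate happens
    # and the conditional is kept as a cond.
    return jax.lax.cond(jnp.any(pred), run, skip, pred, x)


x = jnp.arange(4.0)
none_needed = jnp.zeros(4, dtype=bool)
some_needed = jnp.array([True, False, True, False])

out_skip = jax.jit(batched_gate)(none_needed, x)
out_run = jax.jit(batched_gate)(some_needed, x)
```

The trade-off is granularity: the skip is all-or-nothing per batch, so it helps most when the expensive case is rare across the whole array (as with bound = 0 in the repro above).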

I did try jax.checkpoint, but even having one checkpoint and splitting the loop in half balloons the runtime (I think by 4x IIRC, I’ll have to double check, but that’s surprising, since I would think that as long as checkpoints weren’t encompassing each other, the maximum extra runtime this would incur is 2x). On that note, adding conds vs. wheres also balloons the runtime by 2x… eventually these things start to really add up.
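For reference, a minimal sketch of what segment-level checkpointing can look like, assuming a toy step function in place of grid_op and lax.scan in place of fori_loop. The names step, segment, loss_checkpointed, and loss_plain are hypothetical, and the segment sizes (4 segments of 10 steps) are arbitrary. It is not the setup that was timed above, and it makes no claim about the runtime overhead.

```python
import jax
import jax.numpy as jnp


def step(v, _):
    # Hypothetical stand-in for one simulation step.
    return v - 1e-3 * jnp.tanh(v), None


@jax.checkpoint
def segment(v, _):
    # Intermediate values inside this inner scan are recomputed during the
    # backward pass instead of being stored for all steps.
    v, _ = jax.lax.scan(step, v, None, length=10)
    return v, None


def loss_checkpointed(v0):
    # Outer scan over 4 checkpointed segments of 10 steps each.
    v, _ = jax.lax.scan(segment, v0, None, length=4)
    return jnp.mean(v)


def loss_plain(v0):
    # Same 40 steps with no checkpointing, for comparison.
    v, _ = jax.lax.scan(step, v0, None, length=40)
    return jnp.mean(v)


v0 = jnp.linspace(-1.0, 1.0, 8)
g_ckpt = jax.grad(loss_checkpointed)(v0)
g_plain = jax.grad(loss_plain)(v0)
```

Both versions should give the same gradient; the checkpointed one trades extra forward recomputation for fewer stored intermediates.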

That’s what I thought. It might be helpful to document that for switch, similarly to what’s in the cond docs, e.g. "However, when transformed with vmap to operate over a batch of predicates/indices, switch is converted to select."
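A small sketch that makes the switch behavior visible by printing jaxprs. The function pick and the three toy branches are hypothetical, and the exact primitive name in the printout may differ between JAX versions (recent versions print select_n).

```python
import jax
import jax.numpy as jnp

# Three toy branches; in real code one of these might be expensive.
branches = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x]


def pick(i, x):
    return jax.lax.switch(i, branches, x)


idx = jnp.array([0, 1, 2])
xs = jnp.array([10.0, 20.0, 30.0])

# Scalar index: the switch is kept as a conditional (printed as cond).
unbatched = str(jax.make_jaxpr(pick)(idx[0], xs[0]))

# Batched index: every branch is computed for every element, and the
# results are combined with a select.
batched = str(jax.make_jaxpr(jax.vmap(pick))(idx, xs))

out = jax.vmap(pick)(idx, xs)
```

Printing unbatched and batched side by side shows where the conditional disappears.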

I find this one of the biggest practical issues with jax. vmap + jit are great, but a lot of code also needs cond alongside them, which results in the compute/memory issues described above, and in time spent trying to work around them.

@minqi – the thinking hasn’t changed much since this issue was last active. Although there’s a fundamental puzzle regarding whether/how to do better, for now we’re still producing select when we batch cond’s predicate.
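A minimal sketch for checking this on your own code, assuming a toy gate function (hypothetical name) with one "expensive" branch and one placeholder branch. It compares the jaxpr when the predicate is batched against the jaxpr when only the operand is batched and the predicate is shared across the batch. The printed primitive name for select may vary by JAX version.

```python
import jax
import jax.numpy as jnp


def gate(pred, x):
    # One "expensive" branch (sin here) and one placeholder branch.
    return jax.lax.cond(pred, lambda x: jnp.sin(x), lambda x: jnp.zeros_like(x), x)


preds = jnp.array([True, False, True])
xs = jnp.array([0.5, 1.0, 1.5])

# Predicate batched: both branches are computed, then combined via select.
batched_pred = str(jax.make_jaxpr(jax.vmap(gate))(preds, xs))

# Predicate shared across the batch, only the operand batched: the
# conditional is kept, with the branches themselves batched.
shared_pred = str(
    jax.make_jaxpr(jax.vmap(gate, in_axes=(None, 0)))(preds[0], xs)
)

out = jax.vmap(gate)(preds, xs)
```

So whether the conditional survives depends on whether the predicate itself varies along the mapped axis, not on vmap alone.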