jax: Very slow compile for tree-map over large number of leaves

Description

In my problem I want to run gradient descent over groups of data of different sizes. There is significant heterogeneity in the group sizes: one group might have 100 datapoints, another might have 1000. I need to run a model on each individual datapoint in each group, sum the outputs, and push the result through a loss against a group-level outcome, which we can call group_Y. So for a single group g we have $\{(X_0, \dots, X_{N_g}), Y_g\}$ and want to compute z = (model_fn(X_0), ..., model_fn(X_{N_g})) and then loss_fn(z, Y_g). We then take the mean of the loss over the groups to get a scalar output. The very natural way to do this in JAX, at least to me, is to use tree maps and tree reduces:

from typing import Callable

import jax


def get_mapped_grad_and_value(
        loss_fn: Callable,
        model_fn: Callable,
) -> Callable:
    """Returns a jitted function that computes the mean loss over groups and
    its gradient with respect to the model parameters.

    Args:
        loss_fn (function): loss function, called as loss_fn(z, group_Y)
        model_fn (function): model function, called as model_fn(model_params, group_X)

    Returns:
        function: maps (model_params, grouped_Xs, grouped_Ys) to (loss, grad)
    """
    def mapped_loss(model_params, grouped_Xs, grouped_Ys):
        """Compute the loss function mapped over groups.
        
        Args:
            model_params: model parameters
            grouped_Xs: dict of np.ndarrays of group datapoints
            grouped_Ys: dict of float group outcomes
            
        Returns:
            mean loss over all groups"""
        num_groups = len(grouped_Ys)
        # One loss per group (dict key -> scalar), pre-divided by num_groups.
        group_losses = jax.tree_util.tree_map(
            lambda group_X, group_Y: 
            loss_fn(model_fn(model_params, group_X), group_Y) / num_groups, 
            grouped_Xs, grouped_Ys
        )
        # Summing the pre-scaled losses gives the mean over groups.
        loss = jax.tree_util.tree_reduce(lambda x, y: x + y, group_losses)
        return loss
    
    grad_and_loss = jax.jit(jax.value_and_grad(mapped_loss))
    return grad_and_loss
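In symbols, the objective that mapped_loss computes is the following (writing G for num_groups, ℓ for loss_fn, and f_θ for model_fn with parameters θ). Note that z_g has a different length for each group, so each distinct group size becomes a separately shaped computation under jit.

$$
\mathcal{L}(\theta) = \frac{1}{G} \sum_{g=1}^{G} \ell\left(z_g, Y_g\right),
\qquad
z_g = \left(f_\theta(X_{g,0}), \dots, f_\theta(X_{g,N_g})\right)
$$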

Now I do have a large dataset, about 1000-8000 groups, with roughly 1-20 million datapoints partitioned between the groups. But the compile times I am seeing are very long, about 4-5 hours. I am wondering what, if anything, I can do to speed up compile times here.
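One way to see where the compile time goes, consistent with the discussion further down: under jax.jit, the Python-level tree_map is unrolled during tracing, so the traced program contains a separate copy of the model and loss for every group, and XLA has to compile all of it. The sketch below uses hypothetical toy stand-ins for model_fn and loss_fn and counts the primitive equations in the traced loss-and-gradient program with jax.make_jaxpr. The count grows roughly linearly with the number of groups.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy stand-ins for model_fn and loss_fn.
def model_fn(params, X):
    return X @ params["w"]


def loss_fn(z, y):
    return (jnp.sum(z) - y) ** 2


def mapped_loss(params, grouped_Xs, grouped_Ys):
    num_groups = len(grouped_Ys)
    group_losses = jax.tree_util.tree_map(
        lambda X, y: loss_fn(model_fn(params, X), y) / num_groups,
        grouped_Xs, grouped_Ys)
    return jax.tree_util.tree_reduce(lambda a, b: a + b, group_losses)


def num_traced_ops(num_groups):
    """Number of primitive equations in the traced loss-and-gradient program."""
    params = {"w": jnp.ones(3)}
    # Every group gets a different size, as in the issue.
    Xs = {f"g{i}": jnp.ones((i + 1, 3)) for i in range(num_groups)}
    Ys = {f"g{i}": 1.0 for i in range(num_groups)}
    closed = jax.make_jaxpr(jax.value_and_grad(mapped_loss))(params, Xs, Ys)
    return len(closed.jaxpr.eqns)


print(num_traced_ops(10), num_traced_ops(100))
```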

It’s worth noting I am seeing very similar compile times for any analogous code:

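import jax
import jax.numpy as jnp

# loss_fn and model_fn are the same closures as in the first snippet.

# Variant 1: tree_map, then average the per-group losses.
def mapped_loss(model_params, grouped_Xs, grouped_Ys):
    group_losses = jax.tree_util.tree_map(
        lambda group_X, group_Y: 
        loss_fn(model_fn(model_params, group_X), group_Y), 
        grouped_Xs, grouped_Ys
    )
    # jnp (not NumPy) so this also works on traced values under jit.
    return jnp.mean(jnp.stack(list(group_losses.values())))

# Variant 2: an explicit Python loop over the groups.
def mapped_loss(model_params, grouped_Xs, grouped_Ys):
    loss = 0
    for g in grouped_Xs:
        loss += loss_fn(model_fn(model_params, grouped_Xs[g]), grouped_Ys[g])
    return loss / len(grouped_Xs)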

Currently I am compiling the loss/gradient of the loss function, tree-mapping that, and using a non-compiled tree-reduce to actually compute the final gradient update. This works fine if a bit slow, but it makes it prohibitive to use more interesting optimizers from e.g. jaxopt which use jitted loops.

Maybe it is not possible to speed this up, but if it is I’d be grateful.

What jax/jaxlib version are you using?

jax v0.4.10

Which accelerator(s) are you using?

M1 CPU

Additional system info

Python 3.9.13

NVIDIA GPU info

No response

About this issue

  • State: closed
  • Created a year ago
  • Comments: 16

Most upvoted comments

Yes! Thank you, @patrick-kidger. I think this noinline functionality is basically doing what I was thinking of re: “weak compilation”.

More generally, I find this behavior odd, and I think it would be quite nice to have an override: for example, a parameter to jit that lets a function behave as jitted when used inside other functions (and during recompilation) without being inlined at every call site, which I presume is what is blowing up the compile times. The reason I find it odd is that in my use case I can do the following:

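from typing import Callable

import jax
import jax.numpy as jnp


def get_mapped_loss_and_grad(
        loss_fn: Callable,
        model_fn: Callable,
) -> Callable:
    """Returns a function that computes the mean loss over groups and its
    gradient with respect to the model parameters.

    Args:
        loss_fn (function): loss function, called as loss_fn(z, group_Y)
        model_fn (function): model function, called as model_fn(model_params, group_X)

    Returns:
        function: maps (model_params, group_Xs, group_Ys) to (mean_loss, grad)
    """

    def group_loss(model_params, group_X, group_Y):
        return loss_fn(model_fn(model_params, group_X), group_Y)

    # Only the per-group loss/gradient is jitted; it is compiled once per
    # distinct group size, and the outer loop over groups stays in Python.
    _loss_and_grad_fn = jax.jit(jax.value_and_grad(group_loss))

    def loss_and_grad_fn(model_params, group_Xs, group_Ys):
        num_groups = len(group_Ys)
        # dict: group key -> (loss, grad) tuple
        group_losses = jax.tree_util.tree_map(
            lambda group_X, group_Y: 
            _loss_and_grad_fn(model_params, group_X, group_Y), 
            group_Xs, group_Ys
        )
        # Treat each (loss, grad) tuple as a leaf and sum them pairwise;
        # grads are pytrees, so they are added leaf by leaf.
        total_loss, total_grad = jax.tree_util.tree_reduce(
            lambda x, y: (x[0] + y[0], jax.tree_util.tree_map(jnp.add, x[1], y[1])),
            group_losses, 
            is_leaf=lambda x: not isinstance(x, dict)
        )
        mean_loss = total_loss / num_groups
        grad = jax.tree_util.tree_map(lambda g: g / num_groups, total_grad)
        return mean_loss, grad

    return loss_and_grad_fn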

Compiling this takes a couple of minutes in total across all my groups, and the code is also very fast!

But if I run loss_and_grad_fn = jax.jit(loss_and_grad_fn) (or jit any of the similar functions outlined above), it takes many hours to compile for qualitatively very little gain in runtime (I can run tests for this if it would be useful or interesting). And not being able to jit compile is annoying! Lots of code presumes you can jit internal functions at almost no cost. Presumably, as Jake and Patrick suggest, the general solution is to pad things out in some appropriate way. But this gets quite cumbersome to implement and understand (if you aren’t already familiar with JAX), and I think it would also be quite wasteful of memory, which is a bottleneck in my problem.

Now, in my particular case I can rewrite my codebase to use jax.experimental.sparse (I knew this from the start but was intentionally trying to steer clear of experimental code), so I will probably go that route! I just want to suggest that it would be reasonable to support something like this: options on jit, like those Patrick is pointing to in Equinox, that let users explicitly trade off compile time against runtime in their specific setting. On the other hand, maybe Equinox is all I need and this doesn’t need to make its way into the base package.
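A related way to avoid per-group shapes entirely, which is not from the thread and is named plainly here as a swapped-in technique, is a segmented reduction. Concatenate every datapoint into one array, keep an integer group id per row, run the model once on everything, and use jax.ops.segment_sum to get the per-group sums. This is a sketch under two assumptions taken from the problem description: model_fn maps rows independently, and the loss depends on a group only through the summed model output. The stand-in model and the loss_on_sum helper are hypothetical. Because the concatenated array has one fixed shape, the function compiles once regardless of how many groups there are.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins: a per-datapoint model and a loss on the summed group output.
def model_fn(params, X):
    return X @ params["w"]  # one scalar output per row


def loss_on_sum(s, y):
    return (s - y) ** 2


def flatten_groups(grouped_Xs, grouped_Ys):
    """Concatenate every group into one array, with an integer group id per row."""
    keys = list(grouped_Xs)
    X_all = np.concatenate([np.asarray(grouped_Xs[k]) for k in keys])
    group_ids = np.concatenate(
        [np.full(len(grouped_Xs[k]), i, dtype=np.int32) for i, k in enumerate(keys)])
    Ys = np.asarray([grouped_Ys[k] for k in keys], dtype=np.float32)
    return jnp.asarray(X_all), jnp.asarray(group_ids), jnp.asarray(Ys)


@jax.jit
def mean_loss(params, X_all, group_ids, Ys):
    z = model_fn(params, X_all)  # all datapoints in a single call
    # Per-group sums; Ys.shape[0] is static under jit, so num_segments is a Python int.
    sums = jax.ops.segment_sum(z, group_ids, num_segments=Ys.shape[0])
    return jnp.mean(loss_on_sum(sums, Ys))


# Usage with toy data: five groups of very different sizes.
rng = np.random.default_rng(0)
grouped_Xs = {f"g{i}": rng.normal(size=(n, 3)).astype(np.float32)
              for i, n in enumerate([5, 7, 30, 64, 100])}
grouped_Ys = {g: 1.0 for g in grouped_Xs}
X_all, group_ids, Ys = flatten_groups(grouped_Xs, grouped_Ys)
params = {"w": jnp.ones(3)}
loss, grads = jax.value_and_grad(mean_loss)(params, X_all, group_ids, Ys)
```

The memory cost is one row per real datapoint, with no padding, which may matter given the memory constraint mentioned above.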

Anyway, this has been quite helpful! Thank you all for the discussion. If folks feel this is resolved I’m happy to close the issue.

If you’re on the CPU, you can try wrapping your model_fn in equinox.internal.noinline. (See here.) This compiles the wrapped function only once per group size. The downsides are that this will probably only work on the CPU, and that it dramatically increases runtime.

The “real” fix here is indeed what Jake says: pad out the groups to just a few sizes, and then replace multiple calls fn(group1); fn(group2) on groups of the same size with a single jax.vmap(fn)(jnp.stack([group1, group2])).
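The padding-plus-vmap approach can be sketched as follows, under stated assumptions. The helper names (pad_and_bucket, group_loss) and the bucket sizes are hypothetical. The sketch assumes that model_fn produces one scalar per row and that zeroing the outputs of padded rows leaves the group loss unchanged, which holds when the loss only sees the sum of the outputs. Each group is padded up to the smallest bucket that fits it, groups sharing a bucket are stacked, and one vmapped call is made per bucket. Under jit, the traced program then grows with the number of buckets rather than the number of groups.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins for model_fn / loss_fn: one scalar output per row,
# and a loss that only depends on the sum of the outputs.
def model_fn(params, X):
    return X @ params["w"]


def loss_fn(z, y):
    return (jnp.sum(z) - y) ** 2


def pad_and_bucket(grouped_Xs, grouped_Ys, bucket_sizes):
    """Pad each group up to the smallest bucket size that fits it, then stack
    groups that share a padded size. Returns a list of (Xs, masks, Ys) per bucket.
    Assumes every group fits in the largest bucket."""
    buckets = {}
    for g, X in grouped_Xs.items():
        X = np.asarray(X)
        n = X.shape[0]
        size = min(b for b in bucket_sizes if b >= n)
        padded = np.zeros((size,) + X.shape[1:], dtype=X.dtype)
        padded[:n] = X
        mask = np.zeros(size, dtype=bool)
        mask[:n] = True
        buckets.setdefault(size, []).append((padded, mask, grouped_Ys[g]))
    return [
        (jnp.stack([p for p, _, _ in items]),
         jnp.stack([m for _, m, _ in items]),
         jnp.asarray([y for _, _, y in items], dtype=jnp.float32))
        for items in buckets.values()
    ]


def group_loss(params, X, mask, y):
    # Zero the outputs of padded rows so they do not change the group sum.
    z = jnp.where(mask, model_fn(params, X), 0.0)
    return loss_fn(z, y)


@jax.jit
def mean_loss(params, buckets):
    # One vmapped call per bucket: the traced program grows with the number
    # of distinct bucket sizes, not with the number of groups.
    total, num_groups = 0.0, 0
    for Xs, masks, Ys in buckets:
        per_group = jax.vmap(group_loss, in_axes=(None, 0, 0, 0))(params, Xs, masks, Ys)
        total = total + jnp.sum(per_group)
        num_groups += Xs.shape[0]
    return total / num_groups


# Usage with toy data: five groups of very different sizes fall into three buckets.
rng = np.random.default_rng(0)
grouped_Xs = {f"g{i}": rng.normal(size=(n, 3)).astype(np.float32)
              for i, n in enumerate([5, 7, 30, 64, 100])}
grouped_Ys = {g: 1.0 for g in grouped_Xs}
buckets = pad_and_bucket(grouped_Xs, grouped_Ys, bucket_sizes=[8, 32, 128])
params = {"w": jnp.ones(3)}
loss, grads = jax.value_and_grad(mean_loss)(params, buckets)
```

Fewer buckets mean less compilation but more wasted padding memory. Choosing bucket sizes is the trade-off the original poster raises about memory.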