jax: Very slow compile for tree-map over large number of leaves
Description
In my problem I want to run gradient descent over groups of data of different sizes. There is significant heterogeneity in the group sizes; one group might have 100 datapoints, another might have 1000. I need to run a model on each individual datapoint in each group, sum the outputs, and push that through a loss together with a group-level outcome, which we can call the `group_Y`. So for a single group we have $\{(X_0, \dots, X_{N_g}), Y_g\}$ and want to compute `z = (model_fn(X_0), ..., model_fn(X_{N_g}))` and then `loss_fn(z, Y_g)`. We then need to take the mean of the loss over the groups so we get a scalar output. The very natural way to do this in jax, at least to me, is to use tree maps and tree reduces:
```python
from typing import Callable

import jax


def get_mapped_grad_and_value(
    loss_fn: Callable,
    model_fn: Callable,
) -> Callable:
    """Returns a jitted function that computes the loss and its gradient with
    respect to the model parameters, mapped over groups.

    Args:
        loss_fn (function): loss function
        model_fn (function): model function

    Returns:
        function: given (model_params, grouped_Xs, grouped_Ys), returns
            (mean loss, gradient with respect to model_params)
    """
    def mapped_loss(model_params, grouped_Xs, grouped_Ys):
        """Compute the loss function mapped over groups.

        Args:
            model_params: model parameters
            grouped_Xs: dict of np.ndarrays of group datapoints
            grouped_Ys: dict of float group outcomes

        Returns:
            mean loss over all groups
        """
        num_groups = len(grouped_Ys)
        # One loss term per group; every group's X array has a different shape.
        group_losses = jax.tree_util.tree_map(
            lambda group_X, group_Y:
                loss_fn(model_fn(model_params, group_X), group_Y) / num_groups,
            grouped_Xs, grouped_Ys,
        )
        # Sum the per-group terms to obtain the mean loss.
        loss = jax.tree_util.tree_reduce(lambda x, y: x + y, group_losses)
        return loss

    grad_and_loss = jax.jit(jax.value_and_grad(mapped_loss))
    return grad_and_loss
```
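For concreteness, a toy usage sketch (the linear model, squared-error loss, sizes, and random data below are made-up placeholders; my real model and groups are much larger):

```python
import numpy as np
import jax.numpy as jnp

# Toy placeholders; the real model, loss, and data are much larger.
def model_fn(params, X):
    return X @ params                       # one output per datapoint in the group

def loss_fn(z, group_Y):
    return (jnp.sum(z) - group_Y) ** 2      # sum outputs, compare to group outcome

rng = np.random.default_rng(0)
group_sizes = [100, 1000, 250]              # heterogeneous group sizes
grouped_Xs = {f"g{i}": jnp.asarray(rng.normal(size=(n, 5)))
              for i, n in enumerate(group_sizes)}
grouped_Ys = {f"g{i}": float(rng.normal()) for i in range(len(group_sizes))}

params = jnp.zeros(5)
grad_and_loss = get_mapped_grad_and_value(loss_fn, model_fn)
# Every leaf of grouped_Xs has a different shape, so jit traces and compiles
# a separate model/loss call for each group inside one large program.
loss, grads = grad_and_loss(params, grouped_Xs, grouped_Ys)
```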
Now I do have a large dataset, about 1000-8000 groups, with roughly 1-20 million datapoints partitioned between the groups. But the compile times I am seeing are very long, about 4-5 hours. I am wondering what, if anything, I can do to speed up compile times here.
It’s worth noting I am seeing very similar compile times for any analogous code:
```python
import jax.numpy as jnp

def mapped_loss(model_params, grouped_Xs, grouped_Ys):
    group_losses = jax.tree_util.tree_map(
        lambda group_X, group_Y:
            loss_fn(model_fn(model_params, group_X), group_Y),
        grouped_Xs, grouped_Ys,
    )
    return jnp.mean(jnp.array(list(group_losses.values())))
```
```python
def mapped_loss(model_params, grouped_Xs, grouped_Ys):
    loss = 0
    for g in grouped_Xs:
        loss += loss_fn(model_fn(model_params, grouped_Xs[g]), grouped_Ys[g])
    return loss
```
Currently I am compiling the per-group loss/gradient, tree-mapping that over the groups, and using a non-compiled tree-reduce to actually compute the final gradient update. This works fine, if a bit slowly, but it makes it prohibitive to use more interesting optimizers from e.g. jaxopt, which use jitted loops.
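Roughly, that workaround looks like the following sketch (the helper names and the exact accumulation are illustrative, not my literal code):

```python
import jax

def make_group_value_and_grad(loss_fn, model_fn):
    # Jit a single per-group loss; jit compiles one small program per distinct
    # group shape, rather than one giant program containing every group.
    def group_loss(model_params, group_X, group_Y):
        return loss_fn(model_fn(model_params, group_X), group_Y)
    return jax.jit(jax.value_and_grad(group_loss))

def value_and_grad_over_groups(group_vg, model_params, grouped_Xs, grouped_Ys):
    # Plain Python loop outside of jit: accumulate the mean loss and gradient.
    num_groups = len(grouped_Ys)
    total_loss, total_grad = 0.0, None
    for g in grouped_Xs:
        loss_g, grad_g = group_vg(model_params, grouped_Xs[g], grouped_Ys[g])
        total_loss += loss_g / num_groups
        total_grad = grad_g if total_grad is None else jax.tree_util.tree_map(
            lambda a, b: a + b, total_grad, grad_g)
    mean_grad = jax.tree_util.tree_map(lambda x: x / num_groups, total_grad)
    return total_loss, mean_grad
```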
Maybe it is not possible to speed this up, but if it is I’d be grateful.
What jax/jaxlib version are you using?
jax v0.4.10
Which accelerator(s) are you using?
M1 CPU
Additional system info
Python 3.9.13
NVIDIA GPU info
No response
About this issue
- State: closed
- Created a year ago
- Comments: 16
Yes! Thank you @patrick-kidger, I think this noinline functionality is basically doing what I was thinking of re: "weak compilation".

More generally, I think this behavior is odd, and it'd be quite nice to have an override (e.g. a parameter in `jit` that would let something act "jitted" for the purpose of use in other functions/recompilation without inlining every instance, which I presume is what's exploding the compile times). The reason I think it's odd is that in my use case I can do the following: jit only the per-group loss/gradient and combine the results outside of jit, as described above. The compile time is a couple of minutes in total to run this with all my groups. The code is also very fast!
But if I run `loss_and_grad_fn = jax.jit(loss_and_grad_fn)` (or compile any of the similar functions outlined above), that takes many hours to compile for qualitatively very little gain in runtime (I can run tests for this if it would be useful/interesting). And not being able to jit compile is annoying! Lots of code presumes you can jit internal functions almost without cost. Now presumably, as Jake and Patrick suggest, the general solution is to pad things out in some appropriate way. But this gets quite cumbersome to implement and understand (if you aren't familiar with jax already), and I think it would also be quite wasteful of memory, which is a bottleneck in my problem.

Now in my particular case I can re-write my codebase to use jax.experimental.sparse (I knew this from the start but was intentionally trying to steer clear of experimental code), so I will probably go that route! I guess I just want to gesture vaguely that it would be reasonable to support something like this, allowing users to more explicitly trade off between compile time and run time in their specific setting, with some options on jit like those Patrick is pointing to in equinox. On the other hand, maybe equinox is all I need and this doesn't need to make its way into the base package.
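For what it's worth, the rough shape of the sparse rewrite is something like the sketch below: run the model once over one flat array of all datapoints, then aggregate per group with a sparse indicator matrix. The names and the construction are illustrative, and it assumes the loss only needs the sum of per-datapoint outputs for each group and that model_fn is vectorized over datapoints.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def make_flat_value_and_grad(loss_fn, model_fn, group_matrix, all_Ys):
    # group_matrix: BCOO indicator of shape (num_groups, num_points), with a 1
    # at (g, i) if datapoint i belongs to group g.
    # all_Ys: array of shape (num_groups,) of group outcomes.
    def flat_loss(model_params, all_Xs):
        z = model_fn(model_params, all_Xs)     # (num_points,) outputs
        group_sums = group_matrix @ z          # (num_groups,) per-group sums
        return jnp.mean(jax.vmap(loss_fn)(group_sums, all_Ys))
    return jax.jit(jax.value_and_grad(flat_loss))

# The indicator matrix can be built once from per-datapoint group ids:
# indices = jnp.stack([group_ids, jnp.arange(num_points)], axis=1)
# group_matrix = sparse.BCOO((jnp.ones(num_points), indices),
#                            shape=(num_groups, num_points))
```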
Anyway, this has been quite helpful! Thank you all for the discussion. If folks feel this is resolved I’m happy to close the issue.
If you're on the CPU then you can try wrapping your `model_fn` in an `equinox.internal.noinline`. (See here.) This will compile the wrapped function only once for each group size. The downside is that this will probably only work on the CPU, and it dramatically increases runtime.

The "real" fix here is indeed as Jake says: pad out the groups to just a few sizes, and then replace multiple calls `fn(group1); fn(group2)` to groups of the same size with a `jax.vmap(fn)(jnp.concatenate([group1, group2]))`.
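A minimal sketch of that padding-and-bucketing idea might look like the following. The bucket sizes, masking convention, and helper names are illustrative assumptions, and it presumes the loss only depends on the datapoint outputs through their sum, so padded rows can be masked to zero.

```python
import numpy as np
import jax
import jax.numpy as jnp

def pad_group(X, size):
    # Pad one group's (n, d) array up to (size, d) and return a 0/1 mask.
    pad = size - X.shape[0]
    X_pad = np.concatenate([X, np.zeros((pad,) + X.shape[1:])], axis=0)
    mask = np.concatenate([np.ones(X.shape[0]), np.zeros(pad)])
    return X_pad, mask

def bucketize(grouped_Xs, grouped_Ys, bucket_sizes=(128, 1024, 8192)):
    # Group the padded groups into a few fixed sizes; bucket_sizes must cover
    # the largest group.
    buckets = {s: ([], [], []) for s in bucket_sizes}  # Xs, masks, Ys
    for g, X in grouped_Xs.items():
        size = min(s for s in bucket_sizes if s >= X.shape[0])
        X_pad, mask = pad_group(np.asarray(X), size)
        buckets[size][0].append(X_pad)
        buckets[size][1].append(mask)
        buckets[size][2].append(grouped_Ys[g])
    return {s: (jnp.asarray(np.stack(Xs)), jnp.asarray(np.stack(ms)), jnp.asarray(Ys))
            for s, (Xs, ms, Ys) in buckets.items() if Xs}

def make_bucketed_value_and_grad(loss_fn, model_fn, num_groups):
    def group_loss(params, X_pad, mask, Y):
        # Zero out padded rows; valid if the loss only uses the sum of outputs.
        z = model_fn(params, X_pad) * mask
        return loss_fn(z, Y) / num_groups
    def bucketed_loss(params, buckets):
        # One compilation per bucket size instead of one per group.
        per_bucket = [jnp.sum(jax.vmap(group_loss, in_axes=(None, 0, 0, 0))(params, Xs, ms, Ys))
                      for Xs, ms, Ys in buckets.values()]
        return sum(per_bucket)
    return jax.jit(jax.value_and_grad(bucketed_loss))
```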