jax: Extremely slow GPU execution
The following code is almost instantaneous (<1ms) on the CPU, but extremely slow on the GPU (7s). I’m trying to track down the source of the problem. I have pared my code down from 5000 lines to 80, and I don’t think I can remove any more. I have added comments at the places I found to have surprising (to me) effects on the GPU run time.
How can I make this code run faster on the GPU than it does on the CPU? What am I doing wrong?
from functools import partial
from typing import Any
import haiku as hk
import jax.numpy as jnp
from contexttimer import Timer
from jax import jit
from jax.experimental import enable_x64
from jax.lax import while_loop
from jax.nn import sigmoid, softplus
from jax.random import PRNGKey, normal, split
from tjax.dataclasses import dataclass # Equivalent to flax.struct.dataclass
class Linear(hk.Module):
    def __init__(self, output_size: int):
        super().__init__()
        self.output_size = output_size

    def __call__(self, inputs):
        w = hk.get_parameter("w", [inputs.shape[-1], self.output_size],
                             inputs.dtype,  # Passing dtype costs 23%!
                             init=jnp.zeros)
        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))

class NoisyMLP(hk.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = [Linear(output_size) for output_size in layer_sizes]

    def __call__(self, inputs):
        out = inputs
        for layer in self.layers:
            out = layer(out)
            out = sigmoid(out)  # Sigmoid costs 10%!
        return out

@dataclass
class SamplerState:
    code_momentum: Any
    rng: Any
    iterations: Any
shape = (1,)
def nat_to_exp(natural_explanation):
    mlp = NoisyMLP((12, *shape))
    return mlp(natural_explanation)

def haiku_weight_initializer() -> None:
    nat_to_exp(jnp.zeros(shape))

def state_needs_iteration(maximum_iterations, state) -> bool:
    return state.iterations < maximum_iterations

def update_state(weights, state):
    leak_rng, new_rng = split(state.rng)
    nat_to_exp_f = hk.transform(nat_to_exp).apply
    force = nat_to_exp_f(weights, None, state.code_momentum)
    new_code_momentum = force + normal(leak_rng, force.shape)
    return SamplerState(new_code_momentum, new_rng, state.iterations + 1)

def find_fixed_point(weights, initial_state, maximum_iterations):
    return while_loop(partial(state_needs_iteration, maximum_iterations),
                      partial(update_state, weights),
                      initial_state)
@partial(jit, static_argnums=())  # Passing maximum_iterations non-statically costs 43%!
def infer_encoding(weights, initial_rng, maximum_iterations):
    initial_sampler_state = SamplerState(jnp.zeros(shape), initial_rng, 0)
    return find_fixed_point(weights, initial_sampler_state, maximum_iterations)

with enable_x64():  # Enabling 64-bit costs 50%.
    rng = PRNGKey(12)
    weight_rng, inference_rng = split(rng)
    weights = hk.transform(haiku_weight_initializer).init(weight_rng)
    for _ in range(10):
        with Timer() as timer:
            infer_encoding(weights, inference_rng, 8000)
        print(timer.elapsed)
There’s good documentation on how to use the trace viewer, but unfortunately it looks like it’s all Google-internal, and hasn’t been open-sourced yet…
Thanks for raising this, and for working so hard to minimize it!
The best tool here is to use profiling. If you can get a profile showing a realistic workload, we can really dig in to what improvements can be made (either to your code, to JAX itself, or to XLA:GPU).
There’s one effect that would explain one of your comments, though I don’t think it would explain the code as written being slow. General while_loops can require returning control to the host on each iteration just to decide whether to dispatch another iteration of the loop body on the GPU, incurring expensive synchronization and transfer overheads (which would loom large when the loop body itself is cheap). But in XLA:GPU there’s a “for loop optimization” which is meant to notice when the loop actually has a statically fixed trip count (as it does here, at least with the code as written!) so that control need not be returned to the host on each iteration.
Could you share a profile of the execution so we can dig in?
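For reference, a trace can be captured with something like the following (the output directory is arbitrary, and the names are reused from the snippet above):

import jax

with jax.profiler.trace("/tmp/jax-trace"):
    # Re-run the jitted function and block so the GPU work lands inside the trace.
    final_state = infer_encoding(weights, inference_rng, 8000)
    final_state.code_momentum.block_until_ready()

The resulting trace under /tmp/jax-trace can then be opened with TensorBoard’s profile plugin (tensorboard --logdir /tmp/jax-trace) and attached here.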
I understand. What I’m trying to say is that the custom kernel writing that I would be doing is tantamount to compilation, which is already done by the XLA compiler. I want to simply write in Jax using its primitives. Triton is a different language. As you point out in your comment, it is possible to convert Jax to Triton, so that’s what I’m asking for when I say “I want to program in the Jax I love”.
Just so we’re clear, I’m suggesting that instead of writing Triton code, I would write a rudimentary module-to-Triton converter.
You are right that writing a single Triton kernel is less work. However, I don’t think this is a good approach in my case:
Can you elaborate on this?
I’d like to generalize my suggestion. What I want is for XLA to produce fused kernels rather than the kernels it’s producing now. Why can’t it produce fused kernels? If I can write fused kernels in Triton, surely the XLA compiler can produce such kernels?
The benefit of XLA producing such kernels is that I wouldn’t have to worry about producing backwards passes, which involves:
Yes, I considered something like this. This would be a fantastic step in the right direction. If you consider my more general suggestion, then what I really want is a decorator to demand that the XLA compiler produces a fused kernel for a decorated function.
PS Salut Frederic! We met ages ago when you were working on Theano. Very nice to see you here 😄
First of all, Jax Triton looks amazing! Yes, it should solve my problem with quite a bit of work on my side. So thank you for that.
However, I have some thoughts that I’d like to get feedback on.
My problem boils down to an internal scan that evaluates something like […], where f is a “forwards pass” function of primals, and f_bwd is the corresponding backward pass of cotangents.
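A rough, generic sketch of that kind of loop, with f as an arbitrary stand-in and jax.vjp supplying the backward pass, might look like:

import jax
import jax.numpy as jnp
from jax import lax

def f(params, x):
    # Stand-in "forwards pass"; the real f is a small noisy MLP.
    return jnp.tanh(x @ params)

def step(carry, _):
    params, x, ct = carry
    # jax.vjp gives both the forward value and the corresponding backward pass,
    # so f_bwd does not have to be derived or written by hand.
    y, f_bwd = jax.vjp(lambda p: f(p, x), params)
    (param_ct,) = f_bwd(ct)
    # How y and param_ct feed the next iteration depends on the algorithm;
    # here they are just carried forward as an example.
    return (params + param_ct, y, ct), None

params0 = jnp.eye(4)
carry0 = (params0, jnp.zeros(4), jnp.ones(4))
final_carry, _ = lax.scan(step, carry0, None, length=100)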
If f is a simple neural network with noise, then it’s fairly straightforward to write this in Triton. The backward pass can easily be written, but it’s annoying. Why am I doing this? Jax is already calculating the backward pass, and I might make mistakes that I’ll have to debug. That’s what I meant when I asked if I would have access to Jax’s automatic differentiation. It appears that I’ll have to manually differentiate f and then implement that in Triton.

I also thought about how I would write this in Triton. I could just manually write every fused kernel I need. And at the end of it, I’d have a library of pieces of kernels that I could compose to do what I need. These would probably be extra methods on “modules” (from Haiku or Flax) that would do things like:
Then I would have some way of composing multiple modules into a single fused kernel. This entails two functions, triton_forwards and triton_backwards, that would produce a forward or backward Triton function. Each function …

Then I thought: why am I doing all this? Wouldn’t it make much more sense to have a conversion from XLA to Triton?
I understand that Triton is a very limited language. I understand that it may not be possible to convert everything that XLA can do to Triton. But I’m not doing anything that crazy. If the converter wants to bail out if I try to do something like take a hyperbolic sin, that’s fine! I’m just doing ordinary multiplications, exponentiations, addition, etc.
And I remember Matt explaining to me that Nvidia’s kernels (e.g. matrix multiplication) are better optimized than anything the user can do. But I’m pretty sure that the last time I looked at this, my runtime was dominated by kernel spawning. Even if Triton is 50% as fast as Nvidia’s hand-crafted kernels, the ability to fuse literally hundreds of kernels together would more than compensate. And the reason it’s hundreds is that I have a scan (described above), and each iteration of the scan is a whole new round of kernel launches.
So, my question boils down to: Why have we decided on Jax-Triton as the solution? Why not convert XLA to Triton as best you can, and then we can keep programming in the Jax we love?
You’d need to write differentiation rules for any new Primitives introduced. In the dfm tutorial there’s code in src/kepler_jax/kepler_jax.py which shows how to define a JVP rule for the primitive introduced.
These custom differentiation rules are not meant for custom kernels (though they kinda work there…) but rather for JAX-traceable code. When introducing a new Primitive, as in the dfm tutorial (and as I would recommend for a custom kernel), you just attach transformation rules to the primitive directly.
Once you have a differentiation rule for your primitive, you can differentiate any function that applies it along with other JAX primitives.
(Someday JAX may be able to generate derivatives for Triton code automatically. It’s something we’re looking at, but it’s a long way off.)
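As a heavily simplified illustration (not the tutorial’s actual code; the “kernel” here is just jnp.sin standing in for a real custom op), attaching a JVP rule directly to a new primitive looks roughly like:

import jax
import jax.numpy as jnp
from jax.interpreters import ad

# A toy primitive standing in for one backed by a custom kernel.
my_op_p = jax.core.Primitive("my_op")
my_op_p.def_impl(lambda x: jnp.sin(x))
my_op_p.def_abstract_eval(lambda x: jax.core.ShapedArray(x.shape, x.dtype))

def my_op(x):
    return my_op_p.bind(x)

# JVP rule attached directly to the primitive; reverse mode then comes from
# JAX's linearization and transposition machinery.
ad.defjvp(my_op_p, lambda tangent, x: tangent * jnp.cos(x))

print(jax.grad(lambda x: my_op(x).sum())(jnp.arange(3.0)))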
I only very briefly looked at the custom op myself. I’m sure that you can register another custom op for the gradient. But I think there is also another option.
@mattjj, do you know if we can provide a forward graph whose gradient will be taken for the custom op’s gradient?
Neil, at worst, you can print your forward and backward graphs. From those, you can work out what the gradient graph is. Then you can create a JAX graph that computes it and ask JAX to use it for the gradient: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
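For example, the pattern from that notebook in its simplest form (the function and its gradient here are just placeholders):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_fn(x):
    return jnp.sin(x)            # forward computation (placeholder)

def my_fn_fwd(x):
    return my_fn(x), x           # save residuals needed by the backward pass

def my_fn_bwd(residual_x, ct):
    return (ct * jnp.cos(residual_x),)   # hand-specified gradient graph

my_fn.defvjp(my_fn_fwd, my_fn_bwd)

print(jax.grad(my_fn)(0.3))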
Hi Neil,
Did you see this jax/triton project https://github.com/jax-ml/jax-triton?
Also, what are your shapes? In this example, you have shape = (1,). If you have a layer of only 1 neuron, those are very, very tiny shapes.
XLA:GPU is all open-source!
XLA:GPU’s hands are tied here, because the way to generate the fastest GPU kernels for matmul and conv is proprietary. Only Nvidia can do it, and they release those kernels as binaries. That’s what cuBLAS, cuDNN, etc. are. There are many such kernels; XLA:GPU does autotuning and kernel selection to choose the best routines for your array shapes and specific GPU hardware. See here for example. But because these are proprietary pre-built kernels, it can’t e.g. fuse operations into them. That’s why loop bodies may have to be separated into multiple separate kernels.
Well, I wouldn’t suggest using scan, because I’m pretty sure XLA will do that optimization for you when you make the trip count static (as you did in the example). So there’s nothing additional to be gained by using scan, I don’t think.

Yes, that comment is related.
For this part, because XLA:GPU is likely calling into pre-packaged CUDA kernels for the dot (unless it knows it can generate better code than the closed-source Nvidia-provided kernels, which is rare), adding the softplus may mean that you have to launch two kernels (one for the dot, one for the presumably fused xla-generated softplus computation) per call to the MLP (i.e. at least two per loop iteration), rather than just one (just the dot).
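As an aside, one way to see which kernels XLA:GPU actually produced for the loop body (assuming a JAX version with the lower/compile AOT API) is to dump the optimized HLO of the jitted function from the original snippet:

# Each fusion / custom-call in the optimized HLO corresponds roughly to one
# kernel launch per loop iteration.
lowered = infer_encoding.lower(weights, inference_rng, 8000)
print(lowered.compile().as_text())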
(Tangentially, XLA:TPU has much more flexibility here: since it generates the dot routine too, it can fuse things like elementwise operations into the loads and stores of the dot operation, and indeed on TPU any jitted function leads to one big optimized XLA:TPU program, rather than separate kernels as on XLA:GPU.)

By the way, instead of splitting the RNG on every iteration, you could split it once into a big array (with leading axis size maximum_iterations) and, say, scan over it (or just index into it with the iteration counter). That can also save kernel launches in the loop body, though it’ll mean your program uses more memory.

The overall theme here is to try to minimize the number of kernel launches per loop iteration.
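A sketch of that rewrite against the original snippet (untested; it reuses the imports and definitions above and assumes maximum_iterations is made a static argument so the keys can be pre-split):

from jax.lax import scan

@partial(jit, static_argnums=(2,))
def infer_encoding_scan(weights, initial_rng, maximum_iterations):
    # Pre-split all per-iteration keys up front: O(maximum_iterations) memory,
    # but no split() inside the loop body.
    leak_rngs = split(initial_rng, maximum_iterations)

    def body(code_momentum, leak_rng):
        force = hk.transform(nat_to_exp).apply(weights, None, code_momentum)
        return force + normal(leak_rng, force.shape), None

    final_code_momentum, _ = scan(body, jnp.zeros(shape), leak_rngs)
    return final_code_momentum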
I haven’t looked at your profile yet, but I’ll try to get the chance soon!