jax: Extremely slow GPU execution

The following code is almost instantaneous (<1 ms) on the CPU, but extremely slow on the GPU (7 s). I’m trying to track down the source of the problem. I have pared my code down from 5000 lines to 80, and I don’t think I can remove any more. I have added comments in the places where I found surprising (to me) effects on the GPU run time.

How can I make this code run faster on the GPU than it does on the CPU? What am I doing wrong?

from functools import partial
from typing import Any
import haiku as hk
import jax.numpy as jnp
from contexttimer import Timer
from jax import jit
from jax.experimental import enable_x64
from jax.lax import while_loop
from jax.nn import sigmoid, softplus
from jax.random import PRNGKey, normal, split
from tjax.dataclasses import dataclass  # Equivalent to flax.struct.dataclass

class Linear(hk.Module):
    def __init__(self, output_size: int):
        super().__init__()
        self.output_size = output_size

    def __call__(self, inputs):
        w = hk.get_parameter("w", [inputs.shape[-1], self.output_size],
                             inputs.dtype,  # Passing dtype costs 23%!
                             init=jnp.zeros)
        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))

class NoisyMLP(hk.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = [Linear(output_size) for output_size in layer_sizes]

    def __call__(self, inputs):
        out = inputs
        for layer in self.layers:
            out = layer(out)
            out = sigmoid(out)  # Sigmoid costs 10%!
        return out

@dataclass
class SamplerState:
    code_momentum: Any
    rng: Any
    iterations: Any

shape = (1,)

def nat_to_exp(natural_explanation):
    mlp = NoisyMLP((12, *shape))
    return mlp(natural_explanation)

def haiku_weight_initializer() -> None:
    nat_to_exp(jnp.zeros(shape))

def state_needs_iteration(maximum_iterations, state) -> bool:
    return state.iterations < maximum_iterations

def update_state(weights, state):
    leak_rng, new_rng = split(state.rng)
    nat_to_exp_f = hk.transform(nat_to_exp).apply
    force = nat_to_exp_f(weights, None, state.code_momentum)
    new_code_momentum = force + normal(leak_rng, force.shape)
    return SamplerState(new_code_momentum, new_rng, state.iterations + 1)

def find_fixed_point(weights, initial_state, maximum_iterations):
    return while_loop(partial(state_needs_iteration, maximum_iterations),
                      partial(update_state, weights),
                      initial_state)

@partial(jit, static_argnums=(2,))  # Passing maximum_iterations non-statically costs 43%!
def infer_encoding(weights, initial_rng, maximum_iterations):
    initial_sampler_state = SamplerState(jnp.zeros(shape), initial_rng, 0)
    return find_fixed_point(weights, initial_sampler_state, maximum_iterations)

with enable_x64():  # Enabling 64-bit costs 50%.
    rng = PRNGKey(12)
    weight_rng, inference_rng = split(rng)
    weights = hk.transform(haiku_weight_initializer).init(weight_rng)
    for _ in range(10):
        with Timer() as timer:
            infer_encoding(weights, inference_rng, 8000)
        print(timer.elapsed)

Most upvoted comments

There’s good documentation on how to use the trace viewer, but unfortunately it looks like it’s all Google-internal, and hasn’t been open-sourced yet…

Thanks for raising this, and for working so hard to minimize it!

The best tool here is to use profiling. If you can get a profile showing a realistic workload, we can really dig in to what improvements can be made (either to your code, to JAX itself, or to XLA:GPU).

There’s one effect that would explain one of your comments, though I don’t think it would explain the code as written being slow. General while_loops can require returning control to the host on each iteration just to decide whether to dispatch another iteration of the loop body on the GPU, incurring expensive synchronization and transfer overheads (which would loom large when the loop body itself is cheap). But in XLA:GPU there’s a “for loop optimization” which is meant to notice when the loop actually has a statically fixed trip count (as it does here, at least with the code as written!) so that control need not be returned to the host on each iteration.
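
(To make that distinction concrete, here is a toy sketch, not taken from the thread, contrasting a trip count XLA cannot see at compile time with one it can; the loop body and names are made up for illustration.)

    from functools import partial
    import jax
    from jax import lax

    # Trip count is a traced value: XLA cannot know it at compile time, so the
    # while_loop may need a host round-trip per iteration to decide whether to
    # dispatch the next one.
    @jax.jit
    def dynamic_trip_count(n_steps, x):
        def cond(state):
            i, y = state
            return i < n_steps
        def body(state):
            i, y = state
            return i + 1, y * 1.001
        return lax.while_loop(cond, body, (0, x))

    # Trip count is a static Python int: XLA sees a fixed number of iterations
    # and can apply its "for loop" optimization.
    @partial(jax.jit, static_argnums=0)
    def static_trip_count(n_steps, x):
        return lax.fori_loop(0, n_steps, lambda i, y: y * 1.001, x)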

Could you share a profile of the execution so we can dig in?
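
(For reference, one way to capture such a profile is with jax.profiler.trace; this is a rough sketch reusing the names from the code above, and the output directory is arbitrary. The resulting trace can be opened in TensorBoard’s profile plugin or Perfetto.)

    import jax

    infer_encoding(weights, inference_rng, 8000)  # warm-up so compilation isn't traced
    with jax.profiler.trace("/tmp/jax-trace"):
        result = infer_encoding(weights, inference_rng, 8000)
        # Block so the asynchronously dispatched work finishes inside the trace.
        result.code_momentum.block_until_ready()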

You do not need to do any memory allocation. You still use JAX; you only write a custom kernel for the part that XLA doesn’t generate fast code for.

I understand. What I’m trying to say is that the custom kernel writing that I would be doing is tantamount to compilation, which is already done by the XLA compiler. I want to simply write in Jax using its primitives. Triton is a different language. As you point out in your comment, it is possible to convert Jax to Triton, so that’s what I’m asking for when I say “I want to program in the Jax I love”.

I think it is less work than what you describe above.

Just so we’re clear, I’m suggesting that instead of writing Triton code, I would write a rudimentary module-to-Triton converter.

You are right that writing a single Triton kernel is less work. However, I don’t think this is a good approach in my case:

  • As my research ideas change, I will then have to rewrite that kernel many times. The automated conversion would mean not having to rewrite Triton code.
  • Any errors I make in transcribing Jax code to Triton will require debugging time. The automated conversion (once it works) will not require debugging.
  • The automated conversion means that my code will be easier to read, understand, and therefore debug since my code would comprise a graph of module calls.

Your idea of having an XLA/Triton backend would be hard to implement, as we currently can’t guide what is done in Triton versus XLA.

Can you elaborate on this?

I’d like to generalize my suggestion. What I want is for XLA to produce fused kernels rather than the kernels it’s producing now. Why can’t it produce fused kernels? If I can write fused kernels in Triton, surely the XLA compiler can produce such kernels?

The benefit of XLA producing such kernels is that I wouldn’t have to worry about producing backwards passes, which involves:

  • computing and printing the grad graph, and
  • implementing it in Triton and all of the concomitant debugging.

Maybe a simpler thing could be a decorator on a JAX expression that gets converted to Triton.

Yes, I considered something like this. This would be a fantastic step in the right direction. If you consider my more general suggestion, then what I really want is a decorator to demand that the XLA compiler produce a fused kernel for the decorated function.

PS Salut Frederic! We met ages ago when you were working on Theano. Very nice to see you here 😄

First of all, Jax Triton looks amazing! Yes, it should solve my problem with quite a bit of work on my side. So thank you for that.

However, I have some thoughts that I’d like to get feedback on.

My problem boils down to an internal scan that evaluates something like

x[i+1] = x[i] + k * f_bwd(z - f(x[i]))

where f is a “forward pass” function of primals, and f_bwd is the corresponding backward pass acting on cotangents.

If f is a simple neural network with noise, then it’s fairly straightforward to write this in Triton. The backward pass can easily be written, but it’s annoying. Why am I doing this? Jax is already calculating the backward pass, and I might make mistakes that I’ll have to debug. That’s what I meant when I asked if I would have access to Jax’s automatic differentiation. It appears that I’ll have to manually differentiate f and then implement that in Triton.
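
(For concreteness, a minimal sketch of that iteration written directly in JAX, where jax.vjp supplies f_bwd automatically; f, k, z, and num_steps below are stand-ins, not the issue’s actual model.)

    import jax
    import jax.numpy as jnp
    from jax import lax

    def f(x):
        # Stand-in for the real forward pass (a small noisy MLP in the issue).
        return jnp.tanh(x)

    def fixed_point(z, x0, k, num_steps):
        def body(x, _):
            y, f_bwd = jax.vjp(f, x)   # forward pass plus its backward-pass closure
            (step,) = f_bwd(z - y)     # pull the residual back through f
            return x + k * step, None
        x_final, _ = lax.scan(body, x0, None, length=num_steps)
        return x_final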

I also thought about how I would write this in Triton. I could just manually write every fused kernel I need. And at the end of it, I’d have a library of pieces of kernels that I could compose to do what I need. These would probably be extra methods on “modules” (from Haiku or Flax) that would do things like:

  • report the shapes and dtypes of the inputs and outputs,
  • allocate the intermediate storage,
  • sample from a Jax PRNG whatever random numbers are necessary,
  • produce Triton code for the forward and backward passes (without the decorated function header, just the code).

Then I would have some way of composing multiple modules into a single fused kernel. This entails two functions triton_forwards and triton_backwards that would produce a forward or backward Triton function. Each function

  • has parameters that could be a graph of “modules”.
  • allocates all of the intermediate storage it needs (maybe even reusing space if possible),
  • samples all of the random numbers,
  • then calls a jitted Triton function.
  • The jitted Triton function would contain various sub-calls to ordinary Python functions corresponding to each “module”, and would have Triton pointers to the intermediate data structures, the inputs and outputs, and the intermediate storage.

Then I thought: why am I doing all this? Wouldn’t it make much more sense to have a conversion from XLA to Triton?

I understand that Triton is a very limited language. I understand that it may not be possible to convert everything that XLA can do to Triton. But I’m not doing anything that crazy. If the converter wants to bail out when I try to do something like take a hyperbolic sine, that’s fine! I’m just doing ordinary multiplications, exponentiations, additions, etc.

And I remember Matt explaining to me that Nvidia’s kernels (e.g. matrix multiplication) are better optimized than anything the user can do. But I’m pretty sure that the last time I looked at this, my runtime was dominated by kernel spawning. Even if Triton is 50% as fast as Nvidia’s hand-crafted kernels, the ability to fuse literally hundreds of kernels together would more than compensate. And the reason it’s hundreds is that I have a scan (described above), and each iteration of the scan spawns a whole new set of kernels.

So, my question boils down to: Why have we decided on Jax-Triton as the solution? Why not convert XLA to Triton as best you can, and then we can keep programming in the Jax we love?

You’d need to write differentiation rules for any new Primitives introduced. In the dfm tutorial there’s code in src/kepler_jax/kepler_jax.py which shows how to define a JVP rule for the primitive introduced.

These custom differentiation rules are not meant for custom kernels (though they kinda work there…) but rather for JAX-traceable code. When introducing a new Primitive, as in the dfm tutorial (and as I would recommend for a custom kernel), you just attach transformation rules to the primitive directly.

Once you have a differentiation rule for your primitive, you can differentiate any function that applies it along with other JAX primitives.

(Someday JAX may be able to generate derivatives for Triton code automatically. It’s something we’re looking at, but it’s a long way off.)
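
(A rough sketch of that pattern, with a toy primitive in place of a real custom kernel; the names and the exp implementation are made up, and the registration calls shown follow the “How JAX primitives work” tutorial.)

    import jax
    import jax.numpy as jnp
    from jax import core
    from jax.interpreters import ad

    # Toy primitive standing in for a call into a custom kernel.
    my_op_p = core.Primitive("my_op")

    def my_op(x):
        return my_op_p.bind(x)

    # Concrete implementation and shape/dtype rule so the primitive can be traced.
    my_op_p.def_impl(lambda x: jnp.exp(x))
    my_op_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

    # JVP rule: d/dx exp(x) = exp(x).
    def my_op_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        y = my_op(x)
        return y, y * x_dot

    ad.primitive_jvps[my_op_p] = my_op_jvp

    # With the JVP rule attached, my_op composes with other JAX transformations.
    print(jax.grad(lambda x: my_op(x).sum())(jnp.ones(3)))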

I only very briefly looked at the custom op myself. I’m sure that you can register another custom op for the gradient. But there is also another option, I think.

@mattjj, do you know if we can provide a forward graph whose gradient will be taken and used as the custom op’s gradient?

Neil, at worst, you can print your forward and backward graphs. From these, you can work out what the gradient graph is. Then you can create a JAX graph that does the same thing and ask JAX to use it for the gradient: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
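
(As an illustration of that suggestion, a minimal jax.custom_vjp sketch; the function and its hand-written gradient here are toys, not the issue’s model.)

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def f(x):
        return jnp.tanh(x)

    def f_fwd(x):
        y = jnp.tanh(x)
        return y, y                     # output, plus residuals for the backward pass

    def f_bwd(y, cotangent):
        # Hand-written gradient graph: d tanh(x)/dx = 1 - tanh(x)**2.
        return ((1.0 - y ** 2) * cotangent,)

    f.defvjp(f_fwd, f_bwd)

    print(jax.grad(lambda x: f(x).sum())(jnp.ones(3)))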

Hi Neil,

Have you seen this jax/triton project: https://github.com/jax-ml/jax-triton?

Also, what are your shapes? In this example, you have shape = (1,). If you have a layer of only 1 neuron, these are very, very tiny shapes.

Is this XLA generator code something I can look at?

XLA:GPU is all open-source!

That would definitely explain some of the slowness, but I don’t see how that’s better than just generating one kernel to do both things at once. Even if this were applied to an array with a million elements, it still seems like it would be faster to spawn one kernel instead of two.

XLA:GPU’s hands are tied here, because the way to generate the fastest GPU kernels for matmul and conv is proprietary. Only Nvidia can do it, and they release those kernels as binaries. That’s what cuBLAS, cuDNN, etc. are. There are many such kernels; XLA:GPU does autotuning and kernel selection to choose the best routines for your array shapes and specific GPU hardware. See here for example. But because these are proprietary pre-built kernels, it can’t e.g. fuse operations into them. That’s why loop bodies may have to be separated into multiple kernels.

I guess I should convert these to jax.lax.scan?

Well, I wouldn’t suggest using scan, because I’m pretty sure XLA will do that optimization for you when you make the trip count static (as you did in the example). So I don’t think there’s anything additional to be gained by using scan.

Yes, that comment is related.

        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))

For this part, because XLA:GPU is likely calling into pre-packaged CUDA kernels for the dot (unless it knows it can generate better code than the closed-source Nvidia-provided kernels, which is rare), adding the softplus may mean that you have to launch two kernels (one for the dot, one for the presumably fused XLA-generated softplus computation) per call to the MLP (i.e. at least two per loop iteration), rather than just one (just the dot).

(Tangentially, XLA:TPU has much more flexibility here: since it generates the dot routine too, it can fuse things like elementwise operations into the loads and stores of the dot operation, and indeed on TPU any jitted function leads to one big optimized XLA:TPU program, rather than separate kernels as on XLA:GPU.)

By the way, instead of splitting the RNG on every iteration, you could just split it once into a big array (with leading axis size maximum_iterations) and, say, scan over it (or just index into it with the iteration counter). That can also save kernel launches in the loop body, though it’ll mean your program uses more memory.
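
(A rough sketch of that suggestion, reusing nat_to_exp, weights, and shape from the code at the top; untested, and it drops the SamplerState bookkeeping since the counter and RNG no longer need to live in the loop carry.)

    import haiku as hk
    import jax.numpy as jnp
    from jax import lax, random

    def infer_encoding_scan(weights, initial_rng, maximum_iterations):
        # Split all per-iteration keys up front instead of inside the loop body.
        leak_rngs = random.split(initial_rng, maximum_iterations)
        apply = hk.transform(nat_to_exp).apply

        def body(code_momentum, leak_rng):
            force = apply(weights, None, code_momentum)
            return force + random.normal(leak_rng, force.shape), None

        final_code_momentum, _ = lax.scan(body, jnp.zeros(shape), leak_rngs)
        return final_code_momentum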

The overall theme here is to try to minimize the number of kernel launches per loop iteration.

I haven’t looked at your profile yet, but I’ll try to get the chance soon!