jax: Pallas implementation of attention doesn't work on Cloud TPU

Description

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import attention

bs = 2
seqlen = 1000
n_heads = 32
dim = 128
rng = jax.random.PRNGKey(0)
xq = jax.random.normal(rng, (bs, seqlen, n_heads, dim))
xk = jax.random.normal(rng, (bs, seqlen, n_heads, dim))
xv = jax.random.normal(rng, (bs, seqlen, n_heads, dim))

print('reference')
res = attention.mha_reference(xq, xk, xv, None)
print(res)
print(res.shape)

print('real kernel')
print(attention.mha(xq, xk, xv, None))
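For reference, the sketch below spells out in NumPy what plain softmax attention computes in the repro's (batch, seq_len, n_heads, dim) layout. It assumes no mask and a softmax scale of 1.0. The function name mha_reference_np is hypothetical, and this is not the source of attention.mha_reference. jax.numpy offers the same einsum, so the code ports directly.

```python
import numpy as np


def mha_reference_np(q, k, v, sm_scale=1.0):
    """Plain softmax attention; q, k, v have shape (batch, seq, heads, dim).

    Assumption: no mask and no segment ids (the repro passes None).
    """
    # Scores for every (query, key) pair, per batch and head: (batch, heads, q_seq, k_seq)
    logits = np.einsum("bqhd,bkhd->bhqk", q, k) * sm_scale
    # Numerically stable softmax over the key axis
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values, returned in (batch, seq, heads, dim) layout
    return np.einsum("bhqk,bkhd->bqhd", weights, v)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 16, 4, 8))
k = rng.standard_normal((2, 16, 4, 8))
v = rng.standard_normal((2, 16, 4, 8))
out = mha_reference_np(q, k, v)
print(out.shape)  # (2, 16, 4, 8)
```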

Got:

hanq@t1v-n-cfe84bb3-w-0:~/llama$ python test_jax.py
reference

   ...
   [-0.68359375 -1.9609375  -1.2734375  ...  0.02490234  0.703125
     0.69921875]
   [-0.4453125  -0.5859375   1.28125    ... -0.24414062  1.21875
     0.47851562]
   [ 0.0859375  -1.703125   -0.4921875  ... -0.38476562  0.8828125
     0.09765625]]

  [[ 0.93359375 -0.6015625   1.4296875  ...  1.3515625   0.6796875
    -1.0859375 ]
   [ 0.875       0.22460938  1.0625     ...  1.421875   -0.09472656
    -0.18847656]
   [ 1.34375     0.02441406 -1.34375    ... -0.01501465 -0.90234375
    -0.31054688]
   ...
   [ 0.4609375   0.609375   -0.9921875  ... -1.015625    0.6796875
    -0.5546875 ]
   [-1.25       -0.59765625  0.1328125  ...  0.0390625   0.43945312
    -0.97265625]
   [ 1.140625   -0.17382812 -1.03125    ... -0.76171875  0.77734375
    -0.18261719]]

  ...

   [ 0.734375   -1.3203125   0.92578125 ...  0.9453125   0.18554688
     1.0078125 ]
   [ 0.96875    -1.5390625   0.47265625 ...  0.96875    -1.1640625
     0.29882812]
   ...
   [ 0.04296875  0.53515625  0.20410156 ...  1.28125    -0.69140625
    -0.10058594]
   [-0.04150391 -0.66796875 -1.078125   ... -1.09375    -0.4296875
     0.828125  ]
   [-0.5234375   0.43164062 -0.69140625 ...  1.1796875  -2.296875
     0.21972656]]

  ...

  [[ 0.06445312  0.40429688  0.03735352 ... -1.6796875   1.1796875
    -0.98828125]
   [-0.7109375  -0.34375    -0.23144531 ...  0.13085938 -0.47070312
    -0.21875   ]
   [-0.66796875 -0.02954102 -1.046875   ...  0.2421875   1.203125
    -0.42382812]
   ...
   [ 0.14257812  0.58984375  0.40234375 ... -0.01672363 -0.57421875
     1.046875  ]
   [-0.62890625 -1.1171875   0.84375    ... -0.35351562 -0.22558594
   ...
   [ 0.859375    1.9765625   0.54296875 ...  1.109375    0.05639648
     0.6796875 ]
   [-0.29101562 -1.9921875  -1.734375   ...  1.2265625   0.14453125
    -0.53125   ]
   [ 1.6484375  -0.40820312 -0.828125   ... -0.265625   -0.28320312
    -0.43164062]]

  [[-0.27539062 -1.8671875  -0.078125   ...  0.515625    0.90625
    -2.453125  ]
   [ 0.17773438 -0.11572266  0.5390625  ... -0.5546875  -0.40625
     0.9765625 ]
   [-0.05395508 -0.00325012 -0.08691406 ... -0.8046875  -0.03979492
     0.07666016]
   ...
   [-0.3359375   0.87890625 -1.453125   ...  1.1328125   0.46875
     0.65625   ]
   [ 0.30273438  0.546875    0.11083984 ...  0.98828125 -0.10791016
     1.4375    ]
   [ 1.4375     -0.46484375  1.71875    ...  1.265625    0.21386719
    -0.70703125]]]]
(2, 1000, 32, 128)
real kernel
Traceback (most recent call last):
  File "/home/hanq/llama/test_jax.py", line 20, in <module>
    print(attention.mha(xq, xk, xv, None))
  File "/home/hanq/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/attention.py", line 216, in mha
    return pl.pallas_call(
  File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 410, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: Internal TPU kernel compiler error: Loading elements out of bounds

The MLIR operation involved:
  %209 = "vector.load"(%14, %9, %0, %9, %9) : (memref<1x1000x1x128xf32, #tpu.memory_space<vmem>>, index, index, index, index) -> vector<1x128x1x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/hanq/llama/test_jax.py", line 20, in <module>
    print(attention.mha(xq, xk, xv, None))
  File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 88, in pallas_call_tpu_lowering_rule
    return mlir.lower_fun(_lower_fun, multiple_results=True)(
  File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 79, in _lower_fun
    return mosaic.as_tpu_kernel(
  File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 408, in as_tpu_kernel
    lowered_module_asm, constants = _lower_tpu_kernel(
  File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 335, in _lower_tpu_kernel
    _run_pass_pipeline(pipeline, module, "infer vector layout")
  File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 264, in _run_pass_pipeline
    raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: Loading elements out of bounds

The MLIR operation involved:
  %209 = "vector.load"(%14, %9, %0, %9, %9) : (memref<1x1000x1x128xf32, #tpu.memory_space<vmem>>, index, index, index, index) -> vector<1x128x1x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

What jax/jaxlib version are you using?

jax==0.4.21.dev20231117 jaxlib==0.4.21.dev20231117

Which accelerator(s) are you using?

TPU

Additional system info

Python 3.10.2 Uname=Linux t1v-n-cfe84bb3-w-0 5.19.0-1022-gcp #24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux

About this issue

  • State: closed
  • Created 7 months ago
  • Comments: 16 (8 by maintainers)

Most upvoted comments

When seqlen = 1, it gives ValueError: block_q=128 should be smaller or equal to q_seq_len=1. I understand that this is probably by design at this point. However, I would like to ask for support for seqlen = 1, since that is the case during the decode phase of typical LLM inference.
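The constraint behind that error, and the cost of working around it by padding the query axis up to a full block, can be sketched as follows. check_block_q and padding_waste are hypothetical helpers written for illustration; only the message text comes from the error quoted above.

```python
def check_block_q(block_q: int, q_seq_len: int) -> None:
    # Hypothetical helper: reproduces the condition behind the quoted ValueError
    if block_q > q_seq_len:
        raise ValueError(
            f"block_q={block_q} should be smaller or equal to q_seq_len={q_seq_len}"
        )


def padding_waste(q_seq_len: int, block_q: int) -> float:
    """Fraction of query rows that would be padding if q_seq_len were
    rounded up to a whole number of block_q-row tiles."""
    n_blocks = -(-q_seq_len // block_q)  # ceiling division
    padded = n_blocks * block_q
    return (padded - q_seq_len) / padded


check_block_q(128, 1000)  # prefill-sized input: passes
try:
    check_block_q(128, 1)  # decode step: one new query token
except ValueError as e:
    print(e)  # block_q=128 should be smaller or equal to q_seq_len=1

print(padding_waste(1, 128))     # 0.9921875 (127 of 128 rows are padding)
print(padding_waste(1000, 128))  # 0.0234375 (24 of 1024 rows are padding)
```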

You probably want a slightly different kernel for decoding. On a TPU, naively using a sequence length of 1 will result in an extremely padded matmul. We might want to do a matrix-vector product or something else.
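A NumPy sketch of the math for the decode case, assuming a KV cache laid out as (batch, kv_len, heads, dim) and a single new query token per head: each head's attention reduces to two matrix-vector products. This only illustrates the computation; it is not a Pallas kernel, and decode_attention is a hypothetical name.

```python
import numpy as np


def decode_attention(q, k_cache, v_cache, sm_scale=1.0):
    """Single-token attention for decoding.

    q:       (batch, heads, dim)          the one new query token
    k_cache: (batch, kv_len, heads, dim)  keys for all previous tokens
    v_cache: (batch, kv_len, heads, dim)
    returns: (batch, heads, dim)
    """
    # One matrix-vector product per (batch, head): K (kv_len x dim) times q (dim)
    scores = np.einsum("bkhd,bhd->bhk", k_cache, q) * sm_scale
    scores -= scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # A second matrix-vector product: V^T (dim x kv_len) times weights (kv_len)
    return np.einsum("bhk,bkhd->bhd", weights, v_cache)


rng = np.random.default_rng(0)
batch, kv_len, heads, dim = 2, 1000, 8, 128
q = rng.standard_normal((batch, heads, dim))
k_cache = rng.standard_normal((batch, kv_len, heads, dim))
v_cache = rng.standard_normal((batch, kv_len, heads, dim))
print(decode_attention(q, k_cache, v_cache).shape)  # (2, 8, 128)
```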

It’s not faster than the reference implementation, which kind of defeats the purpose of using a specialized kernel.

You’re seeing these results for 2 reasons:

  1. You are using default block sizes, which tend to be slow. Try passing in block_sizes (https://github.com/google/jax/blob/c855bb0371fd7df3e2c33c0d153a23299b4f1988/jax/experimental/pallas/ops/tpu/flash_attention.py#L148) and sweeping over larger values. You should see significant performance improvements.
  2. You are using a somewhat small sequence length. For sequence lengths <= 4096, XLA has some fusions that do something similar to flash attention, so the expected improvement over XLA isn’t that big. Once you go to 8k and above, you should see much bigger improvements.
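A small harness for both points, written as a sketch: candidate_block_sizes, sweep and score_matrix_bytes are hypothetical helpers. For the TPU kernel linked above, run would rebuild the block_sizes argument for each candidate and call .block_until_ready() on the result; here any callable works. score_matrix_bytes shows how quickly the full score tensor of a naive implementation grows with sequence length, which is the tensor flash attention avoids storing in HBM.

```python
import time


def candidate_block_sizes(seq_len, smallest=128):
    """Powers of two from `smallest` up to seq_len that evenly divide seq_len.
    Restricting to divisors is an assumption made to keep tiles whole."""
    sizes = []
    b = smallest
    while b <= seq_len:
        if seq_len % b == 0:
            sizes.append(b)
        b *= 2
    return sizes


def sweep(run, sizes, repeats=3):
    """Time run(block) for each candidate; return (best_block, timings).
    For JAX, run must call .block_until_ready() on its output, otherwise
    asynchronous dispatch makes the timings meaningless."""
    timings = {}
    for block in sizes:
        run(block)  # warm-up call, so compilation is not timed
        start = time.perf_counter()
        for _ in range(repeats):
            run(block)
        timings[block] = (time.perf_counter() - start) / repeats
    best = min(timings, key=timings.get)
    return best, timings


def score_matrix_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Size of the full (seq_len x seq_len) score tensor per batch and head."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


print(candidate_block_sizes(8192))  # [128, 256, 512, 1024, 2048, 4096, 8192]
print(candidate_block_sizes(1000))  # [] since 1000 is not a multiple of 128
print(score_matrix_bytes(2, 32, 1000) / 2**20)  # 244.140625 (MiB)
print(score_matrix_bytes(2, 32, 8192) / 2**30)  # 16.0 (GiB)

# Dummy workload standing in for a kernel call
best, timings = sweep(lambda b: sum(range(b * 100)), candidate_block_sizes(8192))
```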

Okay, I’ll do a more thorough investigation tomorrow.

With respect to dimension order, the kernel expects the sequence-length dimension (i.e., something usually bigger than the number of heads) in the second-to-last position.
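In the repro the arrays are (bs, seqlen, n_heads, dim), so the head axis, not the sequence axis, is second to last; the block shape 1x1000x1x128 in the failing MLIR op is consistent with that. A minimal sketch of swapping the axes before the call and back afterwards, in NumPy (jnp.swapaxes behaves the same in JAX). Which entry point expects which layout should still be checked against the kernel's own docstring.

```python
import numpy as np

bs, seqlen, n_heads, dim = 2, 1000, 32, 128
xq = np.zeros((bs, seqlen, n_heads, dim), dtype=np.float32)

# (batch, seq, heads, dim) -> (batch, heads, seq, dim): sequence is now second to last
xq_bhsd = np.swapaxes(xq, 1, 2)
print(xq_bhsd.shape)  # (2, 32, 1000, 128)

# Swap back after the kernel returns, to recover the original layout
out = np.swapaxes(xq_bhsd, 1, 2)
print(out.shape)  # (2, 1000, 32, 128)
```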