jax: Pallas implementation of attention doesn't work on CloudTPU
Description
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import attention
# batch size, sequence length, number of heads, head dimension
bs = 2
seqlen = 1000
n_heads = 32
dim = 128
rng = jax.random.PRNGKey(0)
xq = jax.random.normal(rng, (bs, seqlen, n_heads, dim))
xk = jax.random.normal(rng, (bs, seqlen, n_heads, dim))
xv = jax.random.normal(rng, (bs, seqlen, n_heads, dim))
print('reference')
res = attention.mha_reference(xq, xk, xv, None)
print(res)
print(res.shape)
print('real kernel')
print(attention.mha(xq, xk, xv, None))
Got:
hanq@t1v-n-cfe84bb3-w-0:~/llama$ python test_jax.py
reference
[... large truncated array printout from mha_reference elided ...]
(2, 1000, 32, 128)
real kernel
Traceback (most recent call last):
File "/home/hanq/llama/test_jax.py", line 20, in <module>
print(attention.mha(xq, xk, xv, None))
File "/home/hanq/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/attention.py", line 216, in mha
return pl.pallas_call(
File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 410, in wrapped
out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: Internal TPU kernel compiler error: Loading elements out of bounds
The MLIR operation involved:
%209 = "vector.load"(%14, %9, %0, %9, %9) : (memref<1x1000x1x128xf32, #tpu.memory_space<vmem>>, index, index, index, index) -> vector<1x128x1x128xf32>
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/hanq/llama/test_jax.py", line 20, in <module>
print(attention.mha(xq, xk, xv, None))
File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 88, in pallas_call_tpu_lowering_rule
return mlir.lower_fun(_lower_fun, multiple_results=True)(
File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 79, in _lower_fun
return mosaic.as_tpu_kernel(
File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 408, in as_tpu_kernel
lowered_module_asm, constants = _lower_tpu_kernel(
File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 335, in _lower_tpu_kernel
_run_pass_pipeline(pipeline, module, "infer vector layout")
File "/home/hanq/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 264, in _run_pass_pipeline
raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: Loading elements out of bounds
The MLIR operation involved:
%209 = "vector.load"(%14, %9, %0, %9, %9) : (memref<1x1000x1x128xf32, #tpu.memory_space<vmem>>, index, index, index, index) -> vector<1x128x1x128xf32>
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
What jax/jaxlib version are you using?
jax==0.4.21.dev20231117 jaxlib==0.4.21.dev20231117
Which accelerator(s) are you using?
TPU
Additional system info
Python 3.10.2 Uname=Linux t1v-n-cfe84bb3-w-0 5.19.0-1022-gcp #24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux
NVIDIA GPU info
No response
About this issue
- Original URL
- State: closed
- Created 7 months ago
- Comments: 16 (8 by maintainers)
You probably want a slightly different kernel for decoding. On a TPU, naively using a sequence length of 1 will result in an extremely padded matmul. We might want to do a matrix-vector product or something else.
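Below is a minimal sketch of that idea: a single-token decode step written as plain jnp contractions (effectively a matrix-vector product per head) rather than a padded matmul kernel. The function name and shapes are illustrative assumptions, not part of the Pallas ops API.

import jax
import jax.numpy as jnp

def decode_step_attention(q, k_cache, v_cache):
    # q:       (bs, n_heads, dim)         -- one query token per sequence
    # k_cache: (bs, seqlen, n_heads, dim) -- cached keys
    # v_cache: (bs, seqlen, n_heads, dim) -- cached values
    # scores: (bs, n_heads, seqlen); per head this is a (dim,) x (seqlen, dim)^T matvec
    scores = jnp.einsum('bhd,bshd->bhs', q, k_cache) / jnp.sqrt(q.shape[-1])
    probs = jax.nn.softmax(scores, axis=-1)
    # output: (bs, n_heads, dim); per head a (seqlen,) x (seqlen, dim) matvec
    return jnp.einsum('bhs,bshd->bhd', probs, v_cache)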
You’re seeing these results for two reasons. Tune the block_sizes (https://github.com/google/jax/blob/c855bb0371fd7df3e2c33c0d153a23299b4f1988/jax/experimental/pallas/ops/tpu/flash_attention.py#L148) and sweep over larger ones; you should see significant performance improvements.

Okay, I'll do a more thorough investigation tomorrow.
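As a rough illustration of the block-size suggestion above, here is a hedged sketch of calling the TPU flash-attention kernel from the linked file with explicit block sizes. The BlockSizes field names, the (batch, n_heads, seq_len, head_dim) input layout, and the candidate values are assumptions based on that file and may differ across jax versions; check flash_attention.py before relying on them.

from jax.experimental.pallas.ops.tpu import flash_attention as fa

def run_with_blocks(q, k, v, block_q=128, block_k=128):
    # Assumed layout: (batch, n_heads, seq_len, head_dim).
    # Field names below are assumptions; see the BlockSizes dataclass in the linked file.
    block_sizes = fa.BlockSizes(
        block_q=block_q,
        block_k_major=block_k,
        block_k=block_k,
        block_b=1,
    )
    return fa.flash_attention(q, k, v, block_sizes=block_sizes)

# Hypothetical sweep over candidate block sizes:
# for bq in (128, 256, 512):
#     for bk in (128, 256, 512):
#         out = run_with_blocks(xq, xk, xv, block_q=bq, block_k=bk)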
Regarding dimension order, the kernel expects the sequence length dimension (i.e., something usually bigger than the number of heads) in the second-to-last position.
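Applied to the reproduction at the top, that would mean transposing the inputs so that seqlen sits second-to-last before calling the kernel. A minimal sketch, assuming a (batch, n_heads, seq_len, head_dim) layout is what the kernel wants; verify this against the kernel actually being called.

import jax
import jax.numpy as jnp

bs, seqlen, n_heads, dim = 2, 1000, 32, 128
rng = jax.random.PRNGKey(0)
xq = jax.random.normal(rng, (bs, seqlen, n_heads, dim))  # (bs, seqlen, n_heads, dim)
xq_t = jnp.transpose(xq, (0, 2, 1, 3))                   # (bs, n_heads, seqlen, dim)
# ... same transpose for xk and xv, and transpose the output back afterwards.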