triton: Flash Attention test fails for bfloat16

Hi, I have enabled bfloat16 testing in Triton (https://github.com/openai/triton/pull/1244/), but the Flash Attention test fails with the following error for this data type:

giorgio@giorgio:triton$ pytest python/test/unit/operators/test_flash_attention.py -s
========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.10.9, pytest-7.2.1, pluggy-1.0.0
rootdir: /usr/local/home/giorgio/triton/python
collected 2 items                                                                                                                                                                                       

python/test/unit/operators/test_flash_attention.py .error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: builtin.unrealized_conversion_cast
Failed to emit LLVM IR
Translate to LLVM IR failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Fatal Python error: Aborted

Thread 0x00007fbc757fe6c0 (most recent call first):
  <no Python frame>

Current thread 0x00007fbd053ae200 (most recent call first):
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1018 in ttgir_to_llir
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1570 in <lambda>
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1637 in compile
  File "<string>", line 41 in _fwd_kernel
  File "/usr/local/home/giorgio/triton/python/triton/ops/flash_attention.py", line 214 in forward
  File "/usr/local/home/giorgio/triton/python/test/unit/operators/test_flash_attention.py", line 33 in test_op
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/python.py", line 195 in pytest_pyfunc_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/python.py", line 1789 in runtest
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 167 in pytest_runtest_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 260 in <lambda>
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 339 in from_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 259 in call_runtest_hook
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 220 in call_and_report
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 131 in runtestprotocol
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 324 in _main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 167 in main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 190 in console_main
  File "/usr/local/home/giorgio/.local/bin/pytest", line 8 in <module>

Extension modules: torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg.lapack_lite, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, __triton_launcher, cuda_utils (total: 23)
Aborted

Could this be fixed, please? Thanks.

About this issue

  • State: open
  • Created a year ago
  • Comments: 22 (4 by maintainers)

Most upvoted comments

Yeah, or we can reduce the input size. These errors seem pretty unlikely.

We’re putting out a few fires right now, but we’ll look into this more closely once things cool down.