triton: type conversion before tl.dot fails compilation
Context: we are trying to upgrade the Triton pin in PyTorch and started encountering a test failure for `python test/inductor/test_pattern_matcher.py -k test_mixed_mm`.
Doing a type conversion on an input tensor to `tl.dot` fails compilation. It works at Triton commit e6216047b8b0aef1fe8da6ca8667a3ad0a016411, but fails at a recent commit 0410652666ddcad54aa4c403276b478a36379b90.
Here is a standalone repro w.py: https://gist.github.com/shunting314/3a3b8ce1ccee7b51b8ee0d9a2d24dd3d
Running `python w.py` on the new commit reports the following error:
```
loc("/tmp/torchinductor_shunting/bm/cbm7qsh5esh6xdkdddmv7l2ilel4kdbfwgy2luolzmme62njagrb.py":64:17): error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
Failed to emit LLVM IR
Translate to LLVM IR failed
LLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Aborted (core dumped)
```
Here is the same repro, but with the inductor dependencies removed: https://gist.github.com/shunting314/4eb4a6a7d8cbfc6396726518196e26d1
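Since the gists are only linked, here is a minimal sketch of the failing pattern for context: a mixed-dtype matmul where one `tl.dot` operand is loaded as int8 and cast in-kernel. The kernel name, shapes, strides, and block sizes are illustrative assumptions, not the gists' exact code.
```python
import torch
import triton
import triton.language as tl

@triton.jit
def mixed_mm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                    BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # a is fp16, b is int8; row-major and unmasked for simplicity.
        a = tl.load(a_ptr + rm[:, None] * K + (k + rk)[None, :])
        b = tl.load(b_ptr + (k + rk)[:, None] * N + rn[None, :])
        # The in-kernel cast of one dot operand is the pattern at issue.
        acc += tl.dot(a, b.to(tl.float16))
    tl.store(c_ptr + rm[:, None] * N + rn[None, :], acc.to(tl.float16))

M = N = K = 64
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randint(-128, 127, (K, N), device="cuda", dtype=torch.int8)
c = torch.empty((M, N), device="cuda", dtype=torch.float16)
grid = (triton.cdiv(M, 16), triton.cdiv(N, 16))
mixed_mm_kernel[grid](a, b, c, M, N, K, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16)
```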
About this issue
- State: open
- Created 10 months ago
- Comments: 20 (10 by maintainers)
Commits related to this issue
- [BACKEND] Fix BF16 dot operand type mismatch (#2162) https://github.com/openai/triton/issues/2156 — committed to openai/triton by Jokeren 10 months ago
- [BACKEND] Fix BF16 dot operand type mismatch (#2162) https://github.com/openai/triton/issues/2156 — committed to ROCm/triton by Jokeren 10 months ago
- [BACKEND] Fix BF16 dot operand type mismatch (#2162) https://github.com/openai/triton/issues/2156 — committed to siliconflow/triton by Jokeren 10 months ago
@Jokeren Just FYI, we worked around the int8 issue by making sure the block size is at least 32 (sketched at the end of this thread). In case you later figure out why it worked at the previous commit, let us know 😃
Sounds good. Thank you for all your efforts in unblocking the Triton version upgrade in PyTorch!
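For illustration, a minimal sketch of the block-size workaround mentioned above, reusing the hypothetical `mixed_mm_kernel` and tensors from the earlier sketch; the actual PyTorch-side change is not shown in this thread, so treat this purely as the shape of the idea.
```python
# Assumption based on the comment above: block sizes of at least 32 avoid
# the int8 failure, so relaunch the illustrative kernel with 32x32x32 blocks.
grid = (triton.cdiv(M, 32), triton.cdiv(N, 32))
mixed_mm_kernel[grid](a, b, c, M, N, K, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32)
```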