triton: Custom tensor roll kernel produces wrong results depending on block size and tl.constexpr usage
Dear triton team,
I am currently debugging an issue with a kernel that is supposed to replace torch.roll followed by zeroing out the first row of a 2D matrix. This is the code I have:
import torch
import triton
import triton.language as tl
@triton.jit
def triton_roll_and_zero_first_row_kernel(in_ptr, out_ptr, NUM_ROWS: tl.constexpr, XBLOCK: tl.constexpr):
    # Each program handles XBLOCK contiguous elements of the flattened (NUM_ROWS, 2) tensor.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    row_idx = xindex // 2
    col_idx = xindex % 2
    # Read from the next row, wrapping around, i.e. torch.roll(x, -1, 0).
    rolled_row = (row_idx + 1) % NUM_ROWS
    rolled_xindex = 2 * rolled_row + col_idx
    result = tl.load(in_ptr + rolled_xindex)
    # Zero out the first row of the output.
    result = tl.where(row_idx == 0, 0.0, result)
    tl.store(out_ptr + xindex, result)
def triton_roll_and_zero_first_row(x):
    assert x.size(1) == 2
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['XBLOCK']),)
    triton_roll_and_zero_first_row_kernel[grid](x, y, NUM_ROWS=x.size(0), XBLOCK=256)
    return y
def roll_and_zero_first_row(x):
    x = torch.roll(x, -1, 0)
    x[0].fill_(0.0)
    return x
if __name__ == "__main__":
    inp = torch.rand(1024, 2, device="cuda")
    out_eager = roll_and_zero_first_row(inp)
    out_triton = triton_roll_and_zero_first_row(inp)
    print("eager", out_eager)
    print("triton", out_triton)
    assert torch.all(out_eager == out_triton)
    print("PASSED")
In this form the assert does not pass on my system. It passes if I either set XBLOCK=128 or remove tl.constexpr from NUM_ROWS (that variant is sketched below), which strikes me as odd. Could you please help me understand this behaviour?
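For reference, here is a minimal sketch of the constexpr-free variant that passes for me. Only the kernel signature changes (the kernel is renamed here just to keep both versions side by side); the body and the launch stay the same:

@triton.jit
def triton_roll_and_zero_first_row_kernel_no_constexpr(in_ptr, out_ptr, NUM_ROWS, XBLOCK: tl.constexpr):
    # NUM_ROWS is now an ordinary runtime argument instead of a compile-time
    # constant, so Triton no longer specializes the kernel on its value.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    row_idx = xindex // 2
    col_idx = xindex % 2
    rolled_row = (row_idx + 1) % NUM_ROWS
    rolled_xindex = 2 * rolled_row + col_idx
    result = tl.load(in_ptr + rolled_xindex)
    result = tl.where(row_idx == 0, 0.0, result)
    tl.store(out_ptr + xindex, result)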
The blueprint for this kernel was actually produced by applying torch.compile to roll_and_zero_first_row, and that path is affected by this issue as well; I only renamed a few variables and reordered code to make it human-readable. If you can confirm this behaviour, I will open an issue with PyTorch as well.
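For context, this is roughly how the generated kernel can be inspected; a sketch, assuming the usual Inductor debug knobs (TORCH_COMPILE_DEBUG=1 or TORCH_LOGS="output_code" in the environment) dump the generated Triton source:

# Sketch: compile the eager reference and let Inductor dump its generated Triton code.
# Run the script with TORCH_COMPILE_DEBUG=1 (or TORCH_LOGS="output_code") set beforehand.
compiled_fn = torch.compile(roll_and_zero_first_row)
_ = compiled_fn(torch.rand(1024, 2, device="cuda"))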
I am currently on triton-nightly==2.1.0.post20240108192258, torch==2.1.2, and CUDA 12.1.
If we export DISABLE_LLVM_OPT=1, the test case passes, so this is related to LLVM optimizations.
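To try this from Python instead of the shell, a sketch (assuming the variable is read when the kernel is JIT-compiled, so it has to be set before the first launch):

import os

# Assumption: Triton picks this up at JIT-compile time, so set it before any
# kernel is compiled (exporting it in the shell before starting Python works too).
os.environ["DISABLE_LLVM_OPT"] = "1"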
Reply from NVIDIA: This is a ptxas optimization bug that we've already fixed internally. The fix will be available in an update release of CUDA 12.4 soon. The optimization was incorrectly transforming a LOP3 (which supports a predicate output) into a PRMT (which doesn't support a predicate output).
It's very possible that there's a ptxas bug, so we could prepare a reproducer and send it to the NVIDIA compiler folks. They usually respond quickly, in my experience.
Disabling ptxas optimizations also fixes the problem. SASS dumps are attached: def-sass.txt (default build) and no-ptxas-opt-sass.txt (ptxas optimizations disabled).
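For comparing the SASS of the two builds yourself, a rough sketch, assuming the launch returns a compiled-kernel handle whose .asm dict contains the cubin (this detail varies across Triton versions) and that nvdisasm from the CUDA toolkit is on PATH:

import subprocess
import tempfile

def dump_sass(compiled_kernel):
    # Write the cubin to a temp file and disassemble it with nvdisasm.
    with tempfile.NamedTemporaryFile(suffix=".cubin") as f:
        f.write(compiled_kernel.asm["cubin"])
        f.flush()
        return subprocess.run(["nvdisasm", f.name],
                              capture_output=True, text=True).stdout

e.g. handle = triton_roll_and_zero_first_row_kernel[grid](x, y, NUM_ROWS=x.size(0), XBLOCK=256); print(dump_sass(handle)), once with optimizations enabled and once with them disabled, and diff the output.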
Patch to enable debugging: https://github.com/openai/triton/pull/2995