triton: Custom tensor roll kernel produces wrong results for variations of block size and tl.constexpr

Dear triton team,

I am currently debugging an issue with a kernel that is supposed to replace torch.roll followed by zeroing out the first row of a 2D matrix. This is the code I have:

import torch

import triton
import triton.language as tl

@triton.jit
def triton_roll_and_zero_first_row_kernel(in_ptr, out_ptr, NUM_ROWS: tl.constexpr, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    # Flat element indices handled by this program; the input is contiguous with 2 columns.
    xindex = xoffset + tl.arange(0, XBLOCK)
    row_idx = xindex // 2
    col_idx = xindex % 2
    # torch.roll(x, -1, 0): output row i reads from input row (i + 1) % NUM_ROWS.
    rolled_row = (row_idx + 1) % NUM_ROWS
    rolled_xindex = 2 * rolled_row + col_idx
    result = tl.load(in_ptr + rolled_xindex)
    # Zero out the first output row.
    result = tl.where(row_idx == 0, 0.0, result)
    tl.store(out_ptr + xindex, result)

def triton_roll_and_zero_first_row(x):
    assert x.size(1) == 2
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['XBLOCK']), )
    triton_roll_and_zero_first_row_kernel[grid](x, y, NUM_ROWS=x.size(0), XBLOCK=256)
    return y

def roll_and_zero_first_row(x):
    x = torch.roll(x, -1, 0)
    x[0].fill_(0.0)
    return x

if __name__ == "__main__":
    inp = torch.rand(1024, 2, device="cuda")

    out_eager = roll_and_zero_first_row(inp)
    out_triton = triton_roll_and_zero_first_row(inp)

    print("eager", out_eager)
    print("triton", out_triton)
    assert torch.all(out_eager == out_triton)
    print("PASSED")

As written, the assert fails on my system. It passes if I either set XBLOCK=128 or remove tl.constexpr from NUM_ROWS. This strikes me as odd; could you please help me understand this behaviour?
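
For reference, here is a minimal sketch of the two variants that pass on my system (everything not shown is identical to the reproducer above):

# Variant 1: keep the kernel unchanged, but launch with XBLOCK=128 instead of 256.
triton_roll_and_zero_first_row_kernel[grid](x, y, NUM_ROWS=x.size(0), XBLOCK=128)

# Variant 2: declare NUM_ROWS as a regular runtime argument instead of tl.constexpr.
@triton.jit
def triton_roll_and_zero_first_row_kernel(in_ptr, out_ptr, NUM_ROWS, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    row_idx = xindex // 2
    col_idx = xindex % 2
    rolled_row = (row_idx + 1) % NUM_ROWS
    rolled_xindex = 2 * rolled_row + col_idx
    result = tl.load(in_ptr + rolled_xindex)
    result = tl.where(row_idx == 0, 0.0, result)
    tl.store(out_ptr + xindex, result)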

The blueprint for this kernel was actually produced by applying torch.compile to roll_and_zero_first_row, and the generated kernel is affected by this issue as well; I only renamed a few variables and reordered the code to make it more readable. If you can confirm this behaviour, I will open an issue against PyTorch as well.
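
In case it helps with reproducing the torch.compile side, this is roughly how the generated kernel can be inspected; using TORCH_COMPILE_DEBUG to dump the Inductor output is just my choice here, other logging mechanisms would work as well:

import os
# Set before importing torch to be safe; the generated code (including the Triton
# kernels) should end up under ./torch_compile_debug/ in the working directory.
os.environ["TORCH_COMPILE_DEBUG"] = "1"

import torch

def roll_and_zero_first_row(x):
    x = torch.roll(x, -1, 0)
    x[0].fill_(0.0)
    return x

compiled = torch.compile(roll_and_zero_first_row)
inp = torch.rand(1024, 2, device="cuda")
out = compiled(inp)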

I am currently on triton-nightly==2.1.0.post20240108192258, torch==2.1.2, and CUDA 12.1.

About this issue

  • State: open
  • Created 5 months ago
  • Comments: 25 (4 by maintainers)

Most upvoted comments

If we export DISABLE_LLVM_OPT=1, the test case passes, so the issue is related to LLVM optimizations.
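
For anyone who prefers to do this from inside the reproducer rather than the shell, a minimal sketch (assuming the variable only needs to be visible before the kernel is first compiled):

import os
# Equivalent to `export DISABLE_LLVM_OPT=1`; set it before importing triton so it is
# guaranteed to be visible when the kernel gets compiled.
os.environ["DISABLE_LLVM_OPT"] = "1"

import triton
import triton.language as tl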

Reply from NVIDIA: This is a ptxas optimization bug that we’ve already fixed internally. It will be available in an update release of CUDA 12.4 soon. The optimization was incorrectly transforming a LOP3 (which supports predicate output) into a PRMT (which doesn’t support predicate output).

Since it’s very possible that there’s a ptxas bug, we could prepare a reproducer and send it to the NVIDIA compiler folks. In my experience they usually respond quickly.

Disabling ptxas optimizations also fixes the problem. SASS dumps for comparison: def-sass.txt, no-ptxas-opt-sass.txt
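
For reference, one way to reproduce the "ptxas without optimizations" build is a small wrapper that forces -O0; the TRITON_PTXAS_PATH override and the wrapper itself are assumptions on my part, not an officially documented workflow:

#!/usr/bin/env python3
# ptxas_noopt.py -- hypothetical wrapper: forwards all arguments to the real ptxas,
# but strips any optimization level and appends -O0.
# Assumed usage:
#   chmod +x ptxas_noopt.py && TRITON_PTXAS_PATH=$PWD/ptxas_noopt.py python repro.py
import subprocess
import sys

REAL_PTXAS = "/usr/local/cuda/bin/ptxas"  # adjust to your CUDA installation

args = []
skip_next = False
for a in sys.argv[1:]:
    if skip_next:               # value belonging to a previous --opt-level
        skip_next = False
        continue
    if a.startswith("-O") or a.startswith("--opt-level="):
        continue                # drop the optimization level Triton passed
    if a == "--opt-level":
        skip_next = True        # drop the flag and its separate value
        continue
    args.append(a)

sys.exit(subprocess.call([REAL_PTXAS, *args, "-O0"]))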

Patch to enable debugging: https://github.com/openai/triton/pull/2995