triton: Crash occurs when compiling a layer norm backward kernel

The code below crashes with the latest Triton.

import triton
import triton.language

# Note: `language` below is presumably trident's own helper module (the kernel
# lives in trident/operation/layer_norm.py), not triton.language.

@triton.jit
def backward(grad_out_ptr, inp_ptr, grad_inp_ptr, vec_sz, wgt_ptr, grad_wgt_ptr, grad_bis_ptr, mean_ptr,
             std_ptr, vec_bs: triton.language.constexpr):
    # One program handles one row: pid selects the row, vec_bs covers its elements.
    pid = triton.language.program_id(0)
    wgt_blk = triton.language.arange(0, vec_bs)
    inp_blk = wgt_blk + pid * vec_sz
    msk = wgt_blk < vec_sz

    grad_out = triton.language.load(grad_out_ptr + inp_blk, msk, 0)
    inp = triton.language.load(inp_ptr + inp_blk, msk, 0)
    mean = triton.language.load(mean_ptr + pid)
    std = triton.language.load(std_ptr + pid)
    # wgt = triton.language.load(wgt_ptr + wgt_blk, msk, 0)

    if wgt_ptr:
        wgt = triton.language.load(wgt_ptr + wgt_blk, msk, 0)
    else:
        wgt = triton.language.full(wgt_blk.shape, 1, dtype=grad_out.dtype)

    a = grad_out * wgt
    b = (inp - mean) / std
    b = triton.language.where(msk, b, 0)
    c = language.gemv(b, a) / vec_sz
    d = language.gemv(grad_out, wgt) / vec_sz
    grad_inp = (a - c * b - d) / std

    triton.language.store(grad_inp_ptr + inp_blk, grad_inp, msk)

    if grad_wgt_ptr:
        grad_wgt = grad_out * b
        triton.language.atomic_add(grad_wgt_ptr + wgt_blk, grad_wgt, msk)

    if grad_bis_ptr:
        grad_bis = grad_out
        triton.language.atomic_add(grad_bis_ptr + wgt_blk, grad_bis, msk)
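For context, here is a minimal sketch of how a kernel with this signature would typically be launched, with one program per row and vec_bs rounded up to the next power of two; the tensor names, shapes, and statistics below are assumptions for illustration, not the actual trident call site.

import torch
import triton

# Assumed shapes: x is (num_rows, vec_sz); mean/std hold per-row statistics.
num_rows, vec_sz = 32, 100
x = torch.randn(num_rows, vec_sz, device="cuda")
grad_out = torch.randn_like(x)
grad_inp = torch.empty_like(x)
wgt = torch.randn(vec_sz, device="cuda")
grad_wgt = torch.zeros(vec_sz, device="cuda")
grad_bis = torch.zeros(vec_sz, device="cuda")
mean = x.mean(dim=1)
std = x.std(dim=1, unbiased=False)

# tl.arange needs a power-of-two size, so vec_bs is the next power of two >= vec_sz.
vec_bs = triton.next_power_of_2(vec_sz)

# One program per row.
backward[(num_rows,)](grad_out, x, grad_inp, vec_sz, wgt, grad_wgt, grad_bis,
                      mean, std, vec_bs=vec_bs)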

Below is the call stack.

Current thread 0x00007f6743c03700 (most recent call first):
  File "/home/deploy/.local/lib/python3.8/site-packages/triton/compiler/compiler.py", line 106 in ttgir_to_llir
  File "/home/deploy/.local/lib/python3.8/site-packages/triton/compiler/compiler.py", line 403 in <lambda>
  File "/home/deploy/.local/lib/python3.8/site-packages/triton/compiler/compiler.py", line 494 in compile
  File "<string>", line 62 in backward
  File "/tmp/pycharm_project_125/trident/operation/layer_norm.py", line 70 in __backward
  File "/tmp/pycharm_project_125/trident/operation/layer_norm.py", line 36 in backward
  File "/home/deploy/.local/lib/python3.8/site-packages/torch/autograd/function.py", line 274 in apply

You can reproduce this crash easily.

  1. git clone https://github.com/kakaobrain/trident.git
  2. cd triton
  3. git checkout triton
  5. export PYTHONPATH=${PYTHONPATH}:$(pwd)
  5. pytest -k test_layer_norm

About this issue

  • State: closed
  • Created a year ago
  • Comments: 23 (9 by maintainers)

Most upvoted comments

If so, the steps should be:

git clone https://github.com/kakaobrain/trident.git
cd trident # instead of cd triton?
git checkout triton
export PYTHONPATH=${PYTHONPATH}:$(pwd)
pytest -k test_layer_norm

I don’t think it has to do with CUDA versions; the problem occurred in ttgir_to_llir.

We just use the ptxas shipped with triton, which is picked from CUDA 12.0.*
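For reference, one way to check which ptxas binary an installed Triton wheel bundles; the package layout below is an assumption and may differ between Triton versions.

import os
import triton

# Assumed wheel layout: recent Triton releases ship ptxas under third_party/cuda/bin.
ptxas = os.path.join(os.path.dirname(triton.__file__), "third_party", "cuda", "bin", "ptxas")
print(ptxas, "exists:", os.path.exists(ptxas))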

tests/test_layer_norm.py … [100%]

============================================================================= 12 passed, 119 deselected in 8.58s =============================================================================