triton: Crash occurs
The code below causes a crash with the latest Triton:
```python
import triton
import triton.language
from trident import language  # trident's own helper module (assumed to provide gemv here)


@triton.jit
def backward(grad_out_ptr, inp_ptr, grad_inp_ptr, vec_sz, wgt_ptr, grad_wgt_ptr, grad_bis_ptr, mean_ptr,
             std_ptr, vec_bs: triton.language.constexpr):
    pid = triton.language.program_id(0)  # one program instance handles one row/vector
    wgt_blk = triton.language.arange(0, vec_bs)
    inp_blk = wgt_blk + pid * vec_sz  # offsets into the flattened (rows x vec_sz) buffers
    msk = wgt_blk < vec_sz  # guard for the case vec_bs > vec_sz
    grad_out = triton.language.load(grad_out_ptr + inp_blk, msk, 0)
    inp = triton.language.load(inp_ptr + inp_blk, msk, 0)
    mean = triton.language.load(mean_ptr + pid)
    std = triton.language.load(std_ptr + pid)
    # wgt = triton.language.load(wgt_ptr + wgt_blk, msk, 0)
    if wgt_ptr:
        wgt = triton.language.load(wgt_ptr + wgt_blk, msk, 0)
    else:
        wgt = triton.language.full(wgt_blk.shape, 1, dtype=grad_out.dtype)
    a = grad_out * wgt
    b = (inp - mean) / std
    b = triton.language.where(msk, b, 0)
    c = language.gemv(b, a) / vec_sz
    d = language.gemv(grad_out, wgt) / vec_sz
    grad_inp = (a - c * b - d) / std
    triton.language.store(grad_inp_ptr + inp_blk, grad_inp, msk)
    if grad_wgt_ptr:
        grad_wgt = grad_out * b
        triton.language.atomic_add(grad_wgt_ptr + wgt_blk, grad_wgt, msk)
    if grad_bis_ptr:
        grad_bis = grad_out
        triton.language.atomic_add(grad_bis_ptr + wgt_blk, grad_bis, msk)
```
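For context, here is a minimal sketch of how such a kernel is typically launched with one program per row. The tensor names, the zero-initialized gradient buffers, and the use of triton.next_power_of_2 for the block size are my assumptions, not code taken from trident's call site.

```python
# Minimal launch sketch (assumed tensor shapes/names and buffer initialization;
# this is NOT the trident call site from layer_norm.py).
import torch
import triton

def launch_backward(grad_out, inp, wgt, mean, std):
    num_rows, vec_sz = inp.shape
    grad_inp = torch.empty_like(inp)
    grad_wgt = torch.zeros(vec_sz, device=inp.device, dtype=inp.dtype)
    grad_bis = torch.zeros(vec_sz, device=inp.device, dtype=inp.dtype)
    vec_bs = triton.next_power_of_2(vec_sz)  # block must cover one full row
    backward[(num_rows,)](grad_out, inp, grad_inp, vec_sz, wgt, grad_wgt,
                          grad_bis, mean, std, vec_bs=vec_bs)
    return grad_inp, grad_wgt, grad_bis
```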
Below is the call stack:
```
Current thread 0x00007f6743c03700 (most recent call first):
File "/home/deploy/.local/lib/python3.8/site-packages/triton/compiler/compiler.py", line 106 in ttgir_to_llir
File "/home/deploy/.local/lib/python3.8/site-packages/triton/compiler/compiler.py", line 403 in <lambda>
File "/home/deploy/.local/lib/python3.8/site-packages/triton/compiler/compiler.py", line 494 in compile
File "<string>", line 62 in backward
File "/tmp/pycharm_project_125/trident/operation/layer_norm.py", line 70 in __backward
File "/tmp/pycharm_project_125/trident/operation/layer_norm.py", line 36 in backward
File "/home/deploy/.local/lib/python3.8/site-packages/torch/autograd/function.py", line 274 in apply
```
You can reproduce this crash easily:
```shell
git clone https://github.com/kakaobrain/trident.git
cd trident
git checkout triton
export PYTHONPATH=${PYTHONPATH}:$(pwd)
pytest -k test_layer_norm
```
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 23 (9 by maintainers)
If so, the steps should be:
I don't think it has to do with CUDA versions; the problem occurred in `ttgir_to_llir`. We just use the `ptxas` shipped with Triton, which is picked from CUDA 12.0.*
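If you want to confirm which ptxas the installed Triton wheel actually bundles, a sketch like the following can help; walking the package directory is an assumption on my part, since the exact subdirectory holding ptxas varies across Triton releases.

```python
# Sketch: find the ptxas binary bundled inside the installed triton package
# and print its version. Searching the whole package tree is an assumption;
# the exact location (e.g. a third_party/cuda/bin subdirectory) differs
# between Triton versions.
import os
import subprocess
import triton

pkg_dir = os.path.dirname(triton.__file__)
for root, _, files in os.walk(pkg_dir):
    if "ptxas" in files:
        ptxas_path = os.path.join(root, "ptxas")
        print(ptxas_path)
        print(subprocess.run([ptxas_path, "--version"],
                             capture_output=True, text=True).stdout)
```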