triton: Accuracy failure in reduction kernel
This is another bug minimized from our attempt to update PyTorch's Triton pin in pytorch/pytorch#109601. The minimized PyTorch code is:
PyTorch reproducer
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
torch._dynamo.config.translation_validation = True
torch._inductor.config.fallback_random = True
torch._inductor.config.generate_intermediate_hooks = True
torch._inductor.config.triton.cudagraphs = True
torch._inductor.config.split_reductions = False
isolate_fails_code_str = None
# torch version: 2.2.0a0+git9530e5e
# torch cuda version: 12.1
# torch git version: 9530e5e0b0f125bd9cf8e1cfd45693105b7a74e0
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2023 NVIDIA Corporation
# Built on Mon_Apr__3_17:16:06_PDT_2023
# Cuda compilation tools, release 12.1, V12.1.105
# Build cuda_12.1.r12.1/compiler.32688072_0
# GPU Hardware Info:
# NVIDIA A10G : 1
from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, add_21, add_24):
        add_25 = torch.ops.aten.add.Tensor(add_21, add_24); add_21 = add_24 = None
        convert_element_type_72 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None
        clone_38 = torch.ops.aten.clone.default(convert_element_type_72, memory_format = torch.contiguous_format); convert_element_type_72 = None
        getitem_13 = torch.ops.aten.sum(clone_38, [3], keepdim = True)
        sub_10 = torch.ops.aten.sub.Tensor(clone_38, getitem_13); clone_38 = getitem_13 = None
        return (sub_10,)
def load_args(reader):
    buf0 = reader.storage('171672dbc2d675e7d867cdace0e3744e83a3788f', 2408448, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    shape = (1, 16, 16, 128)
    strides = [32768, 16, 1, 256]
    reader.tensor(buf0, shape, strides, dtype=torch.float16, is_leaf=True)  # add_21
    buf1 = reader.storage('61710f795a410091170650fe92ae4a0a93ee918f', 2408448, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf1, shape, dtype=torch.float16, is_leaf=True)  # add_24
load_args._version = 0
mod = Repro()
if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro
    with torch.no_grad():
        run_repro(mod, load_args, accuracy=True, command='run', save_dir='/var/lib/jenkins/workspace/test/torch_compile_debug/run_2023_10_11_15_20_47_513247-pid_16805/minifier/checkpoints', tracing_mode='real', check_str=None)
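For orientation, the graph above is just "add, cast to fp32, make contiguous, sum over the last dim, subtract", and the accuracy=True run checks the compiled graph against running the same graph eagerly. The following is a minimal sketch of that eager reference which I have written out for clarity; it is not part of the original report, and the inputs simply mirror load_args above.
import torch
from torch._dynamo.testing import rand_strided

# Sketch (not part of the original report): eager reference for Repro.forward,
# with inputs shaped and strided as in load_args above.
add_21 = rand_strided((1, 16, 16, 128), (32768, 16, 1, 256), device='cuda:0', dtype=torch.float16)
add_24 = rand_strided((1, 16, 16, 128), (32768, 2048, 128, 1), device='cuda:0', dtype=torch.float16)

x = (add_21 + add_24).to(torch.float32).contiguous()  # add -> fp32 -> contiguous clone
expected = x - x.sum(dim=3, keepdim=True)             # subtract the sum over the last dim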
The corresponding Inductor-generated Triton code is:
Triton reproducer
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_pbell/tk/ctkuf77gwypetcuq2y4drcjiz5fdpijp4ob54gudfgscnfsqxkvg.py
# Source Nodes: [], Original ATen: []
triton_red_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@reduction(
    size_hints=[256, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_0', 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(3, 4))]}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 256
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (256*r1)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask & xmask, tmp6, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp7 = tl.load(in_ptr0 + (x0 + (256*r1)), rmask & xmask, eviction_policy='evict_first', other=0).to(tl.float32)
        tmp8 = tl.load(in_ptr1 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_first', other=0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp10 - tmp5
        tl.store(out_ptr1 + (r1 + (128*x0)), tmp11, rmask & xmask)
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 16, 16, 128), (32768, 16, 1, 256))
    assert_size_stride(arg1_1, (1, 16, 16, 128), (32768, 2048, 128, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        buf1 = empty_strided((1, 16, 16, 128), (32768, 2048, 128, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_cuda_stream(0)
        triton_red_fused_0.run(arg0_1, arg1_1, buf1, 256, 128, grid=grid(256), stream=stream0)
        run_intermediate_hooks('sub', buf1)
        del arg0_1
        del arg1_1
        return (buf1, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((1, 16, 16, 128), (32768, 16, 1, 256), device='cuda:0', dtype=torch.float16)
    arg1_1 = rand_strided((1, 16, 16, 128), (32768, 2048, 128, 1), device='cuda:0', dtype=torch.float16)
    return print_performance(lambda: call([arg0_1, arg1_1]), times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
git bisect suggests that #2181 broke this kernel, and I can see from the ttgir that different layouts are used before and after that commit.
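For anyone trying to reproduce that comparison, one way to get at the ttgir (an assumption about Triton's default cache layout on my part, not something stated in the issue) is that Triton writes the per-kernel IR stages next to the compiled cubin in its cache directory:
import glob, os

# Sketch, assuming Triton's default cache location (overridable via TRITON_CACHE_DIR):
# list the .ttgir files produced for this run so the layouts can be diffed across
# the two commits.
cache = os.path.expanduser(os.environ.get("TRITON_CACHE_DIR", "~/.triton/cache"))
for path in glob.glob(os.path.join(cache, "**", "*.ttgir"), recursive=True):
    print(path)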
About this issue
- State: closed
- Created 9 months ago
- Reactions: 2
- Comments: 39 (28 by maintainers)
Commits related to this issue
- Add failing tests for reduction issues (#2483). https://github.com/openai/triton/issues/2483 GPC: stop — committed to openai/triton by jlebar 8 months ago
- [NVPTX] Expand EXTLOAD for v8f16 and v8bf16 (#72672) In openai/triton#2483 I've encountered a bug in the NVPTX codegen. Given `load<8 x half>` followed by `fpext to <8 x float>` we get ``` ld.sh... — committed to llvm/llvm-project by peterbell10 7 months ago
Okay, I found it: the bug exists because v8f16 is marked as a legal type in NVPTXISelLowering.cpp but setLoadExtAction is never called for the type. I'll submit a PR shortly.

Oh, I see. Yes, we're on the same page. I just finished building LLVM and am testing out your reproducer.
I’m starting to think the bug is actually in llvm. I was able to reduce the ttgir reproducer down to this:
Click here for details
Small changes to the reproducer return the expected result, and the main difference in the llir seems to be the vectorization of the load in the first convert_layout op and the subsequent fpext (failing on left, working on right). And once we get down to the PTX level the diff looks suspicious (again, failing on left).
For some reason the failing PTX loads b16 data into an f32 register and never does the cvt.f32.f16 op. If that results in the lower bytes of the float being set, it would mean a very small but non-zero value, and indeed the output of the sum is rounded to -0.0000.
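To see why a half pattern sitting in the low bits of an f32 register reads as "very small but non-zero", here is a quick illustration of my own (not taken from the issue's dumps):
import struct

# Reinterpret the bit pattern of float16 1.0 as the low 16 bits of a float32,
# which is presumably what happens when the PTX skips the cvt.f32.f16.
half_bits = 0x3C00  # bit pattern of float16 1.0
as_f32 = struct.unpack('<f', struct.pack('<I', half_bits))[0]
print(as_f32)  # ~2.15e-41, a tiny denormal instead of 1.0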
Here is an updated reproducer that doesn't actually call torch.compile, so it should work for more PyTorch versions. We just compare the Triton kernel output against the equivalent PyTorch eager code.
Second reproducer
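The actual second reproducer is collapsed above; as an illustration of the approach it describes (my own sketch, not the attached script), appending something like this to the Triton reproducer so that call() is in scope would compare the kernel output against eager:
import torch
from torch._dynamo.testing import rand_strided

# Sketch only -- not the issue's attached "Second reproducer".
arg0_1 = rand_strided((1, 16, 16, 128), (32768, 16, 1, 256), device='cuda:0', dtype=torch.float16)
arg1_1 = rand_strided((1, 16, 16, 128), (32768, 2048, 128, 1), device='cuda:0', dtype=torch.float16)

# Eager reference: add -> fp32 -> contiguous -> sum over last dim -> subtract.
x = (arg0_1 + arg1_1).to(torch.float32).contiguous()
expected = x - x.sum(dim=3, keepdim=True)

(actual,) = call([arg0_1, arg1_1])
# Loose tolerances: the reported failure is far larger than fp16 rounding noise.
torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)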
This doesn’t seem to have had any effect for me.
As luck would have it, I’m going to start on this today! Sorry for the delay, I’m new on the team, and I’ve been working on other things.
It sounds like the layout propagation change exposed a bug in reduce lowering for some layouts. We'll look at it next week.