triton: Accuracy failure in reduction kernel

This is another bug minimized from our attempt to update PyTorch's Triton pin in pytorch/pytorch#109601. The minimized PyTorch code is:

PyTorch reproducer
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
torch._dynamo.config.translation_validation = True
torch._inductor.config.fallback_random = True
torch._inductor.config.generate_intermediate_hooks = True
torch._inductor.config.triton.cudagraphs = True
torch._inductor.config.split_reductions = False



isolate_fails_code_str = None



# torch version: 2.2.0a0+git9530e5e
# torch cuda version: 12.1
# torch git version: 9530e5e0b0f125bd9cf8e1cfd45693105b7a74e0


# CUDA Info: 
# nvcc: NVIDIA (R) Cuda compiler driver 
# Copyright (c) 2005-2023 NVIDIA Corporation 
# Built on Mon_Apr__3_17:16:06_PDT_2023 
# Cuda compilation tools, release 12.1, V12.1.105 
# Build cuda_12.1.r12.1/compiler.32688072_0 

# GPU Hardware Info: 
# NVIDIA A10G : 1 


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    
    
    def forward(self, add_21, add_24):
        add_25 = torch.ops.aten.add.Tensor(add_21, add_24);  add_21 = add_24 = None
        convert_element_type_72 = torch.ops.prims.convert_element_type.default(add_25, torch.float32);  add_25 = None
        clone_38 = torch.ops.aten.clone.default(convert_element_type_72, memory_format = torch.contiguous_format);  convert_element_type_72 = None
        getitem_13 = torch.ops.aten.sum(clone_38, [3], keepdim = True)
        sub_10 = torch.ops.aten.sub.Tensor(clone_38, getitem_13);  clone_38 = getitem_13 = None
        return (sub_10,)
        
def load_args(reader):
    buf0 = reader.storage('171672dbc2d675e7d867cdace0e3744e83a3788f', 2408448, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    shape = (1, 16, 16, 128)
    strides = [32768, 16, 1, 256]
    reader.tensor(buf0, shape, strides, dtype=torch.float16, is_leaf=True)  # add_21
    buf1 = reader.storage('61710f795a410091170650fe92ae4a0a93ee918f', 2408448, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf1, shape, dtype=torch.float16, is_leaf=True)  # add_24
load_args._version = 0
mod = Repro()
if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro
    with torch.no_grad():        run_repro(mod, load_args, accuracy=True, command='run', save_dir='/var/lib/jenkins/workspace/test/torch_compile_debug/run_2023_10_11_15_20_47_513247-pid_16805/minifier/checkpoints', tracing_mode='real', check_str=None)

and the corresponding Triton code is here:

Triton reproducer
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile

from torch import empty_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_pbell/tk/ctkuf77gwypetcuq2y4drcjiz5fdpijp4ob54gudfgscnfsqxkvg.py
# Source Nodes: [], Original ATen: []

triton_red_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@reduction(
    size_hints=[256, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_0', 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(3, 4))]}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 256
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (256*r1)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask & xmask, tmp6, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp7 = tl.load(in_ptr0 + (x0 + (256*r1)), rmask & xmask, eviction_policy='evict_first', other=0).to(tl.float32)
        tmp8 = tl.load(in_ptr1 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_first', other=0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp10 - tmp5
        tl.store(out_ptr1 + (r1 + (128*x0)), tmp11, rmask & xmask)
''')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 16, 16, 128), (32768, 16, 1, 256))
    assert_size_stride(arg1_1, (1, 16, 16, 128), (32768, 2048, 128, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf1 = empty_strided((1, 16, 16, 128), (32768, 2048, 128, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_cuda_stream(0)
        triton_red_fused_0.run(arg0_1, arg1_1, buf1, 256, 128, grid=grid(256), stream=stream0)
        run_intermediate_hooks('sub', buf1)
        del arg0_1
        del arg1_1
        return (buf1, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((1, 16, 16, 128), (32768, 16, 1, 256), device='cuda:0', dtype=torch.float16)
    arg1_1 = rand_strided((1, 16, 16, 128), (32768, 2048, 128, 1), device='cuda:0', dtype=torch.float16)
    return print_performance(lambda: call([arg0_1, arg1_1]), times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

git bisect suggests that #2181 broke this kernel, and I can see from the ttgir that different layouts are used before and after that commit.

Most upvoted comments

Okay, I found the bug: v8f16 is marked as a legal type in NVPTXISelLowering.cpp, but setLoadExtAction is never called for that type. I’ll submit a PR shortly.

Yes, if the high bytes including the exponent are zeroed and some of the mantissa bits are non-zero, then we would get a denormal float value, which is what I meant by “very small but non-zero”.

Oh, I see. Yes, we’re on the same page. I just finished building LLVM and am testing out your reproducer.

I’m starting to think the bug is actually in LLVM. I was able to reduce the ttgir reproducer down to this:

Click here for details
import torch
import tempfile
from torch import tensor, device
from torch._dynamo.testing import rand_strided

import triton
import triton.language as tl

from torch import device, empty, empty_strided
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg0_1, arg5_1, arg6_1, permute, rsqrt):
        add_2 = permute
        #add_2 = torch.ops.aten.add.Tensor(permute, arg0_1);  permute = arg0_1 = None
        clone_1 = torch.ops.aten.clone.default(add_2, memory_format = torch.contiguous_format);  add_2 = None
        getitem_1 = torch.ops.aten.sum.dim_IntList(clone_1, [2], keepdim=True)
        sub = torch.ops.aten.sub.Tensor(clone_1, getitem_1);
        add_4 = sub
        convert_element_type_11 = torch.ops.prims.convert_element_type.default(add_4, torch.float16);  add_4 = None
        view = torch.ops.aten.view.default(convert_element_type_11, [6000, 256]);  convert_element_type_11 = None
        return view[0:1],
        #return (view,)


repro_ttgir = """
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"triton_gpu.compute-capability" = 86 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @triton_red_fused_0_0d1d2d3d4d5d6de7de(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<6000> : tensor<1x1xi32, #blocked>
    %cst_1 = arith.constant dense<256> : tensor<1x256xi32, #blocked>
    %cst_3 = arith.constant dense<256> : tensor<1x256xi32, #blocked1>
    %cst_5 = arith.constant dense<1500> : tensor<1x1xi32, #blocked>
    %cst_6 = arith.constant dense<256> : tensor<1x1xi32, #blocked>
    %cst_9 = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked>
    %cst_11 = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked1>
    %c16_i32 = arith.constant 1 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c16_i32 : i32
    %2 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %4 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1x1xi32, #blocked>
    %6 = tt.splat %1 : (i32) -> tensor<1x1xi32, #blocked>
    %8 = arith.addi %6, %4 : tensor<1x1xi32, #blocked>
    %10 = arith.cmpi slt, %8, %cst : tensor<1x1xi32, #blocked>
    %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %14 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %16 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x256xi32, #blocked>
    %17 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x256xi32, #blocked1>
    %18 = arith.cmpi slt, %16, %cst_1 : tensor<1x256xi32, #blocked>
    %22 = arith.remsi %8, %cst_5 : tensor<1x1xi32, #blocked>
    %33 = tt.broadcast %18 : (tensor<1x256xi1, #blocked>) -> tensor<1x256xi1, #blocked>
    %35 = tt.broadcast %10 : (tensor<1x1xi1, #blocked>) -> tensor<1x256xi1, #blocked>
    %37 = arith.andi %33, %35 : tensor<1x256xi1, #blocked>

    %39 = arith.constant dense<1.000000e+00> : tensor<1x256xf16, #blocked1>
    %40 = triton_gpu.convert_layout %39 : (tensor<1x256xf16, #blocked1>) -> tensor<1x256xf16, #blocked>

    %43 = tt.broadcast %16 : (tensor<1x256xi32, #blocked>) -> tensor<1x256xi32, #blocked>
    %49 = arith.extf %40 : tensor<1x256xf16, #blocked> to tensor<1x256xf32, #blocked>
    %50 = arith.select %37, %49, %cst_9 : tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked>
    %51 = "tt.reduce"(%50) <{axis = 1 : i32}> ({
    ^bb0(%arg8: f32, %arg9: f32):
      %78 = arith.addf %arg8, %arg9 : f32
      tt.reduce.return %78 : f32
    }) : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %52 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1x1xf32, #blocked>
    %62 = tt.broadcast %52 : (tensor<1x1xf32, #blocked>) -> tensor<1x256xf32, #blocked>
    %63 = arith.subf %49, %62 : tensor<1x256xf32, #blocked>

    %66 = triton_gpu.convert_layout %63 : (tensor<1x256xf32, #blocked>) -> tensor<1x256xf32, #blocked1>

    %71 = arith.muli %8, %cst_6 : tensor<1x1xi32, #blocked>
    %72 = tt.broadcast %71 : (tensor<1x1xi32, #blocked>) -> tensor<1x256xi32, #blocked>
    %73 = arith.addi %43, %72 : tensor<1x256xi32, #blocked>
    %74 = tt.splat %arg5 : (!tt.ptr<f16, 1>) -> tensor<1x256x!tt.ptr<f16, 1>, #blocked>
    %75 = tt.addptr %74, %73 : tensor<1x256x!tt.ptr<f16, 1>, #blocked>, tensor<1x256xi32, #blocked>
    %77 = arith.truncf %63 : tensor<1x256xf32, #blocked> to tensor<1x256xf16, #blocked>
    tt.store %75, %77, %37 {cache = 1 : i32, evict = 1 : i32} : tensor<1x256xf16, #blocked>
    tt.return
  }
}
"""

def call(args):
    assert_size_stride = torch._C._dynamo.guards.assert_size_stride
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1500, 256), (256, 1))
    assert_size_stride(arg1_1, (256, ), (1, ))
    assert_size_stride(arg2_1, (256, ), (1, ))
    assert_size_stride(arg3_1, (4, 1500, 256), (384000, 1, 1500))
    assert_size_stride(arg4_1, (4, 1500, 1), (1500, 1, 1))
    assert_size_stride = torch._C._dynamo.guards.assert_size_stride
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf1 = empty((1, 256), device='cuda', dtype=torch.float16)
        with tempfile.NamedTemporaryFile(suffix=".ttgir") as ttgir_file:
            ttgir_file.write(repro_ttgir.encode("utf-8"))
            ttgir_file.flush()

            triton_red_fused_0 = triton.compile(ttgir_file.name)
        triton_red_fused_0[(1, 1, 1)](arg3_1, arg0_1, arg4_1, arg1_1, arg2_1, buf1, 6000, 256)
    return (buf1, )


torch.manual_seed(0)
arg0_1 = rand_strided((1500, 256), (256, 1), device='cuda:0', dtype=torch.float32).fill_(1.0)
arg1_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32).fill_(1.0)
arg2_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32).fill_(1.0)
arg3_1 = rand_strided((4, 1500, 256), (384000, 1, 1500), device='cuda:0', dtype=torch.float16).fill_(1.0)
arg4_1 = rand_strided((4, 1500, 1), (1500, 1, 1), device='cuda:0', dtype=torch.float32).fill_(1.0)
args = [arg0_1, arg1_1, arg2_1, arg3_1, arg4_1]

mod = Repro()
expect = mod(*args)
actual = call(args)

print(expect[0], actual[0])
torch.testing.assert_close(expect, actual)

Small changes to the reproducer make it return the expected result, and the main difference in the LLVM IR seems to be the vectorization of the load in the first convert_layout op and the subsequent fpext. (Failing version on the left, working on the right.)

  %19 = load <8 x half>, ptr addrspace(3) %18, align 16	      |	  %19 = load half, ptr addrspace(3) %18, align 16
  %20 = fpext <8 x half> %19 to <8 x float>		      |	  %20 = getelementptr inbounds <8 x half>, ptr addrspace(3) %
  %shift = shufflevector <8 x float> %20, <8 x float> poison, |	  %21 = load half, ptr addrspace(3) %20, align 2
  %21 = fadd <8 x float> %shift, %20			      |	  %22 = getelementptr inbounds <8 x half>, ptr addrspace(3) %
  %shift1 = shufflevector <8 x float> %20, <8 x float> poison |	  %23 = load half, ptr addrspace(3) %22, align 4
  %22 = fadd <8 x float> %21, %shift1			      |	  %24 = getelementptr inbounds <8 x half>, ptr addrspace(3) %
  %shift2 = shufflevector <8 x float> %20, <8 x float> poison |	  %25 = load half, ptr addrspace(3) %24, align 2
  %23 = fadd <8 x float> %22, %shift2			      |	  %26 = getelementptr inbounds <8 x half>, ptr addrspace(3) %
  %shift3 = shufflevector <8 x float> %20, <8 x float> poison |	  %27 = load half, ptr addrspace(3) %26, align 8
  %24 = fadd <8 x float> %23, %shift3			      |	  %28 = getelementptr inbounds <8 x half>, ptr addrspace(3) %
  %shift4 = shufflevector <8 x float> %20, <8 x float> poison |	  %29 = load half, ptr addrspace(3) %28, align 2
  %25 = fadd <8 x float> %24, %shift4			      |	  %30 = getelementptr inbounds <8 x half>, ptr addrspace(3) %
  %shift5 = shufflevector <8 x float> %20, <8 x float> poison |	  %31 = load half, ptr addrspace(3) %30, align 4
  %26 = fadd <8 x float> %25, %shift5			      |	  %32 = getelementptr inbounds <8 x half>, ptr addrspace(3) %
  %shift6 = shufflevector <8 x float> %20, <8 x float> poison |	  %33 = load half, ptr addrspace(3) %32, align 2
  %27 = fadd <8 x float> %26, %shift6			      |	  %34 = fpext half %19 to float
  %28 = extractelement <8 x float> %27, i64 0		      |	  %35 = fpext half %21 to float
  %29 = select i1 %14, float %28, float 0.000000e+00	      |	  %36 = fpext half %23 to float
  %30 = bitcast float %29 to i32			      |	  %37 = fpext half %25 to float
  %31 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i |	  %38 = fpext half %27 to float
  %32 = bitcast i32 %31 to float			      |	  %39 = fpext half %29 to float
  %33 = fadd float %29, %32				      |	  %40 = fpext half %31 to float
  %34 = bitcast float %33 to i32			      |	  %41 = fpext half %33 to float

And once we get down to the PTX level, the diff looks suspicious (again, failing version on the left):

	ld.shared.v4.b16 	{%f1, %f2, %f3, %f4}, [%r15+8 |		ld.shared.v4.b32 	{%r16, %r17, %r18, %r19}, [%r
	ld.shared.v4.b16 	{%f5, %f6, %f7, %f8}, [%r15]; |		mov.b32 	{%rs1, %rs2}, %r19;
	add.f32 	%f9, %f6, %f5;			      |		mov.b32 	{%rs3, %rs4}, %r18;
	add.f32 	%f10, %f9, %f7;			      |		mov.b32 	{%rs5, %rs6}, %r17;
	add.f32 	%f11, %f10, %f8;		      |		mov.b32 	{%rs7, %rs8}, %r16;
	add.f32 	%f12, %f11, %f1;		      |		cvt.f32.f16 	%f1, %rs7;
	add.f32 	%f13, %f12, %f2;		      |		cvt.f32.f16 	%f2, %rs8;
	add.f32 	%f14, %f13, %f3;		      |		cvt.f32.f16 	%f3, %rs5;
	add.f32 	%f15, %f14, %f4;		      |		cvt.f32.f16 	%f4, %rs6;
							      >		cvt.f32.f16 	%f5, %rs3;
							      >		cvt.f32.f16 	%f6, %rs4;
							      >		cvt.f32.f16 	%f7, %rs1;
							      >		cvt.f32.f16 	%f8, %rs2;
							      >		add.f32 	%f9, %f1, %f2;
							      >		add.f32 	%f10, %f9, %f3;
							      >		add.f32 	%f11, %f10, %f4;
							      >		add.f32 	%f12, %f11, %f5;
							      >		add.f32 	%f13, %f12, %f6;
							      >		add.f32 	%f14, %f13, %f7;
							      >		add.f32 	%f15, %f14, %f

For some reason the failing PTX loads b16 data into an f32 register and never performs the cvt.f32.f16 op. If that leaves the half bits in the lower bytes of the float, the result is a very small but non-zero value, and indeed the output of the sum is rounded to -0.0000.
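To illustrate what those “very small but non-zero” values look like, here is a minimal Python sketch (not part of the reproducer) that reinterprets the raw bits of a half-precision 1.0 as the low 16 bits of a float32, which is effectively what the failing PTX does when it skips the conversion:

import struct

# IEEE-754 half-precision 1.0 has the bit pattern 0x3C00.
half_one_bits = struct.unpack("<H", struct.pack("<e", 1.0))[0]

# Zero-extend those 16 bits into a 32-bit word and reinterpret it as a
# float32, i.e. load the b16 payload into an f32 register without cvt.f32.f16.
misread = struct.unpack("<f", struct.pack("<I", half_one_bits))[0]

print(hex(half_one_bits))  # 0x3c00
print(misread)             # ~2.15e-41, a denormal that effectively vanishes in the sum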

Here is an updated reproducer that doesn’t actually call torch.compile, so it should work across more PyTorch versions. We just compare the Triton kernel output against the equivalent PyTorch eager code.

Second reproducer
import torch
from torch import tensor, device
from torch._dynamo.testing import rand_strided

import triton
import triton.language as tl

from torch import device, empty, empty_strided
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg0_1, arg5_1, arg6_1, permute, rsqrt):
        add_2 = torch.ops.aten.add.Tensor(permute, arg0_1);  permute = arg0_1 = None
        clone_1 = torch.ops.aten.clone.default(add_2, memory_format = torch.contiguous_format);  add_2 = None
        getitem_1 = torch.ops.aten.sum.dim_IntList(clone_1, [2], keepdim=True)
        sub = torch.ops.aten.sub.Tensor(clone_1, getitem_1);  clone_1 = getitem_1 = None
        mul_6 = torch.ops.aten.mul.Tensor(sub, rsqrt);  sub = rsqrt = None
        mul_7 = torch.ops.aten.mul.Tensor(mul_6, arg5_1);  mul_6 = arg5_1 = None
        add_4 = torch.ops.aten.add.Tensor(mul_7, arg6_1);  mul_7 = arg6_1 = None
        convert_element_type_11 = torch.ops.prims.convert_element_type.default(add_4, torch.float16);  add_4 = None
        view = torch.ops.aten.view.default(convert_element_type_11, [6000, 256]);  convert_element_type_11 = None
        return (view,)

@triton.jit
def triton_red_fused_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
    xnumel = 6000
    rnumel = 256
    XBLOCK: tl.constexpr = 16
    RBLOCK: tl.constexpr = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 1500
    x1 = (xindex // 1500)
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (1500*r2) + (384000*x1)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r2 + (256*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp1 + tmp2
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask & xmask, tmp6, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tmp12 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last')
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp7 = tl.load(in_ptr0 + ((1500*r2) + (384000*(x3 // 1500)) + (x3 % 1500)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tl.load(in_ptr1 + (r2 + (256*(x3 % 1500))), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp14 = tl.load(in_ptr3 + (r2), rmask, eviction_policy='evict_last', other=0.0)
        tmp16 = tl.load(in_ptr4 + (r2), rmask, eviction_policy='evict_last', other=0.0)
        tmp8 = tmp7.to(tl.float32)
        tmp10 = tmp8 + tmp9
        tmp11 = tmp10 - tmp5
        tmp13 = tmp11 * tmp12
        tmp15 = tmp13 * tmp14
        tmp17 = tmp15 + tmp16
        tmp18 = tmp17.to(tl.float32)
        tl.store(out_ptr1 + (r2 + (256*x3)), tmp18, rmask & xmask)


def call(args):
    assert_size_stride = torch._C._dynamo.guards.assert_size_stride
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1500, 256), (256, 1))
    assert_size_stride(arg1_1, (256, ), (1, ))
    assert_size_stride(arg2_1, (256, ), (1, ))
    assert_size_stride(arg3_1, (4, 1500, 256), (384000, 1, 1500))
    assert_size_stride(arg4_1, (4, 1500, 1), (1500, 1, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf1 = empty((6000, 256), device='cuda', dtype=torch.float16)
        # Source Nodes: [], Original ATen: []
        stream0 = get_cuda_stream(0)
        triton_red_fused_0[(6000,)](arg3_1, arg0_1, arg4_1, arg1_1, arg2_1, buf1, 6000, 256)
        return (buf1, )


torch.manual_seed(0)
arg0_1 = rand_strided((1500, 256), (256, 1), device='cuda:0', dtype=torch.float32)
arg1_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg2_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg3_1 = rand_strided((4, 1500, 256), (384000, 1, 1500), device='cuda:0', dtype=torch.float16)
arg4_1 = rand_strided((4, 1500, 1), (1500, 1, 1), device='cuda:0', dtype=torch.float32)
args = [arg0_1, arg1_1, arg2_1, arg3_1, arg4_1]

mod = Repro()
expect = mod(*args)
actual = call(args)

torch.testing.assert_close(expect, actual)

This is optimistic, but perhaps I fixed it with https://github.com/openai/triton/commit/0e3bf3f58061bd62725e7c502e7e0af19690df1c (which was inspired by this issue).

This doesn’t seem to have had any effect for me.

As luck would have it, I’m going to start on this today! Sorry for the delay; I’m new on the team and have been working on other things.

It sounds like the layout propagation change exposed a bug in reduce lowering for some layouts. We’ll look at it next week.