triton: associative_scan gives incorrect result for non-commutative operator

Forgive the somewhat involved example, but consider the problem of computing a cumulative exponential moving average,

$$ x^\prime_{i+1} = x^\prime_i \gamma_{i+1} + x_{i+1}. $$

This can be computed with an associative scan using the operation

$$ (x_a, \gamma_a) \oplus (x_b, \gamma_b) = ( x_a \gamma_b + x_b, \gamma_a \gamma_b). $$
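Here is a minimal pure-Python sketch of the operator. The names `ema_op`, `inclusive_scan` and `ema_recurrence` are illustrative and do not come from Triton. The sketch checks associativity and non-commutativity on values chosen so that floating-point arithmetic stays exact. It also confirms that a sequential left-to-right inclusive scan with ⊕ reproduces the EMA recurrence.

```python
def ema_op(a, b):
    """(x_a, g_a) ⊕ (x_b, g_b) = (x_a * g_b + x_b, g_a * g_b)."""
    xa, ga = a
    xb, gb = b
    return (xa * gb + xb, ga * gb)


def inclusive_scan(op, items):
    """Sequential inclusive scan: out[i] = items[0] ⊕ items[1] ⊕ ... ⊕ items[i]."""
    out = [items[0]]
    for item in items[1:]:
        out.append(op(out[-1], item))
    return out


def ema_recurrence(xs, gs):
    """Reference: x'_0 = x_0 and x'_{i+1} = x'_i * g_{i+1} + x_{i+1}."""
    out = [xs[0]]
    for x, g in zip(xs[1:], gs[1:]):
        out.append(out[-1] * g + x)
    return out


# Dyadic values keep every product and sum exact, so == comparisons are safe.
a, b, c = (1.0, 0.5), (2.0, 0.25), (3.0, 2.0)
print(ema_op(ema_op(a, b), c), ema_op(a, ema_op(b, c)))  # (7.5, 0.25) (7.5, 0.25)
print(ema_op(a, b), ema_op(b, a))                        # (2.25, 0.125) (2.0, 0.125)

xs = [0.0, 1.0, 2.0, 3.0]
gs = [0.9, 0.9, 0.9, 0.9]
scanned = [x for x, _ in inclusive_scan(ema_op, list(zip(xs, gs)))]
print(scanned == ema_recurrence(xs, gs))  # True
```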

Proof that this is associative (though non-commutative) and reduces to the cumulative EMA is left to the reader 😃.
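For completeness, the proof expands both groupings of three elements and shows that they agree, while swapping two operands generally gives a different result:

$$
\begin{aligned}
\big((x_a,\gamma_a)\oplus(x_b,\gamma_b)\big)\oplus(x_c,\gamma_c)
  &= (x_a\gamma_b + x_b,\ \gamma_a\gamma_b)\oplus(x_c,\gamma_c)
   = (x_a\gamma_b\gamma_c + x_b\gamma_c + x_c,\ \gamma_a\gamma_b\gamma_c),\\
(x_a,\gamma_a)\oplus\big((x_b,\gamma_b)\oplus(x_c,\gamma_c)\big)
  &= (x_a,\gamma_a)\oplus(x_b\gamma_c + x_c,\ \gamma_b\gamma_c)
   = (x_a\gamma_b\gamma_c + x_b\gamma_c + x_c,\ \gamma_a\gamma_b\gamma_c),\\
(x_b,\gamma_b)\oplus(x_a,\gamma_a)
  &= (x_b\gamma_a + x_a,\ \gamma_a\gamma_b) \neq (x_a\gamma_b + x_b,\ \gamma_a\gamma_b) \quad \text{in general.}
\end{aligned}
$$

Define the running prefix $P_i = (x_0,\gamma_0)\oplus\cdots\oplus(x_i,\gamma_i)$ and write $\Gamma_i = \gamma_0\cdots\gamma_i$; both symbols are introduced here only for this check. Then $P_0 = (x_0, \gamma_0)$ and

$$ P_{i+1} = P_i \oplus (x_{i+1}, \gamma_{i+1}) = (x^\prime_i \gamma_{i+1} + x_{i+1},\ \Gamma_i \gamma_{i+1}), $$

so the first component of $P_i$ follows exactly the EMA recurrence $x^\prime_{i+1} = x^\prime_i \gamma_{i+1} + x_{i+1}$.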

associative_scan does not currently support multiple input arguments, so the following example uses a dirty work-around. It merges float32 values $x$ and factors $\gamma$ into a single int64. The hacky $\oplus$ first splits them, then performs the ‘sum’ as defined above, then re-merges. This is just a representation change and does not affect the associativity or commutativity of the operator.
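The int64 packing trick can be sketched in pure Python with `struct`. The helper names below are hypothetical. Python ints are unbounded, so this models the bit layout rather than Triton's int64 arithmetic. The `naive_merge` variant mimics the original Triton helper, which sign-extends both halves through int32 → int64 before the OR. When the low-half float is negative, the extended sign bits overwrite the high half. This does not affect the demo script below, whose factors are all positive.

```python
import struct
from typing import Tuple

MASK32 = 0xFFFFFFFF


def f32_bits(x: float) -> int:
    """Bit pattern of x (rounded to float32) as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(u: int) -> float:
    """Reinterpret the low 32 bits of u as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", u & MASK32))[0]


def merge(a: float, b: float) -> int:
    """Pack float32 a into the high 32 bits and float32 b into the low 32 bits."""
    return (f32_bits(a) << 32) | f32_bits(b)


def unmerge(merged: int) -> Tuple[float, float]:
    """Inverse of merge: (high half, low half) as float32 values."""
    return bits_f32(merged >> 32), bits_f32(merged)


def sign_extend_32(u: int) -> int:
    """Value of a 32-bit pattern read as signed int32, i.e. after int32 -> int64."""
    return u - (1 << 32) if u & 0x80000000 else u


def naive_merge(a: float, b: float) -> int:
    """Like the original helper: both halves are sign-extended before the OR."""
    return (sign_extend_32(f32_bits(a)) << 32) | sign_extend_32(f32_bits(b))


print(unmerge(merge(1.5, -0.25)))        # (1.5, -0.25)
print(unmerge(naive_merge(1.5, -0.25)))  # (nan, -0.25): high half clobbered
```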

I’ve included a demonstration script below, but to summarise current behaviour:

  • for sequence_length <= 64, the correct behaviour can be recovered by switching operand order in the implementation
  • for sequence_length >= 128 I can’t find any work-around. In the 128-value case, the first 64 values are correct with the work-around implementation, but the last 64 are always wrong.

This makes me think the merging part of the implementation uses the correct operand order, while the initial accumulation over small sequences uses combine_fn(b, a) rather than combine_fn(a, b).
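The hypothesis can be modelled in pure Python. `hypothetical_blocked_scan` is an illustrative model of the suspected behaviour, not Triton's actual implementation. Inside each block it accumulates elements as `op(new, acc)`, with the operands swapped. It then combines each block with the preceding carry as `op(carry, local)`, in the correct order. Under this model, the theoretically correct operator goes wrong from the second element onward. The swapped work-around operator is correct in the first block only, which matches the reported 64/128 pattern. A block size of 4 is used here for brevity.

```python
def ema_op(a, b):
    """(x_a, g_a) ⊕ (x_b, g_b) = (x_a * g_b + x_b, g_a * g_b)."""
    xa, ga = a
    xb, gb = b
    return (xa * gb + xb, ga * gb)


def workaround_op(a, b):
    """The issue's work-around: operands swapped inside the combine function."""
    return ema_op(b, a)


def hypothetical_blocked_scan(op, items, block):
    """Model only: swapped order within blocks, correct order across blocks."""
    out = []
    carry = None
    for start in range(0, len(items), block):
        chunk = items[start:start + block]
        local = [chunk[0]]
        for item in chunk[1:]:
            local.append(op(item, local[-1]))      # suspected bug: combine_fn(b, a)
        if carry is not None:
            local = [op(carry, v) for v in local]  # cross-block merge: combine_fn(a, b)
        out.extend(local)
        carry = out[-1]
    return out


def ema_recurrence(xs, gs):
    """Reference: x'_0 = x_0 and x'_{i+1} = x'_i * g_{i+1} + x_{i+1}."""
    out = [xs[0]]
    for x, g in zip(xs[1:], gs[1:]):
        out.append(out[-1] * g + x)
    return out


xs = [float(i) for i in range(8)]
gs = [0.5] * 8  # powers of two keep the arithmetic exact
items = list(zip(xs, gs))
expected = ema_recurrence(xs, gs)
with_correct_op = [x for x, _ in hypothetical_blocked_scan(ema_op, items, 4)]
with_workaround = [x for x, _ in hypothetical_blocked_scan(workaround_op, items, 4)]
print(with_correct_op[:4] == expected[:4])  # False
print(with_workaround[:4] == expected[:4])  # True
print(with_workaround[4:] == expected[4:])  # False
```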

```python
import torch
import numpy as np
import triton
import triton.language as tl


@triton.jit
def bitcast_unmerge(merged):
    tl.static_assert(merged.dtype == tl.int64)
    # Low 32 bits hold b, high 32 bits hold a; reinterpret each as float32.
    b = (merged & 0xFFFFFFFF).to(tl.int32).to(tl.float32, bitcast=True)
    a = (merged >> 32).to(tl.int32).to(tl.float32, bitcast=True)  # shifted by 32 bits
    return a, b


@triton.jit
def bitcast_merge(a, b):
    tl.static_assert(a.dtype == tl.float32)
    tl.static_assert(b.dtype == tl.float32)
    a = a.to(dtype=tl.int32, bitcast=True).to(tl.int64)  # reinterpret as int32, widen to int64
    a = a << 32  # move a into the high 32 bits
    b = b.to(dtype=tl.int32, bitcast=True).to(tl.int64)  # reinterpret as int32, widen to int64
    b = b & 0xFFFFFFFF  # clear sign-extension bits so a negative b cannot clobber a
    return a | b


@triton.jit
def cumulative_ema_op(a, b):
    xa, fa = bitcast_unmerge(a)
    xb, fb = bitcast_unmerge(b)
    # x_out = xa * fb + xb  # <- this should be the correct one
    x_out = xb * fa + xa  # <- this makes it correct for small sequence lengths
    f_out = fa * fb
    return bitcast_merge(x_out, f_out)


@triton.jit
def kernel(
    values,
    factors,
    Z,
    SEQUENCE_LENGTH: tl.constexpr,  # must be a power of 2 (tl.arange requirement)
):
    # One program per (batch, channel) row of a contiguous [B, N, T] tensor.
    global_id = tl.num_programs(axis=1) * tl.program_id(axis=0) + tl.program_id(axis=1)

    offsets = tl.arange(0, SEQUENCE_LENGTH) + global_id * SEQUENCE_LENGTH

    values_tensor = tl.load(values + offsets)
    factors_tensor = tl.load(factors + offsets)
    vf = bitcast_merge(values_tensor, factors_tensor)

    z = tl.associative_scan(vf, 0, combine_fn=cumulative_ema_op)

    out_values, out_factors = bitcast_unmerge(z)
    # out_factors contains the cumulative product, which we don't need
    tl.store(Z + offsets, out_values)


def cumulative_ema(
    values: torch.Tensor,
    factors: torch.Tensor,
) -> torch.Tensor:
    """
    Compute cumulative exponential moving average on last axis of rank-3 inputs.

    Args:
        values: [B, N, T] float32 values (contiguous; T a power of 2)
        factors: [B, N, T] float32 decay factors

    Returns:
        cumulative ema values, same shape as values/factors.
    """
    assert len(values.shape) == 3, values.shape
    assert values.shape == factors.shape, (values.shape, factors.shape)

    shape = values.shape
    result = torch.empty_like(values)

    kernel[(shape[0], shape[1])](
        values,
        factors,
        result,
        SEQUENCE_LENGTH=shape[2],
    )
    return result


if __name__ == "__main__":
    shape = (1, 1, 64)  # if this is (1, 1, 128) the result is incorrect even with the work-around cumulative_ema_op
    values = np.arange(np.prod(shape)).reshape(shape)
    factors = np.full(shape, 0.9)

    expected = np.zeros(shape, np.float32)
    expected[:, :, 0] = values[:, :, 0]
    for i in range(1, shape[2]):
        expected[:, :, i] = expected[:, :, i - 1] * factors[:, :, i] + values[:, :, i]
    print(expected)

    device = "cuda"
    values = torch.tensor(values, dtype=torch.float32, device=device)
    factors = torch.tensor(factors, dtype=torch.float32, device=device)

    result = cumulative_ema(values, factors)
    result = result.cpu()
    print(result.numpy())
    # Maximum relative error, skipping index 0 where expected == 0.
    expected = torch.from_numpy(expected)
    print(torch.max(torch.abs((result - expected) / expected)[..., 1:]))
```
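The kernel is launched on a grid of shape (B, N), and each program scans one row of T values. It addresses memory as `global_id * SEQUENCE_LENGTH + t`. That equals the row-major offset of `[b, n, t]` only for contiguous inputs, so non-contiguous tensors would need `.contiguous()` first. Also, `tl.arange` requires a power-of-two range, so `SEQUENCE_LENGTH` must be a power of two. Here is a pure-Python check of the addressing. The helper names are hypothetical.

```python
def kernel_offsets(pid0, pid1, num_programs_1, seq_len):
    """Offsets loaded by program (pid0, pid1), mirroring the kernel's arithmetic."""
    global_id = num_programs_1 * pid0 + pid1
    return [global_id * seq_len + t for t in range(seq_len)]


def row_major_offset(b, n, t, N, T):
    """Flat offset of element [b, n, t] in a contiguous [B, N, T] tensor."""
    return (b * N + n) * T + t


def is_power_of_two(k):
    return k > 0 and k & (k - 1) == 0


B, N, T = 2, 3, 4
ok = all(
    kernel_offsets(b, n, N, T) == [row_major_offset(b, n, t, N, T) for t in range(T)]
    for b in range(B)
    for n in range(N)
)
print(ok)                                         # True
print(is_power_of_two(64), is_power_of_two(100))  # True False
```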

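Using the theoretically correct cumulative_ema_op implementation (the commented-out line instead of the work-around), this prints the expected values (top) and the actual output (bottom):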

```
[[[  0.         1.         2.9        5.61       9.049     13.1441
    17.82969   23.046722  28.74205   34.867844  41.381058  48.24295
    55.418655  62.87679   70.58911   78.5302    86.67718   95.00946
   103.508514 112.15766  120.941895 129.8477   138.86293  147.97664
   157.17897  166.46107  175.81497  185.23347  194.71013  204.23912
   213.81522  223.4337   233.09033  242.7813   252.50317  262.25287
   272.0276   281.82483  291.64233  301.4781   311.3303   321.19727
   331.07755  340.9698   350.8728   360.78552  370.70697  380.63626
   390.57263  400.51538  410.46384  420.41745  430.3757   440.33813
   450.30432  460.2739   470.24652  480.22186  490.19968  500.17972
   510.16174  520.14557  530.131    540.11786 ]]]
[[[ 0.        0.9       2.52      4.707     7.3314   10.283849 13.472494
   16.820572 20.264309 23.751093 27.237875 30.689793 34.078945 37.38337
   40.586117 43.674484 46.639313 49.47443  52.17614  54.742752 57.174282
   59.47208  61.638573 63.677048 65.59144  67.386185 69.06606  70.63611
   72.10148  73.467415 74.73914  75.92186  77.02064  78.04044  78.98608
   79.86218  80.6732   81.42339  82.11682  82.757324 83.34856  83.893974
   84.39682  84.86015  85.28685  85.67961  86.04095  86.37322  86.67863
   86.95922  87.21691  87.45347  87.67055  87.869675 88.05227  88.21964
   88.373024 88.513535 88.642204 88.76001  88.86783  88.966484 89.05673
   89.13927 ]]]
```

If we look at the second and third values of the actual output (bottom array), we see that our ⊕ operation has been evaluated with its operands in the wrong order:

  • the second value should be 0 * 0.9 + 1 = 1, but instead we get 0 + 0.9 * 1 = 0.9
  • if we accept the second value as is (0.9, with cumulative factor 0.9 ** 2), the third value should be 0.9 * 0.9 + 2 = 2.81, but we instead get 0.9 + 0.9 ** 2 * 2 = 2.52
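A quick check of this reading: a sequential scan that always calls the combine function with swapped operands (`op(new, acc)` instead of `op(acc, new)`) reproduces the first entries of the actual output shown above. The helper names are illustrative.

```python
def ema_op(a, b):
    """(x_a, g_a) ⊕ (x_b, g_b) = (x_a * g_b + x_b, g_a * g_b)."""
    xa, ga = a
    xb, gb = b
    return (xa * gb + xb, ga * gb)


def fold_with_swapped_operands(op, items):
    """Sequential scan that calls op(new, acc) instead of op(acc, new)."""
    out = [items[0]]
    for item in items[1:]:
        out.append(op(item, out[-1]))
    return out


items = [(float(i), 0.9) for i in range(5)]
swapped = [x for x, _ in fold_with_swapped_operands(ema_op, items)]
print([round(x, 6) for x in swapped])  # [0.0, 0.9, 2.52, 4.707, 7.3314]
```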

About this issue

  • State: closed
  • Created 9 months ago
  • Reactions: 1
  • Comments: 15 (2 by maintainers)

Most upvoted comments

Once this issue was fixed, one more problem remained before associative_scan was usable: fused multiply-add (FMA) instructions made the scan numerically unstable. The remedy was to disable fusion with enable_fp_fusion=False.
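The scan's combine step computes `x_a * γ_b + x_b`, which a compiler may contract into a single fused multiply-add (FMA). FMA rounds once instead of twice, so the result can depend on whether fusion happens. The pure-Python sketch below emulates float32 rounding with `struct` to show one such case. It illustrates the rounding difference only and does not reproduce the reported instability. The helper names are hypothetical.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def mul_add_unfused(a: float, b: float, c: float) -> float:
    """float32 a*b + c with two roundings: multiply, round, add, round."""
    return to_f32(to_f32(a * b) + c)


def mul_add_fused(a: float, b: float, c: float) -> float:
    """Emulated float32 FMA: a single rounding at the end.

    For float32 inputs the product a*b is exact in double precision, and for
    the inputs below the double-precision sum is exact too, so one final
    rounding to float32 matches a true FMA here.
    """
    return to_f32(a * b + c)


a = to_f32(1.0 + 2.0**-12)
c = -to_f32(1.0 + 2.0**-11)
print(mul_add_unfused(a, a, c))  # 0.0
print(mul_add_fused(a, a, c))    # 2**-24, about 5.96e-08
```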

I made a library for first order scans that uses tl.associative_scan on triton 2.2.0: https://github.com/proger/accelerated-scan

Thanks @PeaBrane. Maybe this is only an issue on CUDA 12?

Edit: Can confirm the issue persists on A100s and a 3090

Thanks for the analysis, this PR should solve it: https://github.com/openai/triton/pull/2362

For the multiple operand support I don’t have plans to work on it in the near future but anybody is welcome to contribute it.

The code worked for me, and the relative deviation from the expected result is only 3.1392e-07. I’m running it on triton-nightly 2.1.0.dev20231014192330.