triton: associative_scan gives incorrect result for non-commutative operator
Forgive the somewhat involved example, but consider the problem of computing a cumulative exponential moving average,
$$ x^\prime_{i+1} = x^\prime_i \gamma_{i+1} + x_{i+1}. $$
This can be computed with an associative scan using the operation
$$ (x_a, \gamma_a) \oplus (x_b, \gamma_b) = ( x_a \gamma_b + x_b, \gamma_a \gamma_b). $$
Proof that this is associative (though non-commutative) and reduces to the cumulative EMA is left to the reader 😃.
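(For the curious: expanding both groupings of three elements gives the same result,

$$
\big((x_a, \gamma_a) \oplus (x_b, \gamma_b)\big) \oplus (x_c, \gamma_c)
= (x_a \gamma_b \gamma_c + x_b \gamma_c + x_c,\; \gamma_a \gamma_b \gamma_c)
= (x_a, \gamma_a) \oplus \big((x_b, \gamma_b) \oplus (x_c, \gamma_c)\big),
$$

whereas swapping the operands gives $(x_b \gamma_a + x_a,\; \gamma_a \gamma_b)$, which differs in general.)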
`associative_scan` does not currently support multiple input arguments, so the following example uses a dirty work-around that merges the float32 values $x$ and factors $\gamma$ into a single int64; the hacky $\oplus$ first splits, then performs the 'sum' as defined above, then re-merges. This is just a representation change and does not affect the associativity/commutativity of the operator.
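To illustrate that the packing really is only a representation change, here is a small host-side sketch in plain NumPy (the `merge`/`unmerge` helpers are just illustrative names) showing that the int64 round-trip is bit-exact:

```python
# Host-side sketch (NumPy, not Triton): pack two float32 arrays into one int64
# array and unpack them again; the round-trip recovers the original bits exactly.
import numpy as np

def merge(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    a_bits = a.astype(np.float32).view(np.int32).astype(np.int64)
    b_bits = b.astype(np.float32).view(np.int32).astype(np.int64)
    # mask b so its sign bit cannot spill into the high 32 bits
    return (a_bits << 32) | (b_bits & 0xFFFFFFFF)

def unmerge(merged: np.ndarray):
    a = (merged >> 32).astype(np.int32).view(np.float32)
    b = (merged & 0xFFFFFFFF).astype(np.int32).view(np.float32)
    return a, b

x = np.random.randn(8).astype(np.float32)
g = np.random.rand(8).astype(np.float32)
x2, g2 = unmerge(merge(x, g))
assert np.array_equal(x, x2) and np.array_equal(g, g2)
```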
I’ve included a demonstration script below, but to summarise current behaviour:
- for `sequence_length <= 64`, the correct behaviour can be recovered by switching the operand order in the implementation
- for `sequence_length >= 128`, I can't find any work-around. In the 128-value case, the first 64 values are correct with the work-around implementation, but the second 64 are always off.
This makes me think the merging part of the implementation uses the correct order, while the initial accumulation over small sequences uses `combine_fn(b, a)` rather than `combine_fn(a, b)`.
```python
import torch
import numpy as np

import triton
import triton.language as tl


@triton.jit
def bitcast_unmerge(merged):
    tl.static_assert(merged.dtype == tl.int64)
    b = (merged & 0xFFFFFFFF).to(tl.int32).to(tl.float32, bitcast=True)
    a = (merged >> 32).to(tl.int32).to(tl.float32, bitcast=True)  # high 32 bits shifted back down
    return a, b


@triton.jit
def bitcast_merge(a, b):
    tl.static_assert(a.dtype == tl.float32)
    tl.static_assert(b.dtype == tl.float32)
    a = a.to(dtype=tl.int32, bitcast=True).to(tl.int64)  # reinterpret float32 bits as int32, widen to int64
    a = a << 32  # move into the high 32 bits
    b = b.to(dtype=tl.int32, bitcast=True).to(tl.int64)  # reinterpret float32 bits as int32, widen to int64
    return a | b


@triton.jit
def cumulative_ema_op(a, b):
    xa, fa = bitcast_unmerge(a)
    xb, fb = bitcast_unmerge(b)
    # x_out = xa * fb + xb  # <- this should be the correct one
    x_out = xb * fa + xa  # <- this makes it correct for small sequence lengths
    f_out = fa * fb
    return bitcast_merge(x_out, f_out)


@triton.jit
def kernel(
    values,
    factors,
    Z,
    SEQUENCE_LENGTH: tl.constexpr,
):
    global_id = tl.num_programs(axis=1) * tl.program_id(axis=0) + tl.program_id(axis=1)
    offsets = tl.arange(0, SEQUENCE_LENGTH) + global_id * SEQUENCE_LENGTH
    values_tensor = tl.load(values + offsets)
    factors_tensor = tl.load(factors + offsets)
    vf = bitcast_merge(values_tensor, factors_tensor)
    z = tl.associative_scan(vf, 0, combine_fn=cumulative_ema_op)
    out_values, out_factors = bitcast_unmerge(z)
    # out_factors contains the cumulative product - but we're not interested in that
    tl.store(Z + offsets, out_values)


def cumulative_ema(
    values: torch.Tensor,
    factors: torch.Tensor,
) -> torch.Tensor:
    """
    Compute cumulative exponential moving average on last axis of rank-3 inputs.

    Args:
        values: [B, N, T] float32 values
        factors: [B, N, T] float32 decay factors

    Returns:
        cumulative ema values, same shape as values/factors.
    """
    assert len(values.shape) == 3, values.shape
    assert values.shape == factors.shape, (values.shape, factors.shape)
    shape = values.shape
    result = torch.empty_like(values)
    kernel[(shape[0], shape[1])](
        values,
        factors,
        result,
        SEQUENCE_LENGTH=shape[2],
    )
    return result


if __name__ == "__main__":
    shape = (1, 1, 64)  # if this is (1, 1, 128) the result is incorrect even with the work-around cumulative_ema_op
    values = np.arange(np.prod(shape)).reshape(shape)
    factors = np.full(shape, 0.9)

    expected = np.zeros(shape, np.float32)
    expected[:, :, 0] = values[:, :, 0]
    for i in range(1, shape[2]):
        expected[:, :, i] = expected[:, :, i - 1] * factors[:, :, i] + values[:, :, i]
    print(expected)

    device = "cuda"
    values = torch.tensor(values, dtype=torch.float32, device=device)
    factors = torch.tensor(factors, dtype=torch.float32, device=device)
    result = cumulative_ema(values, factors)
    result = result.cpu()
    print(result.numpy())
    print(torch.max(torch.abs((result - expected) / expected)[..., 1:]))
```
Using the theoretically correct `cumulative_ema_op` implementation, this prints:

```
[[[ 0. 1. 2.9 5.61 9.049 13.1441
17.82969 23.046722 28.74205 34.867844 41.381058 48.24295
55.418655 62.87679 70.58911 78.5302 86.67718 95.00946
103.508514 112.15766 120.941895 129.8477 138.86293 147.97664
157.17897 166.46107 175.81497 185.23347 194.71013 204.23912
213.81522 223.4337 233.09033 242.7813 252.50317 262.25287
272.0276 281.82483 291.64233 301.4781 311.3303 321.19727
331.07755 340.9698 350.8728 360.78552 370.70697 380.63626
390.57263 400.51538 410.46384 420.41745 430.3757 440.33813
450.30432 460.2739 470.24652 480.22186 490.19968 500.17972
510.16174 520.14557 530.131 540.11786 ]]]

[[[ 0. 0.9 2.52 4.707 7.3314 10.283849 13.472494
16.820572 20.264309 23.751093 27.237875 30.689793 34.078945 37.38337
40.586117 43.674484 46.639313 49.47443 52.17614 54.742752 57.174282
59.47208 61.638573 63.677048 65.59144 67.386185 69.06606 70.63611
72.10148 73.467415 74.73914 75.92186 77.02064 78.04044 78.98608
79.86218 80.6732 81.42339 82.11682 82.757324 83.34856 83.893974
84.39682 84.86015 85.28685 85.67961 86.04095 86.37322 86.67863
86.95922 87.21691 87.45347 87.67055 87.869675 88.05227 88.21964
88.373024 88.513535 88.642204 88.76001 88.86783 88.966484 89.05673
89.13927 ]]]
```
If we look at the second and third values of the actual output (bottom array), we see our add operation has been evaluated in the wrong order (a short simulation reproducing this follows the list):

- the second value should be `0 * 0.9 + 1 = 1`, but instead we get `0 + 0.9 * 1 = 0.9`
- if we accept the second value as is, the third value should be `0.9 ** 2 * 0.9 + 2`, but we instead get `0.9 + 0.9**2 * 2 = 2.52`
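As a sanity check of that hypothesis, a plain-Python simulation that applies the correct $\oplus$ but with swapped arguments reproduces exactly these numbers:

```python
# Plain-Python simulation: the mathematically correct oplus, but applied as
# combine(b, a) in the running scan, reproduces the incorrect output above.
def combine(a, b):
    xa, fa = a
    xb, fb = b
    return (xa * fb + xb, fa * fb)

values = [0.0, 1.0, 2.0, 3.0]
factors = [0.9] * 4

acc = (values[0], factors[0])
out = [acc[0]]
for v, f in zip(values[1:], factors[1:]):
    acc = combine((v, f), acc)  # swapped: combine(b, a) instead of combine(a, b)
    out.append(acc[0])

print([round(v, 4) for v in out])  # [0.0, 0.9, 2.52, 4.707]
```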
About this issue
- State: closed
- Created 9 months ago
- Reactions: 1
- Comments: 15 (2 by maintainers)
Since this issue was fixed, one more issue remained to make `associative_scan` usable: FMA operations made the scan numerically unstable; the workaround is to compile with `enable_fp_fusion=False`. I made a library for first order scans that uses `tl.associative_scan` on triton 2.2.0: https://github.com/proger/accelerated-scan
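For reference, a minimal sketch of how that flag could be passed at the launch site in the script above, assuming a Triton version that accepts `enable_fp_fusion` as a launch-time compile option:

```python
# Assumption: this Triton version forwards enable_fp_fusion to the compiler,
# preventing the multiply-add inside the scan from being contracted into an FMA.
kernel[(shape[0], shape[1])](
    values,
    factors,
    result,
    SEQUENCE_LENGTH=shape[2],
    enable_fp_fusion=False,
)
```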
Thanks @PeaBrane. Maybe this is only an issue on Cuda 12?

Edit: Can confirm the issue persists on A100s and a 3090.
Thanks for the analysis, this PR should solve it: https://github.com/openai/triton/pull/2362
For the multiple operand support I don’t have plans to work on it in the near future but anybody is welcome to contribute it.
The code worked for me, and the relative deviation from the expected result is only `3.1392e-07`. I'm running it on `triton-nightly 2.1.0.dev20231014192330`.