CoLT5-attention: Wrong results given by triton bwd
Here (in the triton backward kernel) it should be `last_da = tl.sum(ds, axis = 0)`.
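For context on why the sum is needed: when a value is broadcast across a block in the forward pass, its gradient in the backward pass is the sum of the per-element upstream gradients. A minimal PyTorch sketch of that rule (illustrative only, not the actual triton kernel):

```python
import torch

a = torch.randn((), requires_grad = True)  # scalar broadcast in the forward pass
ds = torch.randn(128)                      # per-element upstream gradients

out = a.expand(128)                        # broadcast forward
out.backward(ds)                           # backward with upstream gradients ds

# the gradient of a broadcast value is the sum of the upstream gradients,
# which is what the suggested last_da = tl.sum(ds, axis = 0) computes
print(torch.allclose(a.grad, ds.sum()))    # True
```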
test script:

```python
import torch
from torch.autograd import grad

from colt5_attention import coor_descent
from colt5_attention.triton_coor_descent import triton_coor_descent

s = torch.randn(512).cuda().requires_grad_()

result1 = coor_descent(s, n_iters=50, k=64, eps=1.0)
result2 = triton_coor_descent(s, n_iters=50, k=64, eps=1.0, checkpoint_segments=1)

g1 = grad(result1[:32].sum(), s)[0]
g2 = grad(result2[:32].sum(), s)[0]

print((g1 - g2).norm())
```
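If the triton backward pass were correct, the printed norm would be near zero (floating-point noise only); the bug being reported is that it is not.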
A suggestion: I recommend adding a type conversion here https://github.com/lucidrains/CoLT5-attention/blob/860081fcaabfc0125e8530507f4bf8d802596281/colt5_attention/transformer_block.py#L436-L438, i.e. `s.to(torch.float32)`. Then the code works in half-precision training.
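A minimal sketch of the suggested pattern (a hypothetical wrapper, not the repo's actual code): upcast the input to float32 before the numerically sensitive routine and cast the result back to the original dtype:

```python
import torch

def run_in_fp32(fn, s, **kwargs):
    # hypothetical helper: run `fn` in float32 for numerical stability,
    # then cast the result back to the input's (possibly half) dtype
    out = fn(s.to(torch.float32), **kwargs)
    return out.to(s.dtype)

# e.g. run_in_fp32(triton_coor_descent, s_half, n_iters=50, k=64, eps=1.0)
```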
About this issue
- State: closed
- Created 8 months ago
- Comments: 16 (10 by maintainers)
Commits related to this issue
- address https://github.com/lucidrains/CoLT5-attention/issues/8 — committed to lucidrains/CoLT5-attention by lucidrains 4 months ago
It looks good. Thank you!
I agree. I am working on some experiments on auto-regressive colt5. I have read your mixture-of-attention and I use it as the starting point. Sorry, my machine is not publicly accessible, so it is a little hard to share it with you.
Using your way to install triton, I get an error of ~1e-8. Testing two machines with my way of installing triton, I get similarly large errors on both; I believe this is a triton issue.

My install comes from OpenAI's repo:

```
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```

and the version is `2.1.0.post20231114151003`. I get a huge error like 14.92 using 0.10.15. It is really strange that you see such a negligible error.
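To make the comparison reproducible, it may help to print the exact triton build in each environment:

```python
import triton
print(triton.__version__)  # e.g. 2.1.0.post20231114151003 for the nightly above
```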
I suspect there is a problem: `checkpoint_segments > 1` causes `coor_descent_kernel_backward` to be called multiple times, and every call to `coor_descent_kernel_backward` adds a nonzero `last_da` once. In my test case, setting `checkpoint_segments = n_iters` avoids the error.
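As a toy illustration of the suspected over-accumulation (a sketch of the hypothesis, not the actual kernel logic): if every backward call re-adds the full `last_da` contribution, the gradient is inflated by the number of segments:

```python
import torch

ds = torch.randn(4)          # toy upstream gradients for the broadcast value
segments = 4

expected = ds.sum()          # last_da should enter the gradient exactly once

accumulated = torch.zeros(())
for _ in range(segments):    # one backward kernel call per checkpoint segment
    accumulated += ds.sum()  # suspected bug: contribution re-added every call

print(accumulated.item(), (segments * expected).item())  # over-counted by `segments`
```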