CoLT5-attention: Wrong results given by triton bwd

https://github.com/lucidrains/CoLT5-attention/blob/860081fcaabfc0125e8530507f4bf8d802596281/colt5_attention/triton_coor_descent.py#L291

The line here should read last_da = tl.sum(ds, axis = 0)
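For context, a hand-rolled sketch (not the kernel itself) of why the reduction should be over axis 0: when a per-column quantity is broadcast along axis 0 in the forward pass, its gradient in the backward pass is the sum of the upstream gradients ds over that broadcast axis.

```python
import numpy as np

# toy forward: a (one value per column) is broadcast along axis 0 and added to s
s = np.arange(12.0).reshape(3, 4)
a = np.ones(4)
out = s + a          # a is broadcast over axis 0

# upstream gradient ds, same shape as out
ds = np.full_like(out, 0.5)

# gradient w.r.t. a: sum of ds over the broadcast axis,
# the analogue of tl.sum(ds, axis = 0) in the triton kernel
da = ds.sum(axis=0)
```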

test script:


import torch
from torch.autograd import grad
from colt5_attention import coor_descent
from colt5_attention.triton_coor_descent import triton_coor_descent

s = torch.randn(512).cuda().requires_grad_()

result1 = coor_descent(s, n_iters=50, k=64, eps=1.0)
result2 = triton_coor_descent(s, n_iters=50, k=64, eps=1.0, checkpoint_segments=1)

g1 = grad(result1[:32].sum(), s)[0]
g2 = grad(result2[:32].sum(), s)[0]
print((g1 - g2).norm())

A suggestion: I recommend adding a type conversion here https://github.com/lucidrains/CoLT5-attention/blob/860081fcaabfc0125e8530507f4bf8d802596281/colt5_attention/transformer_block.py#L436-L438, i.e. s.to(torch.float32). Then the code also works in half-precision training.
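A minimal sketch of the suggested pattern (the wrapper name and the pass-through callable are hypothetical, not part of the library): up-cast the scores to float32 before running coordinate descent, then cast the result back to the caller's dtype, so the iterative updates stay numerically stable under half-precision training.

```python
import torch

def coor_descent_in_fp32(coor_descent_fn, s, **kwargs):
    # coor_descent_fn is assumed to behave like colt5_attention.coor_descent
    orig_dtype = s.dtype
    out = coor_descent_fn(s.to(torch.float32), **kwargs)  # run in full precision
    return out.to(orig_dtype)                             # restore caller's dtype
```

Under torch.cuda.amp, s would typically arrive as float16 or bfloat16; the up-cast keeps the descent iterations in float32 while the surrounding model stays in half precision.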

About this issue

  • State: closed
  • Created 8 months ago
  • Comments: 16 (10 by maintainers)

Commits related to this issue

Most upvoted comments

It looks good. Thank you!


@LouChao98 ok, let me try. perhaps the triton version i’m on is bugged (triton is always a bit rough around the edges) and masks some error

i want this repository to be perfect, as coordinate descent is very promising

I agree. I am working on some experiments with auto-regressive CoLT5. I have read your mixture-of-attention repo and am using it as my starting point.

@LouChao98 you don’t happen to have a working machine you can lend me to debug this in the next few hours? lol

Sorry, my machine is not publicly accessible, so it would be a little hard to share it with you.

Using your way of installing triton, I get a ~1e-8 error. I tested two machines using my way of installing triton and got a similarly large error on both.

I believe this is a triton issue.

@LouChao98 how did you install triton? i did pip install triton --pre -U. it shows that i have triton 2.1 installed

from the openai’s repo: pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly

and the version is 2.1.0.post20231114151003.

I get a huge error, around 14.92, using 0.10.15. It is really strange that you see such a negligible error.

I suspect this is the problem: checkpoint_segments > 1 causes coor_descent_kernel_backward to be called multiple times, and each entry into coor_descent_kernel_backward adds the nonzero last_da once more. In my test case, setting checkpoint_segments = n_iters avoids the error.
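The suspected double-accumulation can be sketched with plain Python (the function names and values are hypothetical, purely to illustrate the failure mode): if a term that should contribute once to the gradient is re-added on every segment of a checkpointed backward pass, running with k segments inflates that term k times.

```python
def backward_over_segments_buggy(ds_segments, last_da):
    # buggy: last_da is re-added on every segment entry
    grad = 0.0
    for ds in ds_segments:
        grad += ds + last_da   # bug: last_da contributes once per segment
    return grad

def backward_over_segments_fixed(ds_segments, last_da):
    # fixed: last_da contributes exactly once, regardless of segment count
    return sum(ds_segments) + last_da
```

With a single segment the two agree, which is consistent with checkpoint_segments=1 hiding the error; only splitting the backward pass into multiple segments exposes the extra accumulation.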