triton: Dot Product Computes Wrong Values

First off, this is a really cool library! I’m trying to implement the causal dot product kernel, but I’m running into some issues. It could be either a bug or my misunderstanding of the documentation.

Triton code: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py#L26

The above implements Algorithm 1 from the “Transformers are RNNs” paper in Triton. In summary, I’m trying to batch-parallelize a for loop that computes a prefix sum. This is the simple O(n) implementation (not the more efficient O(log(n)) version).

The equivalent naive PyTorch implementation is here: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py#L6
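
For context, here is a minimal sketch of that O(n) loop in plain PyTorch (my own restatement of Algorithm 1, not the exact code from the repo above): keep a running state S as the sum of the outer products k_i v_i^T and emit q_i S at each position.

import torch

def causal_dot_product_ref(q, k, v):
    # Naive O(n) causal dot product (Algorithm 1 of "Transformers are RNNs").
    # q, k, v: (batch, heads, length, dim). Sketch only; see the link above
    # for the actual reference implementation.
    b, h, n, d = q.shape
    e = v.shape[-1]
    state = q.new_zeros(b, h, d, e)            # running sum of k_i v_i^T
    outputs = []
    for i in range(n):
        k_i, v_i = k[:, :, i], v[:, :, i]      # (b, h, d), (b, h, e)
        state = state + k_i.unsqueeze(-1) * v_i.unsqueeze(-2)
        outputs.append(torch.einsum('bhd,bhde->bhe', q[:, :, i], state))
    return torch.stack(outputs, dim=2)         # (b, h, n, e)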

When running a simple unit test, I’m getting very different values:

ref_output tensor([[[[0.0157, 0.6251, 0.6993],
          [0.7910, 2.1930, 2.1000],
          [0.7413, 1.7217, 1.4139],
          [0.8863, 1.4795, 1.5222],
          [1.2453, 2.5844, 1.9665]]]])
triton_output tensor([[[0.4268, 0.6251, 0.6993],
         [2.4132, 0.6186, 0.3389],
         [0.8975, 0.4929, 0.2288],
         [0.8470, 0.0080, 0.3058],
         [1.1330, 0.7073, 0.0776]]])

I also tested a single-dimensional vector and was unable to get matching values. The logic seems correct, but I suspect the issue is related to tl.dot. If anyone has insights, I would appreciate comments/feedback!
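
The check I’m running is roughly the following (the function names are hypothetical stand-ins for the two implementations linked above):

import torch

torch.manual_seed(0)
q = torch.rand(1, 1, 5, 3, device='cuda')
k = torch.rand(1, 1, 5, 3, device='cuda')
v = torch.rand(1, 1, 5, 3, device='cuda')

ref = causal_dot_product_ref(q, k, v)       # naive PyTorch loop
out = causal_dot_product_triton(q, k, v)    # Triton kernel wrapper (hypothetical name)
print('ref_output', ref)
print('triton_output', out)
assert torch.allclose(ref, out, atol=1e-3), 'Triton output diverges from the reference'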

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 2
  • Comments: 23 (3 by maintainers)

Most upvoted comments

@calclavia It’s a bit tricky. Our goal is actually to extend Triton-IR to support indexing operations, so that prefix sum could be implemented with an algorithm like https://en.wikipedia.org/wiki/Prefix_sum#Parallel_algorithms without needing dedicated Triton-IR instructions. The compiler would then look at all the indexing ops and add shared memory barriers (and warp shuffles) accordingly to maximize performance without breaking correctness.
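
For concreteness, the kind of log-step scan those indexing ops would let you express looks roughly like this in plain PyTorch (a sketch of the Hillis-Steele variant from that wiki page, not Triton code):

import torch

def inclusive_scan(x):
    # Hillis-Steele inclusive prefix sum over the last dimension:
    # O(log n) steps of shifted adds, where each step is fully parallel.
    n = x.shape[-1]
    out = x.clone()
    step = 1
    while step < n:
        shifted = torch.zeros_like(out)
        shifted[..., step:] = out[..., :-step]   # read from `step` positions back
        out = out + shifted
        step *= 2
    return out

# inclusive_scan(torch.tensor([1., 2., 3., 4.])) -> tensor([ 1.,  3.,  6., 10.])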

Pattern matching across loop boundaries tends to be pretty hard and brittle. So in order to avoid that, Triton tries really hard to find abstractions that make sense :p It’ll take a while before indexing ops are part of the ecosystem (I’m thinking maybe by Jan - Feb), but hopefully we’ll get a prototype before then that will be enough for causal attention.

Anyhow, the issues you’re having are unrelated and I’ll try to address them sooner than that 😄

Hey~ Sorry for the delay. I’ll be looking into this tonight or tomorrow. Since the dot unit tests pass, I wonder what is causing the results to be incorrect in this case. Maybe it’s some interaction between broadcasting and tensor cores.

Hey!

First of all, thanks for digging into this issue. Sorry for being unresponsive; I’ve been quite busy with other matters, and I am well aware of the instabilities of Triton 😅 I will look into this problem. FWIW, we are planning a rewrite of the backend that should greatly improve stability on these kinds of issues.

I’m seeing the same wrong values for small blocks here! Is there any plan to fix this bug?

Hey! I looked into this and there are multiple issues at play:

  • I think what you ran into is an issue with dot products when the blocks are too small. It’s quite tricky to get the thread partitioning right when the compiler needs to distribute 16 elements over hundreds of threads. This is one of the areas where the compiler is buggy at the moment (a minimal repro sketch follows this list).
  • It also seems like, with the most recent dev version, your causal attention makes the shared memory allocator hang, so that’s another issue that I’ll have to look into.
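
For anyone who wants to poke at the small-block case, here is a minimal single-tile tl.dot repro sketch (my own, written against a recent dev version; the 16x16 block size and the loose tolerance for tensor-core accumulation are assumptions):

import torch
import triton
import triton.language as tl

@triton.jit
def small_dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    # A single program loads one BLOCK x BLOCK tile of A and B and multiplies them.
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    a = tl.load(a_ptr + idx)
    b = tl.load(b_ptr + idx)
    tl.store(c_ptr + idx, tl.dot(a, b))

BLOCK = 16  # assumed to be the smallest block size tl.dot accepts
a = torch.randn(BLOCK, BLOCK, device='cuda', dtype=torch.float16)
b = torch.randn(BLOCK, BLOCK, device='cuda', dtype=torch.float16)
c = torch.empty(BLOCK, BLOCK, device='cuda', dtype=torch.float32)
small_dot_kernel[(1,)](a, b, c, BLOCK=BLOCK)
print(torch.allclose(c, a.float() @ b.float(), atol=1e-2, rtol=1e-2))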

FWIW I think causal attention will be much easier to implement when triton.language contains prefix sum primitives.

@lucidrains I haven’t tried CUDA Python. I’m not an expert at CUDA programming, so Triton seems like a nice alternative that is easy to learn and can get the job done.

A brief look at EPFL’s CUDA code seems to indicate they’re using an O(n) implementation in their fallback path: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/causal_product/causal_product_cuda.cu#L1204, but the optimized version is too complex for me to understand. It should be possible to implement the O(log(n)) version in Triton; it’ll just take some time to figure out how to port the prefix sum algorithm described by NVIDIA.
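
To make the prefix-sum structure explicit, here is the same computation written as a scan over outer products in PyTorch (memory-hungry and only a sketch, but the cumsum over the length dimension is exactly what a parallel O(log(n))-step scan would replace):

import torch

def causal_dot_product_cumsum(q, k, v):
    # q, k, v: (batch, heads, length, dim).
    # S_i = sum_{j <= i} k_j v_j^T and out_i = q_i S_i. This materializes the
    # full (length, dim, dim) state, so it is only useful as an illustration.
    kv = torch.einsum('bhnd,bhne->bhnde', k, v)   # per-position outer products
    state = torch.cumsum(kv, dim=2)               # inclusive prefix sum over length
    return torch.einsum('bhnd,bhnde->bhne', q, state)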

@lucidrains Great to see you’re interested! I was hoping this could be a drop-in replacement for your Performer repo to solve this issue (https://github.com/lucidrains/performer-pytorch/issues/44).

Hey! Sorry I haven’t had time to look into it. I’ve been busy working on older issues (#170 and #176) that require quite a bit of work.