flash-attention: Illegal Memory Access when number of keys != queries

I’m experimenting with a scenario where I have more keys than queries (e.g., the common inference scenario where the prior keys/values are cached). When I try to use different query and key lengths, I get:

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

This only happens during the backward pass; the forward pass runs without crashing. It seems like FlashAttention is indexing something out of bounds?

  File "/workspace/models/attention.py", line 176, in flash_attention
    causal=True
  File "/opt/conda/lib/python3.7/site-packages/flash_attn-0.1-py3.7-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 243, in flash_attn_unpadded_func
    dropout_p, softmax_scale, causal, return_attn_probs)
 (Triggered internally at  /opt/conda/conda-bld/pytorch_1656352464346/work/torch/csrc/autograd/python_anomaly_mode.cpp:102.)
  allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
Error executing job with overrides: ['name=test-tx-persist', '+experiment=tx']

Note: In my scenario, all batches have the same sequence length. I set cu_seqlens_q and cu_seqlens_k to the cumsum of [0] + [q_seq_len] * b and the cumsum of [0] + [k_seq_len] * b, respectively.

If I set both keys and queries to have the same sequence length, then everything works.
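For reference, here is a minimal sketch of the cu_seqlens construction described above (batch size and lengths are illustrative, not taken from the report; as far as I know, flash-attn expects these as int32 tensors on the same CUDA device as q/k/v):

import torch

# Illustrative values: b sequences, all with identical query/key lengths.
b, q_seq_len, k_seq_len = 4, 2, 10

# cumsum of [0] + [q_seq_len] * b and [0] + [k_seq_len] * b, as described above.
cu_seqlens_q = torch.cumsum(torch.tensor([0] + [q_seq_len] * b), dim=0).to(torch.int32)
cu_seqlens_k = torch.cumsum(torch.tensor([0] + [k_seq_len] * b), dim=0).to(torch.int32)
# cu_seqlens_q -> [0, 2, 4, 6, 8], cu_seqlens_k -> [0, 10, 20, 30, 40]
# (move them to the same CUDA device as q/k/v before calling the flash-attn function)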

About this issue

  • State: open
  • Created 2 years ago
  • Comments: 24 (17 by maintainers)

Most upvoted comments

Ya that makes sense. I’ll try to get to it this weekend (been busy with preparing for ICML talks etc.)

@jesus-andres-ferrer haha, masking is definitely one of the most confusing topics in attention and transformers

Hi, first of all I would like to congratulate the authors for this neat work. I am having a lot of fun going through the code and understanding it.

Right now causal means query i will attend to key 1, 2, ..., i. I’m thinking mainly of training auto-regressive LM, where seqlen_q == seqlen_k. I’m not too sure what causal should mean when seqlen_q != seqlen_k. Right now, if there are 2 queries and 10 keys, then query 1 will attend to key 1 and query 2 will attend to key 1 & 2. I’m open to suggestions here (what causal should mean if seqlen_q != seqlen_k). It’s all just simple index calculation in CUDA so it’ll be pretty easy to change.
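To make the semantics just described concrete, here is a small reference sketch in plain PyTorch (not the actual CUDA kernel), using 0-indexed positions:

import torch

# Current causal semantics described above: query i attends to keys 0..i (0-indexed),
# no matter how many extra keys exist. Example: 2 queries, 10 keys.
seqlen_q, seqlen_k = 2, 10
i = torch.arange(seqlen_q).unsqueeze(1)   # query positions, one per row
j = torch.arange(seqlen_k).unsqueeze(0)   # key positions, one per column
current_mask = (j <= i)                   # True where attention is allowed
# Row 0 attends to key 0 only; row 1 attends to keys 0 and 1; keys 2..9 are never attended.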

When is_causal=True and seqlen_q < seqlen_k (not sure what to do in the opposite case, seqlen_q > seqlen_k), I personally would have expected the opposite behaviour: ‘causal’ should mean query ‘i’ can attend to keys ‘seqlen_k - i, …, seqlen_k-1, seqlen_k’. This would simulate autoregressive attention with a prefix and would be more useful than ‘1, 2, …, i’, which essentially means that the last seqlen_k - seqlen_q keys are ignored, if I understand properly (apologies if I have misunderstood).

I think this would automatically cover some of the cases above? Maybe a templated functor could be provided to the kernels to index leftmost, rightmost, or other patterns, which could help other efficient use cases?

Thank you very much,

if you did it this way, you would simply remove the k_len - q_len keys that are omitted from attention altogether. there wouldn’t be any point to passing them in

the most standard use-case for k_len > q_len when the autoregressive setting is turned on (not cross attention) is what Henry described

@tridao Although it no longer crashes, for the example you showed, how would I specify the starting position for the query to attend to? For example:

I have a causal attention setup with 1 query and 10 keys. Does this mean the query will attend to all keys, or just the first key because it’s causal? How can I specify which keys are “before” or “after” the first query?

Another example: if I have 2 queries and 10 keys in a causal setup, does this mean the 1st query will attend to the first 9 keys and the 2nd query will attend to all 10 keys? This is the typical scenario for causal attention at inference. When I tested your example with 1 query and 2 keys, it didn’t seem to exhibit the expected behavior above.

Some clarification on the docs for this would be great!

Right now causal means query i will attend to key 1, 2, ..., i. I’m thinking mainly of training auto-regressive LM, where seqlen_q == seqlen_k. I’m not too sure what causal should mean when seqlen_q != seqlen_k. Right now, if there are 2 queries and 10 keys, then query 1 will attend to key 1 and query 2 will attend to key 1 & 2.

I’m open to suggestions here (what causal should mean if seqlen_q != seqlen_k). It’s all just simple index calculation in CUDA so it’ll be pretty easy to change.

Eventually I want to support full N x N attention bias / mask, but that will probably be slightly less efficient since we need to read in an N x N matrix.

@jesus-andres-ferrer ah no problem, yes, I think both Henry and I are referring to the first configuration!

1 1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1 1

we are all on the same page then
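For concreteness, a small sketch in plain PyTorch (not library code) that reproduces the 5 x 11 mask shown above via the offset seqlen_k - seqlen_q:

import torch

# Prefix-style causal mask from the picture above: query i may attend to keys 0 .. i + offset,
# with offset = seqlen_k - seqlen_q. Sizes match the 5 x 11 example.
seqlen_q, seqlen_k = 5, 11
offset = seqlen_k - seqlen_q              # 6
i = torch.arange(seqlen_q).unsqueeze(1)
j = torch.arange(seqlen_k).unsqueeze(0)
prefix_causal_mask = (j <= i + offset).int()
print(prefix_causal_mask)                 # tensor with the same 0/1 pattern as the matrix above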

Hi, I think we intended to say the same thing; maybe I failed to specify the idea with the indexes above, and now I’m not sure whether I misunderstood the indexing you were using when I answered or simply got confused myself. It is clearer for me in the code:

begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;

IIUC, this should be begin = begin; and the end should be modified so that the causal mask is applied to the rightmost square of K (as given by q_len), while we can still attend to the leftmost rectangular prefix (which shouldn’t be zeroed out). That only makes sense when q_len < k_len (i.e., k_len > q_len). I agree the specification above is incorrect. In both cases, it wouldn’t make sense to pass the extra (delta) columns of k and v, independently of whether we focus on the leftmost square for self-attention or on the rightmost. I essentially intended to refer to the following, for k_len = 11 and q_len = 5, where 1 means attention is possible and 0 means no attention:

1 1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1 1

instead of, IIUC, the current behaviour:

1 0 0 0 0 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0 0 0
1 1 1 0 0 0 0 0 0 0 0
1 1 1 1 0 0 0 0 0 0 0
1 1 1 1 1 0 0 0 0 0 0

Obviously the following is likewise not useful, IMO:

0 0 0 0 0 0 1 0 0 0 0
0 0 0 0 0 0 1 1 0 0 0
0 0 0 0 0 0 1 1 1 0 0
0 0 0 0 0 0 1 1 1 1 0
0 0 0 0 0 0 1 1 1 1 1

which I think is what I had specified originally; after your answer I tried to fix it in the original comment, IIUC now.

Apologies for the mis-specification.

@calclavia yea i agree. i don’t think Tri needs to support some of the newer works like the causal cross attention from flamingo or retro, they are very specialized cases

in which case just passing in the difference between the lengths of the queries and keys is sufficient. this would fix the inference causal attention + cache, transformer-xl like memories, as well as the new perceiver AR attention + some other brain works
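As a tiny worked example of that single offset (numbers are illustrative, matching the cached-inference case mentioned above):

# Inference with a KV cache: 9 cached positions plus 1 new token.
cache_len, q_len = 9, 1
k_len = cache_len + q_len      # 10 keys in total
offset = k_len - q_len         # 9, the "difference between the lengths"
# Under prefix-style causal masking, query i may attend to keys 0 .. i + offset,
# so the single new query sees all 10 keys, which is exactly what cached autoregressive decoding needs.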

re: @lucidrains I’m looking at flash_attn_interface and it doesn’t seem like there’s a way to pass in a custom mask right now. As @tridao hinted, this is probably future work.

I have a simple proposal re: @tridao

As to how causal should behave when num keys != num queries, my thought is that the current behavior probably has no use case:

Right now, if there are 2 queries and 10 keys, then query 1 will attend to key 1 and query 2 will attend to key 1 & 2.

This is equivalent to just slicing out the last 8 keys and doing non-prefix causal attention. If the user’s intention is to do this, they could pre-slice the qkv inputs to achieve this effect. I think the right behavior should be:

If there are 2 queries and 10 keys, then query 1 will attend to key 1 to 9 and query 2 will attend to key 1 to 10 (all keys).

This is causal attention with a prefix and is shown in the image by @lucidrains. @lucidrains’s proposed idea of passing time indices seems more general, but I personally don’t see a use case where you need it. You can always achieve the equivalent by slicing the key tensor prior to passing it to flash attention, since, effectively, some of the tail keys will never be attended to anyway. <- Perhaps I’m missing some use cases where passing time indices would be helpful and cannot be achieved by preparing the input.
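A quick sketch of the equivalence claimed above, using reference masks only (not the flash-attention kernel itself):

import torch

# Current behavior with 2 queries and 10 keys is the same as slicing off the last 8 keys
# and doing ordinary causal attention on the remaining 2 x 2 block.
seqlen_q, seqlen_k = 2, 10
i = torch.arange(seqlen_q).unsqueeze(1)
j = torch.arange(seqlen_k).unsqueeze(0)

current_mask = (j <= i)                                   # current "causal" semantics
sliced_mask = torch.zeros(seqlen_q, seqlen_k, dtype=torch.bool)
sliced_mask[:, :seqlen_q] = torch.ones(seqlen_q, seqlen_q).tril().bool()

assert torch.equal(current_mask, sliced_mask)             # keys 2..9 are never attended either way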

@lucidrains Yep, my training curves seem to match that of regular attention on preliminary tests on enwiki8.

Thanks for putting up that picture, as it helps illustrate my causal + prefix case. I think being able to supply a full attention mask would be more general than supporting causal with a prefix, but causal with prefix should probably be the default behavior when num keys != num queries. Otherwise, it would be equivalent to computing attention with the final few keys trimmed off.