flash-attention: Illegal Memory Access when number of keys != queries
I’m experimenting with a scenario where I have more keys than queries (e.g., a common scenario during inference where the prior keys/values are cached). When I use different numbers of queries and keys, I get:
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
This only happens during the backward pass. I can run the forward pass and it won’t crash. Seems like it’s related to FlashAttention indexing something that is out of bounds?
File "/workspace/models/attention.py", line 176, in flash_attention
causal=True
File "/opt/conda/lib/python3.7/site-packages/flash_attn-0.1-py3.7-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 243, in flash_attn_unpadded_func
dropout_p, softmax_scale, causal, return_attn_probs)
(Triggered internally at /opt/conda/conda-bld/pytorch_1656352464346/work/torch/csrc/autograd/python_anomaly_mode.cpp:102.)
allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass
Error executing job with overrides: ['name=test-tx-persist', '+experiment=tx']
Note: In my scenario, all batches have the same sequence length. I set `cu_seqlens_q` and `cu_seqlens_k` to the cumsum of `[0] + [q_seq_len] * b` and the cumsum of `[0] + [k_seq_len] * b`, respectively.
If I set both keys and queries to have the same sequence length, then everything works.
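For reference, a minimal repro sketch of the setup described above (equal-length sequences per batch, more keys than queries). The shapes, dtypes, and the `flash_attn_unpadded_func` argument order are assumptions pieced together from the stack trace, not the reporter's actual code:

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

b, q_seq_len, k_seq_len, n_heads, head_dim = 4, 128, 512, 8, 64

# Packed (total_tokens, n_heads, head_dim) layout expected by the unpadded interface.
q = torch.randn(b * q_seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(b * k_seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(b * k_seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16, requires_grad=True)

# cu_seqlens = cumsum of [0] + [seq_len] * b, as described in the issue.
cu_seqlens_q = torch.arange(0, (b + 1) * q_seq_len, q_seq_len, device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (b + 1) * k_seq_len, k_seq_len, device="cuda", dtype=torch.int32)

# dropout_p=0.0, softmax_scale=None, causal=True (argument order assumed).
out = flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
                               q_seq_len, k_seq_len, 0.0, None, True)
out.sum().backward()  # the reported crash happens in this backward pass
```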
Ya that makes sense. I’ll try to get to it this weekend (been busy with preparing for ICML talks etc.)
@jesus-andres-ferrer haha, masking is definitely one of the most confusing topics in attention and transformers
if you did it this way, you would simply remove the `k_len - q_len` keys that are omitted from attention altogether. there wouldn’t be any point to passing it in
the most standard use-case for `k_len > q_len`, when the autoregressive setting is turned on (not cross attention), is what Henry described
Right now `causal` means query `i` will attend to keys `1, 2, ..., i`. I’m thinking mainly of training an auto-regressive LM, where `seqlen_q == seqlen_k`. I’m not too sure what `causal` should mean when `seqlen_q != seqlen_k`. Right now, if there are 2 queries and 10 keys, then query 1 will attend to key 1 and query 2 will attend to keys 1 & 2.
I’m open to suggestions here (what `causal` should mean if `seqlen_q != seqlen_k`). It’s all just simple index calculation in CUDA so it’ll be pretty easy to change.
Eventually I want to support a full N x N attention bias / mask, but that will probably be slightly less efficient since we need to read in an N x N matrix.
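For concreteness, here is a plain-PyTorch sketch (not the flash-attn API; tensor names and shapes are illustrative only) of what the current `causal` semantics compute when `seqlen_q != seqlen_k`:

```python
import torch

seqlen_q, seqlen_k = 2, 10

# Current behavior: query i attends to keys 1..i (front-aligned causal),
# so keys i+1..seqlen_k are never attended when seqlen_k > seqlen_q.
mask_current = torch.tril(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool))
print(mask_current.int())
# tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])

# Equivalent result: drop the trailing keys and do square causal attention.
mask_sliced = torch.tril(torch.ones(seqlen_q, seqlen_q, dtype=torch.bool))
```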
@jesus-andres-ferrer ah no problem, yes, I think both Henry and I are referring to the first configuration!
we are all on the same page then
Hi, I think we intended to say the same thing; maybe I failed to specify the indexes clearly above, or I got confused by the indexing you were using when I answered. I think it is clearer for me in the code:
`begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;`
This, IIUC, should be `begin = begin;`, and the end should be modified so that the causal mask is applied to the rightmost square of K (as given by q_len), while we can still attend to the leftmost rectangular prefix (which shouldn’t be zeroed out). That only makes sense when k_len > q_len. I agree the specification above is incorrect. In both cases, it won’t make sense to pass the extra k_len - q_len columns of k and v, regardless of whether we apply the causal square to the leftmost or the rightmost positions. I essentially intended to refer to the mask below (for k_len = 11 and q_len = 5), where 1 means attention is possible and 0 means no attention:
1 1 1 1 1 1 1 0 0 0 0
1 1 1 1 1 1 1 1 0 0 0
1 1 1 1 1 1 1 1 1 0 0
1 1 1 1 1 1 1 1 1 1 0
1 1 1 1 1 1 1 1 1 1 1
instead of IIUC
1 0 0 0 0 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0 0 0
1 1 1 0 0 0 0 0 0 0 0
1 1 1 1 0 0 0 0 0 0 0
1 1 1 1 1 0 0 0 0 0 0
Obviously below is likewise not useful IMO
0 0 0 0 0 0 1 0 0 0 0
0 0 0 0 0 0 1 1 0 0 0
0 0 0 0 0 0 1 1 1 0 0
0 0 0 0 0 0 1 1 1 1 0
0 0 0 0 0 0 1 1 1 1 1
which I think is what I ended up specifying because of that answer; I have tried to fix the original comment accordingly, IIUC.
Apologies for the mis-specification.
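For reference, the three masks above can be written with plain index arithmetic outside the kernel; this is an illustrative PyTorch sketch (names are made up), not the CUDA kernel code:

```python
import torch

q_len, k_len = 5, 11
offset = k_len - q_len

i = torch.arange(q_len).unsqueeze(1)   # query index
j = torch.arange(k_len).unsqueeze(0)   # key index

# Intended mask (first matrix above): causal on the rightmost q_len x q_len
# square, with the leftmost k_len - q_len keys always visible.
prefix_causal = (j <= i + offset).int()

# Front-aligned causal (second matrix): query i attends to keys 1..i only,
# so the trailing keys are never used.
front_causal = (j <= i).int()

# Bottom-right square only (third matrix): causal on the last q_len keys,
# with the prefix masked out entirely.
square_only = ((j >= offset) & (j <= i + offset)).int()
```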
@calclavia yea i agree. i don’t think Tri needs to support some of the newer works like the causal cross attention from flamingo or retro, they are very specialized cases
in which case just passing in the difference between the lengths of the queries and keys is sufficient. this would fix the inference causal attention + cache, transformer-xl like memories, as well as the new perceiver AR attention + some other brain works
re: @lucidrains I’m looking at the flash_attention_interface and it doesn’t seem like there’s a way to pass in a custom mask right now. As @tridao hinted, this is probably future work.
I have a simple proposal re: @tridao
As to how causal should behave when num keys != num queries, my thought is that the current behavior probably has no use case:
This is equivalent to just slicing out the last 8 keys and doing non-prefix causal attention. If the user’s intention is to do this, they could pre-slice the qkv inputs to achieve this effect. I think the right behavior should be:
If there are 2 queries and 10 keys, then query 1 will attend to key 1 to 9 and query 2 will attend to key 1 to 10 (all keys).
This is causal attention with prefix and is shown in the image by @lucidrains. @lucidrains 's proposed idea with passing time indices seems more general but I personally don’t see a use case where you need this. You can always achieve the equivalent via slicing the key tensor prior to passing it to flash attention since, effectively, some of the tail keys will never be attended to anyway. <- Perhaps I’m missing some use cases where passing time indices would be helpful and cannot be achieved via preparing the input.
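To illustrate the slicing argument, here is a hypothetical sketch (shapes and names invented for the example, not the flash-attn API): under any causal variant, keys after the final query's time index are never attended, so they can be trimmed off before calling attention with the same result.

```python
import torch

# Hypothetical example: per-query time indices into a longer key sequence.
b, q_len, k_len, dim = 2, 4, 16, 32
k = torch.randn(b, k_len, dim)
v = torch.randn(b, k_len, dim)
time_idx = torch.tensor([3, 5, 6, 9])  # key position of each query (illustrative)

# Keys beyond the last query's position can never be attended causally,
# so slicing them off beforehand yields an equivalent computation.
t_last = int(time_idx.max())
k_trimmed, v_trimmed = k[:, : t_last + 1], v[:, : t_last + 1]
```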
@lucidrains Yep, my training curves seem to match that of regular attention on preliminary tests on enwiki8.
Thanks for putting up that picture as it helps illustrate my causal + prefix case. I think having to supply a full attention-mask would be more general than supporting causal with the prefix, but causal with prefix should probably be the default scenario when num keys != num queries. Otherwise, it would be equivalent to computing attention with the final few keys trimmed off.