consistency_models: QKVFlashAttention unexpected parameters error, running in Google Colab
I tried to generate samples in Colab and everything works, except that I had to change this line of code in /cm/unet.py, clearing out factory_kwargs.
Not sure if this is a bug or if I did something wrong. This is how I ran it: https://github.com/JonathanFly/consistency_models_colab_notebook/blob/main/Consistency_Models_Make_Samples.ipynb
class QKVFlashAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        from einops import rearrange
        from flash_attn.flash_attention import FlashAttention

        assert batch_first
        # The original line forwarded device/dtype to FlashAttention, which
        # newer flash-attn releases no longer accept:
        # factory_kwargs = {"device": device, "dtype": dtype}
        factory_kwargs = {}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
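For what it's worth, a quick way to confirm which flash-attn build the Colab runtime actually picked up (a generic check, not part of the notebook):

from importlib.metadata import version

# v1.0.x dropped the device/dtype arguments that cm/unet.py forwards
# through factory_kwargs; v0.2.8 still accepts them.
print(version("flash-attn"))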
@boxwayne @aarontan-git @asanakoy For the stride issue, I think it is a rearrange issue caused by the flash-attn version. The print result is torch.Size([1, 256, 3, 6, 64]) (256, 1, 98304, 16384, 256), which means the tensor after rearranging is no longer contiguous. (The old version apparently did not require this, while the new version requires the input to be contiguous.) So I simply added a contiguous operation before calling inner_attn; a sketch of that change is below. Let me know if that solves the issue. I tested it on my side and it works.
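For reference, a minimal sketch of that change. The forward below mirrors what QKVFlashAttention.forward looks like in cm/unet.py (argument names taken from that file; treat them as approximate), and the only addition is the .contiguous() call:

def forward(self, qkv, attn_mask=None, key_padding_mask=None, need_weights=False):
    # (b, 3*h*d, s) -> (b, s, 3, h, d), the layout FlashAttention expects
    qkv = self.rearrange(
        qkv, "b (three h d) s -> b s three h d", three=3, h=self.num_heads
    )
    # rearrange only changes strides; newer flash-attn versions require a
    # contiguous qkv tensor, so force a contiguous copy here.
    qkv = qkv.contiguous()
    qkv, _ = self.inner_attn(
        qkv,
        key_padding_mask=key_padding_mask,
        need_weights=need_weights,
        causal=self.causal,
    )
    return self.rearrange(qkv, "b s h d -> b (h d) s")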
I also met the same error; I guess this codebase was written against an older version of flash_attn.
Installing it with pip install flash-attn==0.2.8 solved it for me.
Solution: make the following change in /content/consistency_models/cm/unet.py, line 359, in __init__ (i.e. the factory_kwargs edit shown in the issue description).
flash-attn v1.0.2 has no device parameter: https://github.com/HazyResearch/flash-attention/blob/v1.0.2/flash_attn/flash_attention.py#L21
But v0.2.8 does have a device parameter: https://github.com/HazyResearch/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py#L21
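If you'd rather not hard-code either behaviour, one option is to forward device/dtype only when the installed FlashAttention actually accepts them. This is just a sketch; the make_flash_attention helper is mine, not part of the repo, and it assumes the only relevant difference between versions is those two parameters:

import inspect

from flash_attn.flash_attention import FlashAttention


def make_flash_attention(attention_dropout=0.0, device=None, dtype=None):
    # Forward device/dtype only if the installed flash-attn (e.g. v0.2.8)
    # still lists them in FlashAttention.__init__; v1.0.x does not.
    kwargs = {"attention_dropout": attention_dropout}
    params = inspect.signature(FlashAttention.__init__).parameters
    if "device" in params:
        kwargs["device"] = device
    if "dtype" in params:
        kwargs["dtype"] = dtype
    return FlashAttention(**kwargs)

You could then call this helper at the spot in unet.py where FlashAttention is currently constructed with factory_kwargs, instead of clearing factory_kwargs out.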