consistency_models: QKVFlashAttention unexpected parameters error, running in Google Colab

I tried to generate samples in Colab and everything works except that I had to change this line of code in /cm/unet.py, clearing out factory_kwargs.

Not sure if this is a bug or I did something wrong. This is how I ran it: https://github.com/JonathanFly/consistency_models_colab_notebook/blob/main/Consistency_Models_Make_Samples.ipynb


class QKVFlashAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        from einops import rearrange
        from flash_attn.flash_attention import FlashAttention

        assert batch_first
        #factory_kwargs = {"device": device, "dtype": dtype}
        factory_kwargs = {}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal

About this issue

  • State: closed
  • Created a year ago
  • Reactions: 4
  • Comments: 21

Most upvoted comments

@boxwayne @aarontan-git @asanakoy For the stride issue, I think it comes from the rearrange step and depends on the flash_attn version.

        qkv = self.rearrange(
            qkv, "b (three h d) s -> b s three h d", three=3, h=self.num_heads
        )
        # print(qkv.shape, qkv.stride())
        qkv, _ = self.inner_attn(
            qkv.contiguous(),
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            causal=self.causal,
        )

The print result is torch.Size([1, 256, 3, 6, 64]) (256, 1, 98304, 16384, 256), which means the tensor is no longer contiguous after rearranging. (The old version might not have required a contiguous input, while the new version does.) So I simply added a contiguous() call before calling inner_attn:

qkv = qkv.contiguous()

Let me know if that solves the issue. I tested on my side and it works.
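
For anyone who wants to check the stride reasoning without a GPU, here is a minimal pure-Python sketch. The helper names contiguous_strides and is_contiguous are hypothetical, written only for this illustration. It assumes the usual row-major rule (the last dimension has stride 1, and each earlier stride is the product of the later sizes) and ignores size-1 dimensions, since their stride does not affect the memory layout.

```python
def contiguous_strides(shape):
    """Row-major strides a contiguous tensor of this shape would have."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """Compare actual strides with the row-major ones, skipping size-1 dims."""
    expected = contiguous_strides(shape)
    return all(
        size == 1 or actual == wanted
        for size, actual, wanted in zip(shape, strides, expected)
    )


# Shape and strides copied from the print output above.
shape = (1, 256, 3, 6, 64)
reported = (256, 1, 98304, 16384, 256)

print(contiguous_strides(shape))       # (294912, 1152, 384, 64, 1)
print(is_contiguous(shape, reported))  # False
```

The reported strides differ from the row-major ones in every dimension of size greater than 1, which is consistent with rearrange returning a permuted view rather than a copy. Calling contiguous() makes a copy laid out in the expected order.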

I also hit the same error. I guess this code base was written against an earlier version of flash_attn.

Logging to /tmp/openai-2023-04-13-14-33-55-278549
creating model and diffusion...
Traceback (most recent call last):
  File "/content/consistency_models/scripts/image_sample.py", line 143, in <module>
    main()
  File "/content/consistency_models/scripts/image_sample.py", line 37, in main
    model, diffusion = create_model_and_diffusion(
  File "/content/consistency_models/cm/script_util.py", line 76, in create_model_and_diffusion
    model = create_model(
  File "/content/consistency_models/cm/script_util.py", line 140, in create_model
    return UNetModel(
  File "/content/consistency_models/cm/unet.py", line 612, in __init__
    AttentionBlock(
  File "/content/consistency_models/cm/unet.py", line 293, in __init__
    self.attention = QKVFlashAttention(channels, self.num_heads)
  File "/content/consistency_models/cm/unet.py", line 359, in __init__
    self.inner_attn = FlashAttention(
TypeError: __init__() got an unexpected keyword argument 'device'

Running pip install flash-attn==0.2.8 solved it for me.

Solution: make the following change in File "/content/consistency_models/cm/unet.py", line 359, in __init__:

-        self.inner_attn = FlashAttention(
-            attention_dropout=attention_dropout, **factory_kwargs
-        )
+        self.inner_attn = FlashAttention(attention_dropout=attention_dropout)
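
If you would rather not hard-code either variant, one possible alternative is to pass device and dtype only when the installed constructor accepts them. This is only a sketch, not something from the repo: supported_kwargs is a hypothetical helper, and the two classes below are stand-ins for a FlashAttention constructor with and without the factory arguments, so the snippet runs without flash_attn installed.

```python
import inspect


def supported_kwargs(cls, **candidates):
    """Keep only the keyword arguments that cls.__init__ can accept."""
    params = inspect.signature(cls.__init__).parameters
    # A **kwargs parameter means the constructor accepts any keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(candidates)
    return {name: value for name, value in candidates.items() if name in params}


class WithFactoryArgs:
    """Stand-in for a constructor that takes device and dtype."""

    def __init__(self, attention_dropout=0.0, device=None, dtype=None):
        self.attention_dropout = attention_dropout


class WithoutFactoryArgs:
    """Stand-in for a constructor that rejects device and dtype."""

    def __init__(self, attention_dropout=0.0):
        self.attention_dropout = attention_dropout


factory_kwargs = {"device": None, "dtype": None}

for cls in (WithFactoryArgs, WithoutFactoryArgs):
    kwargs = supported_kwargs(cls, attention_dropout=0.0, **factory_kwargs)
    cls(**kwargs)  # neither call raises TypeError
    print(cls.__name__, sorted(kwargs))
```

In unet.py the same idea would amount to calling supported_kwargs(FlashAttention, attention_dropout=attention_dropout, **factory_kwargs) and unpacking the result into the constructor. I have not tested that against the actual library.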