flash-attention: Training with flash-attention causes the loss to differ

Hi, since FlashAttention can significantly reduce time cost and device memory usage, we'd like to use it in our model training. However, we've encountered a loss difference compared to the baseline when training:

  1. when we use it in inference, it works just fine
  2. the same code in training degrades our model’s performance.

Note: our model is a diffusion model with slight modifications, trained in fp32. To fit FlashAttention, we cast the relevant tensors in CrossAttention to fp16 and cast the result back to fp32. We train with 8 GPUs using DDP; inference uses only 1 GPU.

Environments Info:

  • python 3.7.11
  • pytorch 1.11+cu113
  • flash-attn 0.2.2
  • A100 80G

Partial code:

        # flash-attn 0.2.x exposes this op as:
        #   from flash_attn.flash_attn_interface import flash_attn_unpadded_func
        qb, qn, qd = q.shape
        kb, kn, kd = k.shape
        vb, vn, vd = v.shape
        # FlashAttention only supports head dims up to 128
        using_flash = (qd // h) <= 128 and (kd // h) <= 128 and (vd // h) <= 128
        if opt_helper.using_flash and using_flash:
            # print('using flash')
            opt_helper.flash_cnt += 1
            # (b, n, d) -> (b * n, h, d // h), cast to fp16 for FlashAttention
            q = q.reshape(qb * qn, h, qd // h).to(dtype=torch.float16)
            k = k.reshape(kb * kn, h, kd // h).to(dtype=torch.float16)
            v = v.reshape(vb * vn, h, vd // h).to(dtype=torch.float16)
            max_seqlen_q = qn
            max_seqlen_k = kn
            # Cumulative sequence lengths for the unpadded interface;
            # every sequence in the batch has the same length here.
            cu_seqlens_q = torch.arange(
                0,
                (qb + 1) * qn,
                step=qn,
                dtype=torch.int32,
                device=q.device,
            )
            cu_seqlens_k = torch.arange(
                0,
                (kb + 1) * kn,
                step=kn,
                dtype=torch.int32,
                device=k.device,
            )
            output = flash_attn_unpadded_func(
                q, k, v,
                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                dropout_p=0.,
                softmax_scale=self.scale,
            )
            # back to (b, n, d) and fp32
            out = output.reshape(qb, qn, qd).float()

In PyTorch AMP, softmax is computed in fp32, so we wonder whether the fp16 cast here could be an issue. Our sequence lengths (qn, kn) can be large (e.g. 4096).
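
One way to gauge how much softmax precision alone matters at this sequence length is a quick standalone check; a minimal sketch (the shapes and logit scale below are made up for illustration, not taken from our model):

    import torch

    torch.manual_seed(0)
    n = 4096                                            # sequence length similar to our qn / kn
    scores = torch.randn(n, n, device='cuda') * 0.125   # illustrative attention logits

    # softmax in fp32 (what PyTorch AMP uses)
    p32 = torch.softmax(scores.float(), dim=-1)
    # softmax in fp16 for comparison
    p16 = torch.softmax(scores.half(), dim=-1).float()

    print('max abs diff :', (p32 - p16).abs().max().item())
    print('mean abs diff:', (p32 - p16).abs().mean().item())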

Most upvoted comments

FlashAttention should give the same output as standard attention, up to some numerical tolerance (we have tests for this). Internally softmax is done in fp32 as well (matrix multiply Q @ K^T and attn @ V are done in fp16/bf16).
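
As a rough illustration of that precision scheme, here is a plain-PyTorch sketch of "matmuls in fp16, softmax in fp32" (this mimics the idea only; it is not FlashAttention's actual kernel):

    import torch

    def attn_fp16_matmul_fp32_softmax(q, k, v, scale):
        # q, k, v: (batch, heads, seqlen, headdim) CUDA tensors in fp16
        scores = (q @ k.transpose(-2, -1)) * scale       # Q @ K^T in fp16
        probs = torch.softmax(scores.float(), dim=-1)    # softmax in fp32
        return probs.half() @ v                          # attn @ V in fp16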

The comparison we usually make is: (FlashAttention in fp16/bf16 - standard attention in fp32) vs (standard attention in fp16/bf16 - standard attention in fp32).

In all of our use cases the errors between the two are around the same.
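
A minimal sketch of that comparison, with illustrative shapes and the flash-attn 0.2.x import path assumed (adapt to your actual q/k/v and module):

    import torch
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func  # flash-attn 0.2.x

    torch.manual_seed(0)
    b, n, h, d = 2, 4096, 8, 64                  # illustrative batch, seqlen, heads, headdim
    scale = d ** -0.5

    q = torch.randn(b, n, h, d, device='cuda')
    k = torch.randn(b, n, h, d, device='cuda')
    v = torch.randn(b, n, h, d, device='cuda')

    def std_attn(q, k, v):
        # standard attention: softmax(Q @ K^T * scale) @ V, heads treated as a batch dim
        s = torch.einsum('bqhd,bkhd->bhqk', q, k) * scale
        p = torch.softmax(s, dim=-1)
        return torch.einsum('bhqk,bkhd->bqhd', p, v)

    ref32 = std_attn(q, k, v)                               # standard attention in fp32
    std16 = std_attn(q.half(), k.half(), v.half()).float()  # standard attention in fp16

    cu = torch.arange(0, (b + 1) * n, step=n, dtype=torch.int32, device='cuda')
    flash16 = flash_attn_unpadded_func(                     # FlashAttention in fp16
        q.half().reshape(b * n, h, d),
        k.half().reshape(b * n, h, d),
        v.half().reshape(b * n, h, d),
        cu, cu, n, n, dropout_p=0., softmax_scale=scale,
    ).reshape(b, n, h, d).float()

    print('|flash fp16 - std fp32|:', (flash16 - ref32).abs().max().item())
    print('|std  fp16 - std fp32|:', (std16 - ref32).abs().max().item())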

For your case:

  1. Can you print out these two numerical differences to compare? (The sketch above shows one way to set that up.)
  2. Can you try running standard attention (softmax(Q @ K^T) @ V) in fp16 to see whether that affects your model quality? (A sketch of this drop-in follows this list.) If it does, the model is simply sensitive to numerical differences and you’d want to run attention in fp32 anyway (so just use standard attention in that case). For the use cases we’ve seen (training language models & vision transformers), swapping out standard attention in fp16/bf16 for FlashAttention in fp16/bf16 does not change the model quality.
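
A sketch of the fp16 standard attention drop-in from item 2; the `h` and `scale` arguments mirror the snippet in the original post, so treat it as illustrative rather than exact:

    import torch

    def standard_attention_fp16(q, k, v, h, scale):
        # softmax(Q @ K^T * scale) @ V with matmuls and softmax all in fp16.
        # q, k, v: (batch, seqlen, h * headdim) fp32 CUDA tensors, as in the snippet above.
        b, nq, d = q.shape
        nk = k.shape[1]
        q = q.reshape(b, nq, h, d // h).transpose(1, 2).half()   # (b, h, nq, headdim)
        k = k.reshape(b, nk, h, d // h).transpose(1, 2).half()
        v = v.reshape(b, nk, h, d // h).transpose(1, 2).half()
        attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        out = attn @ v                                            # (b, h, nq, headdim)
        return out.transpose(1, 2).reshape(b, nq, d).float()      # back to (b, nq, d) in fp32

If the loss already degrades with this fp16 standard attention, the model is sensitive to attention precision itself rather than to FlashAttention specifically.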