flash-attention: Training with flash-attention causing loss to differ
Hi, since flash-attention can significantly reduce time cost and device memory usage, we'd like to use it in our model training. However, we've encountered a loss difference compared with the baseline when training:
- when we use it in inference, it works just fine
- the same code in training hurts our model's performance.
Note: our model is a slightly modified diffusion model trained in fp32. To fit flash-attention, we cast the relevant tensors in CrossAttention to fp16 and cast the result back to fp32. We train with 8 GPUs and DDP; inference uses only 1 GPU.
Environment info:
- python 3.7.11
- pytorch 1.11+cu113
- flash-attn 0.2.2
- A100 80G
Partial code:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func  # flash-attn 0.2.x interface

qb, qn, qd = q.shape
kb, kn, kd = k.shape
vb, vn, vd = v.shape
using_flash = False
# only take the flash path when the head dimension is <= 128
if (qd // h) <= 128 and (kd // h) <= 128 and (vd // h) <= 128:
    using_flash = True
if opt_helper.using_flash and using_flash:
    # print('using flash')
    opt_helper.flash_cnt += 1
    # cast the fp32 activations to fp16 and flatten (batch, seqlen) for the unpadded interface
    q = q.reshape(qb * qn, h, qd // h).to(dtype=torch.float16)
    k = k.reshape(kb * kn, h, kd // h).to(dtype=torch.float16)
    v = v.reshape(vb * vn, h, vd // h).to(dtype=torch.float16)
    max_seqlen_q = qn
    max_seqlen_k = kn
    # cumulative sequence lengths: every sequence in the batch has the same length
    cu_seqlens_q = torch.arange(
        0, (qb + 1) * qn, step=qn, dtype=torch.int32, device=q.device,
    )
    cu_seqlens_k = torch.arange(
        0, (kb + 1) * kn, step=kn, dtype=torch.int32, device=k.device,
    )
    output = flash_attn_unpadded_func(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p=0.,
        softmax_scale=self.scale,
    )
    # cast the result back to fp32 for the rest of the fp32 model
    out = output.reshape(qb, qn, qd).float()
In PyTorch AMP, softmax is computed in fp32, so we wonder whether this could be the issue. Our sequence lengths (qn, kn) can be large (e.g. 4096).
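For reference, here is a minimal check of the autocast behaviour we mean (tensor shapes are illustrative, not from our model):

import torch

q = torch.randn(8, 4096, 64, device="cuda")   # (heads, seq len, head dim), illustrative
k = torch.randn(8, 4096, 64, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    scores = q @ k.transpose(-2, -1)       # matmul autocasts to fp16
    probs = torch.softmax(scores, dim=-1)  # softmax autocasts to fp32
    print(scores.dtype, probs.dtype)       # torch.float16 torch.float32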
About this issue
- Original URL
- State: open
- Created a year ago
- Comments: 20 (10 by maintainers)
FlashAttention should give the same output as standard attention, up to some numerical tolerance (we have tests for this). Internally, softmax is done in fp32 as well (the matrix multiplies Q @ K^T and attn @ V are done in fp16/bf16).
The comparison we usually make is: (FlashAttention in fp16/bf16 - standard attention in fp32) vs (standard attention in fp16/bf16 - standard attention in fp32).
In all of our use cases the errors between the two are around the same.
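A minimal sketch of that comparison (assuming the flash-attn 0.2.x interface with flash_attn_unpadded_func; the batch size, sequence length, and head count below are illustrative):

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

b, n, h, d = 2, 4096, 8, 64          # batch, seq len, heads, head dim (illustrative)
scale = d ** -0.5
q32, k32, v32 = (torch.randn(b, n, h, d, device="cuda") for _ in range(3))

def standard_attn(q, k, v):
    # (b, n, h, d) -> (b, h, n, d) for the matmuls, back to (b, n, h, d) at the end
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return (attn @ v).transpose(1, 2)

ref = standard_attn(q32, k32, v32)                 # standard attention in fp32
q16, k16, v16 = (x.half() for x in (q32, k32, v32))
std16 = standard_attn(q16, k16, v16).float()       # standard attention in fp16

cu = torch.arange(0, (b + 1) * n, step=n, dtype=torch.int32, device="cuda")
flash16 = flash_attn_unpadded_func(
    q16.reshape(b * n, h, d), k16.reshape(b * n, h, d), v16.reshape(b * n, h, d),
    cu, cu, n, n, dropout_p=0.0, softmax_scale=scale,
).reshape(b, n, h, d).float()

print("standard fp16 vs fp32 reference:", (std16 - ref).abs().max().item())
print("flash    fp16 vs fp32 reference:", (flash16 - ref).abs().max().item())

If the two printed errors are of similar magnitude, the fp16 cast itself (not FlashAttention) accounts for the difference from the fp32 baseline.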
For your case: