triton: TypeError: dot() got an unexpected keyword argument 'trans_b'
Hi all,
I am working on integrating the Triton version of Flash Attention into a GPT-like model.
For some reason I am receiving this error: TypeError: dot() got an unexpected keyword argument 'trans_b'
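From what I can tell, the keyword comes from inside the flash-attn Triton kernels, which call tl.dot(..., trans_b=True); newer Triton builds appear to have dropped the trans_a/trans_b keywords, and the replacement (as far as I can tell) is an explicit tl.trans. A minimal sketch of the two spellings, with illustrative names that are not the actual flash-attn kernel:

import triton
import triton.language as tl

@triton.jit
def _qk_block(q_ptr, k_ptr, out_ptr, BLOCK: tl.constexpr):
    # Load one BLOCK x BLOCK tile of q and k (row-major, illustrative layout;
    # tl.dot needs BLOCK >= 16 at launch time).
    offs = tl.arange(0, BLOCK)
    q = tl.load(q_ptr + offs[:, None] * BLOCK + offs[None, :])
    k = tl.load(k_ptr + offs[:, None] * BLOCK + offs[None, :])
    # Older Triton: qk = tl.dot(q, k, trans_b=True)  -> TypeError on newer builds
    # Newer Triton: transpose explicitly instead of passing a keyword.
    qk = tl.dot(q, tl.trans(k))
    tl.store(out_ptr + offs[:, None] * BLOCK + offs[None, :], qk)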
Here is a snippet of the code where the error is occurring:
import math
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import flash_attn.flash_attn_triton as flash_attn_triton
from einops import rearrange, repeat
@dataclass
class GPT2Config:
    num_heads: int = 8
    head_dim: int = 64
    hidden_dim: int = 512
    attn_pdrop: float = 0.1
    resid_pdrop: float = 0.1
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    # noinspection PyMethodMayBeStatic
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (
            0.5 * x * (1.0 + torch.tanh(
                math.sqrt(2.0 / math.pi)
                * (x + 0.044715 * torch.pow(x, 3.0))
            ))
        )
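# Side note (my addition, assuming PyTorch >= 1.12): this module should match the
# built-in tanh approximation, i.e. for a float32 tensor x,
#   torch.allclose(NewGELUActivation()(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6)
# should hold.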
class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x
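# Equivalence sketch (my addition, not part of the repro): Conv1D(nf, nx) behaves like
# nn.Linear(nx, nf) with the weight stored transposed, e.g.
#   lin = nn.Linear(4, 8)
#   conv = Conv1D(8, 4)
#   with torch.no_grad():
#       conv.weight.copy_(lin.weight.t())
#       conv.bias.copy_(lin.bias)
#   x = torch.randn(2, 3, 4)
#   assert torch.allclose(conv(x), lin(x), atol=1e-6)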
class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        inner_dim = config.num_heads * config.head_dim
        self.c_attn = Conv1D(3 * inner_dim, config.hidden_dim)
        self.c_proj = Conv1D(config.hidden_dim, inner_dim)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.shape
        num_heads, head_dim = self.config.num_heads, self.config.head_dim
        # x.shape -> torch.Size([1, 512, 512])
        qkv = self.c_attn(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=num_heads), qkv)
        # batch_size, seq_len, num_heads, head_dim
        # q.shape, k.shape, v.shape -> torch.Size([1, 512, 8, 64])
        flash_attn_out = flash_attn_triton.flash_attn_func(
            q,
            k,
            v,
            None,   # bias
            True,   # causal
            1.0     # softmax_scale
        )
        out = flash_attn_out.contiguous().view(
            batch_size, seq_len, hidden_dim
        )
        attn_out = self.c_proj(out)
        attn_out = self.resid_dropout(attn_out)
        return attn_out
# Test GPT2Attention
config = GPT2Config()
attention = GPT2Attention(config).to(torch.float16).cuda()
print(attention(torch.randn(1, 512, 512).to(torch.float16).cuda()))
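For reference, this is the plain PyTorch attention I intend to compare the Triton output against once it runs (my own addition, matching the causal mask and the softmax scale of 1.0 used above):

def reference_attention(q, k, v):
    # q, k, v: (batch, seq_len, num_heads, head_dim), same layout passed to flash_attn_func
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # -> (batch, heads, seq, dim)
    scores = q @ k.transpose(-2, -1)                   # softmax_scale = 1.0, so no scaling
    causal = torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    attn = torch.softmax(scores.float(), dim=-1).to(v.dtype)
    return (attn @ v).transpose(1, 2)                  # back to (batch, seq, heads, dim)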
Any help would be appreciated.
Thank you,
Enrico
About this issue
- State: closed
- Created a year ago
- Comments: 19 (3 by maintainers)
Commits related to this issue
- in reference to issue https://github.com/openai/triton/issues/1098 , updated as per suggested solution — committed to bmedishe/DeepSpeed by bmedishe a year ago
Yeah, Ubuntu 18.04 includes GCC 7.5.0 by default, which doesn't fully support C++17. You can still use Ubuntu 18.04, but you have to install a more recent version of gcc/clang. As for the other errors, we'll look into them when we have some time; we're still firefighting some issues with the new backend.
@ptillet It works now with gcc/g++ 9. I was able to install from source and run the Triton version of Flash Attention.
And FYI, a wheel built on Ubuntu 22 + CUDA 12 should work everywhere; Triton isn't tied to any particular CUDA version.
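For anyone else landing here, a quick way to check which builds you actually ended up with (my own addition, not from the thread):

import torch
import triton

print("triton:", triton.__version__)
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)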