triton: TypeError: dot() got an unexpected keyword argument 'trans_b'

Hi all,

I am working on integrating the Triton version of Flash Attention into a GPT-like model.

When I run it, I get the following error: TypeError: dot() got an unexpected keyword argument 'trans_b'
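This TypeError means the installed Triton's `tl.dot` has no `trans_b` parameter. The `flash_attn_triton` kernels pass `trans_b=True`, so they were written against an older Triton whose `tl.dot` accepted `trans_a`/`trans_b`; later versions dropped those keywords. One quick way to see which API is installed is to inspect the signature. This is a minimal sketch that assumes `tl.dot` exposes a Python signature `inspect.signature` can read; `accepts_keyword` is a hypothetical helper, not part of Triton.

```python
import inspect
from typing import Callable


def accepts_keyword(fn: Callable, name: str) -> bool:
    """Return True if `fn` can be called with the keyword argument `name`."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # No introspectable signature: report False rather than guess.
        return False
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all also accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


if __name__ == "__main__":
    try:
        import triton
        import triton.language as tl
    except ImportError:
        print("triton is not installed")
    else:
        print("triton", triton.__version__,
              "| tl.dot accepts trans_b:", accepts_keyword(tl.dot, "trans_b"))
```

If this reports False, the installed Triton is newer than the one `flash_attn_triton` expects; check the `flash_attn_triton` source for the Triton version it was tested with and install that one.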

Here is a snippet of the code that triggers the error:

```python
import math
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
import flash_attn.flash_attn_triton as flash_attn_triton

from einops import rearrange, repeat

@dataclass
class GPT2Config:
    # Annotated so that @dataclass registers these as real fields.
    num_heads: int = 8
    head_dim: int = 64
    hidden_dim: int = 512
    attn_pdrop: float = 0.1
    resid_pdrop: float = 0.1

class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    # noinspection PyMethodMayBeStatic
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (
            0.5 * x * (1.0 + torch.tanh(
                math.sqrt(2.0 / math.pi)
                * (x + 0.044715 * torch.pow(x, 3.0))
            ))
        )

class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
    Basically works like a linear layer but the weights are transposed.
    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        # bias + x @ weight, with the leading dims flattened and restored afterwards
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)

        return x

class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        inner_dim = config.num_heads * config.head_dim
        self.c_attn = Conv1D(3 * inner_dim, config.hidden_dim)
        self.c_proj = Conv1D(config.hidden_dim, inner_dim)

        # Not used in forward(): flash_attn_triton does not implement attention dropout.
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.shape
        num_heads, head_dim = self.config.num_heads, self.config.head_dim

        # x.shape -> torch.Size([1, 512, 512])

        qkv = self.c_attn(x).chunk(3, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=num_heads), qkv)

        # batch_size, seq_len, num_heads, head_dim
        # q.shape, k.shape, v.shape -> torch.Size([1, 512, 8, 64])

        # Arguments are passed positionally: (q, k, v, bias, causal, softmax_scale)
        flash_attn_out = flash_attn_triton.flash_attn_func(
            q,
            k,
            v,
            None,                       # bias
            True,                       # causal
            1.0 / math.sqrt(head_dim),  # softmax_scale (conventional 1/sqrt(head_dim))
        )

        # Merge heads: (b, n, h, d) -> (b, n, h * d), the input size c_proj expects
        out = flash_attn_out.contiguous().view(
            batch_size, seq_len, num_heads * head_dim
        )

        attn_out = self.c_proj(out)

        attn_out = self.resid_dropout(attn_out)

        return attn_out

# Test GPT2Attention (requires a CUDA GPU)
config = GPT2Config()

attention = GPT2Attention(config).to(torch.float16).cuda()

x = torch.randn(1, 512, 512, dtype=torch.float16, device="cuda")
print(attention(x))
```
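`GPT2Config` is decorated with `@dataclass`, but defaults written without type annotations are not dataclass fields: `dataclass` only collects annotated class variables, so unannotated ones stay plain class attributes, the generated `__init__` takes no arguments, and `asdict()` returns an empty dict. The stdlib-only comparison below illustrates this; the two class names are made up for the example.

```python
from dataclasses import dataclass, asdict, fields


@dataclass
class UnannotatedConfig:
    # No annotations: @dataclass ignores these; they remain class attributes.
    num_heads = 8
    head_dim = 64


@dataclass
class AnnotatedConfig:
    # Annotated: each line becomes a dataclass field with a default value.
    num_heads: int = 8
    head_dim: int = 64


print(asdict(UnannotatedConfig()))  # {}
print(asdict(AnnotatedConfig()))    # {'num_heads': 8, 'head_dim': 64}

try:
    UnannotatedConfig(num_heads=4)  # generated __init__ accepts no fields
except TypeError as exc:
    print("TypeError:", exc)
```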

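`NewGELUActivation` uses the tanh approximation of GELU, 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))), rather than the exact form x · Φ(x), where Φ is the standard normal CDF. As a sanity check, the stdlib-only sketch below compares the two on scalar floats instead of tensors; `gelu_tanh` and `gelu_exact` are illustrative names.

```python
import math


def gelu_tanh(x: float) -> float:
    # Same formula as NewGELUActivation.forward, applied to a single float.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-3.0, -1.0, 0.0, 0.5, 1.0, 3.0):
    print(f"x={x:+.1f}  tanh={gelu_tanh(x):.6f}  exact={gelu_exact(x):.6f}")
```

The two curves agree closely over typical activation ranges, which is why the cheaper tanh form is used in GPT-style models.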
Any help would be appreciated.

Thank you,

Enrico
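With a causal flag and a `softmax_scale`, `flash_attn_func` computes causal softmax attention independently for each batch element and head; FlashAttention gives the same result as the naive computation (up to floating-point differences) without materializing the full score matrix. A slow pure-Python reference for a single head, like the sketch below, can help check outputs on tiny inputs. `causal_attention` is a hypothetical helper; the conventional scale is 1/√head_dim, which is 0.125 for head_dim = 64, while a scale of 1.0 leaves the scores unscaled.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_attention(q: Matrix, k: Matrix, v: Matrix, softmax_scale: float) -> Matrix:
    """Naive single-head causal attention: softmax(scale * q k^T, masked) v."""
    out: Matrix = []
    for i, q_i in enumerate(q):
        # Causal mask: query i only sees keys 0..i (future positions are excluded).
        scores = [softmax_scale * sum(a * b for a, b in zip(q_i, k[j])) for j in range(i + 1)]
        # Subtract the row maximum before exponentiating for numerical stability.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
scale = 1.0 / math.sqrt(len(q[0]))  # 1/sqrt(head_dim) for this toy head_dim of 2
print(causal_attention(q, k, v, scale))
```

The first query can only attend to itself, so its output row is always `v[0]`; that makes a quick check that the causal mask is in effect.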

About this issue

  • State: closed
  • Created a year ago
  • Comments: 19 (3 by maintainers)

Most upvoted comments

Yeah, Ubuntu 18.04 ships GCC 7.5.0 by default, which doesn't fully support C++17. You can still use Ubuntu 18.04, but you have to install a more recent version of gcc/clang. As for the other errors, we'll look into them when we have some time; we're still firefighting some issues with the new backend.

@ptillet It works now with gcc/g++ 9. I was able to install from source and run the Triton version of Flash Attention.
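Before building Triton from source, it can help to confirm which compiler is on PATH. The sketch below parses the first line of `gcc --version`. The threshold of 9 is an assumption taken from the report in this thread that gcc/g++ 9 works; it is not an official minimum, and `gcc_major` / `MIN_GCC_MAJOR` are illustrative names.

```python
import re
import shutil
import subprocess
from typing import Optional

# Assumption: gcc/g++ 9 is the version reported to work in this thread,
# not a minimum published by the Triton project.
MIN_GCC_MAJOR = 9


def gcc_major(version_line: str) -> Optional[int]:
    """Extract the major version from a line like 'gcc (Ubuntu 7.5.0-...) 7.5.0'."""
    match = re.search(r"\)\s+(\d+)\.\d+", version_line)
    return int(match.group(1)) if match else None


if __name__ == "__main__":
    if shutil.which("gcc") is None:
        print("gcc not found on PATH")
    else:
        lines = subprocess.run(["gcc", "--version"], capture_output=True,
                               text=True, check=True).stdout.splitlines()
        first_line = lines[0] if lines else ""
        major = gcc_major(first_line)
        print(first_line)
        if major is not None and major >= MIN_GCC_MAJOR:
            print("gcc looks recent enough")
        else:
            print(f"consider installing gcc/g++ >= {MIN_GCC_MAJOR}")
```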

And FYI, a wheel built on Ubuntu 22 + CUDA 12 should work everywhere: Triton isn't tied to any CUDA version.