examples: RuntimeError: Triton Error [CUDA]: invalid argument

I get the following error when running the mosaic-bert recipe. It only happens with bf16; fp32 works fine.

Traceback (most recent call last):
  File "<string>", line 21, in _bwd_kernel
KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-7929002797455b30efce6e41eddc6b57-3aa563e00c5c695dd945e23b09a86848-d962222789c30252d492a16cca3bf467-ff946bd4b3b4a4cbdf8cedc6e1c658e0-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.bfloat16, torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, True, True, True, 128, 128), (True, True, True, True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False)))
During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/coc/scratch/sreddy65/examples/examples/bert/main.py", line 141, in <module>
    main(cfg)
  File "/coc/scratch/sreddy65/examples/examples/bert/main.py", line 128, in main
    trainer.fit()
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1787, in fit
    self._train_loop()
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1950, in _train_loop
    total_loss_dict = self._train_batch(use_grad_scaling)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2126, in _train_batch
    optimizer.step(closure=lambda **kwargs: self._train_microbatches(
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 68, in wrapper
    return wrapped(*args, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py", line 289, in step
    loss = closure()
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2126, in <lambda>
    optimizer.step(closure=lambda **kwargs: self._train_microbatches(
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2209, in _train_microbatches
    microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2305, in _train_microbatch
    microbatch_loss.backward(create_graph=self._backwards_create_graph)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/nethome/sreddy65/examples/examples/bert/src/flash_attn_triton.py", line 1041, in backward
    _flash_attn_backward(do,
  File "/nethome/sreddy65/examples/examples/bert/src/flash_attn_triton.py", line 949, in _flash_attn_backward
    _bwd_kernel[grid](  # type: ignore
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 73, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 73, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 63, in _bench
    return do_bench(kernel_call)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/testing.py", line 136, in do_bench
    fn()
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 62, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 200, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 43, in _bwd_kernel
RuntimeError: Triton Error [CUDA]: invalid argument
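
For anyone trying to reproduce this outside the full recipe, the sketch below is a hypothetical minimal repro (not from the issue): it calls the Triton flash-attention function from examples/bert/src/flash_attn_triton.py directly with bf16 inputs and runs a backward pass, which is where the error above surfaces. The import path and the flash_attn_qkvpacked_func(qkv, bias) signature are assumptions based on that file.

import torch
# Assumed import path, relative to examples/bert
from src.flash_attn_triton import flash_attn_qkvpacked_func

batch, seqlen, nheads, headdim = 8, 128, 12, 64
# Packed qkv in bf16, the dtype that triggers the failure (fp32 reportedly works)
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device='cuda', dtype=torch.bfloat16, requires_grad=True)

out = flash_attn_qkvpacked_func(qkv, None)  # forward pass completes
out.sum().backward()  # backward autotunes _bwd_kernel and raises the Triton CUDA error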

About this issue

  • Original URL
  • State: open
  • Created a year ago
  • Reactions: 3
  • Comments: 17 (2 by maintainers)

Most upvoted comments

Thanks, really appreciate it! I'll mess around with versions (probably later this week) and see if that fixes things.

Hey @eldarkurtic, here is the main diff. You can drop the "or True" from the if condition (I added it for testing on my fork and forgot to remove it).

diff --git a/examples/bert/src/bert_layers.py b/examples/bert/src/bert_layers.py
index 4f8403c..db68ba8 100644
--- a/examples/bert/src/bert_layers.py
+++ b/examples/bert/src/bert_layers.py
@@ -209,18 +209,22 @@ class BertUnpadSelfAttention(nn.Module):
                         'b s (t h d) -> b s t h d',
                         t=3,
                         h=self.num_attention_heads)
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
+        # NOTE: FLASH ATTENTION
+        if self.p_dropout or flash_attn_qkvpacked_func is None or True:
             # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
             q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
-            k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
+            # k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
+            k = qkv[:, :, 1, :, :].permute(0, 2, 1, 3)  # b h s d
             v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
-            attention_scores = torch.matmul(q, k) / math.sqrt(
-                self.attention_head_size)
-            attention_scores = attention_scores + bias
-            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-            attention_probs = self.dropout(attention_probs)
-            attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
-                                                                 3)  # b s h d
+            # attention_scores = torch.matmul(q, k) / math.sqrt(
+            #     self.attention_head_size)
+            # attention_scores = attention_scores + bias
+            # attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+            # attention_probs = self.dropout(attention_probs)
+            # attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)
+            # ALWAYS RUNS
+            attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.p_dropout, is_causal=False)
+            attention = attention.permute(0, 2, 1, 3)
         else:
             # Triton implementation only supports 0 attention dropout
             convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
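
For reference, this is roughly what the patched fallback branch looks like once the temporary "or True" is removed, a cleaned-up sketch of the diff above rather than code taken verbatim from the repo. Note it passes attn_mask=None, so the bias handling of the original PyTorch fallback is dropped, and it requires PyTorch 2.0 or newer for torch.nn.functional.scaled_dot_product_attention.

        if self.p_dropout or flash_attn_qkvpacked_func is None:
            # Fall back to PyTorch 2.0's fused scaled_dot_product_attention
            # instead of the Triton kernel.
            q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
            k = qkv[:, :, 1, :, :].permute(0, 2, 1, 3)  # b h s d
            v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
            attention = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.p_dropout, is_causal=False)
            attention = attention.permute(0, 2, 1, 3)  # b s h d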

and here is a file to test whether flash attention is available in your environment:

diff --git a/examples/bert/test_pytorch2.0_attention.py b/examples/bert/test_pytorch2.0_attention.py
new file mode 100644
index 0000000..b30140d
--- /dev/null
+++ b/examples/bert/test_pytorch2.0_attention.py
@@ -0,0 +1,51 @@
+# Lets define a helpful benchmarking function:
+import torch.utils.benchmark as benchmark
+import torch.nn.functional as F
+import torch
+def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+    t0 = benchmark.Timer(
+        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+    )
+    return t0.blocked_autorange().mean * 1e6
+
+# Lets define the hyper-parameters of our input
+device = 'cuda'
+batch_size = 32
+max_sequence_len = 1024
+num_heads = 32
+embed_dimension = 32
+
+dtype = torch.float16
+
+query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
+key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
+value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
+
+print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+
+# Lets explore the speed of each of the 3 implementations
+from torch.backends.cuda import sdp_kernel, SDPBackend
+
+# Helpful arg mapper
+backend_map = {
+    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
+    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
+    SDPBackend.EFFICIENT_ATTENTION: {
+        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
+}
+
+with sdp_kernel(**backend_map[SDPBackend.MATH]):
+    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+
+
+with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+    try:
+        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+    except RuntimeError:
+        print("FlashAttention is not supported. See warnings for reasons.")
+
+with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+    try:
+        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+    except RuntimeError:
+        print("EfficientAttention is not supported. See warnings for reasons.")

Still happens with Python 3.9.16.