DNABERT_2: CompilationError: at 114:24:
Epoch [1/3]
KeyError Traceback (most recent call last) File <string>:21, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)
KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (False, False), (False, False)))
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last) File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:937, in build_triton_ir(fn, signature, specialization, constants) 936 try: –> 937 generator.visit(fn.parse()) 938 except Exception as e:
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter(“ignore”, PendingDeprecationWarning) # python 3.8 –> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) –> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:183, in CodeGenerator.visit_Module(self, node) 182 def visit_Module(self, node): –> 183 ast.NodeVisitor.generic_visit(self, node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:379, in NodeVisitor.generic_visit(self, node) 378 if isinstance(item, AST): –> 379 self.visit(item) 380 elif isinstance(value, AST):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter(“ignore”, PendingDeprecationWarning) # python 3.8 –> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) –> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:252, in CodeGenerator.visit_FunctionDef(self, node) 251 # visit function body –> 252 has_ret = self.visit_compound_statement(node.body) 253 # finalize function
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts) 176 for stmt in stmts: –> 177 self.last_ret_type = self.visit(stmt) 178 if isinstance(stmt, ast.Return):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter(“ignore”, PendingDeprecationWarning) # python 3.8 –> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) –> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:678, in CodeGenerator.visit_For(self, node) 677 self.scf_stack.append(node) –> 678 self.visit_compound_statement(node.body) 679 self.scf_stack.pop()
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts) 176 for stmt in stmts: –> 177 self.last_ret_type = self.visit(stmt) 178 if isinstance(stmt, ast.Return):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter(“ignore”, PendingDeprecationWarning) # python 3.8 –> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) –> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:319, in CodeGenerator.visit_AugAssign(self, node) 318 assign = ast.Assign(targets=[node.target], value=rhs) –> 319 self.visit(assign) 320 return self.get_value(name)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter(“ignore”, PendingDeprecationWarning) # python 3.8 –> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) –> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:301, in CodeGenerator.visit_Assign(self, node) 300 names = _names[0] –> 301 values = self.visit(node.value) 302 if not isinstance(names, tuple):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter(“ignore”, PendingDeprecationWarning) # python 3.8 –> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) –> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:339, in CodeGenerator.visit_BinOp(self, node) 338 lhs = self.visit(node.left) –> 339 rhs = self.visit(node.right) 340 fn = { 341 ast.Add: ‘add’, 342 ast.Sub: ‘sub’, (…) 352 ast.BitXor: ‘xor’, 353 }[type(node.op)]
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter(“ignore”, PendingDeprecationWarning) # python 3.8 –> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) –> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:797, in CodeGenerator.visit_Call(self, node) 795 if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) 796 or impl.is_builtin(fn): --> 797 return fn(*args, _builder=self.builder, **kws) 798 if fn in self.builtins.values():
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/impl/base.py:22, in builtin.<locals>.wrapper(*args, **kwargs) 18 raise ValueError( 19 "Did you forget to add @triton.jit ? " 20 "(`_builder` argument must be provided outside of JIT functions.)" 21 ) ---> 22 return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'
The above exception was the direct cause of the following exception:
CompilationError Traceback (most recent call last) Cell In[15], line 1 ----> 1 teacher_train(T_model, cfg, train_loader, test_loader)
Cell In[14], line 39, in teacher_train(model, config, train_loader, test_loader) 37 mask = mask.to(config.device) 38 labels = labels.to(config.device) ---> 39 outputs = model(ids, mask) 40 model.zero_grad() 41 loss = F.cross_entropy(outputs, labels)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []
Cell In[12], line 12, in BERT_Model.forward(self, context, mask) 11 def forward(self, context, mask): ---> 12 outputs = self.bert(context, attention_mask=mask) 13 pooled = outputs[1] 14 out = self.fc(pooled)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:608, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs) 605 first_col_mask[:, 0] = True 606 subset_mask = masked_tokens_mask | first_col_mask --> 608 encoder_outputs = self.encoder( 609 embedding_output, 610 attention_mask, 611 output_all_encoded_layers=output_all_encoded_layers, 612 subset_mask=subset_mask) 614 if masked_tokens_mask is None: 615 sequence_output = encoder_outputs[-1]
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:446, in BertEncoder.forward(self, hidden_states, attention_mask, output_all_encoded_layers, subset_mask) 444 if subset_mask is None: 445 for layer_module in self.layer: --> 446 hidden_states = layer_module(hidden_states, 447 cu_seqlens, 448 seqlen, 449 None, 450 indices, 451 attn_mask=attention_mask, 452 bias=alibi_attn_mask) 453 if output_all_encoded_layers: 454 all_encoder_layers.append(hidden_states)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:327, in BertLayer.forward(self, hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias) 305 def forward( 306 self, 307 hidden_states: torch.Tensor, (...) 313 bias: Optional[torch.Tensor] = None, 314 ) -> torch.Tensor: 315 """Forward pass for a BERT layer, including both attention and MLP. 316 317 Args: (...) 325 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) 326 """ --> 327 attention_output = self.attention(hidden_states, cu_seqlens, seqlen, 328 subset_idx, indices, attn_mask, bias) 329 layer_output = self.mlp(attention_output) 330 return layer_output
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:240, in BertUnpadAttention.forward(self, input_tensor, cu_seqlens, max_s, subset_idx, indices, attn_mask, bias) 218 def forward( 219 self, 220 input_tensor: torch.Tensor, (...) 226 bias: Optional[torch.Tensor] = None, 227 ) -> torch.Tensor: 228 """Forward pass for scaled self-attention without padding. 229 230 Arguments: (...) 238 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) 239 """ --> 240 self_output = self.self(input_tensor, cu_seqlens, max_s, indices, 241 attn_mask, bias) 242 if subset_idx is not None: 243 return self.output(index_first_axis(self_output, subset_idx), 244 index_first_axis(input_tensor, subset_idx))
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:181, in BertUnpadSelfAttention.forward(self, hidden_states, cu_seqlens, max_seqlen_in_batch, indices, attn_mask, bias) 179 bias_dtype = bias.dtype 180 bias = bias.to(torch.float16) --> 181 attention = flash_attn_qkvpacked_func(qkv, bias) 182 attention = attention.to(orig_dtype) 183 bias = bias.to(bias_dtype)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs) 503 if not torch._C._are_functorch_transforms_active(): 504 # See NOTE: [functorch vjp and autograd interaction] 505 args = _functorch.utils.unwrap_dead_wrappers(args) --> 506 return super().apply(*args, **kwargs) # type: ignore[misc] 508 if cls.setup_context == _SingleLevelFunction.setup_context: 509 raise RuntimeError( 510 'In order to use an autograd.Function with functorch transforms ' 511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context ' 512 'staticmethod. For more details, please see ' 513 'https://pytorch.org/docs/master/notes/extending.func.html')
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/flash_attn_triton.py:1021, in _FlashAttnQKVPackedFunc.forward(ctx, qkv, bias, causal, softmax_scale) 1019 if qkv.stride(-1) != 1: 1020 qkv = qkv.contiguous() -> 1021 o, lse, ctx.softmax_scale = _flash_attn_forward( 1022 qkv[:, :, 0], 1023 qkv[:, :, 1], 1024 qkv[:, :, 2], 1025 bias=bias, 1026 causal=causal, 1027 softmax_scale=softmax_scale) 1028 ctx.save_for_backward(qkv, o, lse, bias) 1029 ctx.causal = causal
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/flash_attn_triton.py:826, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale) 823 # BLOCK = 128 824 # num_warps = 4 if d <= 64 else 8 825 grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads) --> 826 _fwd_kernel[grid]( # type: ignore 827 q, 828 k, 829 v, 830 bias, 831 o, 832 lse, 833 tmp, 834 softmax_scale, 835 q.stride(0), 836 q.stride(2), 837 q.stride(1), 838 k.stride(0), 839 k.stride(2), 840 k.stride(1), 841 v.stride(0), 842 v.stride(2), 843 v.stride(1), 844 *bias_strides, 845 o.stride(0), 846 o.stride(2), 847 o.stride(1), 848 nheads, 849 seqlen_q, 850 seqlen_k, 851 seqlen_q_rounded, 852 d, 853 seqlen_q // 32, 854 seqlen_k // 32, # key for triton cache (limit number of compilations) 855 # Can't use kwargs here because triton autotune expects key to be args, not kwargs 856 # IS_CAUSAL=causal, BLOCK_HEADDIM=d, 857 bias_type, 858 causal, 859 BLOCK_HEADDIM, 860 # BLOCK_M=BLOCK, BLOCK_N=BLOCK, 861 # num_warps=num_warps, 862 # num_stages=1, 863 ) 864 return o, lse, softmax_scale
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/runtime/autotuner.py:90, in Autotuner.run(self, *args, **kwargs) 88 if config.pre_hook is not None: 89 config.pre_hook(self.nargs) ---> 90 return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/runtime/autotuner.py:199, in Heuristics.run(self, *args, **kwargs) 197 for v, heur in self.values.items(): 198 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) --> 199 return self.fn.run(*args, **kwargs)
File <string>:41, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:1621, in compile(fn, **kwargs) 1619 next_module = parse(path) 1620 else: -> 1621 next_module = compile(module) 1622 fn_cache_manager.put(next_module, f"{name}.{ir}") 1623 if os.path.exists(path):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:1550, in compile.<locals>.<lambda>(src) 1545 extern_libs = kwargs.get("extern_libs", dict()) 1546 # build compilation stages 1547 stages = { 1548 "ast": (lambda path: fn, None), 1549 "ttir": (lambda path: parse_mlir_module(path, context), -> 1550 lambda src: ast_to_ttir(src, signature, configs[0], constants)), 1551 "ttgir": (lambda path: parse_mlir_module(path, context), 1552 lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)), 1553 "llir": (lambda path: Path(path).read_text(), 1554 lambda src: ttgir_to_llir(src, extern_libs, capability)), 1555 "ptx": (lambda path: Path(path).read_text(), 1556 lambda src: llir_to_ptx(src, capability)), 1557 "cubin": (lambda path: Path(path).read_bytes(), 1558 lambda src: ptx_to_cubin(src, capability)) 1559 } 1560 # find out the signature of the function 1561 if isinstance(fn, triton.runtime.JITFunction):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:962, in ast_to_ttir(fn, signature, specialization, constants) 961 def ast_to_ttir(fn, signature, specialization, constants): --> 962 mod, _ = build_triton_ir(fn, signature, specialization, constants) 963 return optimize_triton_ir(mod)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:942, in build_triton_ir(fn, signature, specialization, constants) 940 if node is None or isinstance(e, (NotImplementedError, CompilationError)): 941 raise e --> 942 raise CompilationError(fn.src, node) from e 943 ret = generator.module 944 # module takes ownership of the context
Hi guys, I found the solution. I spent 5 hours on it…
The problem is that the model "zhihan1996/DNABERT-2-117M", which we load with AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True), ships a script called flash_attn_triton.py (see https://huggingface.co/zhihan1996/DNABERT-2-117M/blob/main/flash_attn_triton.py). The line qk += tl.dot(q, k, trans_b=True) in that script is no longer compatible with triton 2.0.1 (the one you get from pip install triton) or triton 2.1.0 (the one you get from git clone https://github.com/openai/triton.git): in those versions, tl.dot() no longer accepts the trans_b parameter. See also https://github.com/microsoft/DeepSpeed/issues/3491. The compatible triton version is 2.0.0.dev20221202, which still has the trans_b parameter.
So you have to do: pip install triton==2.0.0.dev20221202. Unfortunately, this will install torch 1.13.1 instead of 2.x.
But it works fine after that 😃.
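If you would rather keep a current Triton instead of pinning, the kernel can also be patched by hand: recent Triton releases dropped the trans_a/trans_b keywords but provide tl.trans() for an explicit transpose, so qk += tl.dot(q, k, trans_b=True) becomes qk += tl.dot(q, tl.trans(k)) (the backward kernel in flash_attn_triton.py uses the same keywords and would need the same edit). I have not tested this against DNABERT-2 itself; below is a minimal, self-contained sketch of the replacement pattern, assuming a Triton release that provides tl.trans:

```python
# Sketch: the tl.dot(q, k, trans_b=True) -> tl.dot(q, tl.trans(k)) rewrite.
import torch
import triton
import triton.language as tl

@triton.jit
def dot_kt_kernel(Q, K, Out, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    ptrs = offs[:, None] * BLOCK + offs[None, :]  # row-major BLOCK x BLOCK tile
    q = tl.load(Q + ptrs)
    k = tl.load(K + ptrs)
    # old API (2.0.0.dev builds only): qk = tl.dot(q, k, trans_b=True)
    # new API: transpose k explicitly
    qk = tl.dot(q, tl.trans(k))
    tl.store(Out + ptrs, qk)  # fp16 x fp16 dot accumulates in fp32

q = torch.randn(16, 16, device="cuda", dtype=torch.float16)
k = torch.randn(16, 16, device="cuda", dtype=torch.float16)
out = torch.empty(16, 16, device="cuda", dtype=torch.float32)
dot_kt_kernel[(1,)](q, k, out, BLOCK=16)
print(torch.allclose(out, q.float() @ k.float().T, atol=1e-2))  # True
```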
The underlying problem is that we need the right versions of all the Python modules, but exact versions were never pinned. As a module like triton gets updated, it stops being compatible with DNABERT-2.
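A quick way to confirm what actually got installed after the downgrade (the expected values in the comments are the versions reported working in this thread, not an official requirement):

```python
import torch
import triton

print("torch :", torch.__version__)   # expected here: 1.13.1
print("triton:", triton.__version__)  # expected here: the 2.0.0.dev20221202 build
```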
@pjsample I could not run the provided scripts, as other bugs kept coming up. I had to write my own script, and it worked after a lot of modifications.
I think I solved it on my system. I have an NVIDIA A100; nvidia-smi reports Driver Version: 535.104.05, CUDA Version: 12.2. I was getting the same error about triton wanting CUDA 11+. What worked for me:
- I made a new environment.
- I forced the torch CUDA version.
- I installed the required packages via a requirements.txt (not pulling/installing triton from GitHub).
- Finally, I installed a CUDA 11 nvcc in the conda environment; I believe triton gets confused by the system-wide CUDA 12 nvcc binary (see the check below).
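One way to check which nvcc Triton's JIT will find first on the PATH (a small sketch; the shadowing explanation above is a hypothesis, not something I can confirm):

```python
import shutil
import subprocess

nvcc = shutil.which("nvcc")
print("nvcc on PATH:", nvcc)  # should resolve inside the conda env, not /usr/local/cuda
if nvcc:
    # print the CUDA toolkit version that this nvcc belongs to
    print(subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout)
```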
At least the example data works now 😃
While the example runs, nvidia-smi reports Python using the GPU.
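The same can be confirmed from inside Python (a minimal check; the printed device name is whatever your GPU reports):

```python
import torch

print(torch.cuda.is_available())      # True
print(torch.version.cuda)             # CUDA version torch was built against
print(torch.cuda.get_device_name(0))  # e.g. an A100 on this commenter's system
```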