DNABERT_2: CompilationError: at 114:24:

Epoch [1/3]

KeyError Traceback (most recent call last) File <string>:21, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)

KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (False, False), (False, False)))

During handling of the above exception, another exception occurred:

TypeError Traceback (most recent call last) File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:937, in build_triton_ir(fn, signature, specialization, constants) 936 try: --> 937 generator.visit(fn.parse()) 938 except Exception as e:

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 --> 855 return super().visit(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) --> 371 return visitor(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:183, in CodeGenerator.visit_Module(self, node) 182 def visit_Module(self, node): --> 183 ast.NodeVisitor.generic_visit(self, node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:379, in NodeVisitor.generic_visit(self, node) 378 if isinstance(item, AST): --> 379 self.visit(item) 380 elif isinstance(value, AST):

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 --> 855 return super().visit(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) --> 371 return visitor(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:252, in CodeGenerator.visit_FunctionDef(self, node) 251 # visit function body --> 252 has_ret = self.visit_compound_statement(node.body) 253 # finalize function

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts) 176 for stmt in stmts: --> 177 self.last_ret_type = self.visit(stmt) 178 if isinstance(stmt, ast.Return):

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 --> 855 return super().visit(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) --> 371 return visitor(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:678, in CodeGenerator.visit_For(self, node) 677 self.scf_stack.append(node) --> 678 self.visit_compound_statement(node.body) 679 self.scf_stack.pop()

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts) 176 for stmt in stmts: --> 177 self.last_ret_type = self.visit(stmt) 178 if isinstance(stmt, ast.Return):

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 --> 855 return super().visit(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) --> 371 return visitor(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:319, in CodeGenerator.visit_AugAssign(self, node) 318 assign = ast.Assign(targets=[node.target], value=rhs) --> 319 self.visit(assign) 320 return self.get_value(name)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 --> 855 return super().visit(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) --> 371 return visitor(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:301, in CodeGenerator.visit_Assign(self, node) 300 names = _names[0] --> 301 values = self.visit(node.value) 302 if not isinstance(names, tuple):

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 --> 855 return super().visit(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) --> 371 return visitor(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:339, in CodeGenerator.visit_BinOp(self, node) 338 lhs = self.visit(node.left) --> 339 rhs = self.visit(node.right) 340 fn = { 341 ast.Add: 'add', 342 ast.Sub: 'sub', (...) 352 ast.BitXor: 'xor', 353 }[type(node.op)]

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node) 854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 --> 855 return super().visit(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node) 370 visitor = getattr(self, method, self.generic_visit) --> 371 return visitor(node)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:797, in CodeGenerator.visit_Call(self, node) 795 if (hasattr(fn, 'self') and self.is_triton_tensor(fn.self))
796 or impl.is_builtin(fn): --> 797 return fn(*args, _builder=self.builder, **kws) 798 if fn in self.builtins.values():

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/impl/base.py:22, in builtin.<locals>.wrapper(*args, **kwargs) 18 raise ValueError( 19 "Did you forget to add @triton.jit ? " 20 "(_builder argument must be provided outside of JIT functions.)" 21 ) ---> 22 return fn(*args, **kwargs)

TypeError: dot() got an unexpected keyword argument 'trans_b'

The above exception was the direct cause of the following exception:

CompilationError Traceback (most recent call last) Cell In[15], line 1 ----> 1 teacher_train(T_model, cfg, train_loader, test_loader)

Cell In[14], line 39, in teacher_train(model, config, train_loader, test_loader) 37 mask = mask.to(config.device) 38 labels = labels.to(config.device) ---> 39 outputs = model(ids, mask) 40 model.zero_grad() 41 loss = F.cross_entropy(outputs, labels)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[12], line 12, in BERT_Model.forward(self, context, mask) 11 def forward(self, context, mask): ---> 12 outputs = self.bert(context, attention_mask=mask) 13 pooled = outputs[1] 14 out = self.fc(pooled)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:608, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs) 605 first_col_mask[:, 0] = True 606 subset_mask = masked_tokens_mask | first_col_mask --> 608 encoder_outputs = self.encoder( 609 embedding_output, 610 attention_mask, 611 output_all_encoded_layers=output_all_encoded_layers, 612 subset_mask=subset_mask) 614 if masked_tokens_mask is None: 615 sequence_output = encoder_outputs[-1]

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:446, in BertEncoder.forward(self, hidden_states, attention_mask, output_all_encoded_layers, subset_mask) 444 if subset_mask is None: 445 for layer_module in self.layer: --> 446 hidden_states = layer_module(hidden_states, 447 cu_seqlens, 448 seqlen, 449 None, 450 indices, 451 attn_mask=attention_mask, 452 bias=alibi_attn_mask) 453 if output_all_encoded_layers: 454 all_encoder_layers.append(hidden_states)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:327, in BertLayer.forward(self, hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias) 305 def forward( 306 self, 307 hidden_states: torch.Tensor, (...) 313 bias: Optional[torch.Tensor] = None, 314 ) -> torch.Tensor: 315 """Forward pass for a BERT layer, including both attention and MLP. 316 317 Args: (...) 325 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) 326 """ --> 327 attention_output = self.attention(hidden_states, cu_seqlens, seqlen, 328 subset_idx, indices, attn_mask, bias) 329 layer_output = self.mlp(attention_output) 330 return layer_output

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:240, in BertUnpadAttention.forward(self, input_tensor, cu_seqlens, max_s, subset_idx, indices, attn_mask, bias) 218 def forward( 219 self, 220 input_tensor: torch.Tensor, (...) 226 bias: Optional[torch.Tensor] = None, 227 ) -> torch.Tensor: 228 """Forward pass for scaled self-attention without padding. 229 230 Arguments: (...) 238 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) 239 """ --> 240 self_output = self.self(input_tensor, cu_seqlens, max_s, indices, 241 attn_mask, bias) 242 if subset_idx is not None: 243 return self.output(index_first_axis(self_output, subset_idx), 244 index_first_axis(input_tensor, subset_idx))

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:181, in BertUnpadSelfAttention.forward(self, hidden_states, cu_seqlens, max_seqlen_in_batch, indices, attn_mask, bias) 179 bias_dtype = bias.dtype 180 bias = bias.to(torch.float16) --> 181 attention = flash_attn_qkvpacked_func(qkv, bias) 182 attention = attention.to(orig_dtype) 183 bias = bias.to(bias_dtype)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs) 503 if not torch._C._are_functorch_transforms_active(): 504 # See NOTE: [functorch vjp and autograd interaction] 505 args = _functorch.utils.unwrap_dead_wrappers(args) --> 506 return super().apply(*args, **kwargs) # type: ignore[misc] 508 if cls.setup_context == _SingleLevelFunction.setup_context: 509 raise RuntimeError( 510 'In order to use an autograd.Function with functorch transforms ' 511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context ' 512 'staticmethod. For more details, please see ' 513 'https://pytorch.org/docs/master/notes/extending.func.html')

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/flash_attn_triton.py:1021, in _FlashAttnQKVPackedFunc.forward(ctx, qkv, bias, causal, softmax_scale) 1019 if qkv.stride(-1) != 1: 1020 qkv = qkv.contiguous() -> 1021 o, lse, ctx.softmax_scale = _flash_attn_forward( 1022 qkv[:, :, 0], 1023 qkv[:, :, 1], 1024 qkv[:, :, 2], 1025 bias=bias, 1026 causal=causal, 1027 softmax_scale=softmax_scale) 1028 ctx.save_for_backward(qkv, o, lse, bias) 1029 ctx.causal = causal

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/flash_attn_triton.py:826, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale) 823 # BLOCK = 128 824 # num_warps = 4 if d <= 64 else 8 825 grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads) --> 826 _fwd_kernel[grid]( # type: ignore 827 q, 828 k, 829 v, 830 bias, 831 o, 832 lse, 833 tmp, 834 softmax_scale, 835 q.stride(0), 836 q.stride(2), 837 q.stride(1), 838 k.stride(0), 839 k.stride(2), 840 k.stride(1), 841 v.stride(0), 842 v.stride(2), 843 v.stride(1), 844 *bias_strides, 845 o.stride(0), 846 o.stride(2), 847 o.stride(1), 848 nheads, 849 seqlen_q, 850 seqlen_k, 851 seqlen_q_rounded, 852 d, 853 seqlen_q // 32, 854 seqlen_k // 32, # key for triton cache (limit number of compilations) 855 # Can't use kwargs here because triton autotune expects key to be args, not kwargs 856 # IS_CAUSAL=causal, BLOCK_HEADDIM=d, 857 bias_type, 858 causal, 859 BLOCK_HEADDIM, 860 # BLOCK_M=BLOCK, BLOCK_N=BLOCK, 861 # num_warps=num_warps, 862 # num_stages=1, 863 ) 864 return o, lse, softmax_scale

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/runtime/autotuner.py:90, in Autotuner.run(self, *args, **kwargs) 88 if config.pre_hook is not None: 89 config.pre_hook(self.nargs) ---> 90 return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/runtime/autotuner.py:199, in Heuristics.run(self, *args, **kwargs) 197 for v, heur in self.values.items(): 198 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) --> 199 return self.fn.run(*args, **kwargs)

File <string>:41, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:1621, in compile(fn, **kwargs) 1619 next_module = parse(path) 1620 else: -> 1621 next_module = compile(module) 1622 fn_cache_manager.put(next_module, f"{name}.{ir}") 1623 if os.path.exists(path):

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:1550, in compile.<locals>.<lambda>(src) 1545 extern_libs = kwargs.get("extern_libs", dict()) 1546 # build compilation stages 1547 stages = { 1548 "ast": (lambda path: fn, None), 1549 "ttir": (lambda path: parse_mlir_module(path, context), -> 1550 lambda src: ast_to_ttir(src, signature, configs[0], constants)), 1551 "ttgir": (lambda path: parse_mlir_module(path, context), 1552 lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)), 1553 "llir": (lambda path: Path(path).read_text(), 1554 lambda src: ttgir_to_llir(src, extern_libs, capability)), 1555 "ptx": (lambda path: Path(path).read_text(), 1556 lambda src: llir_to_ptx(src, capability)), 1557 "cubin": (lambda path: Path(path).read_bytes(), 1558 lambda src: ptx_to_cubin(src, capability)) 1559 } 1560 # find out the signature of the function 1561 if isinstance(fn, triton.runtime.JITFunction):

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:962, in ast_to_ttir(fn, signature, specialization, constants) 961 def ast_to_ttir(fn, signature, specialization, constants): --> 962 mod, _ = build_triton_ir(fn, signature, specialization, constants) 963 return optimize_triton_ir(mod)

File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:942, in build_triton_ir(fn, signature, specialization, constants) 940 if node is None or isinstance(e, (NotImplementedError, CompilationError)): 941 raise e --> 942 raise CompilationError(fn.src, node) from e 943 ret = generator.module 944 # module takes ownership of the context


Most upvoted comments

Hi guys, I found the solution. I spent 5 hours on it.

So the problem is that the model "zhihan1996/DNABERT-2-117M", which we load using AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True), ships a script called flash_attn_triton.py (see https://huggingface.co/zhihan1996/DNABERT-2-117M/blob/main/flash_attn_triton.py) in which the line qk += tl.dot(q, k, trans_b=True) is no longer compatible with triton 2.0.1 (the one you get from pip install triton) or triton 2.1.0 (the one you get from git clone https://github.com/openai/triton.git): the function tl.dot() no longer accepts the trans_b parameter. See also https://github.com/microsoft/DeepSpeed/issues/3491. The compatible triton version is 2.0.0.dev20221202 (it still has the trans_b parameter).
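If you want to fail fast with a clearer message, you can check the installed triton version before loading the model. This is only an illustrative sketch (the check and its wording are mine, not part of DNABERT-2); the version string is the one reported to work above.

# Illustrative sanity check: stop early if the installed triton is not the
# version whose tl.dot still accepts trans_b=True.
import triton

EXPECTED = "2.0.0.dev20221202"

if triton.__version__ != EXPECTED:
    raise RuntimeError(
        f"Found triton {triton.__version__}; the remote flash_attn_triton.py in "
        f"zhihan1996/DNABERT-2-117M calls tl.dot(..., trans_b=True), which needs "
        f"triton {EXPECTED}."
    )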

So you have to do: pip install triton==2.0.0.dev20221202. Unfortunately, this will downgrade torch to 1.13.1 instead of 2.x.

But it works fine after that 😃.

The underlying problem is that the exact versions of all the required Python packages were never pinned. As a package like triton gets updated, it stops being compatible with DNABERT-2.

@pjsample I could not run the provided scripts because other bugs kept coming up. I had to write my own script, and it worked after a lot of modifications.

I think I solved it on my system. I have an NVIDIA A100; nvidia-smi reports Driver Version: 535.104.05 and CUDA Version: 12.2. I saw the same error about triton wanting CUDA 11+.

Made a new environment:

mamba create -n dna python=3.8
conda activate dna

Then I forced the torch CUDA version:

pip install torch==1.13.1+cu117  --extra-index-url https://download.pytorch.org/whl/cu117
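A quick way to confirm you actually got the cu117 build (plain PyTorch calls, nothing DNABERT-2 specific):

import torch

print(torch.__version__)          # expect 1.13.1+cu117
print(torch.version.cuda)         # expect 11.7
print(torch.cuda.is_available())  # expect True on the A100 node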

Then I installed the required packages via this requirements.txt (not pulling/installing triton from github):

triton==2.0.0.dev20221202
transformers==4.29.2
scikit-learn
peft
einops

Finally, I had to install a CUDA 11 nvcc in the conda environment; I believe triton gets confused by the system-wide CUDA 12 nvcc binary.

 mamba install -c "nvidia/label/cuda-11.7.0" cuda-nvcc
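To confirm that the nvcc triton will find is the CUDA 11.7 one from the conda environment rather than the system-wide CUDA 12 binary, a small check like this works (illustrative sketch, not part of any DNABERT-2 tooling):

# Check which nvcc is first on PATH and print its version banner.
import shutil, subprocess

nvcc = shutil.which("nvcc")
assert nvcc is not None, "nvcc not found on PATH"
print("nvcc:", nvcc)  # should point inside the conda env, e.g. .../envs/dna/bin/nvcc
print(subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout)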

At least the example data works now 😃
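Before launching the full fine-tuning command below, a minimal forward pass is enough to confirm the triton kernel now compiles. This is a hedged sketch based on the model card's usage example; running on the GPU (the .cuda() calls, assuming one is available) exercises the same flash-attention code path as the traceback above.

# Smoke test: load DNABERT-2 and run one forward pass on the GPU.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True).cuda()

dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors="pt")["input_ids"].cuda()
with torch.no_grad():
    hidden_states = model(inputs)[0]  # (batch, seq_len, hidden_size)
print(hidden_states.shape)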

Command:

export DATA_PATH=`pwd`/DNABERT_2/sample_data
export LR=3e-5
export MAX_LENGTH=100

python DNABERT_2/finetune/train.py \
    --model_name_or_path zhihan1996/DNABERT-2-117M \
    --data_path  ${DATA_PATH} \
    --kmer -1 \
    --run_name DNABERT2_${DATA_PATH} \
    --model_max_length ${MAX_LENGTH} \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --learning_rate ${LR} \
    --num_train_epochs 5 \
    --fp16 \
    --save_steps 200 \
    --output_dir output/dnabert2 \
    --evaluation_strategy steps \
    --eval_steps 200 \
    --warmup_steps 50 \
    --logging_steps 100 \
    --overwrite_output_dir True \
    --log_level info \
    --find_unused_parameters False

nvidia-smi reports Python using the GPU.

Full log:
WARNING:root:Perform single sequence classification...
WARNING:root:Perform single sequence classification...
WARNING:root:Perform single sequence classification...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['bert.pooler.dense.weight', 'classifier.bias', 'bert.pooler.dense.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using cuda_amp half precision backend
***** Running training *****
  Num examples = 15
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 10
  Number of trainable parameters = 117,070,082
  0%|                                                                                                                                                 | 0/10 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎             | 9/10 [00:05<00:00,  3.53it/s]

Training completed. Do not forget to share your model on huggingface.co/models =)

{'train_runtime': 5.5699, 'train_samples_per_second': 13.465, 'train_steps_per_second': 1.795, 'train_loss': 0.6913905620574952, 'epoch': 5.0}
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.80it/s]
***** Running Evaluation *****
  Num examples = 15
  Batch size = 16
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 107.90it/s]