FastChat: RuntimeError: CUDA error: device-side assert triggered when running Llama on multiple gpus

I’m getting the following error when using more than one GPU:

python3 -m fastchat.serve.cli --model-name /tmp/cache/vicuna-13b/ --num-gpus 2

I am unsure whether this is a problem on my end or something that needs to be fixed in FastChat. Can you please confirm whether using multiple GPUs is supported by FastChat and whether there are any specific requirements that must be met? Thank you.

I’m using 4x V100 (32 GB), and yes, I’ve already tried both 2- and 4-GPU combinations. (Screenshot attached: Screenshot 2023-04-04 at 11 08 39 PM.)

About this issue

  • Original URL
  • State: closed
  • Created a year ago
  • Comments: 17 (1 by maintainers)

Most upvoted comments

Seems like @sethbruder's solution would solve this problem. Closing. Please re-open if the issue persists.

@zhisbug Hi, I have updated the fschat and transformers packages to the latest versions and re-converted the model to the Hugging Face format, but the error mentioned before still occurs when running the client on two RTX 4090 GPUs. I don’t think this issue has been solved; could you reopen it?

In case it helps anybody else: the problem posted by @starphantom666 may be a different problem from the one in the original post (OP) above. A problem consistent with @starphantom666's report can occur because of the recently changed handling of BOS/EOS tokens in the Hugging Face (“HF”) Llama implementation.

  • This problem may occur if the conversion of weights to HF format was done with “old” HF code, whereas you are now using the latest HF code.
  • The symptom of the problem is that, if you print out the prompt tokenization, the BOS token ("<s>") is wrongly represented by ID 32000, whereas the embedding table now expects ID 1 (see the sketch after this list).
  • In that case, if one sets export CUDA_LAUNCH_BLOCKING=1, one will see an assertion message of the form ... Indexing.cu: ... indexSelectLargeIndex: ... Assertion srcIndex < srcSelectDimSize failed. from a CUDA kernel launched by the torch embedding function, presumably because ID 32000 is past the end of the embedding table.
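
As a quick way to see this, here is a minimal sketch of that tokenization check in Python. The model path is an assumption taken from the command above; point it at your own HF-format weights.

# Sketch: verify how the converted tokenizer maps the BOS token.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/tmp/cache/vicuna-13b/")

# A healthy conversion maps "<s>" (BOS) to ID 1; a broken one reports 32000,
# which is one past the end of a 32000-entry embedding table.
print("bos_token_id:", tokenizer.bos_token_id)
print("vocab size:", len(tokenizer))
print("prompt ids:", tokenizer("Hello").input_ids)  # should start with 1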

A solution is to update both the HF transformers and FastChat repos to the latest versions and re-convert: first from the original Llama weights to HF weights, and then from HF weights to Vicuna weights.

I also have the same problem. I can use the Vicuna-13B model properly with the --load-8bit option on a single 4090 GPU, but when I use multiple GPUs (--num-gpus 2), this problem occurs. I’m still seeing the same traceback and can’t figure out why.

I got the same problem on a dual-4090 machine. I tried the same command with two 3090s and it worked fine. I guessed it was a driver/CUDA version problem, but after some searching I found the following post: https://discuss.pytorch.org/t/ddp-training-on-rtx-4090-ada-cu118/168366

It seems the 4090 does not support peer-to-peer (P2P) communication between multiple cards at all. I am not 100% sure this is the root cause since I am not an expert in this domain. Can someone double-check it? Thanks.
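
One rough way to double-check from PyTorch itself (this only reports whether PyTorch believes direct peer access is possible between two devices; it is not a full NCCL diagnosis):

# Sketch: ask PyTorch whether GPU 0 can directly access GPU 1's memory (P2P).
# If this prints False on a dual-4090 box, cross-card P2P is unavailable.
import torch

print("device count:", torch.cuda.device_count())
if torch.cuda.device_count() >= 2:
    print("P2P 0 -> 1:", torch.cuda.can_device_access_peer(0, 1))
    print("P2P 1 -> 0:", torch.cuda.can_device_access_peer(1, 0))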

=== Updated below ===

I tried setting NCCL_P2P_DISABLE=1 and ran other code that trains LoRA with two 4090s. It now works (it used to hang).
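
For reference, NCCL only reads this variable when it initializes, so it just needs to be in the environment before the model is loaded; here is a sketch of setting it from Python (exporting it in the shell before launching works the same way):

# Sketch: disable NCCL peer-to-peer transfers for this process.
# This must run before any code that initializes NCCL / multi-GPU communication.
import os

os.environ["NCCL_P2P_DISABLE"] = "1"
# ... then import/launch the training or serving code as usual.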

But when I tried running Vicuna with P2P disabled, it quit and reported another error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/runpy.py:196 in _run_module_as_main  │                                                                       
│                                                                                                  │                                                                       
│   193 │   main_globals = sys.modules["__main__"].__dict__                                        │                                                                       
│   194 │   if alter_argv:                                                                         │                                                                       
│   195 │   │   sys.argv[0] = mod_spec.origin                                                      │                                                                       
│ ❱ 196 │   return _run_code(code, main_globals, None,                                             │                                                                       
│   197 │   │   │   │   │    "__main__", mod_spec)                                                 │                                                                       
│   198                                                                                            │                                                                       
│   199 def run_module(mod_name, init_globals=None,                                                │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/runpy.py:86 in _run_code             │                                                                       
│                                                                                                  │                                                                       
│    83 │   │   │   │   │      __loader__ = loader,                                                │                                                                       
│    84 │   │   │   │   │      __package__ = pkg_name,                                             │                                                                       
│    85 │   │   │   │   │      __spec__ = mod_spec)                                                │                                                                       
│ ❱  86 │   exec(code, run_globals)                                                                │                                                                       
│    87 │   return run_globals                                                                     │                                                                       
│    88                                                                                            │                                                                       
│    89 def _run_module_code(code, init_globals=None,                                              │                                                                       
│                                                                                                  │                                                                       
│ /XX/FastChat/fastchat/serve/cli.py:132 in <module>                                    │                                                                       
│                                                                                                  │                                                                       
│   129 │   │   │   │   │   │   choices=["simple", "rich"], help="Display style.")                 │                                                                       
│   130 │   parser.add_argument("--debug", action="store_true")                                    │                                                                       
│   131 │   args = parser.parse_args()                                                             │                                                                       
│ ❱ 132 │   main(args)                                                                             │                                                                       
│   133                                                                                            │                                                                       
│                                                                                                  │                                                                       
│ /XX/FastChat/fastchat/serve/cli.py:108 in main                                        │                                                                       
│                                                                                                  │                                                                       
│   105 │   else:                                                                                  │                                                                       
│   106 │   │   raise ValueError(f"Invalid style for console: {args.style}")                       │                                                                       
│   107 │   try:                                                                                   │                                                                       
│ ❱ 108 │   │   chat_loop(args.model_path, args.device, args.num_gpus, args.max_gpu_memory,        │                                                                       
│   109 │   │   │   args.load_8bit, args.conv_template, args.temperature, args.max_new_tokens,     │                                                                       
│   110 │   │   │   chatio, args.debug)                                                            │                                                                       
│   111 │   except KeyboardInterrupt:                                                              │                                                                       
│                                                                                                  │                                                                       
│ /XX/FastChat/fastchat/serve/inference.py:223 in chat_loop                             │                                                                       
│                                                                                                  │                                                                       
│   220 │   │                                                                                      │                                                                       
│   221 │   │   chatio.prompt_for_output(conv.roles[1])                                            │                                                                       
│   222 │   │   output_stream = generate_stream_func(model, tokenizer, params, device)             │                                                                       
│ ❱ 223 │   │   outputs = chatio.stream_output(output_stream, skip_echo_len)                       │                                                                       
│   224 │   │   # NOTE: strip is important to align with the training data.                        │                                                                       
│   225 │   │   conv.messages[-1][-1] = outputs.strip()                                            │                                                                       
│   226                                                                                            │                                                                       
│                                                                                                  │                                                                       
│ /XX/FastChat/fastchat/serve/cli.py:69 in stream_output                                │                                                                       
│                                                                                                  │                                                                       
│    66 │   │   # Create a Live context for updating the console output                            │                                                                       
│    67 │   │   with Live(console=self._console, refresh_per_second=4) as live:                    │                                                                       
│    68 │   │   │   # Read lines from the stream                                                   │                                                                       
│ ❱  69 │   │   │   for outputs in output_stream:                                                  │                                                                       
│    70 │   │   │   │   accumulated_text = outputs[skip_echo_len:]                                 │                                                                       
│    71 │   │   │   │   if not accumulated_text:                                                   │                                                                       
│    72 │   │   │   │   │   continue                                                               │                                                                       
│                                                                                                  │             
│                                                                                                  │                                                             [104/1845]
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/utils/_contextli │                                                                       
│ b.py:56 in generator_context                                                                     │                                                                       
│                                                                                                  │                                                                       
│    53 │   │   │   │   else:                                                                      │                                                                       
│    54 │   │   │   │   │   # Pass the last request to the generator and get its response          │                                                                       
│    55 │   │   │   │   │   with ctx_factory():                                                    │                                                                       
│ ❱  56 │   │   │   │   │   │   response = gen.send(request)                                       │                                                                       
│    57 │   │                                                                                      │                                                                       
│    58 │   │   # We let the exceptions raised above by the generator's `.throw` or                │                                                                       
│    59 │   │   # `.send` methods bubble up to our caller, except for StopIteration                │                                                                       
│                                                                                                  │                                                                       
│ /XX/FastChat/fastchat/serve/inference.py:122 in generate_stream                       │                                                                       
│                                                                                                  │                                                                       
│   119 │   │   │   logits = out.logits                                                            │                                                                       
│   120 │   │   │   past_key_values = out.past_key_values                                          │                                                                       
│   121 │   │   else:                                                                              │                                                                       
│ ❱ 122 │   │   │   out = model(input_ids=torch.as_tensor([[token]], device=device),               │                                                                       
│   123 │   │   │   │   │   │   use_cache=True,                                                    │                                                                       
│   124 │   │   │   │   │   │   past_key_values=past_key_values)                                   │                                                                       
│   125 │   │   │   logits = out.logits                                                            │                                                                       
│                                                                                                  │                                                                       
│ /XX/miniconda3/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │                                                                       
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/accelerate/hooks.py:16 │                                                                       
│ 5 in new_forward                                                                                 │                                                                       
│                                                                                                  │                                                                       
│   162 │   │   │   with torch.no_grad():                                                          │                                                                       
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │                                                                       
│   164 │   │   else:                                                                              │                                                                       
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │                                                                       
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │                                                                       
│   167 │                                                                                          │                                                                       
│   168 │   module.forward = new_forward                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/transformers/models/ll │                                                                       
│ ama/modeling_llama.py:687 in forward                                                             │                                                                       
│                                                                                                  │                                                                       
│   684 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return   │                                                                       
│   685 │   │                                                                                      │                                                                       
│   686 │   │   # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)    │                                                                       
│ ❱ 687 │   │   outputs = self.model(                                                              │                                                                       
│   688 │   │   │   input_ids=input_ids,                                                           │                                                                       
│   689 │   │   │   attention_mask=attention_mask,                                                 │                                                                       
│   690 │   │   │   position_ids=position_ids,                                                     │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │                                                                       
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │    
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/transformers/models/ll │                                                                       
│ ama/modeling_llama.py:577 in forward                                                             │                                                                       
│                                                                                                  │                                                                       
│   574 │   │   │   │   │   None,                                                                  │                                                                       
│   575 │   │   │   │   )                                                                          │                                                                       
│   576 │   │   │   else:                                                                          │                                                                       
│ ❱ 577 │   │   │   │   layer_outputs = decoder_layer(                                             │                                                                       
│   578 │   │   │   │   │   hidden_states,                                                         │                                                                       
│   579 │   │   │   │   │   attention_mask=attention_mask,                                         │                                                                       
│   580 │   │   │   │   │   position_ids=position_ids,                                             │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │                                                                       
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/accelerate/hooks.py:16 │                                                                       
│ 5 in new_forward                                                                                 │                                                                       
│                                                                                                  │                                                                       
│   162 │   │   │   with torch.no_grad():                                                          │                                                                       
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │                                                                       
│   164 │   │   else:                                                                              │                                                                       
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │                                                                       
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │                                                                       
│   167 │                                                                                          │                                                                       
│   168 │   module.forward = new_forward                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/transformers/models/ll │                                                                       
│ ama/modeling_llama.py:305 in forward                                                             │                                                                       
│                                                                                                  │                                                                       
│   302 │   │   # Fully Connected                                                                  │                                                                       
│   303 │   │   residual = hidden_states                                                           │       
│   304 │   │   hidden_states = self.post_attention_layernorm(hidden_states)                       │                                                               [0/1845]
│ ❱ 305 │   │   hidden_states = self.mlp(hidden_states)                                            │                                                                       
│   306 │   │   hidden_states = residual + hidden_states                                           │                                                                       
│   307 │   │                                                                                      │                                                                       
│   308 │   │   outputs = (hidden_states,)                                                         │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │                                                                       
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/accelerate/hooks.py:16 │                                                                       
│ 5 in new_forward                                                                                 │                                                                       
│                                                                                                  │                                                                       
│   162 │   │   │   with torch.no_grad():                                                          │                                                                       
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │                                                                       
│   164 │   │   else:                                                                              │                                                                       
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │                                                                       
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │                                                                       
│   167 │                                                                                          │                                                                       
│   168 │   module.forward = new_forward                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/miniconda3/envs/vicuna-matata/lib/python3.10/site-packages/transformers/models/ll │
│ ama/modeling_llama.py:157 in forward                                                             │                                                                       
│                                                                                                  │                                                                       
│   154 │   │   self.act_fn = ACT2FN[hidden_act]                                                   │                                                                       
│   155 │                                                                                          │                                                                       
│   156 │   def forward(self, x):                                                                  │                                                                       
│ ❱ 157 │   │   return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))            │                                                                       
│   158                                                                                            │                                                                       
│   159                                                                                            │                                                                       
│   160 class LlamaAttention(nn.Module):                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/accelerate/hooks.py:16 │
│ 5 in new_forward                                                                                 │                                                                       
│                                                                                                  │                                                                       
│   162 │   │   │   with torch.no_grad():                                                          │                                                                       
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │                                                                       
│   164 │   │   else:                                                                              │                                                                       
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │                                                                       
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │                                                                       
│   167 │                                                                                          │                                                                       
│   168 │   module.forward = new_forward                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/linea │                                                                       
│ r.py:114 in forward                                                                              │                                                                       
│                                                                                                  │                                                                       
│   111 │   │   │   init.uniform_(self.bias, -bound, bound)                                        │                                                                       
│   112 │                                                                                          │                                                                       
│   113 │   def forward(self, input: Tensor) -> Tensor:                                            │                                                                       
│ ❱ 114 │   │   return F.linear(input, self.weight, self.bias)                                     │                                                                       
│   115 │                                                                                          │                                                                       
│   116 │   def extra_repr(self) -> str:                                                           │                                                                       
│   117 │   │   return 'in_features={}, out_features={}, bias={}'.format(                          │                                                                       
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ 
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, 
&fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`

│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │                                                                       
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │                                                                       
│   167 │                                                                                          │                                                                       
│   168 │   module.forward = new_forward                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/transformers/models/ll │                                                                       
│ ama/modeling_llama.py:687 in forward                                                             │                                                                       
│                                                                                                  │                                                                       
│   684 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return   │                                                                       
│   685 │   │                                                                                      │                                                                       
│   686 │   │   # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)    │                                                                       
│ ❱ 687 │   │   outputs = self.model(                                                              │                                                                       
│   688 │   │   │   input_ids=input_ids,                                                           │                                                                       
│   689 │   │   │   attention_mask=attention_mask,                                                 │                                                                       
│   690 │   │   │   position_ids=position_ids,                                                     │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │                                                                       
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │    
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/transformers/models/ll │                                                                       
│ ama/modeling_llama.py:577 in forward                                                             │                                                                       
│                                                                                                  │                                                                       
│   574 │   │   │   │   │   None,                                                                  │                                                                       
│   575 │   │   │   │   )                                                                          │                                                                       
│   576 │   │   │   else:                                                                          │                                                                       
│ ❱ 577 │   │   │   │   layer_outputs = decoder_layer(                                             │                                                                       
│   578 │   │   │   │   │   hidden_states,                                                         │                                                                       
│   579 │   │   │   │   │   attention_mask=attention_mask,                                         │                                                                       
│   580 │   │   │   │   │   position_ids=position_ids,                                             │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │                                                                       
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/accelerate/hooks.py:16 │                                                                       
│ 5 in new_forward                                                                                 │                                                                       
│                                                                                                  │                                                                       
│   162 │   │   │   with torch.no_grad():                                                          │                                                                       
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │                                                                       
│   164 │   │   else:                                                                              │                                                                       
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │                                                                       
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │                                                                       
│   167 │                                                                                          │                                                                       
│   168 │   module.forward = new_forward                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/transformers/models/ll │                                                                       
│ ama/modeling_llama.py:305 in forward                                                             │                                                                       
│                                                                                                  │                                                                       
│   302 │   │   # Fully Connected                                                                  │                                                                       
│   303 │   │   residual = hidden_states                                                           │       
│   304 │   │   hidden_states = self.post_attention_layernorm(hidden_states)                       │
│ ❱ 305 │   │   hidden_states = self.mlp(hidden_states)                                            │                                                                       
│   306 │   │   hidden_states = residual + hidden_states                                           │                                                                       
│   307 │   │                                                                                      │                                                                       
│   308 │   │   outputs = (hidden_states,)                                                         │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │                                                                       
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/accelerate/hooks.py:16 │                                                                       
│ 5 in new_forward                                                                                 │                                                                       
│                                                                                                  │                                                                       
│   162 │   │   │   with torch.no_grad():                                                          │                                                                       
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │                                                                       
│   164 │   │   else:                                                                              │                                                                       
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │                                                                       
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │                                                                       
│   167 │                                                                                          │                                                                       
│   168 │   module.forward = new_forward                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/miniconda3/envs/vicuna-matata/lib/python3.10/site-packages/transformers/models/ll │
│ ama/modeling_llama.py:157 in forward                                                             │                                                                       
│                                                                                                  │                                                                       
│   154 │   │   self.act_fn = ACT2FN[hidden_act]                                                   │                                                                       
│   155 │                                                                                          │                                                                       
│   156 │   def forward(self, x):                                                                  │                                                                       
│ ❱ 157 │   │   return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))            │                                                                       
│   158                                                                                            │                                                                       
│   159                                                                                            │                                                                       
│   160 class LlamaAttention(nn.Module):                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl                                                                          │                                                                       
│                                                                                                  │                                                                       
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │                                                                       
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │                                                                       
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                       
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                       
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                       
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                       
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/accelerate/hooks.py:16 │
│ 5 in new_forward                                                                                 │                                                                       
│                                                                                                  │                                                                       
│   162 │   │   │   with torch.no_grad():                                                          │                                                                       
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │                                                                       
│   164 │   │   else:                                                                              │                                                                       
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │                                                                       
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │                                                                       
│   167 │                                                                                          │                                                                       
│   168 │   module.forward = new_forward                                                           │                                                                       
│                                                                                                  │                                                                       
│ /XX/envs/vicuna-matata/lib/python3.10/site-packages/torch/nn/modules/linea │                                                                       
│ r.py:114 in forward                                                                              │                                                                       
│                                                                                                  │                                                                       
│   111 │   │   │   init.uniform_(self.bias, -bound, bound)                                        │                                                                       
│   112 │                                                                                          │                                                                       
│   113 │   def forward(self, input: Tensor) -> Tensor:                                            │                                                                       
│ ❱ 114 │   │   return F.linear(input, self.weight, self.bias)                                     │                                                                       
│   115 │                                                                                          │                                                                       
│   116 │   def extra_repr(self) -> str:                                                           │                                                                       
│   117 │   │   return 'in_features={}, out_features={}, bias={}'.format(                          │                                                                       
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ 
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, 
&fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`
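(Not part of the original report, but a quick sanity check one could run on the dual-4090 box: verify whether the two cards report peer-to-peer access to each other. Missing P2P between consumer GPUs can surface as cuBLAS/NCCL failures once a model is sharded across devices. This is only a sketch using standard PyTorch APIs, not a confirmed diagnosis of this error.)

import torch

# Print whether each GPU pair reports peer-to-peer access.
# If this prints False for the 4090 pair, cross-GPU sharding is likely to be unreliable.
num_gpus = torch.cuda.device_count()
for src in range(num_gpus):
    for dst in range(num_gpus):
        if src != dst:
            ok = torch.cuda.can_device_access_peer(src, dst)
            print(f"GPU {src} -> GPU {dst}: peer access = {ok}")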

Hi, I’m also seeing this error. Any update on this? Thank you!

No, loading the model across dual 4090s still doesn’t work. As a workaround, I currently run on a single card with the --load-8bit quantization option (example command below); it should cause very little performance degradation. Hope this helps.
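For reference, the single-card 8-bit workaround is just the usual CLI invocation with --load-8bit added. Assuming the same entry point used earlier in this thread; the model path here is a placeholder, not the exact path from my setup:

python3 -m fastchat.serve.cli --model-name /path/to/vicuna-13b --load-8bit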