accelerate: Device mismatch when using load_checkpoint_and_dispatch with LLaMA

System Info

- `Accelerate` version: 0.20.3
- Platform: Linux-5.4.0-84-generic-x86_64-with-glibc2.31
- Python version: 3.9.16
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- PyTorch XPU available: False
- System RAM: 116.56 GB
- GPU type: GeForce RTX 2080 Ti
- `Accelerate` default config:
        Not found
- `transformers` version: 4.32.0.dev0

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

The model I’m using is https://huggingface.co/huggyllama/llama-7b

Code I’m running (used this as a reference):

import torch
from transformers import LlamaForCausalLM, LlamaConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

MODEL_NAME_OR_PATH = "models/llama-7b"

with init_empty_weights():
    # Build the model skeleton on the meta device (no weight memory allocated yet).
    model = LlamaForCausalLM(LlamaConfig.from_pretrained(MODEL_NAME_OR_PATH))

# Load the checkpoint and spread the weights across the available devices.
model = load_checkpoint_and_dispatch(
    model, checkpoint=MODEL_NAME_OR_PATH, device_map="auto"
)
model.tie_weights()

model.generate(input_ids=torch.as_tensor([[1, 2, 3, 4]]))

I get the following error (truncated):

File ~/miniconda3/envs/llama_prl/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:184, in apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    182 cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    183 sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
--> 184 cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    185 sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    186 q_embed = (q * cos) + (rotate_half(q) * sin)

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
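
The fix suggested in the comments below points at the cause: without no_split_module_classes, the "auto" placement is free to split the submodules of a single LlamaDecoderLayer across GPUs, and the rotary-embedding indexing then crosses a device boundary. A minimal diagnostic sketch for inspecting that placement (it reuses the meta-initialized model from the reproduction above; infer_auto_device_map is the same helper the "auto" mode relies on, and the module names in the comment are illustrative, not from an actual run):

from accelerate import infer_auto_device_map

# Compute the kind of placement that device_map="auto" would produce.
device_map = infer_auto_device_map(model)

# Look for a decoder layer whose submodules land on different devices,
# e.g. "model.layers.N.self_attn" on one GPU and "model.layers.N.mlp" on the other.
for name, device in device_map.items():
    print(name, device)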

Expected behavior

The model should run inference normally, I think.

About this issue

  • Original URL
  • State: closed
  • Created 10 months ago
  • Comments: 24 (10 by maintainers)

Most upvoted comments

Hi @CrazyBrick, you need to check for layers that have residual connections. Most of the time, you need to put the encoder/decoder layer classes in no_split_module_classes.
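
For anyone hitting this with a different architecture, a quick way to find the class to protect is to look at the repeated transformer block on the model itself (a small sketch, assuming the meta-initialized model from the reproduction above is still in scope):

# The repeated decoder block (attention + MLP joined by residual adds) is the
# unit that should stay on a single device.
print(type(model.model.layers[0]).__name__)   # -> "LlamaDecoderLayer"

# Recent transformers releases also record this on the pretrained model class.
print(model._no_split_modules)                # -> ["LlamaDecoderLayer"]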

@SunMarc that did the trick! I quickly looked through the LLaMA code for mentions of residual connections and came up with no_split_module_classes=['LlamaDecoderLayer']. Now the model parallelizes and quantizes successfully, and I don’t get an OOM as I did with device_map='sequential'. Thank you for the help! 😊
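
For completeness, a sketch of the corrected call under the same setup as the reproduction above (same checkpoint path and dummy input ids):

import torch
from transformers import LlamaForCausalLM, LlamaConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

MODEL_NAME_OR_PATH = "models/llama-7b"

with init_empty_weights():
    model = LlamaForCausalLM(LlamaConfig.from_pretrained(MODEL_NAME_OR_PATH))

# Keep each decoder layer on a single device so its residual path never
# straddles two GPUs.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=MODEL_NAME_OR_PATH,
    device_map="auto",
    no_split_module_classes=["LlamaDecoderLayer"],
)
model.tie_weights()

model.generate(input_ids=torch.as_tensor([[1, 2, 3, 4]]))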