diffusers: enable_sequential_cpu_offload followed by load_textual_inversion does not handle placement of loaded weights

Describe the bug

Calling pipe.enable_sequential_cpu_offload() before pipe.load_textual_inversion() crashes diffusers: the loaded embedding weights are not placed on the device the offload hooks expect, and inference fails with a device-mismatch RuntimeError.

Reproduction

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

# Enabling offload first and loading the embedding afterwards triggers the crash.
pipe.enable_sequential_cpu_offload()
pipe.load_textual_inversion("./happy512.pt", token="happy512")
prompt = "girl, happy512"

image = pipe(prompt).images[0]

Logs

(D:\2.35\dev\stable-diffusion\env) D:\2.35\dev\models\embeddings>python x.py
safety_checker\model.safetensors not found
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ D:\2.35\dev\models\embeddings\x.py:13 in <module>                                                │
│                                                                                                  │
│   10 pipe.load_textual_inversion("./happy512.pt", token="happy512")                              │
│   11 prompt = "girl, happy512"                                                                   │
│   12                                                                                             │
│ ❱ 13 image = pipe(prompt).images[0]                                                              │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\utils\_contextlib.py:115 in             │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\diffusers\pipelines\stable_diffusion\pipeline │
│ _stable_diffusion.py:688 in __call__                                                             │
│                                                                                                  │
│   685 │   │   text_encoder_lora_scale = (                                                        │
│   686 │   │   │   cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not N   │
│   687 │   │   )                                                                                  │
│ ❱ 688 │   │   prompt_embeds = self._encode_prompt(                                               │
│   689 │   │   │   prompt,                                                                        │
│   690 │   │   │   device,                                                                        │
│   691 │   │   │   num_images_per_prompt,                                                         │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\diffusers\pipelines\stable_diffusion\pipeline │
│ _stable_diffusion.py:382 in _encode_prompt                                                       │
│                                                                                                  │
│   379 │   │   │   else:                                                                          │
│   380 │   │   │   │   attention_mask = None                                                      │
│   381 │   │   │                                                                                  │
│ ❱ 382 │   │   │   prompt_embeds = self.text_encoder(                                             │
│   383 │   │   │   │   text_input_ids.to(device),                                                 │
│   384 │   │   │   │   attention_mask=attention_mask,                                             │
│   385 │   │   │   )                                                                              │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\accelerate\hooks.py:165 in new_forward        │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\transformers\models\clip\modeling_clip.py:816 │
│ in forward                                                                                       │
│                                                                                                  │
│    813 │   │   """                                                                            │
│    814 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return  │
│    815 │   │                                                                                     │
│ ❱  816 │   │   return self.text_model(                                                           │
│    817 │   │   │   input_ids=input_ids,                                                          │
│    818 │   │   │   attention_mask=attention_mask,                                                │
│    819 │   │   │   position_ids=position_ids,                                                    │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\transformers\models\clip\modeling_clip.py:712 │
│ in forward                                                                                       │
│                                                                                                  │
│    709 │   │   input_shape = input_ids.size()                                                    │
│    710 │   │   input_ids = input_ids.view(-1, input_shape[-1])                                   │
│    711 │   │                                                                                     │
│ ❱  712 │   │   hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)   │
│    713 │   │                                                                                     │
│    714 │   │   bsz, seq_len = input_shape                                                        │
│    715 │   │   # CLIP's text model uses causal mask, prepare it here.                            │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\accelerate\hooks.py:165 in new_forward        │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\transformers\models\clip\modeling_clip.py:230 │
│ in forward                                                                                       │
│                                                                                                  │
│    227 │   │   │   inputs_embeds = self.token_embedding(input_ids)                               │
│    228 │   │                                                                                     │
│    229 │   │   position_embeddings = self.position_embedding(position_ids)                       │
│ ❱  230 │   │   embeddings = inputs_embeds + position_embeddings                                  │
│    231 │   │                                                                                     │
│    232 │   │   return embeddings                                                                 │
│    233                                                                                           │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\_prims_common\wrappers.py:220 in _fn    │
│                                                                                                  │
│   217 │   │   │   │   │   if k not in kwargs:                                                    │
│   218 │   │   │   │   │   │   kwargs[k] = out_attr                                               │
│   219 │   │   │                                                                                  │
│ ❱ 220 │   │   │   result = fn(*args, **kwargs)                                                   │
│   221 │   │   │   assert (                                                                       │
│   222 │   │   │   │   isinstance(result, TensorLike)                                             │
│   223 │   │   │   │   and is_tensor                                                              │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\_prims_common\wrappers.py:130 in _fn    │
│                                                                                                  │
│   127 │   │   │   }                                                                              │
│   128 │   │   │   bound.arguments.update(promoted_args)                                          │
│   129 │   │   │                                                                                  │
│ ❱ 130 │   │   │   result = fn(**bound.arguments)                                                 │
│   131 │   │   │                                                                                  │
│   132 │   │   │   if isinstance(result, TensorLike):                                             │
│   133 │   │   │   │   return _maybe_convert_to_dtype(result, result_dtype)                       │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\_refs\__init__.py:975 in add            │
│                                                                                                  │
│    972 │   │   │   raise ValueError(msg)                                                         │
│    973 │   │   b = prims.mul(b, alpha)                                                           │
│    974 │                                                                                         │
│ ❱  975 │   return prims.add(a, b)                                                                │
│    976                                                                                           │
│    977                                                                                           │
│    978 # TODO: add docstring                                                                     │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\_ops.py:287 in __call__                 │
│                                                                                                  │
│   284 │   │   )                                                                                  │
│   285 │                                                                                          │
│   286 │   def __call__(self, *args, **kwargs):                                                   │
│ ❱ 287 │   │   return self._op(*args, **kwargs or {})                                             │
│   288 │                                                                                          │
│   289 │   def __hash__(self):                                                                    │
│   290 │   │   return hash(self._op)                                                              │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\_prims\__init__.py:346 in               │
│ _elementwise_meta                                                                                │
│                                                                                                  │
│    343 │   if args_with_fixed_dtypes is not None:                                                │
│    344 │   │   args_ = list(args_with_fixed_dtypes) + args_                                      │
│    345 │                                                                                         │
│ ❱  346 │   utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)                        │
│    347 │   utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)                         │
│    348 │                                                                                         │
│    349 │   strides = utils.compute_elementwise_output_strides(*args_)                            │
│                                                                                                  │
│ D:\2.35\dev\stable-diffusion\env\lib\site-packages\torch\_prims_common\__init__.py:596 in        │
│ check_same_device                                                                                │
│                                                                                                  │
│    593 │   │   │   │   │   + str(device)                                                         │
│    594 │   │   │   │   │   + "!"                                                                 │
│    595 │   │   │   │   )                                                                         │
│ ❱  596 │   │   │   │   raise RuntimeError(msg)                                                   │
│    597 │   │   else:                                                                             │
│    598 │   │   │   msg = (                                                                       │
│    599 │   │   │   │   "Unexpected type when checking for same device, " + str(type(arg)) + "!"  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!

System Info

Diffusers 0.17.1, Python 3.8.5, Windows 10

Who can help?

@williamberman, @patrickvonplaten, @sayakpaul

About this issue

  • State: closed
  • Created a year ago
  • Comments: 21 (19 by maintainers)

Most upvoted comments

Hi @JeLuF, you are right: the textual inversion weights are not moved to the appropriate device when they are loaded. Thanks for spotting it!

One workaround until this is fixed is to reverse the call order; the following should work:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

# Loading the embedding first and enabling offload afterwards works as expected.
pipe.load_textual_inversion("./happy512.pt", token="happy512")
pipe.enable_sequential_cpu_offload()
prompt = "girl, happy512"

image = pipe(prompt).images[0]

In addition, I’d recommend using enable_model_cpu_offload() if you can. In case you didn’t know about it: it is much faster than enable_sequential_cpu_offload(), at the cost of more memory. It is also more robust in this particular scenario; there is no need to reverse the calls when using it, as shown below.
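For reference, here is the same script using model-level offloading. pipe.enable_model_cpu_offload() is the standard diffusers API; per the comment above, the call order should not matter here:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

# Model-level offloading moves whole components (text encoder, UNet, VAE)
# to the GPU one at a time instead of hooking individual submodules.
pipe.enable_model_cpu_offload()
pipe.load_textual_inversion("./happy512.pt", token="happy512")

image = pipe("girl, happy512").images[0]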

Great, then I think we should try to remove all hooks and re-apply them later when loading textual inversion and/or LoRA weights.

@williamberman I think Patrick was asking for your inputs. I will give this a try.

The safest way to deal with accelerate hooks when adding new weights is going to be removing all hooks, adding the new weights, and then re-applying the hooks. This should be cheap so long as the model stays on the CPU the whole time and we don’t have to move any weights to or from the GPU (which I think is the way it works today).
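A minimal sketch of that remove/add/re-apply sequence, assuming accelerate’s public hook utilities; the helper name and the restriction to the text encoder are illustrative assumptions, not the actual fix:

import torch
from accelerate import cpu_offload
from accelerate.hooks import remove_hook_from_module


def load_textual_inversion_with_offload(pipe, path, token, device="cuda"):
    # 1. Strip the offload hooks so the text encoder is a plain
    #    nn.Module again (its weights stay on the CPU).
    remove_hook_from_module(pipe.text_encoder, recurse=True)

    # 2. Add the new weights; textual inversion resizes the token
    #    embedding, which the old hooks knew nothing about.
    pipe.load_textual_inversion(path, token=token)

    # 3. Re-apply sequential offloading so the enlarged embedding is
    #    covered by fresh hooks.
    cpu_offload(pipe.text_encoder, execution_device=torch.device(device))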

This will work with the current LoRA weights as-is, because the patching just swaps out the linear nn.Module for a patched linear nn.Module, i.e. from the view of accelerate there is nothing abnormal about the model.
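To illustrate why such a patched module looks ordinary to accelerate, here is a hypothetical PatchedLinear (a sketch of the general pattern, not diffusers’ actual implementation):

import torch
import torch.nn as nn


class PatchedLinear(nn.Module):
    # Wraps a frozen linear layer and adds a low-rank LoRA update.
    def __init__(self, linear: nn.Linear, rank: int = 4):
        super().__init__()
        self.linear = linear  # original, frozen weights
        self.lora_down = nn.Linear(linear.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, linear.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # start as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + self.lora_up(self.lora_down(x))


# Swapping a projection in place leaves an ordinary nn.Module tree,
# which accelerate's hooks can traverse like any other model:
patched = PatchedLinear(nn.Linear(320, 320))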

@patrickvonplaten are you asking for just my opinion or should I implement this? If the latter, it should only be implemented for load_lora_weights and load_textual_inversion, yes?

I think we can fix both this and #3958 by making sure the newly added weights (whether from textual inversion or LoRA) are correctly placed on the devices. That means we have to add hooks to the newly added weights, just like the ones that were there before. @sayakpaul do you want to give this a try?
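As a rough sketch of that approach using accelerate’s public hook API (the helper name is hypothetical, and attaching a hook to just the resized token embedding is an assumption about where the fix would live):

import torch
from accelerate.hooks import AlignDevicesHook, add_hook_to_module


def rehook_token_embedding(pipe, device="cuda"):
    # The embedding resized by load_textual_inversion carries no hook,
    # so its weights never follow the offload placement.
    embedding = pipe.text_encoder.get_input_embeddings()

    # Attach a hook that aligns the module and its inputs with the
    # execution device before each forward pass.
    hook = AlignDevicesHook(
        execution_device=torch.device(device), io_same_device=True
    )
    add_hook_to_module(embedding, hook)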