transformers: `KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'`
System Info
- transformers version: 4.36.0
- Platform: Linux-5.15.0-70-generic-x86_64-with-glibc2.35
- Python version: 3.11.4
- Huggingface_hub version: 0.19.4
- Safetensors version: 0.3.3
- Accelerate version: 0.25.0.dev0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
long_text = "..."  # long input document, elided in the original report
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, use_flash_attention_2=True, torch_dtype=torch.float16, device_map="auto")
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
messages = [
{"role": "user", "content": f"Summarize the following:\n{long_text}"}
]
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=8192, do_sample=True, streamer=streamer)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Expected behavior
Expected it to work, or at least to fail with a CUDA error.
About this issue
- Original URL
- State: closed
- Created 7 months ago
- Reactions: 4
- Comments: 21 (7 by maintainers)
Hi! I'm having the same error: it only happens when the token length is greater than the sliding window size. (I do not have the error with transformers version 4.34.0, but when I upgrade to 4.36.0 I get it.)
Thanks!
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/uvicorn/protocols/http/httptools_impl.py", line 426, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/opt/conda/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 84, in __call__
    return await self.app(scope, receive, send)
  File "/opt/conda/lib/python3.10/site-packages/fastapi/applications.py", line 1106, in __call__
    await super().__call__(scope, receive, send)
  File "/opt/conda/lib/python3.10/site-packages/starlette/applications.py", line 122, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/opt/conda/lib/python3.10/site-packages/starlette/middleware/errors.py", line 184, in __call__
    raise exc
  File "/opt/conda/lib/python3.10/site-packages/starlette/middleware/errors.py", line 162, in __call__
    await self.app(scope, receive, _send)
  File "/opt/conda/lib/python3.10/site-packages/starlette/middleware/cors.py", line 83, in __call__
    await self.app(scope, receive, send)
  File "/opt/conda/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 79, in __call__
    raise exc
  File "/opt/conda/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 68, in __call__
    await self.app(scope, receive, sender)
  File "/opt/conda/lib/python3.10/site-packages/fastapi/middleware/asyncexitstack.py", line 20, in __call__
    raise e
  File "/opt/conda/lib/python3.10/site-packages/fastapi/middleware/asyncexitstack.py", line 17, in __call__
    await self.app(scope, receive, send)
  File "/opt/conda/lib/python3.10/site-packages/starlette/routing.py", line 718, in __call__
    await route.handle(scope, receive, send)
  File "/opt/conda/lib/python3.10/site-packages/starlette/routing.py", line 276, in handle
    await self.app(scope, receive, send)
  File "/opt/conda/lib/python3.10/site-packages/starlette/routing.py", line 66, in app
    response = await func(request)
  File "/opt/conda/lib/python3.10/site-packages/fastapi/routing.py", line 274, in app
    raw_response = await run_endpoint_function(
  File "/opt/conda/lib/python3.10/site-packages/fastapi/routing.py", line 191, in run_endpoint_function
    return await dependant.call(**values)
  File "/home/mmockus/dev/chatbot/rGPT/host.py", line 301, in benchmark_model
    response, time, previous_prompt = rgpt(
  File "/home/mmockus/dev/chatbot/rGPT/RGPT.py", line 384, in __call__
    output = self.__generate_text(final_prompt)
  File "/opt/conda/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1140, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/opt/conda/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1147, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/opt/conda/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1046, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/mmockus/dev/chatbot/rGPT/instruct_pipeline.py", line 60, in _forward
    generated_sequence = self.model.generate(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2861, in sample
    outputs = self(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1212, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1080, in forward
    layer_outputs = decoder_layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 796, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 441, in forward
    past_key = past_key_value[0]
  File "/opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py", line 78, in __getitem__
    raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'
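For reference, a minimal sketch of the "greater than the sliding window size" condition, reusing model and inputs from the reproduction above (the sliding_window lookup is an assumption on my part; some Mixtral checkpoints leave it unset):

# Hedged check: the failure reported here only shows up when the prompt is
# longer than the model's sliding window.
sliding_window = getattr(model.config, "sliding_window", None)  # may be None
prompt_len = inputs.shape[-1]
if sliding_window is not None and prompt_len > sliding_window:
    # This is the regime where 4.36.0 raises the KeyError; shorter prompts work.
    print(f"{prompt_len} prompt tokens exceed the sliding window of {sliding_window}")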
Opening a PR to fix it.
Thank you so much! I can confirm training works now.
The PR linked above fixes it.
This indeed seems like a caching issue. cc @gante. It seems like this snippet was not updated to work with the new Cache class: https://github.com/huggingface/transformers/blob/2788f8d8d5f9cee2fe33a9292b0f3570bd566a6d/src/transformers/models/mistral/modeling_mistral.py#L388-L407
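For context, a minimal sketch of how the new cache object behaves in 4.36.0 (DynamicCache from transformers.cache_utils): indexing a layer that has not been populated yet raises exactly the error seen in the traceback.

from transformers.cache_utils import DynamicCache

cache = DynamicCache()
print(len(cache))  # 0: no layers have been cached yet

try:
    past_key, past_value = cache[0]  # legacy tuple-style indexing on an empty cache
except KeyError as err:
    print(err)  # Cache only has 0 layers, attempted to access layer with index 0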
In the meantime, I suspect that not using Flash Attention 2 may work around the issue.
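A rough workaround sketch along those lines, until the fix is released (the choice of attn_implementation="sdpa" is an assumption; "eager" should also avoid the Flash Attention 2 code path):

import torch
from transformers import AutoModelForCausalLM

# Same load call as the reproduction, but selecting a non-FA2 attention backend.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    attn_implementation="sdpa",  # instead of use_flash_attention_2=True
    torch_dtype=torch.float16,
    device_map="auto",
)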