transformers: Possible Bug with KV Caching in Llama (original) model
System Info
transformers==4.31.0
- huggingface_hub version: 0.15.1
- Platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Running in iPython ?: No
- Running in notebook ?: No
- Running in Google Colab ?: No
- Token path ?: /u/k/h/khanov/.cache/huggingface/token
- Has saved token ?: False
- Configured git credential helpers:
- FastAI: N/A
- Tensorflow: N/A
- Torch: 2.0.0
- Jinja2: 3.0.3
- Graphviz: N/A
- Pydot: N/A
- Pillow: 9.0.1
- hf_transfer: N/A
- gradio: N/A
- numpy: 1.24.2
- ENDPOINT: https://huggingface.co
- HUGGINGFACE_HUB_CACHE: /u/k/h/khanov/.cache/huggingface/hub
- HUGGINGFACE_ASSETS_CACHE: /u/k/h/khanov/.cache/huggingface/assets
- HF_TOKEN_PATH: /u/k/h/khanov/.cache/huggingface/token
- HF_HUB_OFFLINE: False
- HF_HUB_DISABLE_TELEMETRY: False
- HF_HUB_DISABLE_PROGRESS_BARS: None
- HF_HUB_DISABLE_SYMLINKS_WARNING: False
- HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
- HF_HUB_DISABLE_IMPLICIT_TOKEN: False
- HF_HUB_ENABLE_HF_TRANSFER: False
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
I was working on a custom decoding method; however, I found a deviation from greedy search when using KV caching.
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
MODEL_PATH = "/nobackup-fast/khanov/llama-7b" # "huggyllama/llama-7b"
GEN_DEV = "cuda:0"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(GEN_DEV)
def get_input_ids(prompt: str) -> torch.Tensor:
    global model, tokenizer
    tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(GEN_DEV)
    return tokens

def tokens_to_text(tokens: torch.Tensor):
    return tokenizer.batch_decode(tokens, skip_special_tokens=True)
PROMPT = "This is a " # this is just a test prompt
# greedy decoding without caching
tokens = get_input_ids(PROMPT)
for _ in tqdm(range(40)):
    with torch.no_grad():
        mout = model(tokens)
    tokens = torch.hstack((tokens, torch.argmax(mout.logits[0, -1]).unsqueeze(0).unsqueeze(0)))
without_cache = tokens_to_text(tokens)[0]
print(f"{without_cache=}")
# greedy decoding WITH caching
tokens = get_input_ids(PROMPT)
cached = None
for _ in tqdm(range(40)):
    with torch.no_grad():
        if cached is None:
            mout = model(tokens, output_hidden_states=True, use_cache=True)
            cached = mout.past_key_values
        else:
            mout = model(tokens, past_key_values=cached, use_cache=True, output_hidden_states=True)
            cached = mout.past_key_values
    tokens = torch.hstack((tokens, torch.argmax(mout.logits[0, -1]).unsqueeze(0).unsqueeze(0)))
with_cache = tokens_to_text(tokens)[0]
print(f"{with_cache=}")
# normal greedy search with HF Generate implementation
tokens = get_input_ids(PROMPT)
tokens = model.generate(tokens, num_return_sequences=1, max_new_tokens=40)
generate_output = tokens_to_text(tokens)[0]
print(f"{generate_output=}")
# this matches exactly
assert without_cache == generate_output
# this does not!
assert without_cache == with_cache
Expected behavior
I was expecting the results not to change when using the past_key_values kwarg; however, when passing past_key_values, the model assigned different logits to the tokens. This also deviates from the model.generate behavior. This is possibly related to #18809 and #21080.
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 22 (6 by maintainers)
Hey folks 👋 I’ve done a deep dive on this issue, and I will link related issues to this comment, which attempts to summarize my findings.
cc:
TL;DR
Using KV caches (and, in some models, left-padding) does change the logits. This happens in most, if not all, models at all precisions, but it is almost imperceptible in FP32. With 16 bits, the difference becomes non-negligible. The model was not trained with KV caches or left-padding, so this introduces a distribution shift – it’s part of the cost of using lower precision and other related optimizations. The effect is more visible when do_sample=True, as greedy decoding (do_sample=False) often selects the same token despite the differences.
Why does this happen?
A key operation in neural networks is matrix multiplication, where values are multiplied and accumulated. Unless you have infinite precision, different implementations or different shapes (e.g. cropping a few rows of the first matrix) may produce different outputs, as the intermediate calculations must remain in the specified precision and are subject to rounding. For instance, our models with TF and JAX implementations never have exactly the same output as the PyTorch implementation; they tend to differ by a maximum of 1e-5 at FP32 for the same exact input, due to minor differences in the frameworks’ inner implementations.
When using KV caches (and, in some models, left-padding), we are changing the input shape of some matrix multiplication operations. For instance, in Llama, when you apply the linear projection to obtain the QKV for the attention layer, the input shape will be different depending on whether you’re using left-padding and/or KV caches. Therefore, the output of these operations may be different, and these tiny differences build up across layers and across generated tokens, especially at lower precisions.
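As a minimal illustration of this (the layer size, sequence length, and variable names below are arbitrary, not Llama's actual dimensions or code): projecting a full sequence versus only its last position through the same linear layer can yield slightly different values for that position at 16-bit precision, because the two shapes may dispatch to kernels with different accumulation orders. The gap may be zero on some backends, but it is typically non-zero on GPU kernels.
import torch

torch.manual_seed(0)
proj = torch.nn.Linear(4096, 4096, bias=False).to("cuda", torch.float16)
hidden = torch.randn(1, 8, 4096, device="cuda", dtype=torch.float16)  # (batch, seq_len, hidden)

with torch.no_grad():
    full = proj(hidden)[:, -1]               # last position, projected together with the rest of the sequence
    last_only = proj(hidden[:, -1:])[:, -1]  # last position, projected alone (as happens with a KV cache)

print((full - last_only).abs().max())  # often non-zero in fp16/bf16, ~0 in fp32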
If you place a breakpoint inside the model and inspect what happens with and without KV caches, you'll see exactly this: the inputs to these matrix multiplications have different shapes.
How big is this difference?
Let’s do a simple experiment: for the same set of inputs, let’s measure the hidden states’ and the logits’ maximum difference for the first generated token, with and without KV caching. I created the following test script from an example given in a related issue (https://github.com/huggingface/transformers/issues/26344). TL;DR it averages the maximum value for the variables described above over 1000 runs:
Test script
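The original script is collapsed above; a hypothetical sketch of the same kind of measurement (a single run rather than an average over 1000 runs; the model name comes from the reproduction, the remaining names are illustrative) could look like this:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "huggyllama/llama-7b"  # assumption: same checkpoint family as in the reproduction
DTYPE = torch.float16               # switch to torch.float32 / torch.bfloat16 to compare

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to("cuda")
model.eval()

prompt_ids = tokenizer("This is a ", return_tensors="pt").input_ids.to("cuda")

with torch.no_grad():
    # Forward pass over the prompt: pick the first generated token greedily and keep the cache
    first = model(prompt_ids, use_cache=True, output_hidden_states=True)
    next_token = first.logits[:, -1].argmax(dim=-1, keepdim=True)
    full_ids = torch.cat([prompt_ids, next_token], dim=-1)

    # Path A (no cache): re-run the full sequence, prompt + new token
    no_cache = model(full_ids, output_hidden_states=True)

    # Path B (with cache): feed only the new token, with explicit position_ids and attention mask
    # (this mirrors what prepare_inputs_for_generation builds inside generate())
    cache = model(
        next_token,
        past_key_values=first.past_key_values,
        position_ids=torch.tensor([[prompt_ids.shape[1]]], device="cuda"),
        attention_mask=torch.ones_like(full_ids),
        output_hidden_states=True,
    )

logits_diff = (no_cache.logits[:, -1] - cache.logits[:, -1]).abs().max()
hidden_diff = (no_cache.hidden_states[-1][:, -1] - cache.hidden_states[-1][:, -1]).abs().max()
print(f"max |logits diff| = {logits_diff.item():.3e}, max |last hidden state diff| = {hidden_diff.item():.3e}")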
Here are the results I got for CodeLlama (which uses the same code as Llama and Llama 2), with GPT2 in FP16 for comparison:
Llama, FP32
Llama, FP16 (the expected 16-bit format to use)
Llama, BF16 (the wrong 16-bit format to use with Llama)
GPT2, FP16
As we can see, the mismatch is negligible in FP32 but becomes visible in the 16-bit formats, and it is not specific to Llama – GPT2 shows the same behavior.
What can we do about it?
First of all: the benefits of using lower-precision variables and KV caching are obvious. Are they worth the downsides? My advice is to measure the model on metrics relevant to your task (e.g. perplexity) and compare the costs and benefits for your use case. I suspect using KV caching will remain cost-effective 😃
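To make "measure on metrics relevant to your task" concrete, here is a small, hypothetical helper (the perplexity function and its texts argument are illustrative, not part of the discussion above); running it once per configuration, e.g. with fp32 weights and again with fp16, gives a rough picture of the real downstream cost:
import torch

def perplexity(model, tokenizer, texts, device="cuda"):
    # Teacher-forced perplexity over a small evaluation set; lower is better.
    total_nll, total_tokens = 0.0, 0
    for text in texts:
        ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            loss = model(ids, labels=ids).loss  # mean NLL over the shifted tokens
        total_nll += loss.item() * (ids.shape[1] - 1)
        total_tokens += ids.shape[1] - 1
    return float(torch.exp(torch.tensor(total_nll / total_tokens)))

# e.g. compare perplexity(model_fp32, tokenizer, eval_texts) vs perplexity(model_fp16, tokenizer, eval_texts)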
Secondly: there may be ways to reduce this mismatch, but so far I haven’t found any. A common trick is to upcast some sensitive operations to FP32 (like the attention layers’ softmax). For completeness, on Llama, I tried upcasting:
- the Linear layers in the attention layer
- apply_rotary_pos_emb in FP32 (while keeping sin and cos in FP32 as well)
- self.input_layernorm(hidden_states)
- self.post_attention_layernorm(hidden_states)
Most had no impact; some reduced the mismatch at a high throughput cost.
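For reference, here is a minimal sketch of the softmax-upcast trick mentioned above (the helper name is illustrative; the Llama attention in transformers already computes its softmax this way):
import torch

def softmax_in_fp32(attn_scores: torch.Tensor) -> torch.Tensor:
    # attn_scores: (batch, num_heads, q_len, kv_len) in fp16/bf16.
    # Compute the attention weights in FP32, then cast back to the working dtype.
    weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
    return weights.to(attn_scores.dtype)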
Finally, regarding left-padding: We might be able to mitigate problems here when we migrate batched generation to nested tensors, which don’t need padding.
I hope this comprehensive analysis helps you understand what’s going on 🤗 And, who knows, be the spark that ignites a solution to this issue 🪄
Hey @jmzeng 👋
It’s impossible to convert between fp16 and bf16 without rounding, which means that your model will lose performance once you switch. Switching before fine-tuning might be okay, depending on the model and how long your fine-tuning is – you give the model a chance to recover from the rounding errors. However, switching before inference will be a source of distribution drift, which almost surely will negatively impact your downstream performance.
That being said, note that bf16 is indeed better for fine-tuning due to its larger dynamic range, and fp16 tends to excel at inference time due to its better accumulation precision. So it’s not an easy answer here 😄
Finally, if you’re using techniques like LoRA (see our peft library), you can get away with doing fine-tuning in fp32. Then you can downcast to fp16 with fewer problems.
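As a small illustration of the rounding involved, converting fp16 to bf16 and back is not an identity mapping, because bf16 keeps fewer mantissa bits than fp16:
import torch

x = torch.randn(10_000, dtype=torch.float16)
roundtrip = x.to(torch.bfloat16).to(torch.float16)  # fp16 -> bf16 -> fp16
print((x - roundtrip).abs().max())  # non-zero: the extra fp16 mantissa bits were rounded away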
Ok, I think I found the culprit! It seems that when using past_key_values and bfloat16, the errors are huge.
float32 (default): max abs diff between logits (with vs without past_key_values) = 1.0490e-05
With bfloat16: max abs diff between logits (with vs without past_key_values) = 0.1250
With float16: max abs diff between logits (with vs without past_key_values) = 0.0195
Since the unit tests only check fp32, they aren’t catching this.
Here’s the script to measure this:
Script
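The script itself is collapsed above, but a hypothetical sketch of this kind of measurement (the model name and generation settings are assumptions, not the original code) could look like this. Note that the fp32 run of a 7B model needs roughly 28 GB of GPU memory.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "huggyllama/llama-7b"  # assumption: same checkpoint as in the reproduction
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
prompt_ids = tokenizer("This is a ", return_tensors="pt").input_ids

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype).to("cuda")
    ids = prompt_ids.to("cuda")
    with_cache = model.generate(ids, max_new_tokens=2, do_sample=False, use_cache=True,
                                output_scores=True, return_dict_in_generate=True)
    without_cache = model.generate(ids, max_new_tokens=2, do_sample=False, use_cache=False,
                                   output_scores=True, return_dict_in_generate=True)
    # scores[0] comes from the same full-prompt forward pass in both runs;
    # scores[1] is the first step where the cached and uncached code paths differ
    diff = (with_cache.scores[1] - without_cache.scores[1]).abs().max()
    print(f"{dtype}: max abs diff between logits (with vs without cache) = {diff.item():.4e}")
    del model
    torch.cuda.empty_cache()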
Any ideas of how to fix this discrepancy?
Thanks so so much! Turns out the prepare_inputs_for_generation function prepared the positional ID information as you said, and after adding that in, the results match exactly! I’ll go ahead and close this!
@varung “wrong” is perhaps too strong of a word – suboptimal would be more precise. We have collaborated with the authors of Llama 2, and they have suggested the use of fp16. You can see it in our examples from when we released the model (e.g. here).
In practice, it depends on how the model is saved – we should load the model in the format in which it was stored. If it was stored in fp32 and you want to run it in a 16-bit precision, fp16 is superior.
Thanks for the detailed explanation @gante! It makes a lot of sense!
Hey @maximkha, I don’t have an update on this right now, no 😅 I’ll let @gante have a look, as I won’t have time to dive into this.
Hi @maximkha 👋
Thank you for raising this issue! Sadly, our bandwidth is limited, so our capacity to dive into custom code for which a solution already exists is also limited 😃
As @ArthurZucker wrote, you are missing the position IDs, which may have a significant impact on the output. The same is true for the attention mask. Our modeling code makes its best effort to infer these two inputs when they are missing, but it fails in some cases.
My suggestion would be to introduce a breakpoint() in generate, before the model forward pass, and compare the inputs that go into the model 😃
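For completeness, here is a hedged sketch of the corrected cached loop, reusing the setup (PROMPT, get_input_ids, tokens_to_text, model) from the reproduction script above: once a cache exists, only the newest token is fed to the model, together with explicit position_ids and an attention mask, which are the inputs prepare_inputs_for_generation builds inside generate(). The exact position computation below is an assumption for the unbatched, unpadded case.
tokens = get_input_ids(PROMPT)
cached = None
for _ in range(40):
    with torch.no_grad():
        if cached is None:
            mout = model(tokens, use_cache=True)
        else:
            position_ids = torch.tensor([[tokens.shape[1] - 1]], device=tokens.device)
            mout = model(
                tokens[:, -1:],                         # only the newest token
                past_key_values=cached,
                position_ids=position_ids,              # its absolute position in the sequence
                attention_mask=torch.ones_like(tokens), # covers past + current tokens (no padding)
                use_cache=True,
            )
        cached = mout.past_key_values
    next_token = torch.argmax(mout.logits[0, -1]).reshape(1, 1)
    tokens = torch.hstack((tokens, next_token))
print(tokens_to_text(tokens)[0])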