llama.cpp: parallel/server crashes with: ggml.c:16521: i != GGML_HASHTABLE_FULL when defragmentation is enabled
Context
Using the latest llama.cpp server at commit 17e98d4c96a583d420f12046bc92102381dbd28e.
Server started with a Llama 70B F16-class model:
server \
--model model-f16.gguf \
--ctx-size 32768 \
--n-predict 4096 \
--parallel 32 \
--n-gpu-layers 81 \
--batch-size 4096 \
--ubatch-size 256 \
--metrics \
--mg 1 \
--log-format text \
--defrag-thold 0.1
When sending 32 concurrent requests, the server crashes with:
GGML_ASSERT: /llama.cpp/ggml.c:16521: i != GGML_HASHTABLE_FULL
Backend is CUDA, on 2 A100, compute capability 80.
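For reference, a minimal way to generate that load is a shell loop around curl; the endpoint, port and payload below are assumptions for illustration, not the exact client used here:

# Fire 32 completion requests at the llama.cpp server in parallel
# (assumes the default port 8080 and the /completion endpoint; adjust as needed).
for i in $(seq 1 32); do
  curl -s http://localhost:8080/completion \
    -H "Content-Type: application/json" \
    -d '{"prompt": "Write a very long story about llamas.", "n_predict": 4096}' \
    > /dev/null &
done
wait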
EDIT: The issue is related to defragmentation; quick fix: disable defragmentation.
About this issue
- State: open
- Created 3 months ago
- Comments: 28 (3 by maintainers)
Commits related to this issue
- KV Cache defrag hash overflow - TMP Fix by @slaren #6685 — committed to ggerganov/llama.cpp by phymbert 2 months ago
A proper fix will take some time, but this should fix it for now.
This command triggers the error quite easily:
Let me know if it does not work and I can send you a patch that triggers it every time.
Edit: this one also triggers it and uses a 13B model:
I think that the issue is that ggml_backend_sched allocates a hash table based on the actual size of the graph used during measure, but llama_kv_cache_defrag_internal uses up to LLAMA_MAX_NODES. Limiting the size of the hash table is important for performance, as clearing a large table has a significant cost.
32 sequences, each with 4096 tokens, requires a KV cache of size 32*4096 = 131072 in order to handle the worst case, so this setup should theoretically run out of KV cache slots. Not sure about the error that you get though, but it could be related. Another thing to look into is the hypothesis that we are evaluating batches partially:
https://github.com/ggerganov/llama.cpp/issues/6617#issuecomment-2051618514
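For reference, here is the worst case spelled out (my own back-of-the-envelope, using only the flags from the command at the top of this issue):

required KV slots  = parallel * n-predict = 32 * 4096 = 131072
available KV slots = --ctx-size = 32768 (shared across the 32 slots, i.e. 1024 cells per sequence)

So in the worst case the requests need four times more cells than the cache provides, which is why shifting and defragmentation come into play under this load.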
Do you get the error with equal batch and ubatch?
@slaren The temporary fix is behaving well: 32 users on F16 70B with 1024 max context each for 30 min. No shifting, no KV cache full. And no crash 😃 I have pushed eedd42e3767efb49cd497cdef3943397b42ee935 in order to retrieve it securely, but I will delete this temp branch once you feel ready to submit the target patch. FYI: 2 A100, average PP=200 tk/s per sequence, TG=4.5 tk/s per sequence, KV cache usage ratio=0.78.
I will also test flash attention, but this is quite out of the scope of this issue.
Maybe @ggerganov has a suggestion for how to trigger a worst-case defragment in a simple way.
I am sorry, but what I meant by "simple command" was something like curl; this is too complicated and it is going to take too much time to reproduce.
I am not sure which model size will cause the issue, but these steps can be done for any model. I am trying to reproduce with a Llama 70B F16 in an isolated environment.
Assuming you are on a Debian-like OS, the server needs to be started with --defrag-thold 0.1. I will update the steps if I miss something here, but this is the general idea; a rough sketch follows below.
EDIT: I am not German, but this is the latest model I found on HF 😄
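Since the original step list did not survive here, this is a rough sketch of the general idea on a Debian-like box (my own reconstruction; the build flag, model file and paths are assumptions, not the author's exact commands):

# Build llama.cpp with CUDA and start the server with defragmentation enabled
# (assumes the CUDA toolkit is already installed; model path is a placeholder).
sudo apt-get install -y build-essential git
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
make LLAMA_CUDA=1
./server --model model-f16.gguf --ctx-size 32768 --n-predict 4096 \
         --parallel 32 --n-gpu-layers 81 --batch-size 4096 --ubatch-size 256 \
         --defrag-thold 0.1 &
# then hit it with the concurrent curl loop shown near the top of this issue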
Is there a simple command that I can run to reproduce this issue? It seems that the issue is the size of the KV defrag graph.
It should give you a full call stack automatically if gdb is installed. Otherwise just run
gdb --args ./server ... and type bt to get the call stack when it crashes. It will be more accurate with a debug build, but probably even a release build will be enough to figure out the issue. The size of a hash table is probably being underestimated; however, without the call stack it is hard to tell which one.
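Concretely, something along these lines (the flags are just the ones from the command at the top of this issue):

# Run the server under gdb so the GGML_ASSERT produces a usable backtrace
gdb --args ./server --model model-f16.gguf --ctx-size 32768 --n-predict 4096 \
    --parallel 32 --n-gpu-layers 81 --batch-size 4096 --ubatch-size 256 \
    --defrag-thold 0.1
(gdb) run
# ... reproduce the crash with the concurrent requests ...
(gdb) bt    # prints the call stack at the point of the assert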