TransformerLens: [Bug Report] Pythia models / Rotary Embeddings don't match Huggingface.
Describe the bug
Pythia model outputs don't exactly match the Huggingface Transformers implementation.
Code example
```python
import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

def check_similarity_with_hf_model(tl_model, hf_model, atol, prompt="Hello, world!"):
    tokens = tl_model.tokenizer.encode(prompt, return_tensors="pt")
    logits = tl_model(tokens, prepend_bos=False)
    hf_logits = hf_model(tokens).logits
    assert torch.allclose(torch.softmax(logits, dim=-1), torch.softmax(hf_logits, dim=-1), atol=atol)

model_name = "EleutherAI/pythia-70m"
tl_model = HookedTransformer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
check_similarity_with_hf_model(tl_model, hf_model, atol=1e-5)
```
This fails with `model_name = "EleutherAI/pythia-70m"`, but passes with every other model I tried. It passes with pythia-70m if I set `atol=0.1`. Arthur says it works for him with `atol=1e-3`.
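To make the size of the discrepancy concrete rather than relying on a binary `allclose`, a quick measurement like the following (a sketch reusing the `tl_model` and `hf_model` loaded in the snippet above) prints the worst-case probability difference:

```python
# Reuses tl_model and hf_model from the snippet above; same default prompt.
tokens = tl_model.tokenizer.encode("Hello, world!", return_tensors="pt")
tl_probs = torch.softmax(tl_model(tokens, prepend_bos=False), dim=-1)
hf_probs = torch.softmax(hf_model(tokens).logits, dim=-1)
print("max abs probability diff:", (tl_probs - hf_probs).abs().max().item())
```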
System Info
Describe the characteristics of your environment:
- transformer_lens: main branch
- macOS (CPU)
- Python 3.10.12
- PyTorch 1.13.1
Additional context
See the discussion in the Open Source Mechanistic Interpretability Slack here: https://opensourcemechanistic.slack.com/archives/C04SRRE96UV/p1695593544494209
Checklist
- I have checked that there is no similar issue in the repo (required)
@ed1d1a8d ^The egregious Llama-2 errors are fixed, we think! Now only ~1e-4 errors.
We are working on the dull task of porting TL functions to match HF exactly, which should resolve all remaining issues, but we haven't finished this yet.
@ed1d1a8d I've run into the same problem; have you solved it yet?
This issue also exists with Llama-7b-hf-chat, and there it is much worse: the check fails even with `atol=1` (it passes with `atol=2`). Code:
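The commenter's snippet isn't reproduced here; a minimal check along the same lines (a sketch, assuming the model meant is `meta-llama/Llama-2-7b-chat-hf` and reusing `check_similarity_with_hf_model` from the issue body) might look like:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

model_name = "meta-llama/Llama-2-7b-chat-hf"  # assumed model; gated repo, needs HF access
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tl_model = HookedTransformer.from_pretrained(model_name, hf_model=hf_model, tokenizer=tokenizer)

check_similarity_with_hf_model(tl_model, hf_model, atol=1)  # reportedly fails
check_similarity_with_hf_model(tl_model, hf_model, atol=2)  # reportedly passes
```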
I've assigned myself to this, as I've started trying to debug it. The most likely culprit seems to be a deviation in `calculate_sin_cos_rotary`, but fixing that doesn't fix the compounding deviation as layers increase. Will put updates here when I know more.
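One way to localize where the deviation enters is to compare the residual stream layer by layer against HF's hidden states. A debugging sketch (not part of the original thread) that assumes `from_pretrained_no_processing` leaves the weights unmodified so the streams are directly comparable, and that HF's `output_hidden_states` returns one tensor per block input:

```python
import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

model_name = "EleutherAI/pythia-70m"
# No-processing load skips LayerNorm folding / weight centering, so the
# residual stream should line up with HF's hidden states (assumption).
tl_model = HookedTransformer.from_pretrained_no_processing(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

tokens = tl_model.tokenizer.encode("Hello, world!", return_tensors="pt")
_, cache = tl_model.run_with_cache(tokens, prepend_bos=False)
hf_hidden = hf_model(tokens, output_hidden_states=True).hidden_states

# Assumption: hf_hidden[i] is the input to block i, i.e. TL's resid_pre at layer i.
for layer in range(tl_model.cfg.n_layers):
    diff = (cache["resid_pre", layer] - hf_hidden[layer]).abs().max().item()
    print(f"layer {layer}: max |resid_pre diff| = {diff:.2e}")
```

If the per-layer maxima grow with depth, that would be consistent with a small rotary-embedding deviation compounding through the blocks rather than a single bad layer.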