attention_sinks: Trying a minimal example with LlamaForCausalLM, sadly it fails

My minimal example:

import torch
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from attention_sinks import LlamaForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

repo = "meta-llama/Llama-2-13b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto", load_in_4bit=True)

# Set the text you want to generate text based on
#text = "<s> you are hepful assistant. </s> <u> Tell me the pros and cons of coffee. Two points. </u>"
text = "<s> you are hepful assistant. </s> <u> Write me a long essay on the reasons for fall of roman empire/u>"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

# Generate text
generated_tokens = model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=4096)

# Decode the generated text
generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

print(generated_text)

Fails here:

File ~/mambaforge/envs/data_science/lib/python3.10/site-packages/attention_sinks/models/llama/pos_shift.py:103, in llama_pos_shift_attention_forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    101 if attention_mask is not None:
    102     if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
--> 103         raise ValueError(
    104             f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
    105         )
    106     attn_weights = attn_weights + attention_mask
    108 # upcast attention to fp32

ValueError: Attention mask should be of size (1, 1, 1, 1025), but is torch.Size([1, 1, 1, 1026])

The root of the issue is clear, but dumb fixes (like slicing the attention mask to make it “fit”) don’t work. Is it at least reproducible in your env? 👀 I’d really appreciate any pointers on how to fix this 🙏
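For what it’s worth, here is a purely illustrative sketch of where the off-by-one likely comes from (the 4 + 1020 default cache sizes are an assumption on my part, not something confirmed by this traceback): generate() sizes the attention mask from the full running sequence length, while the attention-sink cache has already been truncated to a fixed number of positions.

# Editorial sketch, not library code. Assumes the default cache of
# 4 sink tokens + a 1020-token window = 1024 cached positions.
attention_sink_size = 4
attention_sink_window_size = 1020
cache_cap = attention_sink_size + attention_sink_window_size  # 1024

tokens_so_far = 1026  # prompt + everything generated up to the failing step

# generate() builds the attention mask from the full running length...
mask_len = tokens_so_far                            # -> 1026
# ...but the model expects a mask covering only the truncated cache
# plus the current token:
kv_seq_len = min(tokens_so_far - 1, cache_cap) + 1  # -> 1025

print(mask_len, kv_seq_len)  # 1026 vs 1025, matching the ValueError above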

About this issue

  • State: closed
  • Created 9 months ago
  • Comments: 16 (10 by maintainers)

Most upvoted comments

Resolved in 48bb293d4fb15d08bdeb3a0425cee0ea78f8ba52, thanks again for reporting

Just chiming in here to say thank you for all your hard work that makes it easier to experiment with the results of the paper, you rock 🤗

Please try the following snippet with the model of your choice and a corresponding prompt. The generation here is set up to run endlessly, so the model may still eventually lose track of what it was doing, but it shouldn’t forget English the way it would with pure transformers or windowed attention.

import torch
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModelForCausalLM


# model_id = "meta-llama/Llama-2-7b-hf"
# model_id = "mistralai/Mistral-7B-v0.1"
model_id = "mosaicml/mpt-7b"
# model_id = "tiiuae/falcon-7b"
# model_id = "EleutherAI/pythia-6.9b-deduped"

# Load the chosen model and corresponding tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # for efficiency:
    device_map="auto",
    torch_dtype=torch.float16,
    # `attention_sinks`-specific arguments:
    attention_sink_size=4,
    attention_sink_window_size=252,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Our input text
text = "Vaswani et al. (2017) introduced the Transformers"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

# Print tokens as they're being generated
streamer = TextStreamer(tokenizer)
generated_tokens = model.generate(
    input_ids,
    generation_config=GenerationConfig(
        # use_cache=True is required, the rest can be changed up.
        use_cache=True,
        min_new_tokens=20000,
        max_new_tokens=50000,
        penalty_alpha=0.6,
        top_k=5,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ),
    streamer=streamer,
)
# Decode the final generated text
output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
  • Tom Aarsen
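For reference, here is a minimal sketch (not from this thread) of how the capped cache can be checked with direct forward calls. The model id, the 4 + 252 cache sizes, and the assumption that past_key_values is the usual per-layer (key, value) tuple are mine:

import torch
from transformers import AutoTokenizer
from attention_sinks import AutoModelForCausalLM

# Hypothetical check: with attention sinks enabled, the key/value cache
# should never grow past attention_sink_size + attention_sink_window_size.
model_id = "meta-llama/Llama-2-7b-hf"  # assumption: any supported model works
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    attention_sink_size=4,
    attention_sink_window_size=252,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

input_ids = tokenizer.encode(
    "Vaswani et al. (2017) introduced the Transformers", return_tensors="pt"
).to(model.device)

past_key_values = None
with torch.no_grad():
    for step in range(600):
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        # Keys are [batch, heads, seq, head_dim]; seq should cap at 4 + 252 = 256
        cache_len = past_key_values[0][0].shape[2]
        assert cache_len <= 256, f"cache grew to {cache_len}"
        # Greedy next token, fed back in together with the cache
        input_ids = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)

print("Cache stayed capped at", cache_len, "positions")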

Feel free to experiment with #6 to get model.generate working.

I also get failures when calling generate, although the model does work if I do the generation manually like so:

with torch.no_grad():
    past_key_values = None
    for i in range(4096):
        input_ids = input_ids.to(model.device)  # .to() is not in-place, so reassign
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        logits = outputs.logits.view(-1, model.config.vocab_size)
        past_key_values = outputs.past_key_values
        # Greedy decoding: take the most likely token at the last position
        token = logits[-1, :].argmax()
        print(i, tokenizer.decode(token, clean_up_tokenization_spaces=False))
        # Feed only the new token back in; the cache carries the rest
        input_ids = token.unsqueeze(0).unsqueeze(0)

It ends up writing <u> Write me a long essay on the reasons for fall of roman empire/u> over and over for thousands of tokens. Since this is not an instruct-tuned model, that is also how the pure transformers model reacts.
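(As a purely illustrative aside, not something raised in this thread: the Llama-2 chat checkpoints were trained on the [INST]/<<SYS>> prompt template rather than ad-hoc <u>...</u> tags, so an instruct-tuned model would need something like the following. Model and wording are assumptions.)

# Hypothetical prompt for a Llama-2 *chat* checkpoint; the tokenizer adds <s> itself.
prompt = (
    "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
    "Write me a long essay on the reasons for the fall of the Roman Empire. [/INST]"
)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)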

I’ll also check what happens if I use the windowed attention approach, i.e. the green line here. Edit: see the outputs below; the left column is the token index and the right is the generated token. It completely loses the plot.

996 Write
997 O
998 O
999 u
1000 /
1001 u
1002 /
1003 u
1004 /
1005 u
1006 /
1007 /
1008 /
1009 /
1010 u
1011 /
1012 /
1013 /
1014 /
1015 /
...
1704 O
1705 O
1706 O
1707 O
1708 in
1709 .
1710 Ћ
1711 nobody
1712 nobody
1713 nobody
1714 nobody
1715 nobody
1716 nobody

So, attention_sinks does work, but not with model.generate at the moment. I’ll have to debug the generate method to figure out where the issue originates.

  • Tom Aarsen

Let me look into this! I haven’t tried generating myself: I’ve only called forward directly on the LlamaModel/FalconModel in my benchmarks.
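For context, a rough sketch of what a direct-forward evaluation can look like (an illustration, not the actual benchmark code; the model id and the long_text.txt input are assumptions): stream a long text through forward() one token at a time and record the negative log-likelihood of each next token.

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from attention_sinks import AutoModelForCausalLM

# Illustrative only: token-by-token forward passes with the cache enabled,
# accumulating next-token negative log-likelihoods.
model_id = "meta-llama/Llama-2-7b-hf"  # assumption
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

text = open("long_text.txt").read()  # assumption: any long document
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

nlls = []
past_key_values = None
with torch.no_grad():
    for idx in range(input_ids.size(1) - 1):
        outputs = model(
            input_ids[:, idx : idx + 1],
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        # NLL of the true next token under the model's prediction
        logits = outputs.logits[0, -1]
        target = input_ids[0, idx + 1]
        nlls.append(F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0)).item())

print("mean NLL:", sum(nlls) / len(nlls))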