transformers: Unable to reproduce the simple code snippet from the official docs for google/gemma-7b

System Info

  • transformers version: 4.39.0.dev0
  • Platform: Linux-5.15.0-60-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.21.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.0 (True)

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer, GemmaForCausalLM

model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

prompt = "What is your favorite condiment?"
inputs = tokenizer(prompt, return_tensors="pt")

# max_length counts the prompt tokens as well as the generated ones
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

Running it raises the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/net/nfs/mosaic/day/envs/opensource/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/day/transformers/src/transformers/generation/utils.py", line 1549, in generate
    result = self.greedy_search(
  File "/home/day/transformers/src/transformers/generation/utils.py", line 2420, in greedy_search
    outputs = self(
  File "/net/nfs/mosaic/day/envs/opensource/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/day/transformers/src/transformers/models/gemma/modeling_gemma.py", line 1062, in forward
    outputs = self.model(
  File "/net/nfs/mosaic/day/envs/opensource/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/day/transformers/src/transformers/models/gemma/modeling_gemma.py", line 904, in forward
    layer_outputs = decoder_layer(
  File "/net/nfs/mosaic/day/envs/opensource/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/day/transformers/src/transformers/models/gemma/modeling_gemma.py", line 624, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/net/nfs/mosaic/day/envs/opensource/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/day/transformers/src/transformers/models/gemma/modeling_gemma.py", line 279, in forward
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
RuntimeError: shape '[1, 8, 3072]' is invalid for input of size 32768

Expected behavior

Hi HF! I ran into a shape mismatch when using the official code example to test the google/gemma-7b model. I would really appreciate any help fixing this issue. I’m using the library built from the latest source code.

About this issue

  • State: closed
  • Created 4 months ago
  • Reactions: 4
  • Comments: 20 (4 by maintainers)

Most upvoted comments

Reopening.

This issue is still not solved: the perplexities of gemma-2b and gemma-7b are very different, with gemma-7b much worse (near random). On WikiText-2, token-level perplexity is about 21 for gemma-2b but about 1e13 for gemma-7b.

I’m not sure of the reason, but it has to be a problem on the model side: possibly the implementation, the weights, or some embedding/tokenizer mismatch.
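For scale, token-level perplexity is the exponential of the mean per-token negative log-likelihood (in nats). The minimal stdlib sketch below is illustrative only: the function name `perplexity` and the inputs are hypothetical, not taken from any evaluation script in this thread. It also shows that a model guessing uniformly over the 256,000-token vocabulary (the `Embedding(256000, ...)` size in the module printout in this thread) scores a perplexity of exactly 256,000, so a value around 1e13 is many orders of magnitude worse even than that baseline.

```python
import math


def perplexity(token_nlls: list[float]) -> float:
    """Perplexity = exp(mean negative log-likelihood per token), NLL in nats."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


# A model assigning probability 1/V to every token has NLL = ln(V) at every
# position, so its perplexity is exactly V.
VOCAB_SIZE = 256_000
uniform_nlls = [math.log(VOCAB_SIZE)] * 10
print(round(perplexity(uniform_nlls)))  # 256000

# A perplexity of ~1e13 corresponds to a mean NLL of ln(1e13) nats.
print(math.log(1e13))  # ≈ 29.93
```

By comparison, the gemma-2b figure of about 21 corresponds to a mean NLL of roughly 3 nats per token.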

It is not related to the torch version. If the traceback is:

  File "/home/ray/anaconda3/lib/python3.8/site-packages/transformers/models/gemma/modeling_gemma.py", line 279, in forward
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
RuntimeError: shape '[1, 49, 3072]' is invalid for input of size 200704

then you are on an outdated version of transformers, because we shipped this change:

- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = attn_output.reshape(bsz, q_len, -1)
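The arithmetic behind both tracebacks can be checked without loading the model. In the module printout in this thread, `q_proj` maps 3072 to 4096 features and `o_proj` maps 4096 back to 3072, so the attention output is 4096 wide per token while `hidden_size` is 3072. The stdlib sketch below mimics `reshape`'s element-count check; the helper names are hypothetical, and the split of 4096 into 16 heads of `head_dim` 256 is an assumption taken from the public gemma-7b config.

```python
from math import prod

# gemma-7b widths. HIDDEN_SIZE and the 4096 attention width match the module
# printout; the 16 x 256 head split is assumed from the public config.
HIDDEN_SIZE = 3072
NUM_HEADS = 16
HEAD_DIM = 256


def attn_output_numel(bsz: int, q_len: int) -> int:
    """Element count of attn_output before the reshape."""
    return bsz * q_len * NUM_HEADS * HEAD_DIM


def fits(numel: int, shape: tuple[int, ...]) -> bool:
    """Mimic reshape's check: an explicit target shape must hold exactly numel elements."""
    return prod(shape) == numel


def infer_last_dim(numel: int, bsz: int, q_len: int) -> int:
    """Mimic reshape(bsz, q_len, -1): the last dimension is inferred from numel."""
    if numel % (bsz * q_len):
        raise ValueError("numel is not divisible by bsz * q_len")
    return numel // (bsz * q_len)


for bsz, q_len in [(1, 8), (1, 49)]:  # the two tracebacks in this thread
    n = attn_output_numel(bsz, q_len)
    print(n, fits(n, (bsz, q_len, HIDDEN_SIZE)), infer_last_dim(n, bsz, q_len))
# 32768 False 4096
# 200704 False 4096
```

With `-1` (or, equivalently, `self.head_dim * self.num_heads`), the last dimension comes out as 4096, which matches `o_proj`'s `in_features=4096`.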

Hi everyone! I received the following suggestion from @alisafaya earlier (I’m not sure why the reply was deleted; if there are better alternatives, please feel free to post them here).

Replacing attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) with attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) in the [transformers/models/gemma/modeling_gemma.py](https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L279) file solves this issue.

Note that it solves the bug, but the outputs may not be correct.

cc @arnavgarg1


In [2]: model
Out[2]: 
GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear(in_features=24576, out_features=3072, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=3072, out_features=256000, bias=False)
)
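As a sanity check that this printout is the full-size checkpoint, its parameter count can be tallied from the layer shapes alone. This is a sketch under stated assumptions: `lm_head` is assumed to share its weight with `embed_tokens` (tied embeddings), each `GemmaRMSNorm` is assumed to hold one weight vector of size 3072, and `GemmaRotaryEmbedding` is assumed to have no learned parameters. Under those assumptions the total comes to roughly 8.54B.

```python
# Shapes read off the GemmaForCausalLM printout above.
VOCAB, HIDDEN, ATTN, MLP, LAYERS = 256_000, 3072, 4096, 24_576, 28


def linear(in_features: int, out_features: int) -> int:
    """Weight count of a bias-free nn.Linear."""
    return in_features * out_features


def gemma_param_count(tied_lm_head: bool = True) -> int:
    attn = 3 * linear(HIDDEN, ATTN) + linear(ATTN, HIDDEN)  # q/k/v_proj + o_proj
    mlp = 2 * linear(HIDDEN, MLP) + linear(MLP, HIDDEN)  # gate/up_proj + down_proj
    norms = 2 * HIDDEN  # input/post_attention RMSNorm, assumed one weight vector each
    per_layer = attn + mlp + norms
    total = VOCAB * HIDDEN + LAYERS * per_layer + HIDDEN  # embeddings + layers + final norm
    if not tied_lm_head:
        total += linear(HIDDEN, VOCAB)  # separate lm_head weight
    return total


print(f"{gemma_param_count():,}")  # 8,537,680,896
print(f"{gemma_param_count(tied_lm_head=False):,}")  # 9,324,112,896
```

On the loaded model, `sum(p.numel() for p in model.parameters())` should give a comparable figure, since PyTorch yields a shared parameter only once.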

Double-checked that I am using the eager attention implementation.

@ArthurZucker What’s interesting here is that google/gemma-2b-it works fine with the script above on both master and the release, and its outputs are coherent and make sense.

So the issue seems to be specific to google/gemma-7b and google/gemma-7b-it.
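One difference between the checkpoints that may be relevant: in gemma-7b the attention width `num_heads * head_dim` differs from `hidden_size`, while in gemma-2b the two coincide, so any code that conflates them can work for 2b and fail for 7b. The values below are assumed from the public model configs rather than taken from this thread (only the 7b widths are corroborated by the printout above), and `AttnDims` is a hypothetical stand-in for the relevant `GemmaConfig` fields.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnDims:
    """Hypothetical stand-in for the relevant GemmaConfig fields."""

    hidden_size: int
    num_attention_heads: int
    head_dim: int

    @property
    def attn_width(self) -> int:
        # Width of the concatenated head outputs fed into o_proj.
        return self.num_attention_heads * self.head_dim


# Values assumed from the public configs of each checkpoint.
CONFIGS = {
    "google/gemma-2b": AttnDims(hidden_size=2048, num_attention_heads=8, head_dim=256),
    "google/gemma-7b": AttnDims(hidden_size=3072, num_attention_heads=16, head_dim=256),
}

for name, dims in CONFIGS.items():
    print(name, dims.attn_width, dims.hidden_size, dims.attn_width == dims.hidden_size)
# google/gemma-2b 2048 2048 True
# google/gemma-7b 4096 3072 False
```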