transformers: Pix2Struct: unable to overfit on a single training sample

System Info

  • transformers version: 4.28.0
  • Platform: Linux-5.4.0-1037-aws-x86_64-with-glibc2.27
  • Python version: 3.9.16
  • Huggingface_hub version: 0.13.4
  • Safetensors version: 0.3.0
  • PyTorch version (GPU?): 1.13.0+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

Here’s the minimal training loop:

import requests
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, AutoProcessor
from torch.optim import AdamW
import torch

torch.manual_seed(42)

model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
processor = AutoProcessor.from_pretrained("google/pix2struct-base")

dummy_target = "The model should overfit this sentence"
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)

encoded_image = processor(images=image, return_tensors="pt")
encoded_text = processor(text=dummy_target, return_tensors='pt', max_length=20)
optimizer = AdamW(model.parameters(), lr=1e-4)

model.train()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
flattened_patches = encoded_image.flattened_patches.to(device)
attention_mask = encoded_image.attention_mask.to(device)
labels = encoded_text.input_ids.to(device)

for i in range(1000):
    outputs = model(
        flattened_patches=flattened_patches,
        attention_mask=attention_mask,
        labels=labels,
    )
    loss = outputs.loss
    
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    if i % 50 == 0:
        model.eval()
        prediction = model.generate(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask)
        print(f'step: {i} train_loss: {loss.item()} prediction: {processor.batch_decode(prediction)}')
        model.train()

Here’s the output I got:

step: 0 train_loss: 8.259493827819824 prediction: ['<pad> <img_src=cropped-img-20180924']
step: 50 train_loss: 1.9695181846618652 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 100 train_loss: 2.071323871612549 prediction: ['<pad> <The model should overfit this sentence should overfit this sentence should overfit this sentence should']
step: 150 train_loss: 2.0366554260253906 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 200 train_loss: 1.8225889205932617 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 250 train_loss: 1.6568734645843506 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 300 train_loss: 1.6770282983779907 prediction: ['<pad> The model should overfit this sentence sentence should overfit this sentence sentence should overfit this sentence']
step: 350 train_loss: 1.688515067100525 prediction: ['<pad> The model should overfit this sentence sentence overfit this sentence sentence overfit this sentence sentence over']
step: 400 train_loss: 1.6118296384811401 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 450 train_loss: 1.6204414367675781 prediction: ['<pad> The model should overfit this sentence sentence should overfit this sentence should overfit this sentence should']
step: 500 train_loss: 1.59645676612854 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 550 train_loss: 1.5818239450454712 prediction: ['<pad> The model should overfit this sentence sentence sentence sentence sentence sentence sentence sentence sentence sentence sentence sentence sentence']
step: 600 train_loss: 1.5775129795074463 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 650 train_loss: 1.561257243156433 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 700 train_loss: 1.5319150686264038 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 750 train_loss: 1.646193504333496 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 800 train_loss: 1.533736228942871 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 850 train_loss: 1.6203268766403198 prediction: ['<pad> The model should overfit this sentence should overfit this sentence should overfit this sentence should over']
step: 900 train_loss: 1.5132172107696533 prediction: ['<pad> The model should overfit this sentence sentence should overfit this sentence sentence should overfit this sentence']
step: 950 train_loss: 1.491452693939209 prediction: ['<pad> The model should overfit this sentence The model should overfit this sentence The model should overfit']

Expected behavior

I’ve been trying to fine-tune Pix2Struct starting from the base pretrained model, and have been unable to do so. The model collapses consistently and fails to overfit on that single training sample. I noticed a comment about this on the fine-tuning notebook: https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb

Let’s train the model! Run the simply the cell below for training the model. We have observed that finding the best hyper-parameters was quite challenging and required a lot of trials and errors, as the model can easily enter in “collapse-model” (always predicting the same output, no matter the input) if the HP are not chosen correctly. In this example, we found out that using AdamW optimizer with lr=1e-5 seemed to be the best approach.

To dig a little deeper, I’ve been training on a single sample with a minimal training loop to see whether the model can learn that one sample correctly. It does not manage to overfit it, even after 1000 training steps. Unless I missed something in my training loop, that seems like weird behavior and might be a symptom of a bug somewhere.

About this issue

  • State: closed
  • Created a year ago
  • Comments: 26 (24 by maintainers)

Most upvoted comments

I would love to be an official contributor, even if it’s just a one-line code change 😅 I will put together a PR shortly.

@younesbelkada I shared a notebook on how to train Matcha/Pix2Struct model for Kaggle’s Benetech competition, in case anyone is interested. This model achieved silver zone and includes the updates with the fix.

Yeah, the model seems to be learning well on a dataset of >3k images with the change to the decoder config. This seems to be the root cause. Really good catch @gbarello-uipath 😃

I have also been trying to fine-tune Pix2Struct. I find that the loss goes to zero very quickly, which made me suspect that the attention masks are not being set properly.

What I see is that in the Pix2StructText module, self.config.is_decoder is set to False, causing this line to output a non-causal attention mask.

If I add self.config.is_decoder = True on the line above that one, forcing the module to act as a decoder, things look more normal.
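To illustrate why a non-causal mask would let the training loss drop to zero so quickly, here is a minimal pure-Python sketch. It is not the transformers implementation; `build_mask`, `visible_positions` and `leaks_target` are hypothetical helper names made up for this illustration. It assumes teacher forcing, where the decoder input is the target shifted right by one position, so the label for position i sits at input position i + 1. If position i is allowed to attend there, the decoder can copy the answer instead of learning it.

```python
def build_mask(seq_len, causal):
    """Return a seq_len x seq_len 0/1 matrix.

    mask[i][j] == 1 means query position i may attend to key position j.
    A causal mask only allows j <= i; a non-causal mask allows everything.
    """
    return [
        [1 if (not causal or j <= i) else 0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


def visible_positions(mask, i):
    """List the key positions that query position i may attend to."""
    return [j for j, allowed in enumerate(mask[i]) if allowed]


def leaks_target(mask):
    """Check whether any position can see its own label.

    Assumption: decoder inputs are the labels shifted right by one,
    so the label for position i is the input token at position i + 1.
    """
    n = len(mask)
    return any(mask[i][i + 1] for i in range(n - 1))


non_causal = build_mask(4, causal=False)
causal = build_mask(4, causal=True)

print(visible_positions(non_causal, 1))  # every position, including future ones
print(visible_positions(causal, 1))      # only positions 0 and 1
print(leaks_target(non_causal), leaks_target(causal))
```

With the non-causal mask every position can see the token it is supposed to predict, which is consistent with a loss that collapses during training while generation (which has no future tokens to peek at) still fails.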

Indeed, the loss should go down to 0. I notice 2 things here:

Hi, thanks for the detailed report; indeed, this seems weird. I will have a look at it once I am back on Tuesday. cc also @NielsRogge and @nbroad1881 for visibility, as they have also been working on fine-tuning Pix2Struct.

Thanks very much for sharing! It is really cool to see Matcha/Pix2Struct being used in winning notebooks in major Kaggle competitions 🔥

Let’s close this issue, as we merged #23051! Thanks, everyone. @NielsRogge has also made a nice tutorial at https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Pix2Struct

I think the pretrained model configs should be fixed directly.

Yeah I had a hard time fine-tuning Pix2Struct myself. However looking at your code snippet, when you encode the target sequence:

from transformers import Pix2StructProcessor

processor = Pix2StructProcessor.from_pretrained("google/pix2struct-base")

dummy_target = "The model should overfit this sentence"
encoded_text = processor(text=dummy_target, return_tensors='pt', max_length=20)

then when decoding back to text:

processor.decode(encoded_text.input_ids.squeeze())

prints:

'The model should overfit this sentence'

So this target sequence contains neither an EOS (end-of-sequence) token nor a BOS (beginning-of-sequence) token. Hence, when generating text with the generate() method, the model will just keep predicting tokens, as this method only stops when the model predicts the EOS token. Since the model was never trained to produce the EOS token, it simply keeps generating text (hence you’re getting ‘<pad> The model should overfit this sentence should overfit this sentence’ etc.). Also, it looks like the first token is <pad> because the model’s BOS token is equal to the pad token, so you’ll need to pass skip_special_tokens=True to the batch_decode method.
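The stopping behavior described here can be sketched with a toy greedy decoding loop. This is a pure-Python illustration under stated assumptions, not the generate() implementation: `make_memorized_model` and `greedy_generate` are hypothetical names, and the "model" is a stand-in that has memorized one sequence and wraps around once it runs past the end, loosely mimicking the repetition in the logs above.

```python
EOS = "</s>"


def make_memorized_model(target_tokens):
    """Toy stand-in for an overfitted decoder.

    It returns the next token of a memorized sequence and wraps around
    when asked to continue past the end (an assumption for illustration).
    """
    def next_token(generated):
        return target_tokens[len(generated) % len(target_tokens)]

    return next_token


def greedy_generate(next_token, max_new_tokens=20, eos=EOS):
    """Generate tokens until EOS is predicted or the length budget runs out."""
    generated = []
    for _ in range(max_new_tokens):
        token = next_token(generated)
        if token == eos:
            break  # the only early-stopping condition
        generated.append(token)
    return generated


words = "The model should overfit this sentence".split()

# Target without EOS: nothing ever tells the loop to stop early.
without_eos = greedy_generate(make_memorized_model(words), max_new_tokens=14)

# Target with EOS appended: generation stops right after the sentence.
with_eos = greedy_generate(make_memorized_model(words + [EOS]), max_new_tokens=14)

print(" ".join(without_eos))
print(" ".join(with_eos))
```

In this sketch, the first run fills the whole 14-token budget with repeated text, while the second run stops after the six words of the target.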

So cc @younesbelkada: we’ll need to check that, when the user sets the max length to 20, the tokenizer still sets the EOS token as the last token. It looks like the processor’s tokenizer has this set:

>>> processor.tokenizer.eos_token
'</s>'
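One way to express the desired behavior (EOS always ends up as the last token, even when the sequence is truncated to a maximum length) is sketched below. This is a hypothetical helper, not what the Pix2Struct processor or tokenizer does; `encode_with_eos` is a made-up name, and the token ids, including the EOS id of 1, are placeholder values chosen for the example.

```python
def encode_with_eos(token_ids, eos_id, max_length):
    """Truncate token_ids so the result fits max_length and ends with eos_id.

    One slot is always reserved for EOS, so at most max_length - 1
    content tokens are kept.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1 to hold the EOS token")
    ids = list(token_ids)
    # Avoid a duplicated EOS if the input already ends with one.
    if ids and ids[-1] == eos_id:
        ids = ids[:-1]
    return ids[: max_length - 1] + [eos_id]


EOS_ID = 1  # placeholder id for illustration only

short = encode_with_eos([5, 6, 7], EOS_ID, max_length=20)
long = encode_with_eos(range(100, 130), EOS_ID, max_length=20)

print(short)      # the three content ids followed by the EOS id
print(len(long))  # never more than max_length
print(long[-1])   # the last id is the EOS id
```

Labels built this way always end with EOS, so a model trained on them has a chance to learn when to stop.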