transformers: MMS: target_lang=fra in pipeline() leads to "Size mismatch for lm_head.weight/bias when loading state_dict for Wav2Vec2ForCTC"

System Info

  • transformers version: 4.31.0.dev0
  • Platform: Linux-5.15.107+-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • Huggingface_hub version: 0.15.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.6.9 (gpu)
  • Jax version: 0.4.10
  • JaxLib version: 0.4.10
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@patrickvonplaten @sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

Colab link: https://colab.research.google.com/drive/1YmABsYKCk39Z7GF390G316ZsEWH7GkqT?usp=sharing

Expected behavior

pipe = pipeline(model="facebook/mms-1b-l1107", model_kwargs={"target_lang":"fra"})

I expected this to create the pipeline with the fra adapter loaded, as seems to be intended here.

It fails with a size mismatch error. Ignoring the mismatch seems to load the English adapter instead, since the result is poor and doesn’t match the demo on the official space (https://huggingface.co/spaces/facebook/MMS).

About this issue

  • State: closed
  • Created a year ago
  • Comments: 15 (11 by maintainers)

Most upvoted comments

For now this is a hack we have to do, but this PR: https://github.com/huggingface/transformers/pull/24335 should solve it nicely.

There was indeed a bug! What is happening here is the following.

1.) Specifying target_lang in the pipeline constructor passes it on to the model’s from_pretrained(...) method, which means that inside the pipeline(...) function the following is called:

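from transformers import Wav2Vec2ForCTC

model_id = "facebook/mms-1b-fl102"

# roughly what pipeline(..., model_kwargs={"target_lang": "fra"}) runs internally
model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="fra")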

2.) Passing target_lang="fra", however, loads the French adapter weights in the model’s __init__ method, here: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1880

3.) However, __init__ runs before from_pretrained(...) loads the state dict into the model. This means the correctly loaded French adapter layers are afterwards overwritten by the original English adapter layers (the default in the state dict). Sadly, this was not noticed in the tests because the English adapter weights work surprisingly well for French 😅
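The ordering problem in steps 1.) to 3.) can be reproduced with a small pure-Python simulation. Everything in it (ToyCTCModel, ADAPTER_WEIGHTS, the checkpoint dict) is a hypothetical stand-in, not transformers code; it only mimics the order in which from_pretrained runs __init__ and then loads the state dict.

```python
# Hypothetical per-language adapter values standing in for real adapter tensors.
ADAPTER_WEIGHTS = {"eng": [0.1, 0.2], "fra": [0.9, 0.8]}


class ToyCTCModel:
    """Toy model whose 'parameters' live in a plain dict (not transformers code)."""

    def __init__(self, target_lang=None):
        self.state = {"encoder.w": [0.0], "adapter.w": list(ADAPTER_WEIGHTS["eng"])}
        if target_lang is not None:
            self.load_adapter(target_lang)  # step 2: adapter loaded inside __init__

    def load_adapter(self, lang):
        self.state["adapter.w"] = list(ADAPTER_WEIGHTS[lang])

    def load_state_dict(self, state_dict):
        # Copies every entry of the checkpoint over the current state.
        self.state.update({k: list(v) for k, v in state_dict.items()})

    @classmethod
    def from_pretrained(cls, checkpoint, **kwargs):
        model = cls(**kwargs)              # __init__ runs first ...
        model.load_state_dict(checkpoint)  # ... then the checkpoint overwrites it (step 3)
        return model


# The default checkpoint ships with the English adapter.
checkpoint = {"encoder.w": [1.0], "adapter.w": list(ADAPTER_WEIGHTS["eng"])}

model = ToyCTCModel.from_pretrained(checkpoint, target_lang="fra")
print(model.state["adapter.w"])  # [0.1, 0.2]: English values, French adapter overwritten
model.load_adapter("fra")        # reload the adapter after from_pretrained
print(model.state["adapter.w"])  # [0.9, 0.8]: French values
```

The second print shows why calling load_adapter once more after from_pretrained, as in the workaround that follows, restores the French weights.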

=> A quick’n’dirty fix that gives users exactly the same results as posted by @vineelpratap here is to run the following code:

from transformers import pipeline

model_id = "facebook/mms-1b-all"
pipe = pipeline(model=model_id, model_kwargs={"target_lang":"fra", "ignore_mismatched_sizes":True})
pipe.model.load_adapter("fra")  # THIS CORRECTS THE INCORRECTLY OVERWRITTEN WEIGHTS!

print(pipe("http://french.voiceoversamples.com/jeanNL.mp3"))

gives:

la première fois que vous allez ouvrir une interaction client vous seraet dirigée vers la page d'identification il s’agit du mode par des fauts utilisé pour toutes les interactions clients. veuillez vérifier le numéro de sécurité sociale de l'appelan avant de poursuivre. une fois après avoir confirmé clique sur le bouton suivant comme ceci très bien passans maintenant à l’étape 2

Also, this is not a problem for the original demo, because the demo only calls load_adapter after from_pretrained has been called, which avoids the overwrite.
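For reference, the load order the demo relies on can be written directly against the model and processor classes. This is a sketch, assuming a transformers release with MMS adapter support (the reporter used 4.31.0.dev0); load_mms_with_adapter is a hypothetical helper name, while AutoProcessor, Wav2Vec2ForCTC, tokenizer.set_target_lang and load_adapter are the MMS-related APIs documented for Wav2Vec2.

```python
def load_mms_with_adapter(model_id="facebook/mms-1b-all", lang="fra"):
    """Hypothetical helper: load an MMS CTC model, then switch model and tokenizer to `lang`."""
    # Imported lazily so the helper can be defined without transformers installed.
    from transformers import AutoProcessor, Wav2Vec2ForCTC

    processor = AutoProcessor.from_pretrained(model_id)
    model = Wav2Vec2ForCTC.from_pretrained(model_id)  # default (English) adapter

    # Switch adapters only after from_pretrained has loaded the state dict,
    # so nothing overwrites the target-language weights afterwards.
    processor.tokenizer.set_target_lang(lang)
    model.load_adapter(lang)
    return processor, model
```

Because the adapter is loaded last, the vocabulary used by the tokenizer and the lm_head loaded by load_adapter both correspond to the same language.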

Hi, for the above sample I get this result with FL102 models and greedy decoding. I converted the .mp3 to .wav with this command: ffmpeg -y -i audio.mp3 -ar 16000 -ac 1 audio.wav

la première fois que vous allez ouvrir une interaction client vous seraet dirigée vers la page d'identification il s’agit du mode par des fauts utilisé pour toutes les interactions clients. veuillez vérifier le numéro de sécurité sociale de l'appelan avant de poursuivre. une fois après avoir confirmé clique sur le bouton suivant comme ceci très bien passans maintenant à l’étape 2

Is it possible that we used an incorrect dictionary?
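The conversion command can also be driven from Python. This is a minimal sketch, assuming the ffmpeg binary is on PATH; to_16k_mono_cmd and convert_to_16k_mono are hypothetical helper names. The 16 kHz mono target matches the sampling rate that MMS (wav2vec2-based) models expect.

```python
import shlex
import subprocess


def to_16k_mono_cmd(src: str, dst: str, sample_rate: int = 16_000) -> list[str]:
    """Build the ffmpeg argv: overwrite output (-y), resample (-ar), downmix to mono (-ac 1)."""
    return ["ffmpeg", "-y", "-i", src, "-ar", str(sample_rate), "-ac", "1", dst]


def convert_to_16k_mono(src: str, dst: str) -> None:
    # Runs ffmpeg; raises CalledProcessError on a non-zero exit status.
    subprocess.run(to_16k_mono_cmd(src, dst), check=True)


# Prints the same command line as above without running ffmpeg.
print(shlex.join(to_16k_mono_cmd("audio.mp3", "audio.wav")))
# ffmpeg -y -i audio.mp3 -ar 16000 -ac 1 audio.wav
```

Passing an argument list (rather than a shell string) avoids quoting problems with file names that contain spaces.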

it’s a bit surprising that the fl102 model performs worse than the “all” model here, no?

Note that the MMS-FL102 model is trained only on FLEURS data, which consists of about 10 hours of data per language, while the MMS-ALL model is trained on a combination of MLS, Common Voice, FLEURS, etc. So it is expected that MMS-ALL performs better than MMS-FL102.

MMS-FL102 and MMS-L1107 were open-sourced so that one can reproduce some of the results in the paper. If you want the best-performing ASR model, MMS-ALL is the best choice. Running LM decoding will further boost performance, as we discuss in the MMS paper, and we are working on open-sourcing the LMs soon.
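Greedy (best-path) CTC decoding, the baseline that LM decoding improves on, takes the most likely token per audio frame, merges consecutive repeats, and drops the blank token. A minimal sketch with a hypothetical five-entry vocabulary; blank_id=0 is an assumption of this toy vocabulary, and real tokenizers define their own vocabularies and special tokens. The id_to_char mapping is exactly the "dictionary" in question: decoding the same frame ids with a different mapping yields different text.

```python
def ctc_greedy_decode(frame_ids, id_to_char, blank_id=0):
    """Greedy CTC decoding: collapse consecutive repeats, then drop blanks."""
    out = []
    prev = None
    for i in frame_ids:
        # Emit a token only when it differs from the previous frame and is not blank.
        if i != prev and i != blank_id:
            out.append(id_to_char[i])
        prev = i
    return "".join(out)


# Hypothetical toy vocabulary (id 0 is the blank/pad token).
vocab = {0: "<pad>", 1: "l", 2: "a", 3: " ", 4: "e"}

# Per-frame argmax ids, e.g. what logits.argmax(-1) would give for one utterance.
frames = [1, 1, 0, 2, 2, 3, 0, 1, 4, 4]
print(ctc_greedy_decode(frames, vocab))  # la le
```

A doubled letter such as "ll" can only be produced if a blank separates the two frames (e.g. [1, 0, 1]), which is why the blank token is needed at all.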

Hey @erickedji,

We were indeed missing docs here; I’ve added them in #24292.

However, the way you used the pipeline is 100% correct! For some reason “facebook/mms-1b-l1107” doesn’t perform very well. However, “facebook/mms-1b-all” works well.

from transformers import pipeline

model_id = "facebook/mms-1b-all"
pipe = pipeline(model=model_id, model_kwargs={"target_lang":"fra", "ignore_mismatched_sizes":True})

print(pipe("http://french.voiceoversamples.com/jeanNL.mp3"))

gives

{'text': "la première fois que vous allez ouvrir une interaction client vous serait dirigée vers la page d'identification il s'agit du mode par défaut utilisé pour toutes les interactions clients veuillez vérifier le numéro de sécurité sociale de l'appelant avant de poursuivre une fois après avoir confirmé cliqué sur le bouton suivant comme ceci très bien passons maintenant à l'étape dex"}

And “facebook/mms-1b-fl102” also seems to have problems:

from transformers import pipeline

model_id = "facebook/mms-1b-fl102"
pipe = pipeline(model=model_id, model_kwargs={"target_lang":"fra", "ignore_mismatched_sizes":True})

print(pipe("http://french.voiceoversamples.com/jeanNL.mp3"))

gives

{'text': "la première fois que vous alez ouvrir une interaction client vous seraen dirigée vers la page d'identification il s’agit du mode par des fauts utilisé pour toutes les interactions clients veuillez vérifier le numéro de sécurité sociale de l'appelan avant de poursuivre une fois après avoir confirmé clicque sur le bouton suivant comme ceci très bien passons maintenant à l’étape d"}

cc @vineelpratap it’s a bit surprising that the fl102 model performs worse than the "all" model here, no? Maybe I’ve made an error with the weight conversion? Could you check what the original checkpoint and code give with pure CTC decoding for "http://french.voiceoversamples.com/jeanNL.mp3"?

@Vaibhavs10 we should also run some evals on the whole FLEURS dataset to be sure.
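Such an eval usually reports word error rate (WER) or character error rate (CER). A minimal pure-Python sketch based on Levenshtein edit distance; the function names are hypothetical, and a real FLEURS evaluation would more likely use an existing metric implementation (for example the wer metric in the evaluate library) together with an agreed text normalization.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (lists of words or strings of chars)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(
                prev[j] + 1,             # deletion of r
                cur[j - 1] + 1,          # insertion of h
                prev[j - 1] + (r != h),  # substitution (cost 0 if equal)
            ))
        prev = cur
    return prev[-1]


def wer(reference: str, hypothesis: str) -> float:
    """Word-level edits divided by the number of reference words (reference must be non-empty)."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def cer(reference: str, hypothesis: str) -> float:
    """Character-level edits divided by the reference length (reference must be non-empty)."""
    return edit_distance(reference, hypothesis) / len(reference)


print(round(wer("the cat sat", "the cat sat down"), 3))  # 0.333 (one inserted word)
```

Normalization matters for comparisons like the ones in this thread: the transcripts differ in apostrophe style (' vs ’), which this sketch counts as errors unless both sides are normalized first.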

See this guide for details on installing from main (or source) @erickedji: https://huggingface.co/docs/transformers/installation#install-from-source

That’s right 👍 you are installing from source (which includes the latest fix).

@patrickvonplaten @xenova The only difference seems to be the pip install, and I don’t see why it would lead to different behavior.

I just tried again by running the first and last cells here: https://colab.research.google.com/drive/1YmABsYKCk39Z7GF390G316ZsEWH7GkqT?usp=sharing

It worked. The above notebook basically does !pip install git+https://github.com/huggingface/transformers datasets[torch], then:

from transformers import pipeline
model_id = "facebook/mms-1b-all"
pipe = pipeline(model=model_id, model_kwargs={"target_lang":"fra", "ignore_mismatched_sizes":True})
output = pipe("http://french.voiceoversamples.com/jeanNL.mp3")
output

I’m not familiar enough with pip to comment. @xenova Can you try with the same pip call as my notebook?
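One way to narrow down a "same code, different pip install" discrepancy is to print the installed package versions in both notebooks and compare them. A sketch using only the standard library's importlib.metadata; installed_versions is a hypothetical helper and the package list is an assumption. A git install of transformers usually reports a .dev0 version, like the 4.31.0.dev0 in the System Info above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("transformers", "torch", "datasets", "huggingface_hub")):
    """Return {package: version string, or None if the package is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```

Running this cell in both environments and diffing the output shows whether the two installs actually resolved to the same transformers build.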

I don’t know yet; I’ll try to look into it over the weekend.

@patrickvonplaten - do you know why the FL102 model’s output differs between fairseq and transformers for the French audio sample above? It would be good to figure out the underlying issue.