TTS: [Bug] BCELoss should not be masked

I have trained Tacotron2, but during eval/inference it often doesn’t know when to stop decoding. This is a known issue in seq2seq models, and I was trying to solve it in TensorFlowTTS before giving up due to TensorFlow problems.

Training with enable_bos_eos=True helps a bit, but the output is still 3x the ground-truth mel length for shorter audio: see length_data_eos.csv vs length_data_no_eos.csv.

One reason is the BCELossMasked criterion – in its current form, it encourages the model never to stop decoding once it has passed mel_length. Some of the loss results don’t quite make sense, as seen below:

import torch
def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max
    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
    return mask

from torch.nn import functional
length = torch.tensor([95])
mask = sequence_mask(length, 100)
pos_weight = torch.tensor([5.0])
target = 1. - sequence_mask(length - 1, 100).float()  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
zero_x = torch.zeros(target.shape) - 100.  # simulate logits if it never stops decoding
early_x = -200. * sequence_mask(length - 3, 100).float() + 100.  # simulate logits on early stopping
late_x = -200. * sequence_mask(length + 1, 100).float() + 100.  # simulate logits on late stopping

# if we mask
>>> functional.binary_cross_entropy_with_logits(mask * true_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(3.4657)  # Should be zero! It's not zero because of trailing zeros in the mask
>>> functional.binary_cross_entropy_with_logits(mask * zero_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)
>>> functional.binary_cross_entropy_with_logits(mask * late_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)  # Stopping late should be better than never stopping, but the mask zeroes out everything past mel_length, so both look identical to the loss
>>> functional.binary_cross_entropy_with_logits(mask * early_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(203.4657)  # Early stopping should be worse than late stopping because the audio will be cut
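
# The 3.4657 above comes from the 5 masked-out frames: masking zeroes the
# logits, sigmoid(0) = 0.5, so each frame past mel_length adds log(2) to the
# sum even for a perfect prediction.
>>> 5 * torch.log(torch.tensor(2.))
tensor(3.4657)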

# if we don't mask
>>> functional.binary_cross_entropy_with_logits(true_x, target, pos_weight=pos_weight, reduction='sum')
tensor(0.)  # correct
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=pos_weight, reduction='sum')
tensor(3000.)  # correct
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=pos_weight, reduction='sum')
tensor(1000.)  # stopping late is penalized less than never stopping, which is the right ordering
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=pos_weight, reduction='sum')
tensor(200.)  # still wrong: early stopping costs less than stopping late

# pos_weight should be < 1 so that early stopping is penalized more than late stopping
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(120.0000)
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(40.0000)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(200.)  # correct: early stopping now costs more than stopping late
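
# To see why pos_weight < 1 helps: pos_weight only scales the target=1 frames
# (those after the last mel frame), so a late-stop frame costs pos_weight * 100
# while an early-stop frame still costs the full 100.
>>> functional.binary_cross_entropy_with_logits(torch.tensor([100.]), torch.tensor([0.]), pos_weight=torch.tensor([0.2]))
tensor(100.)  # one early-stop frame (target 0, predicted stop)
>>> functional.binary_cross_entropy_with_logits(torch.tensor([-100.]), torch.tensor([1.]), pos_weight=torch.tensor([0.2]))
tensor(20.0000)  # one late-stop frame (target 1, predicted keep going)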

For now I am passing length=None to bypass the mask and setting pos_weight=0.2 to experiment. I will update with the training results.
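
Roughly, this unmasked setup amounts to the following (a minimal sketch using plain BCEWithLogitsLoss rather than the repo's exact BCELossMasked API):

import torch
from torch import nn

# Sum BCE over all frames, no length mask, pos_weight < 1
stop_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([0.2]), reduction='sum')

def stop_loss(stop_logits, stop_targets):
    # stop_logits, stop_targets: (B, T_max); targets are 0 up to the last mel
    # frame and 1 from the last frame onwards, with no masking applied
    return stop_criterion(stop_logits, stop_targets)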

Additional context

I would also propose renaming stop_tokens to either stop_probs or stop_logits depending on context. Currently, inference() produces stop_tokens that represent stop probabilities, while forward() produces the logits before sigmoid. Confusingly, both are called stop_tokens.
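
For illustration, with the proposed names the relationship would simply be (shapes here are made up):

stop_logits = torch.randn(1, 100)        # what forward() returns, pre-sigmoid
stop_probs = torch.sigmoid(stop_logits)  # what inference() returns, probabilities in [0, 1]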

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 19 (10 by maintainers)

Most upvoted comments

Follow these steps:

1st Fork the 🐸 TTS repository (use the “Fork” button at the top of the page)

2nd Clone from your fork (dev branch). The command will be something like: git clone https://github.com/iamanigeeit/TTS.git -b dev

3rd Change the files that you need.

4th Commit the changes with the commands below (note: change the commit message 😃):

git add .
git commit -m "Commit message"

5th Push the commits to your fork with the command: git push

6th Go to your fork (https://github.com/iamanigeeit/TTS). GitHub will detect that you’ve made changes and show a pull request button below the “Go to file”, “Add file” and “Code” buttons. Click it to send a pull request from your dev branch to Coqui’s dev branch 😃.

I used the same config as recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py except batch_size=32 (due to GPU memory limits) and r=1 (I think r=1 is the correct value for Tacotron2). Training ran for 100k steps each.
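
For reference, the overrides amount to roughly this on top of the recipe config (a sketch; the Tacotron2Config import path and field names are assumed, check the recipe for the exact ones):

from TTS.tts.configs.tacotron2_config import Tacotron2Config

config = Tacotron2Config()  # otherwise the same settings as the DCA recipe
config.batch_size = 32      # reduced to fit GPU memory
config.r = 1                # reduction factor 1 for these runs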