TTS: [Bug] BCELoss should not be masked
I have trained Tacotron2, but during eval/inference it often doesn’t know when to stop decoding. This is a known issue in seq2seq models, and I was trying to solve it in TensorFlowTTS when I gave up due to TensorFlow problems.
Training with `enable_bos_eos=True` helps a bit, but for shorter audio the output is still 3x the ground-truth mel length: see `length_data_eos.csv` vs `length_data_no_eos.csv`.
One reason is the `BCELossMasked` criterion: in its current form, it encourages the model never to stop decoding once it has passed `mel_length`, because every frame past that point is masked out and contributes no gradient. Some of the loss results don’t quite make sense, as seen below:
```python
import torch
from torch.nn import functional


def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max
    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
    return mask


length = torch.tensor([95])
mask = sequence_mask(length, 100)
pos_weight = torch.tensor([5.0])
target = 1. - sequence_mask(length - 1, 100).float()  # [0, 0, ..., 1, 1] where the first 1 is the last mel frame
true_x = target * 200 - 100  # creates logits of [-100, -100, ..., 100, 100] corresponding to target
zero_x = torch.zeros(target.shape) - 100.  # simulate logits if it never stops decoding
early_x = -200. * sequence_mask(length - 3, 100).float() + 100.  # simulate logits on early stopping
late_x = -200. * sequence_mask(length + 1, 100).float() + 100.  # simulate logits on late stopping

# if we mask
>>> functional.binary_cross_entropy_with_logits(mask * true_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(3.4657)  # Should be zero! Masking turns the 5 padded logits into 0, and each then costs log 2
>>> functional.binary_cross_entropy_with_logits(mask * zero_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)
>>> functional.binary_cross_entropy_with_logits(mask * late_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)  # Stopping late should cost less than never stopping; again, the masked padding hides the difference
>>> functional.binary_cross_entropy_with_logits(mask * early_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(203.4657)  # Early stopping should cost more than late stopping, because the audio gets cut off

# if we don't mask
>>> functional.binary_cross_entropy_with_logits(true_x, target, pos_weight=pos_weight, reduction='sum')
tensor(0.)  # correct
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=pos_weight, reduction='sum')
tensor(3000.)  # correct
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=pos_weight, reduction='sum')
tensor(1000.)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=pos_weight, reduction='sum')
tensor(200.)  # still wrong: early stopping costs less than late stopping (1000)

# pos_weight should be < 1 to penalize early stopping
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(120.0000)
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(40.0000)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(200.)  # correct: early stopping is now the most expensive outcome
```
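To see where the numbers above come from, here is a dependency-free sketch (plain Python with `math`, not PyTorch) of the per-element weighted BCE-with-logits formula, assuming the same layout as the example: 100 frames and `length = 95`, so the stop target is 0 for frames 0 to 93 and 1 from frame 94 onward. The helper names (`softplus`, `bce_with_logits_sum`, `make_logits`, `masked`) are hypothetical. Multiplying by the mask sets the 5 padded logits to 0, and a logit of 0 with target 0 costs log 2 ≈ 0.693 per frame, which is the constant 3.4657 (5 × log 2) offset in every masked result.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits_sum(logits, targets, pos_weight=1.0):
    # Per-element weighted BCE on logits, summed (like reduction='sum'):
    #   pos_weight * y * -log(sigmoid(x)) + (1 - y) * -log(1 - sigmoid(x))
    # using -log(sigmoid(x)) = softplus(-x) and -log(1 - sigmoid(x)) = softplus(x)
    return sum(pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)
               for x, y in zip(logits, targets))


T, LENGTH = 100, 95  # padded length and mel length, as in the example above

# Stop target: 0 before the last mel frame, 1 from the last mel frame onward
target = [1.0 if t >= LENGTH - 1 else 0.0 for t in range(T)]
# Mask: 1 for real frames, 0 for padding
mask = [1.0 if t < LENGTH else 0.0 for t in range(T)]


def make_logits(first_stop):
    # Saturated logits: -100 before the predicted stop frame, +100 from it onward
    return [100.0 if t >= first_stop else -100.0 for t in range(T)]


def masked(values):
    # Element-wise multiplication by the mask, as in the example above
    return [v * m for v, m in zip(values, mask)]


true_x = make_logits(LENGTH - 1)   # stops exactly on the last mel frame
late_x = make_logits(LENGTH + 1)   # stops two frames late
early_x = make_logits(LENGTH - 3)  # stops two frames early
zero_x = make_logits(T)            # never stops

for name, x in [("true", true_x), ("late", late_x), ("early", early_x), ("never", zero_x)]:
    masked_loss = bce_with_logits_sum(masked(x), masked(target), pos_weight=5.0)
    unmasked_loss = bce_with_logits_sum(x, target, pos_weight=5.0)
    unmasked_low_pw = bce_with_logits_sum(x, target, pos_weight=0.2)
    print(f"{name:>5}: masked={masked_loss:.4f} unmasked={unmasked_loss:.4f} "
          f"unmasked(pos_weight=0.2)={unmasked_low_pw:.4f}")
```

Running it should reproduce the tensor values listed above, which makes it easy to try other stop positions or `pos_weight` values without a GPU.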
For now, I am passing `length=None` to avoid the mask and setting `pos_weight=0.2` to experiment. I will update with the training results.
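A back-of-the-envelope derivation shows why `pos_weight` below 1 is what makes early stopping the most expensive outcome once the mask is removed. It assumes saturated logits $\pm M$ ($M = 100$ in the example), $T$ padded frames, mel length $L$, positive-class weight $p$ (`pos_weight`), and a stop that is off by $k$ frames. Each misclassified frame then costs about $M$ times its class weight:

$$
\begin{aligned}
\mathcal{L}_{\text{late by } k} &\approx k\,p\,M \\
\mathcal{L}_{\text{early by } k} &\approx k\,M \\
\mathcal{L}_{\text{never}} &\approx (T - L + 1)\,p\,M
\end{aligned}
$$

With $T = 100$, $L = 95$, $M = 100$ and $k = 2$, $p = 5$ gives 1000, 200 and 3000, and $p = 0.2$ gives 40, 200 and 120, matching the unmasked results above. Early stopping costs more than late stopping exactly when $p < 1$.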
Additional context
I would also propose renaming `stop_tokens` to either `stop_probs` or `stop_logits`, depending on context. Currently, `inference()` produces `stop_tokens` that represent stop probabilities, while `forward()` produces the logits before the sigmoid. Confusingly, both are called `stop_tokens`.
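To make the proposed naming concrete, here is a minimal plain-Python sketch: `stop_logits` are the raw pre-sigmoid outputs (what `forward()` returns and what a BCE-with-logits loss must receive), and `stop_probs = sigmoid(stop_logits)` are what inference compares against a threshold. The helper names and the 0.5 threshold are illustrative assumptions, not the TTS library's API.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def stop_probs_from_logits(stop_logits):
    # forward()-style logits -> inference()-style probabilities
    return [sigmoid(x) for x in stop_logits]


def first_stop_frame(stop_probs, threshold=0.5):
    # Index of the first frame whose stop probability exceeds the threshold, or None
    for t, p in enumerate(stop_probs):
        if p > threshold:
            return t
    return None


stop_logits = [-3.0, -1.0, 0.5, 2.0]  # hypothetical decoder outputs for four frames
stop_probs = stop_probs_from_logits(stop_logits)
print(first_stop_frame(stop_probs))  # -> 2
```

Keeping the two names distinct guards against a subtle bug: passing `stop_probs` to a loss that expects logits would apply the sigmoid twice.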
About this issue
- State: closed
- Created 2 years ago
- Comments: 19 (10 by maintainers)
Commits related to this issue
- Fix BCELoss adressing #1192 — committed to coqui-ai/TTS by erogol 2 years ago
- v0.8.0 (#1810) * Fix checkpointing GAN models (#1641) * checkpoint sae step crash fix * checkpoint save step crash fix * Update gan.py updated requested changes * crash fix * Fix th... — committed to coqui-ai/TTS by erogol 2 years ago
Follow these steps:
1st Fork the 🐸 TTS repository (use the “Fork” button at the top of the page).
2nd Clone your fork (dev branch). The command will be something like:
git clone https://github.com/iamanigeeit/TTS.git -b dev
3rd Change the files you need to.
4th Commit the changes with the commands (note: change the commit message 😃):
5th Push the commits to your fork with the command:
git push
6th Go to your fork (https://github.com/iamanigeeit/TTS). GitHub will detect that you’ve made changes, suggest a pull request, and show a pull request button below the “Go to file”, “Add file” and “Code” buttons. Click that button and send a pull request from your dev branch to Coqui’s dev branch 😃.
I used the same config as `recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py` except `batch_size=32` (due to the GPU memory limit) and `r=1` (I think `r=1` is the correct one for Tacotron2). Each model was trained for 100k steps.