fairseq: Text-to-Speech problem
Hi, I am trying to use this model with fairseq: https://huggingface.co/facebook/tts_transformer-zh-cv7_css10. I am using the following code snippet to download and initialize the model:
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
"facebook/tts_transformer-zh-cv7_css10",
arg_overrides={"vocoder": "hifigan", "fp16": False}
)
model = models[0]
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(model, cfg)
But when I run this code, I get the error "TTSTransformerModel object is not subscriptable". If I replace model with models in build_generator, like this:
generator = task.build_generator(models, cfg)
then the initialization goes through, but I then get an error during the text-to-speech step. The code snippet for this part is the following:
text = "您好,这是试运行。"
sample = TTSHubInterface.get_model_input(task, text)
wav, rate = TTSHubInterface.get_prediction(task, model, generator, sample)
get_prediction then fails with the following error: "Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same".
Any suggestion or idea how I can get this to work? Or where can I find an example? Thanks!
The environment is the following:
- fairseq Version: main
- PyTorch Version: 1.8.1+cu111
- OS (e.g., Linux): Ubuntu 18.04
- How you installed fairseq (pip, source): from source
- Build command you used (if compiling from source): pip install --editable ./
- Python version: 3.7.12
- CUDA/cuDNN version: 11.1.1
- GPU models and configuration: NVIDIA A100
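Regarding the "Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor)" error: it usually means the model weights are on the GPU while the tensors returned by get_model_input are still on the CPU. Below is a minimal sketch of one way to rule that out; the explicit device handling (torch.device, utils.move_to_cuda) is an illustration added here, not part of the original snippets, and the vocoder held by the generator may need the same treatment.
import torch
from fairseq import utils

# Pick one device and keep the model weights and the input sample on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models[0].to(device)

sample = TTSHubInterface.get_model_input(task, text)
if device.type == "cuda":
    # move all tensors inside the sample dict onto the GPU
    sample = utils.move_to_cuda(sample)

wav, rate = TTSHubInterface.get_prediction(task, model, generator, sample)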
About this issue
- State: open
- Created 2 years ago
- Reactions: 3
- Comments: 15 (2 by maintainers)
Finally this worked (the problem was related to IPython.display.Audio not liking that the wav variable was on device cuda:0):

I solved it like this:

models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
    "facebook/tts_transformer-zh-cv7_css10",
    arg_overrides={"vocoder": "hifigan", "fp16": False}
)
model = models  # first change: removed the [0]
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(model, cfg)

text = "您好,这是试运行。"
sample = TTSHubInterface.get_model_input(task, text)
wav, rate = TTSHubInterface.get_prediction(task, models[0], generator, sample)  # second change: model -> models[0]
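If the goal is playback in a notebook, a small follow-up sketch (assuming wav comes back as a torch tensor; the conversion call is an illustration, not from the comment above): move the waveform to the CPU before handing it to IPython.display.Audio.
from IPython.display import Audio

# IPython.display.Audio expects CPU data, so detach the tensor,
# move it off cuda:0, and convert it to numpy before playback.
Audio(wav.detach().cpu().numpy(), rate=rate)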