audiolm-pytorch: Heads up that ComplexFloat doesn't appear to be supported by DistributedDataParallel (NCCL backend)

SoundStream training seems to work well on a single GPU, but as soon as you try to use more than one GPU, this error is thrown: RuntimeError: Input tensor data type is not supported for NCCL process group: ComplexFloat.

The ComplexLeakyReLU class appears to be the likely cause, but the workaround mentioned here doesn't necessarily seem like a great answer.

Not really a bug for AudioLM, just a heads up for anyone attempting multi-GPU training right now.

(PyTorch 1.13)
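
For context, a minimal sketch of how a complex-valued parameter trips up DDP over NCCL (a hypothetical toy module and torchrun-style setup, not code from the repo):

# hypothetical minimal reproduction -- not code from audiolm-pytorch.
# DDP syncs parameters/gradients over NCCL, which (as of PyTorch 1.13)
# rejects ComplexFloat tensors, so a module holding complex parameters
# fails either at DDP construction (parameter broadcast) or on backward()
# (gradient all-reduce).
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class ComplexLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a single complex weight is enough to trigger the error
        self.weight = torch.nn.Parameter(torch.randn(8, 8, dtype = torch.cfloat))

    def forward(self, x):
        return x @ self.weight

def main():
    dist.init_process_group("nccl")  # assumes launch via torchrun (RANK/WORLD_SIZE set)
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = DDP(ComplexLayer().cuda(rank), device_ids = [rank])

    x = torch.randn(4, 8, dtype = torch.cfloat, device = rank)
    loss = model(x).abs().mean()
    loss.backward()  # RuntimeError: Input tensor data type is not supported
                     # for NCCL process group: ComplexFloat

if __name__ == "__main__":
    main()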

About this issue

  • Original URL
  • State: closed
  • Created 2 years ago
  • Reactions: 1
  • Comments: 32 (29 by maintainers)

Commits related to this issue

Most upvoted comments

@lucidrains well they’re both GAN papers, to be fair. Seems like they’re here to stay. I foresee future commenters complaining about GAN collapse issues and how fiddly they are to fix, but that’s what the community is for.

@lucidrains I just wanted to share because I saw this method pop up again.

“the high loss is actually coming from the new ‘multi spectral recon loss’ which wasn’t present in previous version, added in 0.11.2. Just made it so you can turn it off by setting the multi spectral recon loss weight to 0”
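
The comment suggests the weight is exposed on the SoundStream constructor; a hedged sketch of turning it off (the keyword name below is inferred from the quote, so check it against the constructor in your installed version):

# sketch only -- verify the keyword name against your installed audiolm_pytorch
from audiolm_pytorch import SoundStream

soundstream = SoundStream(
    codebook_size = 1024,
    # ... other constructor arguments as usual ...
    multi_spectral_recon_loss_weight = 0.  # disables the multi spectral recon loss added in 0.11.2
)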

@lucidrains another reason for a loss balancer? #60

Schaffer et al. (2022) train a GAN and include a poor man’s (non-EMA) loss balancer too: “To find suitable loss weights for all 3 types of losses on the generator (LSGAN loss, deep matching loss, and multi-scale spectral loss; see Section 3), we solved a least squares equation to weigh all loss terms equally for the first 1k training iterations. After that, we froze those weights applied to the losses for the remainder of training”
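
A rough sketch of that “equalize over the first 1k iterations, then freeze” idea (a hypothetical helper, not the balancer used by Encodec or this repo):

# hypothetical loss balancer sketch: accumulate each loss over a warm-up
# window, then freeze weights that make every term contribute equally on
# average (the closed-form solution of the least-squares balancing)
from collections import defaultdict

class FrozenLossBalancer:
    def __init__(self, warmup_steps = 1000):
        self.warmup_steps = warmup_steps
        self.step = 0
        self.sums = defaultdict(float)
        self.weights = None  # frozen after warm-up

    def __call__(self, losses: dict):
        # losses: name -> scalar loss tensor
        if self.weights is None:
            self.step += 1
            for name, loss in losses.items():
                self.sums[name] += loss.detach().item()

            if self.step >= self.warmup_steps:
                means = {n: s / self.step for n, s in self.sums.items()}
                target = sum(means.values()) / len(means)
                self.weights = {n: target / max(m, 1e-8) for n, m in means.items()}

            return sum(losses.values())  # unweighted during warm-up

        return sum(self.weights[n] * l for n, l in losses.items())

Usage would be something like total = balancer({'gan': gan_loss, 'feat': feat_loss, 'spec': spec_loss}) before calling backward().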

also, after some back and forth on the pytorch thread, i was able to get the proposed workaround running for complex conv2d
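
For reference, the usual shape of that kind of workaround is to hold the weights as two real convolutions and apply (a + ib)(c + id) = (ac - bd) + i(ad + bc) by hand, so all parameters and gradients stay real-valued and NCCL can sync them. A hedged sketch, not necessarily the exact module that landed in soundstream.py:

# hedged sketch of a complex conv2d built from two real convolutions
import torch
from torch import nn

class ComplexConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, stride = 1, padding = 0):
        super().__init__()
        # complex weight (c + id) held as two real convs
        self.conv_real = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
        self.conv_imag = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)

    def forward(self, x):
        # x is complex: (a + ib)
        a, b = x.real, x.imag
        real = self.conv_real(a) - self.conv_imag(b)   # ac - bd
        imag = self.conv_imag(a) + self.conv_real(b)   # ad + bc
        return torch.complex(real, imag)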

going to close this issue… inb4 @djqualia tells me it stopped converging again haha

just started training on 0.11.6 – FYI it started with the following warning 4x:

/home/qualia/anaconda3/envs/audiolm/lib/python3.8/site-packages/torchaudio/functional/functional.py:571: UserWarning: At least one mel filterbank has all zero values. The value for n_mels (64) may be set too high. Or, the value for n_freqs (65) may be set too low.
  warnings.warn(
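
That warning is about the mel filterbank shape: n_freqs = n_fft // 2 + 1, so 65 frequency bins implies an FFT size of 128, which is too few bins to spread 64 mel filters over without some filters ending up empty. An illustration (the n_fft and sample rate here are inferred/assumed, not taken from the repo's config):

# illustration of the torchaudio warning above; n_fft inferred from
# n_freqs = n_fft // 2 + 1 = 65, sample_rate is just an assumption
import torch
import torchaudio

spec = torchaudio.transforms.MelSpectrogram(
    sample_rate = 24000,
    n_fft = 128,   # -> n_freqs = 65
    n_mels = 64,   # too many filters for 65 bins -> some all-zero filterbanks
)
_ = spec(torch.randn(1, 24000))  # emits the "all zero values" UserWarning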

@zhvng @djqualia ohh crap, i broke it sorry 😓 fixed! 🤞

which version are you using? i’ve trained with 0.11.1 the past ~24h and it’s working well for me. i haven’t been able to try the newer versions yet (but will test them later…)

heads up @lucidrains I get an error when running the newest version of the code:

Traceback (most recent call last):                                                                                                                                                                                               
  File ".../train_soundstream.py", line 24, in <module>                                                                                                                                              
    trainer.train()                                                                                                                                                                                                              
  File ".../audiolm-pytorch/audiolm_pytorch/trainer.py", line 418, in train                                                                                                                                        
    logs = self.train_step()                                                                                                                                                                                                     
  File ".../audiolm-pytorch/audiolm_pytorch/trainer.py", line 337, in train_step                                                                                                                                   
    discr_losses = self.soundstream(                                                                                                                                                                                             
  File ".../env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl                                                                                                   
    return forward_call(*input, **kwargs)                                                                                                                                                                                        
  File ".../audiolm-pytorch/audiolm_pytorch/soundstream.py", line 651, in forward                                                                                                                                  
    stft_discr_loss = hinge_discr_loss(stft_fake_logits, stft_real_logits)        
File "...audiolm_pytorch/soundstream.py", line 42, in hinge_discr_loss
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()
  File ".../env/lib/python3.10/site-packages/torch/nn/functional.py", line 1457, in relu
    result = torch.relu(input)
RuntimeError: clamp is not supported for complex types

tried to fix this by passing abs(stft_fake_logits) and abs(stft_real_logits) into hinge_discr_loss here, but i get a very high loss (~10^7):

20: soundstream total loss: 13529093.125, soundstream recon loss: 0.076 | discr (scale 1) loss: 3.104 | discr (scale 0.5) loss: 2.482 | discr (scale 0.25) loss: 2.575
21: soundstream total loss: 11155972.625, soundstream recon loss: 0.072 | discr (scale 1) loss: 3.264 | discr (scale 0.5) loss: 2.181 | discr (scale 0.25) loss: 2.203
22: soundstream total loss: 14122796.625, soundstream recon loss: 0.081 | discr (scale 1) loss: 3.616 | discr (scale 0.5) loss: 1.932 | discr (scale 0.25) loss: 2.035
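
The magnitude workaround described above amounts to something like the following (a sketch of the experiment, not the fix that ended up in the library):

# sketch of the attempted fix: take magnitudes of the complex STFT
# discriminator logits before the hinge loss, since F.relu / torch.clamp
# is not defined for complex tensors
import torch
import torch.nn.functional as F

def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# stand-in complex logits, just to show the call shape
stft_fake_logits = torch.randn(4, 32, dtype = torch.cfloat)
stft_real_logits = torch.randn(4, 32, dtype = torch.cfloat)

stft_discr_loss = hinge_discr_loss(stft_fake_logits.abs(), stft_real_logits.abs())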

Unfortunately I ran into a different issue when trying this on Monday so wasn’t able to verify. However, this note from one of the Encodec researchers (in the context of multi-GPU training) is interesting: “we actually did not use DDP for the training but custom distributed routines.”

Unsure if it’s related, just flagging in case.

Awesome, thanks, will do.

@eonglints thanks for reporting this! i’ll try to debug this later this week

for now, i addressed something that was bothering me, and switched over to modReLU, which seems to be popular in the complex-valued neural network field
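
modReLU (from the unitary/complex RNN literature) thresholds the magnitude with a learned bias while preserving the phase; a hedged sketch, not necessarily the exact module in soundstream.py:

# sketch of modReLU: relu on the magnitude (plus a learned bias), phase kept
import torch
import torch.nn.functional as F
from torch import nn

class ModReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.tensor(0.))  # learned real-valued bias

    def forward(self, z):
        mag = z.abs()
        phase = z / mag.clamp(min = 1e-8)  # unit-modulus phase term
        return F.relu(mag + self.b) * phase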