pytorch-lightning: Training stuck at 0% after a few epochs while training with DDP

🐛 Bug

I recently updated to pytorch_lightning 1.1.7 and noticed that after a few epochs of training, the progress bar gets stuck at 0% and never advances. When I switch back to 1.1.4, this strange behavior does not occur. I do not know the root cause of this issue.

  • PyTorch Lightning Version: 1.1.7
  • OS (e.g., Linux): Linux (Ubuntu 18.04)
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.6
  • CUDA/cuDNN version: 10.0
  • GPU models and configuration: RTX 2080 x3
  • Any other relevant information:

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 11
  • Comments: 45 (16 by maintainers)

Most upvoted comments

In my case, training got stuck at 0% at epoch 18 with 2-GPU DDP. After switching to a single GPU, it has trained for 100+ epochs without any problem.

Wanted to add some details. We ran some different code (copy-pasted) and the problem recurred. We had the checkpoint set up to monitor train_loss, and once we took that out everything worked fine. In other words, I suspect this has something to do with monitor under DDP; a sketch of that kind of setup is below.
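
For context, here is a minimal sketch of the kind of configuration being described (the callback arguments and Trainer flags below are assumptions based on the 1.1.x-era API, not the commenter's actual code):

# Hypothetical reproduction sketch: a ModelCheckpoint monitoring a per-rank
# metric under DDP. Removing the monitored checkpoint avoided the hang.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(monitor="train_loss", mode="min")

trainer = Trainer(
    gpus=2,
    accelerator="ddp",          # single-machine multi-GPU DDP
    callbacks=[checkpoint_cb],  # taking this callback out made training run fine
)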

Same problem here. pytorch-lightning version 1.3.7post0.

You can downgrade to 1.2.4

Yeah, training completed without any issues.

Same issue with 1.8.3.post1.

Edit: response to a deleted question ¯\_(ツ)_/¯

self.all_gather will return a tensor of shape [num_gpus, ...], where the trailing dimensions are the shape of the result from a single GPU. If you’re only syncing a scalar, then you can just take a mean (or min/max/std) as in my example below.

If, for example, x is not a scalar (e.g. you want to calculate IoU or something), you can combine the per-GPU results using something like:

def gather_and_squash(self, x):
    # The reshape goes from [n_gpus, bs, N, N] to [n_gpus * bs, N, N]
    return torch.reshape(self.all_gather(x), [-1] + list(x.shape)[1:])
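
For instance (an illustrative usage sketch, not part of the original comment; the "y_pred" key is an assumption), this could be used inside a validation epoch hook:

def validation_epoch_end(self, outputs):
    # Concatenate this rank's per-batch predictions: [bs_total, N, N]
    y_pred = torch.cat([x["y_pred"] for x in outputs])
    # Gather and flatten across ranks: [n_gpus * bs_total, N, N] on every rank
    y_pred_all = self.gather_and_squash(y_pred)
    # ... compute IoU (or another metric) on y_pred_all, identically on every rank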

At the moment I’m working on a joint segmentation & classification problem and using valid_clf as my monitor. Here’s what my validation_epoch_end looks like:

def validation_epoch_end(self, outputs):
    loss_val = torch.stack([x["val_loss"] for x in outputs]).mean()
    loss_seg = torch.stack([x["loss_seg"] for x in outputs]).mean()
    loss_clf = torch.stack([x["loss_clf"] for x in outputs]).mean()
    
    log = {
        "loss/valid_seg": torch.mean(self.all_gather(loss_seg)),
        "loss/valid_clf": torch.mean(self.all_gather(loss_clf)),
    }
           
    # AUC calculation etc...
    y_true = torch.cat([x["y_true"] for x in outputs])
    y_pred = torch.cat([x["y_pred"] for x in outputs])
    # etc ... (the elided AUC code stores its result in log under the "metric" key,
    # which is what log["metric"] below refers to)

    self.log_dict(log)
    self.log_dict(
        {"loss/valid": loss_val, "auc/overall": log["metric"]}, prog_bar=True,
    )

Also having this problem. Training always gets stuck at epoch 2 if I have checkpointing enabled when training with DDP on a single machine.

Same problem here, stuck at 0% at epoch 18.

I ran the script again with NCCL_ASYNC_ERROR_HANDLING=1 python train.py and saw no difference in the output:

https://pastebin.com/jtLyuXTN

This time it got stuck at epoch 11.

I think I’ve isolated the issue from the discussion in https://github.com/PyTorchLightning/pytorch-lightning/issues/5604#issuecomment-785783418.

This issue started when I switched the ModelCheckpoint monitor from AUC (using PL metrics and dist_sync_on_step=True) to val_loss. val_loss was not being synced between the 2 GPUs, so adding this line:

val_loss = torch.mean(self.all_gather(val_loss))

in validation_epoch_end fixed the issue for me. The model has been training successfully overnight and is still running 😃
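
For clarity, here is a minimal sketch of the fix in context (assuming val_loss is the key that ModelCheckpoint monitors; the surrounding code is illustrative, not the author's exact code):

def validation_epoch_end(self, outputs):
    # Per-rank mean of the validation loss.
    val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    # Average across all GPUs so every rank logs the same value; otherwise ranks
    # can disagree on whether to save a checkpoint and DDP hangs.
    val_loss = torch.mean(self.all_gather(val_loss))
    self.log("val_loss", val_loss, prog_bar=True)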

Hey everyone,

Could it be related to this issue and solved by this PR: https://github.com/PyTorchLightning/pytorch-lightning/pull/6004

@HareshKarnan,

I have seen val_loss in your logs. I think it might be related. Mind trying the fix?

Best, T.C

Hi, @HareshKarnan @ndrplz. Folks, may I ask you to run your pipelines with NCCL_ASYNC_ERROR_HANDLING=1? It seems we have the same problem. I use the 1.2.0.dev0 Lightning version, so this issue may carry over to the new release… My error trace:

https://pastebin.com/LNxG2JF6