pytorch-lightning: Metrics API when using DDP and multi-GPU freezes on compute() at end of validation phase
🐛 Bug
I implemented an AUC metric class to calculate train/validation AUC per epoch, but my progress bar freezes at the end of the first epoch with the GPUs at 100% utilization. It works with 1 GPU, but not with more. I basically copied the source code of the ExplainedVariance metric, but it doesn't work for me in DDP with multiple GPUs. The hang happens after the return in compute(), since print statements inside compute() successfully print the preds and targets variables.
I'm training ResNet101 on 2700 3D images stored as .npy files.
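import torch
from pytorch_lightning.metrics import Metric
from pytorch_lightning.metrics.functional.classification import multiclass_auroc


class AUC(Metric):
    def __init__(self, dist_sync_on_step=False):
        # compute_on_step=False: forward() only updates the state and returns None instead of a per-batch value
        super().__init__(compute_on_step=False, dist_sync_on_step=dist_sync_on_step)
        # List states: one tensor is appended per batch; with dist_reduce_fx=None the gathered states are not reduced
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("targets", default=[], dist_reduce_fx=None)

    def update(self, preds: torch.Tensor, targets: torch.Tensor):
        # Accumulate per-batch class scores (N x C) and integer class targets (N)
        self.preds.append(preds)
        self.targets.append(targets)

    def compute(self):
        # Concatenate all accumulated batches and compute multiclass AUROC over the whole epoch
        preds = torch.cat(self.preds)
        targets = torch.cat(self.targets)
        return multiclass_auroc(preds, targets)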
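To make the compute() step concrete, here is a pure-Python sketch (no torch) of what the metric evaluates once the per-batch lists are concatenated. It assumes multiclass_auroc macro-averages one-vs-rest AUROC over the classes; check the documentation of your Lightning version for its exact definition. The sketch does not reproduce the DDP synchronization, and binary_auroc / macro_ovr_auroc are hypothetical helpers, not library functions.

```python
from typing import List, Sequence


def binary_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUROC of one binary task via pairwise comparison (Mann-Whitney U); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC is undefined without both positive and negative samples")
    wins = 0.0
    for p in pos:  # O(n_pos * n_neg): fine for a sketch, too slow for large epochs
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def macro_ovr_auroc(batch_preds: List[List[List[float]]], batch_targets: List[List[int]]) -> float:
    """Assumed definition: mean of one-vs-rest AUROC over all classes."""
    # Mirror compute(): flatten the per-batch lists that update() accumulated.
    preds = [row for batch in batch_preds for row in batch]
    targets = [t for batch in batch_targets for t in batch]
    num_classes = len(preds[0])
    per_class = []
    for c in range(num_classes):
        scores = [row[c] for row in preds]              # score assigned to class c
        labels = [1 if t == c else 0 for t in targets]  # class c vs. the rest
        per_class.append(binary_auroc(scores, labels))
    return sum(per_class) / num_classes


# Two "batches", as update() would receive them on a single process.
batch_preds = [
    [[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]],
    [[0.1, 0.2, 0.7], [0.3, 0.6, 0.1]],
]
batch_targets = [[0, 1], [2, 0]]

print(round(macro_ovr_auroc(batch_preds, batch_targets), 4))  # → 0.8056
```

Every class needs at least one positive and one negative sample, so a small shard of data can make a per-class AUROC undefined; the sketch raises ValueError in that case rather than guessing a value.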
- PyTorch Version (e.g., 1.0): 1.7.1+cu101
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): pip
- Build command you used (if compiling from source):
- Python version: 3.7.6
- CUDA/cuDNN version: 10.1
- GPU models and configuration: 4 V100 on google cloud VM
- Any other relevant information: 32 cores, 128GB mem
- Pytorch Lightning Version: 1.1.8
Additional context
About this issue
- State: closed
- Created 3 years ago
Hi, I am experiencing the same problem. My model breaks at epoch 5: it freezes, I cannot interrupt it, and I have to kill the terminal. This only happens with DDP on multiple GPUs; a single GPU works fine.
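A hang that appears only with multi-GPU DDP usually points at a collective call during metric synchronization. One relevant constraint: torch.distributed.all_gather must be called by every process in the group, and its output list must hold correctly sized tensors, so in practice each rank has to contribute a tensor of the same shape. Whether that is the cause in this thread is an assumption, not something confirmed here. A common workaround for variable-length list states is to share the lengths first, pad each rank's data to the maximum, gather, then trim. The pure-Python sketch below simulates ranks as plain lists (no torch.distributed calls) to show only the bookkeeping; pad_to and simulated_uneven_gather are hypothetical helpers.

```python
from typing import List


def pad_to(values: List[float], length: int, fill: float = 0.0) -> List[float]:
    """Right-pad one rank's local values so every rank contributes the same shape."""
    return values + [fill] * (length - len(values))


def simulated_uneven_gather(per_rank: List[List[float]]) -> List[float]:
    """Simulate gathering variable-length per-rank data through an equal-shape collective."""
    # Step 1: every rank shares its true local length (itself an equal-shape gather of one int).
    lengths = [len(v) for v in per_rank]
    max_len = max(lengths)
    # Step 2: every rank pads to the global maximum, so the equal-shape requirement holds.
    padded = [pad_to(v, max_len) for v in per_rank]
    assert all(len(p) == max_len for p in padded)
    # Step 3: trim each gathered slice back to its true length and concatenate in rank order.
    gathered: List[float] = []
    for chunk, n in zip(padded, lengths):
        gathered.extend(chunk[:n])
    return gathered


# Three simulated ranks holding different numbers of predictions.
print(simulated_uneven_gather([[0.9, 0.2, 0.4], [0.7], [0.1, 0.3]]))
# → [0.9, 0.2, 0.4, 0.7, 0.1, 0.3]
```

Padding keeps every rank's collective call identical in shape; the gathered lengths are what make it safe to drop the padding afterwards.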
I made it work with pytorch-lightning 1.4+ and torchmetrics 0.5+.
@angadkalra from your notebook it seems you are using 16-bit precision for training, right? I know there was a problem with 16-bit precision not working as it should, but it should have been fixed by PR https://github.com/PyTorchLightning/pytorch-lightning/pull/6080, which is included in v1.2.1 (released a few hours ago). Could you try upgrading once more to see if this fixes your remaining problem?
I fixed it by running pip install pytorch-lightning==1.1.1. I think it must be something to do with the current 1.1.8 version.