torchmetrics: Validation with torchmetrics extremely slow
Bug description
Hi all,
I recently tried to implement a DeepLabV3 training pipeline. I wanted to use the built-in torchmetrics.JaccardIndex as my evaluation metric. My LightningModule looks like this:
from statistics import mean
from timeit import default_timer as timer

import torch
import torchmetrics
from torch import nn
from pytorch_lightning import LightningModule
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False,
        )
        self.loss = nn.CrossEntropyLoss(ignore_index=255, reduction="mean")
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass",
            threshold=0.5,
            num_classes=38,
            average="macro",
        )

    def training_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        # measure runtime of metric update
        start = timer()
        self.iou_metric.update(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")
When using this validation procedure, it is extremely slow: on average, the update step of the metric takes 23.4 seconds. However, the first 3 updates are very fast (<1 second); after that, every update becomes slow.
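A note on the timing itself: CUDA kernels are launched asynchronously, so a host-side timer around update() can end up measuring whichever operation happens to block on the GPU rather than the update kernels themselves. A variant of validation_step with explicit synchronization gives a more faithful per-update number; the torch.cuda.synchronize() calls below are additions for measurement only and are not part of the original pipeline.

def validation_step(self, batch, batch_idx):
    imgs, masks = batch
    out = self.model(imgs)
    preds = out["out"]
    loss = self.loss(preds, masks)
    preds = torch.softmax(preds, dim=1)
    pred_labels = torch.argmax(preds, dim=1)

    # Synchronize so the timer captures the kernels launched by update() itself,
    # not work already queued by earlier ops. These synchronize() calls are for
    # measurement only and are not part of the original pipeline.
    torch.cuda.synchronize()
    start = timer()
    self.iou_metric.update(pred_labels, masks)
    torch.cuda.synchronize()
    return timer() - start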
I tried to reproduce this behavior in a MWE:
from timeit import default_timer as timer
from statistics import mean

import torch
import torchmetrics

num_classes = 38

iou_metric = torchmetrics.JaccardIndex(
    task="multiclass",
    threshold=0.5,
    num_classes=num_classes,
    average="macro",
).to("cuda")

# dummy labels in shape [b, h, w]
label_mask = torch.randint(low=0, high=num_classes - 1, size=(8, 480, 640), device="cuda")
# dummy predicted labels in shape [b, h, w]
pred_mask = torch.randint(low=0, high=num_classes - 1, size=(8, 480, 640), device="cuda")

runtime_hist = []
for i in range(100):
    start = timer()
    iou_metric.update(label_mask, pred_mask)
    elapsed = timer() - start
    runtime_hist.append(elapsed)

avg_runtime = round(mean(runtime_hist), 2)
print(avg_runtime)
Here I get an average update duration of 0.03 seconds, so I do not encounter the extremely slow updates I see in my LightningModule above. To me this looks like something is wrong, but at this point I am not sure whether the problem lies in pytorch-lightning or in torchmetrics.
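One structural difference between this MWE and the real validation loop is that the MWE reuses the same GPU tensors every iteration, while the Lightning loop moves a fresh batch from the dataloader to the GPU before each update. The sketch below extends the MWE to mimic that; the dummy CPU batch and the extra timing of .to() are additions for illustration, and iou_metric, label_mask and pred_mask are reused from the MWE above.

# Extension of the MWE: interleave a host-to-device copy with the metric update.
cpu_batch = torch.randn(8, 3, 480, 640).pin_memory()  # dummy image batch on the host

for i in range(100):
    start = timer()
    gpu_batch = cpu_batch.to("cuda", non_blocking=True)  # mimics moving a new batch each step
    to_elapsed = timer() - start

    start = timer()
    iou_metric.update(label_mask, pred_mask)
    update_elapsed = timer() - start

    print(f"step {i}: to(device) {to_elapsed:.4f} s, update {update_elapsed:.4f} s")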
Here is some information about my pytorch-lightning training pipeline:
- OS: Ubuntu 20.04.4
- CUDA 11.3
- DDP training strategy
- GPUs: 4x V100
- batch size: 8
- image size (width x height): 640 x 480
- number of workers in dataloader: 8
My package versions:
- pytorch lightning: 1.8.4.post0 (installed via pip)
- torch: 1.13.0
- torchvision: 0.14.0
- torchmetrics: 0.11.0
- numpy: 1.23.5
Thanks so much! Lukas
How to reproduce the bug
No response
Error messages and logs
No response
Environment
No response
More info
No response
About this issue
- State: closed
- Created 2 years ago
- Reactions: 3
- Comments: 24 (12 by maintainers)
Any progress?
@lukazso Your repro was very helpful. I am able to reproduce the slowdown locally.
First I tried to simplify your version:
- Using trainer.validate and removing distributed: no change.
- After seeing no changes, I avoided trainer.validate and drove the LightningModule manually: still the same behaviour.
- So I completely stripped out PyTorch Lightning, leaving only its advanced profiler, i.e. pure PyTorch: still no change.
So it must be an issue with PyTorch or torchmetrics. You'll have to debug further.
The profiler report shows that the difference is in batch[0].to(device): updating the iou_metric makes {method 'to' of 'torch._C._TensorBase' objects} much slower.
cc @SkafteNicki or @justusschock in case you have any ideas of what could be causing this in torchmetrics. Maybe this issue should be transferred there.
Does one of the maintainers have an opinion on this? 😃
Sent! #2184
FWIW, I experience similar slowdowns with MulticlassF1Score, MulticlassPrecision and MulticlassRecall, but only when the average argument is set to "macro". This starts to bite pretty fast (say, at hundreds of classes), and my classification task has more than 10k classes, which makes those metrics completely unusable for me. Surprisingly, this doesn't happen at all with their multi-label counterparts.
(TorchMetrics: 1.0.0; Lightning: 1.9.5; PyTorch: 1.13.1)
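A small benchmark sketch that could be used to quantify the difference between the two reductions at a large class count; the class count, batch shape and iteration count below are arbitrary placeholders.

from timeit import default_timer as timer

import torch
from torchmetrics.classification import MulticlassF1Score

num_classes = 10_000  # placeholder, roughly matching "more than 10k classes"
preds = torch.randint(0, num_classes, (4096,), device="cuda")
target = torch.randint(0, num_classes, (4096,), device="cuda")

for average in ("micro", "macro"):
    metric = MulticlassF1Score(num_classes=num_classes, average=average).to("cuda")
    torch.cuda.synchronize()
    start = timer()
    for _ in range(50):
        metric.update(preds, target)
    metric.compute()
    torch.cuda.synchronize()
    print(f"average={average}: {timer() - start:.3f} s for 50 updates + compute")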
Did PyTorch acknowledge the potential issue / is there a link?
@SkafteNicki The snippet you are using does not use Lightning other than for the profiler. Notice how the model does not subclass LightningModule and no Trainer is used. That example can be compressed further and shared in the PyTorch issue tracker for further debugging by their dev team. As @lukazso correctly noticed, the interaction is between Tensor.to(device) and torch.bincount.
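A compressed repro along those lines might look like the sketch below: no torchmetrics involved, just a bincount over num_classes**2 bins (the size of a flattened confusion matrix) interleaved with a host-to-device copy. The tensor shapes and iteration count are placeholders.

from timeit import default_timer as timer

import torch

num_classes = 38
# Flattened confusion-matrix-style indices: one entry per pixel of an [8, 480, 640] batch.
idx = torch.randint(0, num_classes * num_classes, (8 * 480 * 640,), device="cuda")
cpu_batch = torch.randn(8, 3, 480, 640).pin_memory()  # dummy image batch on the host

for i in range(20):
    # Bincount of the kind used to build a multiclass confusion matrix.
    _ = torch.bincount(idx, minlength=num_classes * num_classes)

    # Time the subsequent host-to-device copy, mirroring batch[0].to(device).
    start = timer()
    _ = cpu_batch.to("cuda", non_blocking=True)
    print(f"iter {i}: to(device) took {timer() - start:.4f} s")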
Hey @lukazso Sorry for the late reply. Unfortunately, I currently have no idea what could be causing this.
To further narrow this down: Do you experience the same with other metrics?
Could you maybe try with ConfusionMatrix (which is the base class of Jaccard) and Accuracy (as another classification metric)?
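For reference, swapping the metric in the MWE only requires changing the constructor; a sketch reusing num_classes and the dummy tensors from the MWE above:

import torchmetrics

# Same construction pattern as the JaccardIndex in the MWE; only the metric class changes.
confmat_metric = torchmetrics.ConfusionMatrix(
    task="multiclass",
    num_classes=num_classes,
).to("cuda")

accuracy_metric = torchmetrics.Accuracy(
    task="multiclass",
    num_classes=num_classes,
    average="macro",
).to("cuda")

# Then time the updates exactly as in the MWE, e.g.:
start = timer()
confmat_metric.update(label_mask, pred_mask)
print(f"ConfusionMatrix update: {timer() - start:.4f} s")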