torchmetrics: Validation with torchmetrics extremely slow

Bug description

Hi all,

I recently tried to implement a DeepLabV3 training pipeline. I wanted to use the built-in torchmetrics.JaccardIndex as my evaluation metric. My LightningModule looks like this:

from statistics import mean
from timeit import default_timer as timer

import torch
import torchmetrics
from pytorch_lightning import LightningModule
from torch import nn
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.loss = nn.CrossEntropyLoss(ignore_index=255, reduction="mean")
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def training_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)       
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        self.iou_metric.update(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")

When using this validation procedure, validation is extremely slow: on average, the metric's update step takes 23.4 seconds. However, the first 3 updates are very fast (<1 second); after that, they become slow.

I tried to reproduce this behavior in an MWE:

from timeit import default_timer as timer
from statistics import mean
import torchmetrics
import torch

num_classes = 38

iou_metric = torchmetrics.JaccardIndex(
    task="multiclass",
    threshold=0.5, 
    num_classes=num_classes,
    average="macro"
).to("cuda")

# dummy labels in shape [b, h, w]; torch.randint's `high` is exclusive,
# so use num_classes to cover all 38 classes
label_mask = torch.randint(low=0, high=num_classes, size=(8, 480, 640), device="cuda")

# dummy predicted labels in shape [b, h, w]
pred_mask = torch.randint(low=0, high=num_classes, size=(8, 480, 640), device="cuda")


runtime_hist = []
for i in range(100):
    start = timer()
    iou_metric.update(pred_mask, label_mask)  # torchmetrics expects (preds, target)
    elapsed = timer() - start
    runtime_hist.append(elapsed)


avg_runtime = round(mean(runtime_hist), 2)
print(avg_runtime)

Here I get an average update duration of 0.03 seconds, so I do not encounter the extremely slow updates I see in my LightningModule above. To me this looks like something is wrong. At this point, I am not sure if this is an issue with pytorch-lightning, torchmetrics, or my own setup.
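
One caveat about this MWE: CUDA kernels launch asynchronously, so the timer above may only capture the launch overhead rather than the kernel execution itself. Below is the same loop with explicit synchronization (a sketch, assuming a CUDA device), which gives a number that is actually comparable across setups:

torch.cuda.synchronize()  # make sure no earlier queued work skews the first measurement
runtime_hist = []
for i in range(100):
    start = timer()
    iou_metric.update(pred_mask, label_mask)
    torch.cuda.synchronize()  # wait for the update kernels to actually finish
    runtime_hist.append(timer() - start)

print(round(mean(runtime_hist), 4))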

Here is some training information for my pytorch-lightning training pipeline:

  • OS: Ubuntu 20.04.4
  • CUDA 11.3
  • DDP training strategy
  • GPUs: 4x V100
  • batch size: 8
  • image size (width x height): 640 x 480
  • number of workers in dataloader: 8

My package versions:

  • pytorch lightning: 1.8.4.post0 (installed via pip)
  • torch: 1.13.0
  • torchvision: 0.14.0
  • torchmetrics: 0.11.0
  • numpy: 1.23.5

Thanks so much! Lukas

How to reproduce the bug

No response

Error messages and logs

No response

Environment

No response

More info

No response

cc @carmocca @justusschock @awaelchli @borda

Most upvoted comments

Any progress?

@lukazso Your repro was very helpful. I am able to reproduce the slowdown locally.

First I tried to simplify your version:

Using trainer.validate, removing distributed
from statistics import mean
from timeit import default_timer as timer

import pytorch_lightning as pl
import torch
import torchmetrics
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning import LightningModule, LightningDataModule
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)
    
    def __len__(self):
        return 320
    
    def __getitem__(self, index):
        return self.random_img, self.mask


class DummyDataModule(LightningDataModule):
    def __init__(self, bs: int = 8, num_workers: int = 8) -> None:
        super().__init__()
        self.bs = bs
        self.num_workers = num_workers
    
    def val_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=False, num_workers=self.num_workers)
        return dataloader


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.trainer.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        # preds [b, num_classes, h, w] -> pred_labels [b, h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        with self.trainer.profiler.profile("iou_metric.update"):
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")


if __name__ == "__main__":
    seed_everything(1, workers=True)
    data_module = DummyDataModule(bs=8, num_workers=0)
    model = DeepLabV3LightningModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, limit_val_batches=5, profiler="simple"
    )
    trainer.validate(model, data_module)

After seeing no changes, I decided to avoid trainer.validate while still driving the LightningModule manually:

Manually driving the LightningModule
from statistics import mean
from timeit import default_timer as timer

import pytorch_lightning as pl
import torch
import torchmetrics
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)
    
    def __len__(self):
        return 8 * 5  # batch size * limit_val_batches
    
    def __getitem__(self, index):
        return self.random_img, self.mask


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.trainer.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        # preds [b, num_classes, h, w] -> pred_labels [b, h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        with self.trainer.profiler.profile("iou_metric.update"):
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")


if __name__ == "__main__":
    seed_everything(1, workers=True)

    model = DeepLabV3LightningModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, limit_val_batches=5, profiler="simple"
    )
    model.trainer = trainer
    dataset = DummyDataset()
    dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=0)
    device = torch.device("cuda")

    model.to(device)
    outputs = []
    for batch in dataloader:
        batch = model.transfer_batch_to_device(batch, device, 0)
        with torch.inference_mode():
            elapsed = model.validation_step(batch, 0)
        outputs.append(elapsed)
    model.validation_epoch_end(outputs)
    print(trainer.profiler.summary())

Still, same behaviour

So I completely stripped out PyTorch Lightning, leaving only its profiler and switching to the advanced profiler:

Pure PyTorch
from statistics import mean
from timeit import default_timer as timer

import pytorch_lightning as pl
import torch
import torchmetrics
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)
    
    def __len__(self):
        return 8 * 5  # batch size * limit_val_batches
    
    def __getitem__(self, index):
        return self.random_img, self.mask


class DeepLabV3Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

        self.profiler = pl.profilers.AdvancedProfiler()

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)

        start = timer()
        with self.profiler.profile("iou_metric.update"):
            # COMMENT THIS to see difference
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"{avg_runtime} seconds")


if __name__ == "__main__":
    model = DeepLabV3Module()
    dataset = DummyDataset()
    dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=0)
    device = torch.device("cuda")

    model.to(device)
    outputs = []
    for batch in dataloader:
        batch = model.transfer_batch_to_device(batch, device, 0)
        with torch.inference_mode():
            elapsed = model.validation_step(batch, 0)
        outputs.append(elapsed)
    model.validation_epoch_end(outputs)
    print(model.profiler.summary())

And still, no changes. So it must be an issue with PyTorch or torchmetrics. You’ll have to debug further.

The profiler report shows that the difference is in batch[0].to(device): updating the iou_metric makes the {method 'to' of 'torch._C._TensorBase' objects} call much slower.

cc @SkafteNicki or @justusschock in case you have any ideas of what could be causing this in torchmetrics. Maybe this issue should be transferred there.

Does one of the maintainers have an opinion on this? 😃

FWIW, I experience similar slowdowns with MulticlassF1Score, MulticlassPrecision, and MulticlassRecall, but only when the average argument is set to "macro". This starts to bite pretty quickly (say, at hundreds of classes), and my classification task has more than 10k classes, which makes those metrics completely unusable for me. Surprisingly, this doesn't happen at all with their multi-label counterparts.

(TorchMetrics: 1.0.0; Lightning: 1.9.5; PyTorch: 1.13.1)
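
A rough sketch of how this can be measured (the class count and shapes here are made up for illustration, and this assumes a CUDA device):

from timeit import default_timer as timer

import torch
from torchmetrics.classification import MulticlassF1Score

num_classes = 10_000
preds = torch.randint(0, num_classes, (256,), device="cuda")
target = torch.randint(0, num_classes, (256,), device="cuda")

for average in ("micro", "macro"):
    metric = MulticlassF1Score(num_classes=num_classes, average=average).to("cuda")
    torch.cuda.synchronize()
    start = timer()
    for _ in range(50):
        metric.update(preds, target)
    torch.cuda.synchronize()  # include the queued kernels in the measurement
    print(f"average={average!r}: {timer() - start:.3f} s for 50 updates")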

Did PyTorch acknowledge the potential issue / is there a link?

The TL;DR seems to be that it is some weird interaction between Lightning and torch.bincount.

@SkafteNicki The snippet you are using does not use Lightning other than for the profiler. Notice how the model does not subclass LightningModule and no Trainer is used. That example can be compressed further and shared in the PyTorch issue tracker for further debugging by their dev team. As @lukazso correctly noticed, the interaction is between Tensor.to(device) and torch.bincount.
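
A compressed repro might look something like this (a sketch only; whether the copy actually absorbs measurable time depends on how slow the queued bincount kernels are on the given setup):

from timeit import default_timer as timer

import torch

device = torch.device("cuda")
# roughly the number of pixels per batch in the original report, 38 classes
labels = torch.randint(0, 38, (8 * 480 * 640,), device=device)
cpu_img = torch.rand((8, 3, 480, 640))  # pageable CPU memory, like a dataloader batch

for i in range(10):
    start = timer()
    gpu_img = cpu_img.to(device)  # stream-ordered: waits for previously queued kernels
    copy_time = timer() - start
    torch.bincount(labels, minlength=38)  # queued asynchronously, paid for at the next sync
    print(f"iter {i}: .to(device) took {copy_time:.4f} s")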

Hey @lukazso, sorry for the late reply. Unfortunately, I currently have no idea what could be causing this.

To further narrow this down: Do you experience the same with other metrics?

Could you maybe try with ConfusionMatrix (which is the base class of Jaccard) and Accuracy (as another classification metric)?
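
For reference, the MWE above can be adapted along these lines (a sketch using the torchmetrics 0.11 task wrappers):

from timeit import default_timer as timer

import torch
import torchmetrics

num_classes = 38
preds = torch.randint(0, num_classes, (8, 480, 640), device="cuda")
target = torch.randint(0, num_classes, (8, 480, 640), device="cuda")

metrics = {
    "confusion_matrix": torchmetrics.ConfusionMatrix(task="multiclass", num_classes=num_classes),
    "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=num_classes),
}

for name, metric in metrics.items():
    metric = metric.to("cuda")
    torch.cuda.synchronize()
    start = timer()
    for _ in range(100):
        metric.update(preds, target)
    torch.cuda.synchronize()  # make sure all update kernels are counted
    print(f"{name}: {(timer() - start) / 100:.4f} s per update")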