torchmetrics: RuntimeError when using MAP-metric
🐛 Bug
Hi! I am training a detection model and use MAP-metric during validation. I got the following error at the validation_step:
RuntimeError: expected scalar type Float but found Bool
To Reproduce
Pick a Faster R-CNN model (I used fasterrcnn_resnet50_fpn_v2() from torchvision). Implement validation_step, where self.metric.update(...) is called on the model outputs and targets, and validation_epoch_end, where self.metric.compute() is called on the previously gathered results.
Code sample
import pytorch_lightning as pl
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchmetrics.detection.mean_ap import MeanAveragePrecision


class FasterRCNNModel(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()
        model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn_v2()
        # Replace the box predictor head so it outputs num_classes classes
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        self.model = model
        self.metric = MeanAveragePrecision(box_format='xyxy', iou_type='bbox')

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        preds = self.model(images)
        # Accumulate predictions and targets for the epoch-level metric
        self.metric.update(preds, targets)

    def validation_epoch_end(self, outs):
        mAP = self.metric.compute()  # the RuntimeError is raised here
        self.log("val/mAP", mAP)
        self.metric.reset()
targets (List[Dict]), containing:
- boxes (torch.float32)
- labels (torch.int64)
preds (List[Dict]), containing:
- boxes (torch.float32)
- scores (torch.float32)
- labels (torch.int64)
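For illustration, a batch in the format listed above might look like the sketch below. The coordinates, scores, and labels are made-up values, and the `dtypes_ok` helper is hypothetical; only the keys and dtypes come from the description. The second prediction shows how an image without any detections could be represented, assuming the detector returns empty tensors in that case.

```python
import torch

# Hypothetical batch of two images. The coordinates, scores, and labels are
# made-up values; only the keys and dtypes follow the format listed above.
targets = [
    {
        "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]], dtype=torch.float32),
        "labels": torch.tensor([1], dtype=torch.int64),
    },
    {
        "boxes": torch.tensor([[30.0, 40.0, 90.0, 120.0]], dtype=torch.float32),
        "labels": torch.tensor([2], dtype=torch.int64),
    },
]
preds = [
    {
        "boxes": torch.tensor([[12.0, 18.0, 108.0, 215.0]], dtype=torch.float32),
        "scores": torch.tensor([0.87], dtype=torch.float32),
        "labels": torch.tensor([1], dtype=torch.int64),
    },
    {
        # An image for which the detector returned nothing:
        # empty tensors that keep the same dtypes.
        "boxes": torch.zeros((0, 4), dtype=torch.float32),
        "scores": torch.zeros(0, dtype=torch.float32),
        "labels": torch.zeros(0, dtype=torch.int64),
    },
]

EXPECTED = {"boxes": torch.float32, "scores": torch.float32, "labels": torch.int64}


def dtypes_ok(items):
    """Return True if every tensor in every dict has the expected dtype."""
    return all(t.dtype == EXPECTED[k] for d in items for k, t in d.items())


print(dtypes_ok(targets), dtypes_ok(preds))
```

A check like this can be run on a batch before calling self.metric.update(preds, targets) to rule out dtype problems in the inputs themselves.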
Error message
  File "/homes/vsoboleva/scripts/pascal_voc/train.py", line 65, in validation_epoch_end
    mAP = self.metric.compute()
  File "/homes/vsoboleva/miniconda3/lib/python3.9/site-packages/torchmetrics/metric.py", line 523, in wrapped_func
    value = compute(*args, **kwargs)
  File "/homes/vsoboleva/miniconda3/lib/python3.9/site-packages/torchmetrics/detection/mean_ap.py", line 908, in compute
    precisions, recalls = self._calculate(classes)
  File "/homes/vsoboleva/miniconda3/lib/python3.9/site-packages/torchmetrics/detection/mean_ap.py", line 758, in _calculate
    recall, precision, scores = MeanAveragePrecision.__calculate_recall_precision_scores(
  File "/homes/vsoboleva/miniconda3/lib/python3.9/site-packages/torchmetrics/detection/mean_ap.py", line 831, in __calculate_recall_precision_scores
    det_scores = torch.cat([e["dtScores"][:max_det] for e in img_eval_cls_bbox])
RuntimeError: expected scalar type Float but found Bool
Expected behavior
self.metric.compute() computes the values correctly and does not fail with RuntimeError: expected scalar type Float but found Bool.
Environment
- TorchMetrics 0.9.2, installed with pip
- Python 3.9.12, torch 1.12.0, torchvision 0.13.0
- OS: Ubuntu 20.04.3
Additional context
About this issue
- State: closed
- Created 2 years ago
- Reactions: 2
- Comments: 19 (12 by maintainers)
This issue happens when using PyTorch 1.12.0 on a GPU device. Although the CPU implementation and other versions of PyTorch can cast a 0-dim bool tensor to a float tensor, dtScores should be initialized as a float type, because a score is usually a real number.
@SkafteNicki The current test cases in tests/unittests/detection/test_map.py do not cover this bool/float concatenation case. There is only one pair of pred and target in _inputs3. Adding another pair of pred and target can make the tests fail.
Okay, I think I found it:
https://github.com/Lightning-AI/metrics/blob/31c384411bc9a28f4ad2085cf123f68f382b6f82/torchmetrics/detection/mean_ap.py#L505
I believe this type needs to be changed to torch.float32 from torch.bool. Can you check whether, after making that change, you're no longer experiencing the issue, please?
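The effect of the proposed dtype change can be sketched in isolation. This is a minimal illustration, not the torchmetrics code itself: it assumes that an image without detections contributes an empty placeholder tensor to the list of scores that is later concatenated, and the `make_scores` helper and score values are hypothetical.

```python
import torch

# Use the GPU if one is available; the thread reports the failure on
# PyTorch 1.12.0 with a GPU device only.
device = "cuda" if torch.cuda.is_available() else "cpu"


def make_scores(placeholder_dtype):
    """Build per-image score tensors (hypothetical values).

    The first image has no detections, so it contributes an empty
    placeholder tensor of the given dtype.
    """
    empty = torch.zeros(0, dtype=placeholder_dtype, device=device)
    real = torch.tensor([0.9, 0.4], dtype=torch.float32, device=device)
    return [empty, real]


# Proposed fix: the placeholder is float32, so every tensor in the list
# has the same dtype and concatenation needs no implicit cast.
fixed = torch.cat(make_scores(torch.float32))
print(fixed.dtype)  # torch.float32

# Old behaviour: a bool placeholder mixed with float scores. Whether this
# raises depends on the PyTorch version and device, so it is guarded here.
try:
    mixed = torch.cat(make_scores(torch.bool))
    print("mixed concat succeeded with dtype", mixed.dtype)
except RuntimeError as err:
    print("mixed concat failed:", err)
```

Keeping the placeholder in the same dtype as real scores avoids relying on type promotion inside torch.cat, which is the behaviour that differs between devices and versions according to this thread.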
@austinmw This should be fixed now by PR #1150. Please try installing from master, which should solve the issue. If not, please report back and we can reopen the issue and try to fix it.
@dreaquil, changing the type helped me as well. Thank you very much!
Hi, don't mind me, just sliding into this thread. I had the same issue. @dreaquil, changing the type fixed the issue for me. Best, Simon