pytorch-metric-learning: Error in Centroid Triplet Loss (version 1.2.0)
In version 1.2.0, Centroid Triplet Loss is unstable. When the batch size is small, the error below always occurs. With a larger batch size, it occurs less often.
  File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/base_metric_loss_function.py", line 38, in forward
    embeddings, labels, indices_tuple, ref_emb, ref_labels
  File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in compute_loss
    indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
  File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in <listcomp>
    indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
RuntimeError: shape '[6, -1]' is invalid for input of size 70330
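The failing line reshapes each index tensor to len(one_labels) rows and asks for the column count to be inferred with -1. That inference only works when the element count divides evenly by the row count, and 70330 is not a multiple of 6. The pure-Python sketch below mimics that check without needing torch; the helper name inferred_view_shape is hypothetical and is not part of PyTorch or pytorch-metric-learning.

```python
def inferred_view_shape(numel, rows):
    """Mimic how a reshape to (rows, -1) infers its second dimension.

    Returns (rows, cols) when numel divides evenly by rows; otherwise
    raises RuntimeError with a message modeled on PyTorch's.
    """
    if rows <= 0:
        raise ValueError("rows must be positive")
    if numel % rows != 0:
        raise RuntimeError(
            f"shape '[{rows}, -1]' is invalid for input of size {numel}"
        )
    return rows, numel // rows


# Values taken from the traceback: 6 labels, 70330 elements.
print(70330 % 6)  # 4 elements left over, so -1 cannot be inferred
try:
    inferred_view_shape(70330, 6)
except RuntimeError as err:
    print(err)  # shape '[6, -1]' is invalid for input of size 70330
print(inferred_view_shape(70326, 6))  # (6, 11721)
```

Any element count that leaves a remainder when divided by the number of labels triggers the same error, which is consistent with the report that the failure depends on batch size.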
About this issue
- State: open
- Created 2 years ago
- Comments: 20 (12 by maintainers)
Commits related to this issue
- fixes #451 with valueerror, more documentation, more code comments — committed to cwkeam/pytorch-metric-learning by cwkeam 2 years ago
As a temporary workaround, we added an assertion to the compute_loss function of losses.CentroidTripletLoss, as follows:
This is only a temporary solution, and it may hurt results because some triplet samples are truncated. We look forward to a proper fix.
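A minimal sketch of the kind of truncation such a workaround implies, using plain Python lists instead of tensors. This is not the commenter's code; the helpers truncate_to_multiple and reshape_rows are hypothetical and only illustrate dropping trailing entries so that a (rows, -1) reshape becomes possible.

```python
def truncate_to_multiple(indices, rows):
    """Drop trailing entries so len(indices) is a multiple of rows.

    Returns the truncated list and the number of dropped entries.
    """
    keep = (len(indices) // rows) * rows
    return indices[:keep], len(indices) - keep


def reshape_rows(flat, rows):
    """Split a flat list into `rows` equal-length rows, like view(rows, -1)."""
    if len(flat) % rows != 0:
        raise ValueError(f"{len(flat)} entries cannot be split into {rows} equal rows")
    cols = len(flat) // rows
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


flat = list(range(14))                 # 14 entries, 4 rows: not divisible
kept, dropped = truncate_to_multiple(flat, 4)
print(dropped)                         # 2
print(reshape_rows(kept, 4))           # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
```

The dropped tail entries correspond to the "truncated triplet samples" mentioned above, and this sketch does nothing to ensure the remaining rows still correspond to the intended labels.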