pytorch-metric-learning: Error in CentroidTripletLoss (version 1.2.0)

In version 1.2.0, CentroidTripletLoss is unstable. When the batch size is small, the error below always occurs; with a larger batch size the problem is somewhat relieved.

File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/base_metric_loss_function.py", line 38, in forward embeddings, labels, indices_tuple, ref_emb, ref_labels File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in compute_loss indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple] File "/root/miniconda3/lib/python3.7/site-packages/pytorch_metric_learning/losses/centroid_triplet_loss.py", line 124, in <listcomp> indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple] RuntimeError: shape '[6, -1]' is invalid for input of size 70330

About this issue

  • State: open
  • Created 2 years ago
  • Reactions: 1
  • Comments: 20 (12 by maintainers)


Most upvoted comments

As a temporary workaround, we patch the compute_loss function of losses.CentroidTripletLoss as follows:

```python
# make only query vectors be anchor vectors
indices_tuple = [x[: len(x) // 3] + starting_idx for x in indices_tuple]

# added by Jason.Fang: drop the remainder when the tuple length is not evenly divisible
remainder = len(indices_tuple[0]) % len(one_labels)
# quotient = len(indices_tuple[0]) // len(one_labels)
indices_tuple = [x[: len(indices_tuple[0]) - remainder] for x in indices_tuple]

# make only pos_centroids be positive examples
indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
indices_tuple = [x.chunk(2, dim=1)[0] for x in indices_tuple]

# make only neg_centroids be negative examples
indices_tuple = [x.chunk(len(one_labels), dim=1)[-1].flatten() for x in indices_tuple]
```

This is a temporary solution and may degrade results, since some triplet samples are truncated. We look forward to an effective solution.
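
To illustrate the truncation with the numbers from the traceback above: the flattened index tensor has 70330 entries while the batch has 6 unique labels, so a plain .view(6, -1) fails; dropping the remainder makes the reshape valid. A quick sketch of that arithmetic (the tensor here is a stand-in, not real triplet indices):

```python
import torch

# Stand-in for one element of indices_tuple, with the size from the traceback.
x = torch.arange(70330)
num_labels = 6                       # len(one_labels) in the patch above

remainder = len(x) % num_labels      # 70330 % 6 == 4, so .view(6, -1) fails
x_truncated = x[: len(x) - remainder]

print(remainder)                                # 4 entries are dropped
print(x_truncated.view(num_labels, -1).shape)   # torch.Size([6, 11721])
```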