1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
|
"""
PyTorch adaptation of the online triplet mining losses described at
https://omoindrot.github.io/triplet-loss
Reference TensorFlow implementation:
https://github.com/omoindrot/tensorflow-triplet-loss
"""
import torch
import torch.nn as nn
class TripletMarginLoss(nn.Module):
    """Triplet margin loss with online triplet mining inside a batch.

    The actual computation is delegated to ``batch_all_triplet_loss`` or
    ``batch_hard_triplet_loss`` depending on ``mining``.

    Args:
        margin: Margin added to ``d(anchor, positive) - d(anchor, negative)``.
        p: Norm degree forwarded to ``torch.cdist`` by the mining function.
        mining: Mining strategy, either ``"batch_all"`` or ``"batch_hard"``.

    Raises:
        ValueError: If ``mining`` is not a supported strategy.
    """

    def __init__(self, margin=1.0, p=2.0, mining="batch_all"):
        super().__init__()
        self.margin = margin
        self.p = p
        self.mining = mining
        # Reject unknown strategies here; otherwise ``self.loss_fn`` would be
        # left unset and ``forward`` would fail later with an AttributeError.
        if mining == "batch_all":
            self.loss_fn = batch_all_triplet_loss
        elif mining == "batch_hard":
            self.loss_fn = batch_hard_triplet_loss
        else:
            raise ValueError(
                f"Unknown mining strategy {mining!r}; "
                "expected 'batch_all' or 'batch_hard'."
            )

    def forward(self, embeddings, labels):
        """Compute the loss for one batch.

        Note the argument order: this method takes ``(embeddings, labels)``
        while the mining functions take ``(labels, embeddings, ...)``.

        Returns:
            The 2-tuple returned by the selected mining function: the loss
            first, then the fraction of positive triplets for ``"batch_all"``
            or ``-1`` for ``"batch_hard"``.
        """
        return self.loss_fn(labels, embeddings, self.margin, self.p)
def batch_hard_triplet_loss(labels, embeddings, margin, p):
    """Compute the batch-hard triplet loss.

    For every anchor (row of the pairwise distance matrix) the largest
    distance to a same-label sample and the smallest distance to a
    different-label sample are selected; the hinge
    ``max(d_pos - d_neg + margin, 0)`` is then averaged over all anchors.

    Args:
        labels: Per-sample labels, compared pairwise for equality.
            Presumably shape ``(batch,)`` -- TODO confirm against callers.
        embeddings: Sample embeddings, passed to ``torch.cdist`` against
            themselves. Presumably shape ``(batch, dim)`` -- TODO confirm.
        margin: Value added to ``d_pos - d_neg`` before clipping at zero.
        p: Norm degree forwarded to ``torch.cdist``.

    Returns:
        Tuple ``(loss, -1)``: the mean loss, and a constant ``-1`` in the
        second slot.
    """
    distances = torch.cdist(embeddings, embeddings, p=p)

    # Zero the distances of non-positive pairs so the row-wise max selects
    # the hardest (farthest) positive for every anchor.
    positive_mask = _get_anchor_positive_triplet_mask(labels).float()
    hardest_positive, _ = (positive_mask * distances).max(1, keepdim=True)

    # Add each row's max distance to the invalid negatives so the row-wise
    # min selects the hardest (closest) valid negative.
    negative_mask = _get_anchor_negative_triplet_mask(labels).float()
    row_max, _ = distances.max(1, keepdim=True)
    penalised = distances + row_max * (1.0 - negative_mask)
    hardest_negative, _ = penalised.min(1, keepdim=True)

    hinge = hardest_positive - hardest_negative + margin
    # Easy triplets (negative hinge) contribute nothing.
    hinge = hinge.masked_fill(hinge < 0, 0)
    return hinge.mean(), -1
def batch_all_triplet_loss(labels, embeddings, margin, p):
    """Compute the batch-all triplet loss.

    The hinge ``d(i, j) - d(i, k) + margin`` is evaluated for every index
    triplet ``(i, j, k)``, masked down to the valid triplets reported by
    ``_get_triplet_mask``, clipped at zero, and averaged over the triplets
    whose loss is still positive.

    Args:
        labels: Per-sample labels used to build the valid-triplet mask.
            Presumably shape ``(batch,)`` -- TODO confirm against callers.
        embeddings: Sample embeddings, passed to ``torch.cdist`` against
            themselves. Presumably shape ``(batch, dim)`` -- TODO confirm.
        margin: Value added to ``d(i, j) - d(i, k)`` before clipping at zero.
        p: Norm degree forwarded to ``torch.cdist``.

    Returns:
        Tuple ``(loss, fraction_positive)``: the summed loss divided by the
        number of positive triplets, and the share of valid triplets whose
        loss exceeds ``1e-16``.
    """
    distances = torch.cdist(embeddings, embeddings, p=p)

    # Broadcasting: entry [i, j, k] becomes d(i, j) - d(i, k) + margin.
    per_triplet = distances.unsqueeze(2) - distances.unsqueeze(1) + margin

    valid_mask = _get_triplet_mask(labels)
    per_triplet = valid_mask.float() * per_triplet
    # Drop easy triplets, i.e. those with a negative loss.
    per_triplet = per_triplet.masked_fill(per_triplet < 0, 0)

    # Triplets that still carry a strictly positive loss.
    positive_count = per_triplet[per_triplet > 1e-16].size(0)
    valid_count = valid_mask.sum()
    # The 1e-16 terms guard against division by zero.
    positive_fraction = positive_count / (valid_count.float() + 1e-16)

    mean_loss = per_triplet.sum() / (positive_count + 1e-16)
    return mean_loss, positive_fraction
def _get_triplet_mask(labels):
# Check that i, j and k are distinct
indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
indices_not_equal = ~indices_equal
i_not_equal_j = indices_not_equal.unsqueeze(2)
i_not_equal_k = indices_not_equal.unsqueeze(1)
j_not_equal_k = indices_not_equal.unsqueeze(0)
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
i_equal_j = label_equal.unsqueeze(2)
i_equal_k = label_equal.unsqueeze(1)
valid_labels = ~i_equal_k & i_equal_j
return valid_labels & distinct_indices
def _get_anchor_positive_triplet_mask(labels):
# Check that i and j are distinct
indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
indices_not_equal = ~indices_equal
# Check if labels[i] == labels[j]
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
return labels_equal & indices_not_equal
def _get_anchor_negative_triplet_mask(labels):
return labels.unsqueeze(0) != labels.unsqueeze(1)
|