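# NOTE: the following lines are web-page residue captured together with the file; they are not part of the module.
# You can not select more than 25 topics
# Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# 267 lines
# 12 KiB
# 267 lines
# 12 KiB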
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
|
|
import torch
|
|
from torch import Tensor, nn
|
|
|
|
from sentence_transformers.SentenceTransformer import SentenceTransformer
|
|
from sentence_transformers.util import pytorch_cos_sim
|
|
|
|
|
|
class BatchHardTripletLossDistanceFunction:
    """This class defines distance functions, that can be used with Batch[All/Hard/SemiHard]TripletLoss.

    Every function takes a single batch of embeddings and returns the square matrix of
    pairwise distances between all rows of that batch.
    """

    @staticmethod
    def cosine_distance(embeddings: Tensor) -> Tensor:
        """Compute the 2D matrix of cosine distances (1 - cosine_similarity) between all embeddings.

        Args:
            embeddings: batch of embeddings; it is compared against itself.

        Returns:
            ``1 - pytorch_cos_sim(embeddings, embeddings)``.
        """
        return 1 - pytorch_cos_sim(embeddings, embeddings)

    @staticmethod
    def eucledian_distance(embeddings: Tensor, squared=False) -> Tensor:
        """
        Compute the 2D matrix of Euclidean distances between all the embeddings.

        Note: the method name keeps its historical spelling ("eucledian") so that existing
        callers continue to work.

        Args:
            embeddings: tensor of shape (batch_size, embed_dim)
            squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                If false, output is the pairwise euclidean distance matrix.

        Returns:
            pairwise_distances: tensor of shape (batch_size, batch_size)
        """
        dot_product = torch.matmul(embeddings, embeddings.t())

        # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
        # This also provides more numerical stability (the diagonal of the result will be exactly 0).
        # shape (batch_size,)
        square_norm = torch.diag(dot_product)

        # Compute the pairwise distance matrix as we have:
        # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
        # shape (batch_size, batch_size)
        distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)

        # Because of computation errors, some distances might be negative so we put everything >= 0.0.
        # An out-of-place clamp is used instead of boolean-mask assignment so that no tensor
        # is mutated in place.
        distances = torch.clamp(distances, min=0.0)

        if not squared:
            # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
            # we need to add a small epsilon where distances == 0.0
            mask = distances.eq(0).float()
            distances = distances + mask * 1e-16

            # Multiplying by (1 - mask) puts the exact zeros back after the sqrt.
            distances = (1.0 - mask) * torch.sqrt(distances)

        return distances
|
|
|
|
|
|
class BatchHardTripletLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        distance_metric=BatchHardTripletLossDistanceFunction.eucledian_distance,
        margin: float = 5,
    ) -> None:
        """
        BatchHardTripletLoss takes a batch with (sentence, label) pairs and computes the loss for all possible, valid
        triplets, i.e., anchor and positive must have the same label, anchor and negative a different label. It then looks
        for the hardest positive and the hardest negatives.
        The labels must be integers, with same label indicating sentences from the same class. Your train dataset
        must contain at least 2 examples per label class.

        Args:
            model: SentenceTransformer model
            distance_metric: Function that takes a batch of embeddings and
                returns the matrix of pairwise distances between them. The
                class BatchHardTripletLossDistanceFunction contains
                pre-defined metrics that can be used
            margin: Negative samples should be at least margin further
                apart from the anchor than the positive.

        Definitions:
            :Easy triplets: Triplets which have a loss of 0 because
                ``distance(anchor, positive) + margin < distance(anchor, negative)``.
            :Hard triplets: Triplets where the negative is closer to the anchor than the positive, i.e.,
                ``distance(anchor, negative) < distance(anchor, positive)``.
            :Semi-hard triplets: Triplets where the negative is not closer to the anchor than the positive, but which
                still have a positive loss, i.e.,
                ``distance(anchor, positive) < distance(anchor, negative) < distance(anchor, positive) + margin``.

        References:
            * Source: https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
            * Paper: In Defense of the Triplet Loss for Person Re-Identification, https://huggingface.co/papers/1703.07737
            * Blog post: https://omoindrot.github.io/triplet-loss

        Requirements:
            1. Each sentence must be labeled with a class.
            2. Your dataset must contain at least 2 examples per labels class.
            3. Your dataset should contain hard positives and negatives.

        Inputs:
            +------------------+--------+
            | Texts            | Labels |
            +==================+========+
            | single sentences | class  |
            +------------------+--------+

        Recommendations:
            - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
              ensure that each batch contains 2+ examples per label class.

        Relations:
            * :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only the hardest positive and negative samples.
            * :class:`BatchSemiHardTripletLoss` uses only semi-hard, valid triplets, rather than only the hardest positive and negative samples.
            * :class:`BatchHardSoftMarginTripletLoss` does not require setting a margin, while this loss does.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                # E.g. 0: sports, 1: economy, 2: politics
                train_dataset = Dataset.from_dict({
                    "sentence": [
                        "He played a great game.",
                        "The stock is up 20%",
                        "They won 2-1.",
                        "The last goal was amazing.",
                        "They all voted against the bill.",
                    ],
                    "label": [0, 1, 0, 0, 2],
                })
                loss = losses.BatchHardTripletLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.sentence_embedder = model
        self.triplet_margin = margin
        self.distance_metric = distance_metric

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        """Embed the batch of sentences and return the batch-hard triplet loss.

        Args:
            sentence_features: model inputs; only the first element (``sentence_features[0]``)
                is embedded, so the argument must support indexing.
            labels: class label of every sentence in the batch.

        Returns:
            The scalar loss computed by :meth:`batch_hard_triplet_loss`.
        """
        rep = self.sentence_embedder(sentence_features[0])["sentence_embedding"]
        return self.batch_hard_triplet_loss(labels, rep)

    # Hard Triplet Loss
    # Source: https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
    # Paper: In Defense of the Triplet Loss for Person Re-Identification, https://huggingface.co/papers/1703.07737
    # Blog post: https://omoindrot.github.io/triplet-loss
    def batch_hard_triplet_loss(self, labels: Tensor, embeddings: Tensor) -> Tensor:
        """Build the triplet loss over a batch of embeddings.

        For each anchor, we get the hardest positive and hardest negative to form a triplet.
        The margin is read from ``self.triplet_margin`` and the pairwise distances are
        computed with ``self.distance_metric``.

        Args:
            labels: labels of the batch, of size (batch_size,)
            embeddings: tensor of shape (batch_size, embed_dim)

        Returns:
            Label_Sentence_Triplet: scalar tensor containing the triplet loss
        """
        # Get the pairwise distance matrix
        pairwise_dist = self.distance_metric(embeddings)

        # For each anchor, get the hardest positive
        # First, we need to get a mask for every valid positive (they should have same label)
        mask_anchor_positive = BatchHardTripletLoss.get_anchor_positive_triplet_mask(labels).float()

        # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
        anchor_positive_dist = mask_anchor_positive * pairwise_dist

        # shape (batch_size, 1)
        # For an anchor without any valid positive the whole row is 0, so its maximum is 0.
        hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)

        # For each anchor, get the hardest negative
        # First, we need to get a mask for every valid negative (they should have different labels)
        mask_anchor_negative = BatchHardTripletLoss.get_anchor_negative_triplet_mask(labels).float()

        # We add the maximum value in each row to the invalid negatives (label(a) == label(n))
        # so that they cannot be selected by the row-wise minimum below.
        max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
        anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)

        # shape (batch_size, 1)
        hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)

        # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss.
        # Negative values (easy triplets) are clamped to 0 out of place.
        tl = hardest_positive_dist - hardest_negative_dist + self.triplet_margin
        tl = torch.clamp(tl, min=0.0)
        triplet_loss = tl.mean()

        return triplet_loss

    @staticmethod
    def get_triplet_mask(labels: Tensor) -> Tensor:
        """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.

        A triplet (i, j, k) is valid if:
            - i, j, k are distinct
            - labels[i] == labels[j] and labels[i] != labels[k]

        Args:
            labels: torch `Tensor` of class labels with shape [batch_size]

        Returns:
            mask: torch.bool `Tensor` with shape [batch_size, batch_size, batch_size]
        """
        # Check that i, j and k are distinct
        indices_not_equal = ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
        i_not_equal_j = indices_not_equal.unsqueeze(2)
        i_not_equal_k = indices_not_equal.unsqueeze(1)
        j_not_equal_k = indices_not_equal.unsqueeze(0)

        distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k

        # Check that labels[i] == labels[j] and labels[i] != labels[k]
        label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        i_equal_j = label_equal.unsqueeze(2)
        i_equal_k = label_equal.unsqueeze(1)

        valid_labels = ~i_equal_k & i_equal_j

        return valid_labels & distinct_indices

    @staticmethod
    def get_anchor_positive_triplet_mask(labels: Tensor) -> Tensor:
        """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.

        Args:
            labels: torch `Tensor` of class labels with shape [batch_size]

        Returns:
            mask: torch.bool `Tensor` with shape [batch_size, batch_size]
        """
        # Check that i and j are distinct
        indices_not_equal = ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)

        # Check if labels[i] == labels[j]
        # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)

        return labels_equal & indices_not_equal

    @staticmethod
    def get_anchor_negative_triplet_mask(labels: Tensor) -> Tensor:
        """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.

        Args:
            labels: torch `Tensor` of class labels with shape [batch_size]

        Returns:
            mask: torch.bool `Tensor` with shape [batch_size, batch_size]
        """
        # Check if labels[i] != labels[k]
        # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
        return labels.unsqueeze(0) != labels.unsqueeze(1)

    @property
    def citation(self) -> str:
        """BibTeX entry of the paper this loss is based on."""
        return """
@misc{hermans2017defense,
    title={In Defense of the Triplet Loss for Person Re-Identification},
    author={Alexander Hermans and Lucas Beyer and Bastian Leibe},
    year={2017},
    eprint={1703.07737},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
"""
|