You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

265 lines
13 KiB

from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Literal
import torch
from torch import Tensor, nn
from transformers import PreTrainedTokenizerBase
from sentence_transformers.models import StaticEmbedding
from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.util import all_gather_with_grad
class GISTEmbedLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        guide: SentenceTransformer,
        temperature: float = 0.01,
        margin_strategy: Literal["absolute", "relative"] = "absolute",
        margin: float = 0.0,
        contrast_anchors: bool = True,
        contrast_positives: bool = True,
        gather_across_devices: bool = False,
    ) -> None:
        """
        This loss is used to train a SentenceTransformer model using the GISTEmbed algorithm.
        It takes a model and a guide model as input, and uses the guide model to guide the
        in-batch negative sample selection. The cosine similarity is used to compute the loss
        and the temperature parameter is used to scale the cosine similarities.

        You can apply different false-negative filtering strategies to discard hard negatives that are too similar to
        the positive. Two strategies are supported:

            - "absolute": Discards negatives whose guide similarity score is strictly greater than
              ``positive_score - margin``.
            - "relative": Discards negatives whose guide similarity score is strictly greater than
              ``positive_score * (1 - margin)``.

        Here ``positive_score`` is the guide model's similarity between an anchor and its own positive.
        An anchor's own positive is never discarded, regardless of the strategy or margin.

        Args:
            model: SentenceTransformer model based on a `transformers` model.
            guide: SentenceTransformer model to guide the in-batch negative sample selection.
            temperature: Temperature parameter to scale the cosine similarities. Inverse of the ``scale`` parameter
                in :class:`MultipleNegativesRankingLoss`.
            margin_strategy: Strategy used for false negative filtering. One of {"absolute", "relative"}.
            margin: The margin value for filtering negatives. Defaults to 0.0. Together with the "absolute" strategy,
                this only removes negatives that are more similar to the query than the positive is to the query.
            contrast_anchors: If True, include anchor-anchor pairs in the loss computation, resulting in the embeddings
                of the anchors being pushed further apart. Defaults to True, following the original GISTEmbed paper.
            contrast_positives: If True, include positive-positive pairs in the loss computation, resulting in the embeddings
                of the positives being pushed further apart. Defaults to True, following the original GISTEmbed paper,
                but setting to False may yield better results in some retrieval tasks.
            gather_across_devices: If True, gather the embeddings across all devices before computing the loss.
                Recommended when training on multiple GPUs, as it allows for larger batch sizes, but it may slow down
                training due to communication overhead, and can potentially lead to out-of-memory errors.

        References:
            - For further details, see: https://huggingface.co/papers/2402.16829

        Requirements:
            1. (anchor, positive, negative) triplets
            2. (anchor, positive) pairs

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | (anchor, positive, negative) triplets | none   |
            +---------------------------------------+--------+
            | (anchor, positive) pairs              | none   |
            +---------------------------------------+--------+

        Recommendations:
            - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
              ensure that no in-batch negatives are duplicates of the anchor or positive samples.

        Relations:
            - :class:`MultipleNegativesRankingLoss` is similar to this loss, but it does not use
              a guide model to guide the in-batch negative sample selection. `GISTEmbedLoss` yields
              a stronger training signal at the cost of some training overhead.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                guide = SentenceTransformer("all-MiniLM-L6-v2")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                })
                loss = losses.GISTEmbedLoss(model, guide)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.guide = guide
        self.temperature = temperature
        self.similarity_fct = nn.CosineSimilarity(dim=-1)
        if not hasattr(model, "tokenizer") or not hasattr(guide, "tokenizer"):
            raise ValueError("Both the training model and the guiding model must have a tokenizer attribute.")
        if not isinstance(model.tokenizer, PreTrainedTokenizerBase) or not isinstance(
            guide.tokenizer, PreTrainedTokenizerBase
        ):
            raise ValueError(
                "Both the training model and the guiding model must use a PreTrainedTokenizer from transformers."
            )
        # The guide must see re-tokenized text when its vocabulary differs from the training model's,
        # or when it would truncate inputs to a shorter length than the training model does.
        self.must_retokenize = (
            model.tokenizer.get_vocab() != guide.tokenizer.get_vocab() or guide.max_seq_length < model.max_seq_length
        )
        if self.must_retokenize:
            # Used in forward() to decode the training model's input_ids back into text.
            self.tokenizer = self.model.tokenizer
            if isinstance(self.model[0], StaticEmbedding):
                raise ValueError(
                    "If we must retokenize because the guide model has a different tokenizer, "
                    "then the Sentence Transformer model must not be based on a StaticEmbedding."
                )
        if margin_strategy not in ("absolute", "relative"):
            raise ValueError("margin_strategy must be 'absolute' or 'relative'.")
        self.margin_strategy = margin_strategy
        self.margin = margin
        self.contrast_anchors = contrast_anchors
        self.contrast_positives = contrast_positives
        self.gather_across_devices = gather_across_devices
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
        """Return the pairwise cosine-similarity matrix between the rows of ``embed1`` and ``embed2``."""
        return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))

    def _false_negative_threshold(self, guided_sim: Tensor) -> Tensor:
        """Per-anchor cut-off above which a candidate's guide similarity marks it as a false negative."""
        if self.margin_strategy == "absolute":
            # Remove candidates whose guided similarity is higher than (positive_sim - margin)
            return guided_sim - self.margin
        if self.margin_strategy == "relative":
            # Remove candidates whose guided similarity is higher than (positive_sim * (1 - margin))
            return guided_sim * (1 - self.margin)
        # Re-checked here because the attribute may have been changed after __init__ validated it.
        raise ValueError("margin_strategy must be 'absolute' or 'relative'.")

    def _mask_false_negatives(
        self,
        guided_sim_mat: Tensor,
        sim_mat: Tensor,
        threshold: Tensor,
        positive_mask: Tensor | None = None,
    ) -> Tensor:
        """Return ``sim_mat`` with entries whose guide similarity exceeds ``threshold`` replaced by ``-inf``.

        ``threshold`` has one value per anchor row (shape ``(n_anchors, 1)``) and is broadcast across columns.
        Entries flagged in ``positive_mask`` are never masked.
        """
        mask = guided_sim_mat > threshold
        if positive_mask is not None:
            # Ensure true positive pairs are not masked out
            mask = mask & ~positive_mask
        # Out-of-place fill: avoids mutating a tensor that is part of the autograd graph.
        return sim_mat.masked_fill(mask, -torch.inf)

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        """
        Compute the GISTEmbed loss for one batch.

        Args:
            sentence_features: Tokenized features for the anchors, the positives and, optionally, the negatives,
                in that order (2 or 3 entries). Each entry is embedded by both the training model and the guide.
            labels: Not used by this loss.

        Returns:
            The cross-entropy loss (mean-reduced, as configured in ``__init__``) over the temperature-scaled,
            false-negative-filtered similarity scores.

        Raises:
            ValueError: If ``sentence_features`` does not contain exactly 2 or 3 entries, or if
                ``margin_strategy`` is not one of "absolute"/"relative".
        """
        # Materialize once: the features are iterated for both the training model and the guide, and a
        # one-shot iterable (e.g. a generator) would otherwise be exhausted after the first pass.
        sentence_features = list(sentence_features)
        if len(sentence_features) not in (2, 3):
            raise ValueError(f"Expected 2 or 3 embeddings, got {len(sentence_features)}")

        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        device = embeddings[0].device

        with torch.no_grad():
            if self.must_retokenize:
                decoded = [
                    self.tokenizer.batch_decode(sentence_feature["input_ids"], skip_special_tokens=True)
                    for sentence_feature in sentence_features
                ]
                sentence_features = [self.guide.tokenize(sentences) for sentences in decoded]
                sentence_features = [
                    {key: value.to(self.guide.device) for key, value in sentence_feature.items()}
                    for sentence_feature in sentence_features
                ]
            # Bring the guide embeddings onto the training embeddings' device so the guide-derived boolean
            # masks can index the training similarity matrices even if the guide lives on another device.
            guide_embeddings = [
                self.guide(sentence_feature)["sentence_embedding"].to(device) for sentence_feature in sentence_features
            ]

        negative = None
        negative_guide = None
        if len(embeddings) == 2:
            anchor, positive = embeddings
            anchor_guide, positive_guide = guide_embeddings
        else:
            anchor, positive, negative = embeddings
            anchor_guide, positive_guide, negative_guide = guide_embeddings

        batch_size = anchor.size(0)
        offset = 0
        if self.gather_across_devices:
            # Gather the candidates across all devices, with gradients, but not the anchors. We compute only this
            # device's anchors with all candidates from all devices, such that the backward pass on the document
            # embeddings can flow back to the original devices.
            positive = all_gather_with_grad(positive)
            positive_guide = all_gather_with_grad(positive_guide)
            if negative is not None:
                negative = all_gather_with_grad(negative)
                negative_guide = all_gather_with_grad(negative_guide)
            # All have this shape: (batch_size * world_size * (1 + num_negatives), embedding_dim)
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                # This rank's own candidates start at column rank * batch_size of the gathered matrices.
                offset = torch.distributed.get_rank() * batch_size

        # Compute the similarities given the training and guide embeddings.
        ap_sim = self.sim_matrix(anchor, positive)
        guided_ap_sim = self.sim_matrix(anchor_guide, positive_guide)

        # Guide similarity of each anchor with its own positive, used as the per-anchor reference score.
        guided_sim = guided_ap_sim.diagonal(offset=offset).view(-1, 1)
        threshold = self._false_negative_threshold(guided_sim)

        # Protect the true positive pairs. Anchor i's positive is at column offset + i (not i) once the
        # candidates have been gathered from all ranks, so the mask must be shifted by the same offset that
        # is used for the diagonal above and for the labels below.
        rows = torch.arange(batch_size, device=guided_ap_sim.device)
        positive_mask = torch.zeros_like(guided_ap_sim, dtype=torch.bool)
        positive_mask[rows, rows + offset] = True

        # Apply false-negative suppression to each similarity matrix, using the guide similarities.
        scores = [self._mask_false_negatives(guided_ap_sim, ap_sim, threshold, positive_mask)]  # anchor-positive

        if self.contrast_anchors:
            aa_sim = self.sim_matrix(anchor, anchor)
            guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide)
            scores.append(self._mask_false_negatives(guided_aa_sim, aa_sim, threshold))  # anchor-anchor

        if self.contrast_positives:
            # Only this rank's positives act as "queries" here, compared against all (possibly gathered) positives.
            pp_sim = self.sim_matrix(positive[offset : offset + batch_size], positive)
            guided_pp_sim = self.sim_matrix(positive_guide[offset : offset + batch_size], positive_guide)
            scores.append(self._mask_false_negatives(guided_pp_sim, pp_sim, threshold))  # positive-positive

        # Handle the case where we have a negative sample
        if negative is not None:
            an_sim = self.sim_matrix(anchor, negative)
            guided_an_sim = self.sim_matrix(anchor_guide, negative_guide)
            scores.append(self._mask_false_negatives(guided_an_sim, an_sim, threshold))  # anchor-negative

        scores = torch.cat(scores, dim=1) / self.temperature

        # anchor[i] should be most similar to candidates[offset + i], as that is its paired positive,
        # so the label for anchor[i] is offset + i.
        range_labels = torch.arange(offset, offset + batch_size, device=anchor.device)
        return self.cross_entropy_loss(scores, range_labels)

    def get_config_dict(self) -> dict[str, Any]:
        """Return the loss hyper-parameters (including the guide model) for model-card/config serialization."""
        return {
            "guide": self.guide,
            "temperature": self.temperature,
            "margin_strategy": self.margin_strategy,
            "margin": self.margin,
            "contrast_anchors": self.contrast_anchors,
            "contrast_positives": self.contrast_positives,
            "gather_across_devices": self.gather_across_devices,
        }

    @property
    def citation(self) -> str:
        """BibTeX citation for the GISTEmbed paper."""
        return """
@misc{solatorio2024gistembed,
title={GISTEmbed: Guided In-sample Selection of Training Negatives for Text Embedding Fine-tuning},
author={Aivin V. Solatorio},
year={2024},
eprint={2402.16829},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
"""