You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
129 lines
5.3 KiB
129 lines
5.3 KiB
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from typing import Any
|
|
|
|
import torch
|
|
from torch import Tensor, nn
|
|
|
|
from sentence_transformers import util
|
|
from sentence_transformers.SentenceTransformer import SentenceTransformer
|
|
|
|
|
|
class CoSENTLoss(nn.Module):
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.pairwise_cos_sim) -> None:
        """
        This class implements CoSENT (Consistent SENTence embedding) loss.
        It expects that each of the InputExamples consists of a pair of texts and a float valued label, representing
        the expected similarity score between the pair.

        It computes the following loss function:

        ``loss = logsum(1+exp(s(k,l)-s(i,j))+exp...)``, where ``(i,j)`` and ``(k,l)`` are any of the input pairs in the
        batch such that the expected similarity of ``(i,j)`` is greater than ``(k,l)``. The summation is over all possible
        pairs of input pairs in the batch that match this condition.

        Anecdotal experiments show that this loss function produces a more powerful training signal than :class:`CosineSimilarityLoss`,
        resulting in faster convergence and a final model with superior performance. Consequently, CoSENTLoss may be used
        as a drop-in replacement for :class:`CosineSimilarityLoss` in any training script.

        Args:
            model: SentenceTransformerModel
            scale: Output of similarity function is multiplied by scale
                value. Represents the inverse temperature.
            similarity_fct: Function to compute the PAIRWISE similarity
                between embeddings. Default is
                ``util.pairwise_cos_sim``.

        References:
            - For further details, see: https://penghao-bdsc.github.io/papers/CoSENT_TASLP2024.pdf

        Requirements:
            - Sentence pairs with corresponding similarity scores in range of the similarity function. Default is [-1,1].

        Inputs:
            +--------------------------------+------------------------+
            | Texts                          | Labels                 |
            +================================+========================+
            | (sentence_A, sentence_B) pairs | float similarity score |
            +--------------------------------+------------------------+

        Relations:
            - :class:`AnglELoss` is CoSENTLoss with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``.
            - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than CoSENTLoss. In our experiments, CoSENTLoss is recommended.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "sentence1": ["It's nice weather outside today.", "He drove to work."],
                    "sentence2": ["It's so sunny.", "She walked to the store."],
                    "score": [1.0, 0.3],
                })
                loss = losses.CoSENTLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.similarity_fct = similarity_fct
        self.scale = scale

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        """
        Embed every entry of ``sentence_features`` with the wrapped model and return the CoSENT loss.

        Args:
            sentence_features: Iterable of feature dicts; each one is passed to ``self.model`` and the
                ``"sentence_embedding"`` entry of the result is used.
            labels: Labels indicating the similarity scores of the pairs

        Returns:
            Loss value
        """
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]

        return self.compute_loss_from_embeddings(embeddings, labels)

    def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
        """
        Compute the CoSENT loss from embeddings.

        Args:
            embeddings: List of embeddings; only the first two entries are used, as the two sides of each pair
            labels: Labels indicating the similarity scores of the pairs

        Returns:
            Loss value
        """
        # One scaled similarity per input pair.
        # NOTE(review): the broadcasting below assumes ``similarity_fct`` returns a 1-D tensor - confirm for custom functions.
        scores = self.similarity_fct(embeddings[0], embeddings[1]) * self.scale

        # score_diffs[i, j] = score_i - score_j for every ordered combination of input pairs
        score_diffs = scores[:, None] - scores[None, :]

        # label matrix indicating which pairs are relevant:
        # entry (i, j) is 1.0 only when pair i is labelled as less similar than pair j
        relevant = (labels[:, None] < labels[None, :]).float()

        # mask out irrelevant pairs so they are negligible after exp()
        score_diffs = score_diffs - (1 - relevant) * 1e12

        # append a zero as e^0 = 1; the constant is created directly on the score device
        # instead of being allocated on the CPU and copied over on every call
        zero = torch.zeros(1, device=score_diffs.device)
        loss = torch.logsumexp(torch.cat((zero, score_diffs.view(-1)), dim=0), dim=0)

        return loss

    def get_config_dict(self) -> dict[str, Any]:
        """Return the loss configuration: the scale and the ``__name__`` of the similarity function."""
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}

    @property
    def citation(self) -> str:
        """BibTeX entry for the CoSENT paper."""
        return """
@article{10531646,
    author={Huang, Xiang and Peng, Hao and Zou, Dongcheng and Liu, Zhiwei and Li, Jianxin and Liu, Kay and Wu, Jia and Su, Jianlin and Yu, Philip S.},
    journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
    title={CoSENT: Consistent Sentence Embedding via Similarity Ranking},
    year={2024},
    doi={10.1109/TASLP.2024.3402087}
}
"""
|