# DistillKLDivLoss: KL-divergence knowledge-distillation loss for SentenceTransformer models.
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
|
|
import torch
|
|
from torch import Tensor, nn
|
|
|
|
from sentence_transformers import util
|
|
from sentence_transformers.SentenceTransformer import SentenceTransformer
|
|
|
|
|
|
class DistillKLDivLoss(nn.Module):
    def __init__(
        self, model: SentenceTransformer, similarity_fct=util.pairwise_dot_score, temperature: float = 1.0
    ) -> None:
        """
        Knowledge-distillation loss based on the KL divergence between student and teacher
        similarity distributions.

        For each example of the form (query, positive, negative_1, ..., negative_n), the student
        model's similarity scores between the query and every candidate are turned into a
        log-probability distribution via log-softmax, the teacher's precomputed scores (supplied
        as labels) are turned into a probability distribution via softmax, and the KL divergence
        between the two is minimized. Similarity defaults to the pairwise dot product.

        Args:
            model: The SentenceTransformer acting as the student.
            similarity_fct: Pairwise similarity function applied to the student embeddings.
            temperature: Softening factor applied to both score distributions before the
                softmax; larger values produce softer distributions. A value of 1.0 leaves the
                scores unscaled. Note: in the v5.0.1 release, the default temperature was
                changed from 2.0 to 1.0.

        References:
            - For more details, please refer to https://huggingface.co/papers/2010.11386

        Requirements:
            1. (query, positive, negative_1, ..., negative_n) examples
            2. Labels containing teacher model's scores between query-positive and
               query-negative pairs

        Inputs:
            +------------------------------------------------+------------------------------------------------------------+
            | Texts                                          | Labels                                                     |
            +================================================+============================================================+
            | (query, positive, negative)                    | [Teacher(query, positive), Teacher(query, negative)]       |
            +------------------------------------------------+------------------------------------------------------------+
            | (query, positive, negative_1, ..., negative_n) | [Teacher(query, positive), Teacher(query, negative_i)...]  |
            +------------------------------------------------+------------------------------------------------------------+

        Relations:
            - Similar to :class:`~sentence_transformers.losses.MarginMSELoss` but uses KL
              divergence instead of MSE
            - More suited for distillation tasks where preserving ranking is important

        Example:

            Using a teacher model to compute similarity scores for distillation
            (additional negatives work identically — encode each ``negative_i`` column and
            stack its scores as extra entries in the label tensor):

            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset
                import torch

                student_model = SentenceTransformer("microsoft/mpnet-base")
                teacher_model = SentenceTransformer("all-mpnet-base-v2")
                train_dataset = Dataset.from_dict({
                    "query": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to work."],
                    "negative": ["It's very cold.", "She walked to the store."],
                })

                def compute_labels(batch):
                    emb_queries = teacher_model.encode(batch["query"])
                    emb_positives = teacher_model.encode(batch["positive"])
                    emb_negatives = teacher_model.encode(batch["negative"])

                    pos_scores = teacher_model.similarity_pairwise(emb_queries, emb_positives)
                    neg_scores = teacher_model.similarity_pairwise(emb_queries, emb_negatives)

                    # Stack the scores for positive and negative pairs
                    return {
                        "label": torch.stack([pos_scores, neg_scores], dim=1)
                    }

                train_dataset = train_dataset.map(compute_labels, batched=True)
                loss = losses.DistillKLDivLoss(student_model)

                trainer = SentenceTransformerTrainer(model=student_model, train_dataset=train_dataset, loss=loss)
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.similarity_fct = similarity_fct
        self.temperature = temperature
        # "batchmean" averages the summed per-row divergence over the batch, which is the
        # mathematically correct reduction for KL divergence.
        self.loss_fct = nn.KLDivLoss(reduction="batchmean")

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        """Encode every input column with the student model and delegate to the loss computation.

        Args:
            sentence_features: Tokenized features for (query, positive, negative_1, ..., negative_n).
            labels: Teacher scores of shape (batch_size, 1 + num_negatives).

        Returns:
            Scalar loss tensor.
        """
        embeddings = []
        for features in sentence_features:
            embeddings.append(self.model(features)["sentence_embedding"])
        return self.compute_loss_from_embeddings(embeddings, labels)

    def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
        """Compute the temperature-scaled KL divergence from precomputed student embeddings.

        Args:
            embeddings: [query_embeddings, positive_embeddings, negative_1, ..., negative_n].
            labels: Teacher scores of shape (batch_size, 1 + num_negatives).

        Returns:
            Scalar loss tensor.
        """
        anchors = embeddings[0]
        candidates = embeddings[1:]

        # Student similarity per (query, candidate) column -> shape (batch, 1 + num_negatives).
        per_column_scores = [self.similarity_fct(anchors, candidate) for candidate in candidates]
        student_scores = torch.stack(per_column_scores, dim=1)

        # Soften both distributions with the same temperature before comparing them.
        student_log_probs = torch.log_softmax(student_scores / self.temperature, dim=1)
        teacher_probs = torch.softmax(labels / self.temperature, dim=1)

        # nn.KLDivLoss expects log-probabilities as the input and probabilities as the target.
        divergence = self.loss_fct(student_log_probs, teacher_probs)

        # The T^2 factor compensates for the 1/T^2 gradient shrinkage introduced by the
        # temperature scaling, keeping gradient magnitudes comparable across temperatures.
        return divergence * (self.temperature**2)

    @property
    def citation(self) -> str:
        return """
@misc{lin2020distillingdenserepresentationsranking,
    title={Distilling Dense Representations for Ranking using Tightly-Coupled Teachers},
    author={Sheng-Chieh Lin and Jheng-Hong Yang and Jimmy Lin},
    year={2020},
    eprint={2010.11386},
    archivePrefix={arXiv},
    primaryClass={cs.IR},
    url={https://arxiv.org/abs/2010.11386},
}
"""