You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

293 lines
13 KiB

from __future__ import annotations
import logging
import random
from collections.abc import Iterable, Sequence
from typing import Any
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from sentence_transformers.losses import (
CachedGISTEmbedLoss,
CachedMultipleNegativesRankingLoss,
CachedMultipleNegativesSymmetricRankingLoss,
)
from sentence_transformers.SentenceTransformer import SentenceTransformer
logger = logging.getLogger(__name__)
def shrink(tensor: Tensor, dim: int) -> Tensor:
    """Keep the first ``dim`` entries of the last axis and L2-normalize the result.

    Args:
        tensor: Tensor whose last axis is truncated.
        dim: Number of leading entries of the last axis to keep.

    Returns:
        The truncated tensor, L2-normalized along its last axis.

    Raises:
        ValueError: If ``dim`` is larger than the size of the last axis.
    """
    full_dim = tensor.shape[-1]
    if full_dim < dim:
        raise ValueError(
            f"Dimension {dim} in matryoshka_dims cannot be greater than the model's embedding dimension: {full_dim}"
        )
    # Slice first, then renormalize so the shortened vectors are L2-normalized again.
    return F.normalize(tensor[..., :dim], p=2, dim=-1)
class ForwardDecorator:
    """
    Wrapper around a forward function that remembers its outputs so they can be
    truncated and handed out again for other Matryoshka dimensionalities, instead
    of running the wrapped function once per dimensionality.

    While the active dimension equals the dimension the cache was built with (or
    no cache exists yet), every call runs the wrapped function and stores its
    output. After :meth:`set_dim` switches to a different dimension, calls return
    the stored outputs in the order they were recorded.

    In both cases the returned dict is truncated in place with ``shrink``, and the
    dict held in the cache is that same object.

    This decorator is applied to `SentenceTransformer.forward`.
    """

    def __init__(self, fn) -> None:
        self.fn = fn
        self.dim = None
        self.cache = []
        self.cache_dim = None
        self.idx = 0

    def set_dim(self, dim) -> None:
        """Select the dimension for the following calls and rewind the cache cursor."""
        self.dim = dim
        self.idx = 0

    def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return the (possibly cached) forward output, truncated to ``self.dim``."""
        replaying = self.cache_dim is not None and self.cache_dim != self.dim
        if replaying:
            # Serve the output recorded for this call position.
            output = self.cache[self.idx]
        else:
            # Still recording: run the wrapped function and remember what it returned.
            output = self.fn(features)
            self.cache.append(output)
            self.cache_dim = self.dim

        if "token_embeddings" in output:
            output["token_embeddings"] = shrink(output["token_embeddings"], self.dim)
        output["sentence_embedding"] = shrink(output["sentence_embedding"], self.dim)

        self.idx += 1
        return output
class CachedLossDecorator:
    """
    Wrapper used with the Cached... losses that evaluates the wrapped loss function
    once per Matryoshka dimensionality. The pre-computed embeddings are truncated to
    each selected dimensionality and passed to the wrapped function, and the weighted
    results are summed.

    This decorator is applied to the `calculate_loss` method of the Cached... losses.
    """

    def __init__(
        self,
        fn,
        matryoshka_dims: Sequence[int],
        matryoshka_weights: Sequence[float] | Sequence[int],
        n_dims_per_step: int = -1,
    ) -> None:
        self.fn = fn
        self.matryoshka_dims = matryoshka_dims
        self.matryoshka_weights = matryoshka_weights
        self.n_dims_per_step = n_dims_per_step

    def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor:
        """Return the weighted sum of the wrapped loss over the selected dimensions.

        Args:
            reps: Nested lists of embedding tensors; every tensor is truncated with
                ``shrink`` for each selected dimension.
            *args: Forwarded unchanged to the wrapped function.
            **kwargs: Forwarded unchanged to the wrapped function.
        """
        n_dims = len(self.matryoshka_dims)
        selected = range(n_dims)
        if 0 < self.n_dims_per_step < n_dims:
            selected = random.sample(selected, self.n_dims_per_step)

        total = 0.0
        for position in selected:
            dim = self.matryoshka_dims[position]
            weight = self.matryoshka_weights[position]

            shrunk = [[shrink(rep, dim) for rep in group] for group in reps]
            grad_enabled = torch.is_grad_enabled()
            # The wrapped function gets detached copies: otherwise its first backward pass
            # would clear the computation graph of the truncation itself.
            leaves = (
                [[rep.detach().requires_grad_() for rep in group] for group in shrunk] if grad_enabled else shrunk
            )

            total += weight * self.fn(leaves, *args, **kwargs)

            if not grad_enabled:
                continue
            # Resume the backward pass through the truncation, scaling the gradients by the
            # Matryoshka weight so the weight is reflected in the backward pass as well.
            for shrunk_group, leaf_group in zip(shrunk, leaves):
                for shrunk_rep, leaf in zip(shrunk_group, leaf_group):
                    shrunk_rep.backward(weight * leaf.grad)
        return total
class MatryoshkaLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        loss: nn.Module,
        matryoshka_dims: Sequence[int],
        matryoshka_weights: Sequence[float] | Sequence[int] | None = None,
        n_dims_per_step: int = -1,
    ) -> None:
        """
        The MatryoshkaLoss can be seen as a loss *modifier* that allows you to use other loss functions at various
        different embedding dimensions. This is useful for when you want to train a model where users have the option
        to lower the embedding dimension to improve their embedding comparison speed and costs.

        This loss is also compatible with the Cached... losses, which are in-batch negative losses that allow for
        higher batch sizes. The higher batch sizes allow for more negatives, and often result in a stronger model.

        Args:
            model: SentenceTransformer model
            loss: The loss function to be used, e.g.
                :class:`MultipleNegativesRankingLoss`,
                :class:`CoSENTLoss`, etc.
            matryoshka_dims: A list of embedding dimensions to be used
                for the loss function, e.g. [768, 512, 256, 128, 64].
            matryoshka_weights: A list of weights to be used for the
                loss function, e.g. [1, 1, 1, 1, 1]. If None, then the
                weights will be set to 1 for all dimensions.
            n_dims_per_step: The number of dimensions to use per step.
                If -1, then all dimensions are used. If > 0, then a
                random sample of n_dims_per_step dimensions are used per
                step. The default value is -1.

        Raises:
            ValueError: If ``matryoshka_dims`` is empty, contains a non-positive value, contains a value larger
                than the model's embedding dimension, or if ``matryoshka_weights`` has a different length than
                ``matryoshka_dims``.

        References:
            - The concept was introduced in this paper: https://huggingface.co/papers/2205.13147
            - `Matryoshka Embeddings <../../../examples/sentence_transformer/training/matryoshka/README.html>`_

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | any                                   | any    |
            +---------------------------------------+--------+

        Relations:
            - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`AdaptiveLayerLoss` which allows for
              layer reduction for faster inference.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                })
                loss = losses.MultipleNegativesRankingLoss(model)
                loss = losses.MatryoshkaLoss(model, loss, [768, 512, 256, 128, 64])

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.loss = loss

        if not matryoshka_dims:
            raise ValueError("You must provide at least one dimension in matryoshka_dims.")
        if any(dim <= 0 for dim in matryoshka_dims):
            raise ValueError("All dimensions passed to a matryoshka loss must be > 0.")
        if matryoshka_weights is None:
            matryoshka_weights = [1] * len(matryoshka_dims)
        elif len(matryoshka_weights) != len(matryoshka_dims):
            raise ValueError("matryoshka_weights must be the same length as matryoshka_dims.")

        model_embedding_dim = model.get_sentence_embedding_dimension()
        if model_embedding_dim is not None:
            if any(d > model_embedding_dim for d in matryoshka_dims):
                raise ValueError(
                    f"Dimensions in matryoshka_dims cannot exceed the model's embedding dimension of {model_embedding_dim}."
                )
            if model_embedding_dim not in matryoshka_dims:
                logger.warning(
                    f"The model's embedding dimension {model_embedding_dim} is not included in matryoshka_dims: {matryoshka_dims}. "
                    "This means that the full model dimension won't be trained, which may lead to degraded performance "
                    "when using the model without specifying a lower truncation dimension. It is strongly recommended to include "
                    f"{model_embedding_dim} in matryoshka_dims."
                )

        # Sort the dimensions and weights in descending order
        dims_weights = zip(matryoshka_dims, matryoshka_weights)
        self.matryoshka_dims: tuple[int, ...]
        self.matryoshka_weights: tuple[float, ...] | tuple[int, ...]
        self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True))
        self.n_dims_per_step = n_dims_per_step

        # The Cached... losses require a special treatment as their backward pass is incompatible with the
        # ForwardDecorator approach. Instead, we use a CachedLossDecorator to compute the loss for each
        # Matryoshka dimensionality given pre-computed embeddings passed to `calculate_loss`.
        self.cached_losses = (
            CachedMultipleNegativesRankingLoss,
            CachedGISTEmbedLoss,
            CachedMultipleNegativesSymmetricRankingLoss,
        )
        if isinstance(loss, self.cached_losses):
            # n_dims_per_step must be forwarded here: for cached losses the per-step sampling of
            # dimensions happens inside the decorator, not in `forward`.
            loss.calculate_loss = CachedLossDecorator(
                loss.calculate_loss,
                self.matryoshka_dims,
                self.matryoshka_weights,
                n_dims_per_step=self.n_dims_per_step,
            )

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        """Compute the weighted sum of the wrapped loss over the (sampled) Matryoshka dimensions.

        Args:
            sentence_features: Model inputs, passed unchanged to the wrapped loss.
            labels: Labels for the wrapped loss. If they are a 2D tensor whose last axis matches the
                model's embedding dimension, they are truncated to each Matryoshka dimension.

        Returns:
            The accumulated weighted loss.
        """
        # For the Cached... losses, the CachedLossDecorator has been applied to the `calculate_loss` method,
        # so we can directly call the loss function.
        if isinstance(self.loss, self.cached_losses):
            return self.loss(sentence_features, labels)

        # If the labels seem to be embeddings, they are truncated to match the soon-to-be-truncated predicted
        # embeddings. This allows for MatryoshkaLoss with a direct distillation loss. The check does not depend
        # on the Matryoshka dimension, so it is evaluated once rather than per dimension.
        labels_are_embeddings = (
            isinstance(labels, torch.Tensor)
            and labels.ndim == 2
            and labels.size(-1) == self.model.get_sentence_embedding_dimension()
        )

        # Otherwise, we apply the ForwardDecorator to the model's forward pass, which will cache the output
        # embeddings for each Matryoshka dimensionality, allowing it to be reused for the smaller dimensions.
        original_forward = self.model.forward
        try:
            decorated_forward = ForwardDecorator(original_forward)
            self.model.forward = decorated_forward

            dim_indices = range(len(self.matryoshka_dims))
            if 0 < self.n_dims_per_step < len(dim_indices):
                dim_indices = random.sample(dim_indices, self.n_dims_per_step)
                # Ascending indices correspond to descending dimensions (see the sort in __init__).
                dim_indices.sort()

            loss = 0.0
            for idx in dim_indices:
                dim = self.matryoshka_dims[idx]
                weight = self.matryoshka_weights[idx]
                decorated_forward.set_dim(dim)

                dim_labels = labels[:, :dim] if labels_are_embeddings else labels
                loss += weight * self.loss(sentence_features, dim_labels)
        finally:
            # Always restore the undecorated forward, even if the wrapped loss raised.
            self.model.forward = original_forward
        return loss

    def get_config_dict(self) -> dict[str, Any]:
        """Return the wrapped loss class name and the Matryoshka settings of this loss."""
        return {
            "loss": self.loss.__class__.__name__,
            "matryoshka_dims": self.matryoshka_dims,
            "matryoshka_weights": self.matryoshka_weights,
            "n_dims_per_step": self.n_dims_per_step,
        }

    @property
    def citation(self) -> str:
        """BibTeX entry for the Matryoshka Representation Learning paper."""
        return """
@misc{kusupati2024matryoshka,
title={Matryoshka Representation Learning},
author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
year={2024},
eprint={2205.13147},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
"""