# NOTE(review): removed web-scrape artifacts (repository topic help text and
# "293 lines / 13 KiB" page metadata) that are not part of this module.
from __future__ import annotations

import logging
import random
from collections.abc import Iterable, Sequence
from typing import Any

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from sentence_transformers.losses import (
    CachedGISTEmbedLoss,
    CachedMultipleNegativesRankingLoss,
    CachedMultipleNegativesSymmetricRankingLoss,
)
from sentence_transformers.SentenceTransformer import SentenceTransformer
# Module-level logger; used to warn when matryoshka_dims omits the model's full dimension.
logger = logging.getLogger(__name__)
def shrink(tensor: Tensor, dim: int) -> Tensor:
    """Truncate ``tensor`` to its first ``dim`` features and L2-normalize it.

    Args:
        tensor: Embedding tensor whose last axis is the embedding dimension.
        dim: Target Matryoshka dimensionality; must not exceed the tensor's
            embedding dimension.

    Returns:
        The truncated, L2-normalized tensor.

    Raises:
        ValueError: If ``dim`` is larger than the tensor's embedding dimension.
    """
    embedding_dim = tensor.shape[-1]
    if dim > embedding_dim:
        raise ValueError(
            f"Dimension {dim} in matryoshka_dims cannot be greater than the model's embedding dimension: {embedding_dim}"
        )
    # Truncation first, normalization second: renormalizing restores unit length
    # after features are dropped.
    return F.normalize(tensor[..., :dim], p=2, dim=-1)
class ForwardDecorator:
    """Memoizing wrapper around ``SentenceTransformer.forward``.

    The first pass through the data runs the real forward and records every
    output; for each subsequent (smaller) Matryoshka dimensionality the cached
    embeddings are truncated ("shrunk") and reused, so the model never
    recomputes embeddings per dimensionality.

    This decorator is applied to `SentenceTransformer.forward`.
    """

    def __init__(self, fn) -> None:
        self.fn = fn  # the original forward callable

        self.dim = None        # dimensionality currently requested
        self.cache = []        # outputs recorded while the cache is growing
        self.cache_dim = None  # dimensionality the cache was recorded at
        self.idx = 0           # read cursor into the cache during replay

    def set_dim(self, dim) -> None:
        """Select the next Matryoshka dimensionality and rewind the cache cursor."""
        self.dim = dim
        self.idx = 0

    def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        if self.cache_dim in (None, self.dim):
            # Growing the cache: run the true forward pass and record its output.
            result = self.fn(features)
            self.cache.append(result)
            self.cache_dim = self.dim
        else:
            # Replaying the cache: shrink the recorded embeddings to the new,
            # smaller dimensionality (stored back so later dims shrink further).
            result = self.cache[self.idx]
            if "token_embeddings" in result:
                result["token_embeddings"] = shrink(result["token_embeddings"], self.dim)
            result["sentence_embedding"] = shrink(result["sentence_embedding"], self.dim)
            self.idx += 1
        return result
class CachedLossDecorator:
    """Wrapper around the ``calculate_loss`` method of the Cached... losses.

    The pre-computed embeddings handed to ``calculate_loss`` are truncated to
    each desired Matryoshka dimensionality, the underlying loss is evaluated
    once per dimensionality, and the weighted results are summed.

    This decorator is applied to the `calculate_loss` method of the Cached... losses.
    """

    def __init__(
        self,
        fn,
        matryoshka_dims: Sequence[int],
        matryoshka_weights: Sequence[float] | Sequence[int],
        n_dims_per_step: int = -1,
    ) -> None:
        self.fn = fn  # the original `calculate_loss` callable
        self.matryoshka_dims = matryoshka_dims
        self.matryoshka_weights = matryoshka_weights
        self.n_dims_per_step = n_dims_per_step  # -1 means: use every dimensionality

    def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor:
        """Return the weighted sum of the wrapped loss over (a sample of) the dims."""
        chosen = range(len(self.matryoshka_dims))
        if 0 < self.n_dims_per_step < len(chosen):
            chosen = random.sample(chosen, self.n_dims_per_step)

        total = 0.0
        for index in chosen:
            dim = self.matryoshka_dims[index]
            weight = self.matryoshka_weights[index]

            truncated = [[shrink(rep, dim) for rep in batch] for batch in reps]
            grad_enabled = torch.is_grad_enabled()
            # Detach the truncated embeddings before handing them to the cached loss:
            # its first minibatched backward pass would otherwise free the computation
            # graph of the truncation. Gradients are propagated back manually below.
            if grad_enabled:
                detached = [[rep.detach().requires_grad_() for rep in batch] for batch in truncated]
            else:
                detached = truncated
            total = total + weight * self.fn(detached, *args, **kwargs)

            # Resume the backward pass through the truncation step; the incoming
            # gradients are scaled by the matryoshka weight so the weighting is
            # also reflected in the backward pass.
            if grad_enabled:
                for trunc_batch, det_batch in zip(truncated, detached):
                    for trunc, det in zip(trunc_batch, det_batch):
                        trunc.backward(weight * det.grad)
        return total
class MatryoshkaLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        loss: nn.Module,
        matryoshka_dims: Sequence[int],
        matryoshka_weights: Sequence[float] | Sequence[int] | None = None,
        n_dims_per_step: int = -1,
    ) -> None:
        """
        The MatryoshkaLoss can be seen as a loss *modifier* that allows you to use other loss functions at various
        different embedding dimensions. This is useful for when you want to train a model where users have the option
        to lower the embedding dimension to improve their embedding comparison speed and costs.

        This loss is also compatible with the Cached... losses, which are in-batch negative losses that allow for
        higher batch sizes. The higher batch sizes allow for more negatives, and often result in a stronger model.

        Args:
            model: SentenceTransformer model
            loss: The loss function to be used, e.g.
                :class:`MultipleNegativesRankingLoss`,
                :class:`CoSENTLoss`, etc.
            matryoshka_dims: A list of embedding dimensions to be used
                for the loss function, e.g. [768, 512, 256, 128, 64].
            matryoshka_weights: A list of weights to be used for the
                loss function, e.g. [1, 1, 1, 1, 1]. If None, then the
                weights will be set to 1 for all dimensions.
            n_dims_per_step: The number of dimensions to use per step.
                If -1, then all dimensions are used. If > 0, then a
                random sample of n_dims_per_step dimensions are used per
                step. The default value is -1.

        References:
            - The concept was introduced in this paper: https://huggingface.co/papers/2205.13147
            - `Matryoshka Embeddings <../../../examples/sentence_transformer/training/matryoshka/README.html>`_

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | any                                   | any    |
            +---------------------------------------+--------+

        Relations:
            - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`AdaptiveLayerLoss` which allows for
              layer reduction for faster inference.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                })
                loss = losses.MultipleNegativesRankingLoss(model)
                loss = losses.MatryoshkaLoss(model, loss, [768, 512, 256, 128, 64])

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.loss = loss

        # Validate the dims/weights configuration early, with clear error messages.
        if not matryoshka_dims:
            raise ValueError("You must provide at least one dimension in matryoshka_dims.")
        if any(dim <= 0 for dim in matryoshka_dims):
            raise ValueError("All dimensions passed to a matryoshka loss must be > 0.")
        if matryoshka_weights is None:
            matryoshka_weights = [1] * len(matryoshka_dims)
        elif len(matryoshka_weights) != len(matryoshka_dims):
            raise ValueError("matryoshka_weights must be the same length as matryoshka_dims.")

        model_embedding_dim = model.get_sentence_embedding_dimension()
        if model_embedding_dim is not None:
            if any(d > model_embedding_dim for d in matryoshka_dims):
                raise ValueError(
                    f"Dimensions in matryoshka_dims cannot exceed the model's embedding dimension of {model_embedding_dim}."
                )
            # Not training the full dimension usually hurts untruncated usage, so warn.
            if model_embedding_dim not in matryoshka_dims:
                logger.warning(
                    f"The model's embedding dimension {model_embedding_dim} is not included in matryoshka_dims: {matryoshka_dims}. "
                    "This means that the full model dimension won't be trained, which may lead to degraded performance "
                    "when using the model without specifying a lower truncation dimension. It is strongly recommended to include "
                    f"{model_embedding_dim} in matryoshka_dims."
                )

        # Sort the dimensions and weights in descending order
        dims_weights = zip(matryoshka_dims, matryoshka_weights)
        self.matryoshka_dims: tuple[int, ...]
        self.matryoshka_weights: tuple[float, ...] | tuple[int, ...]
        self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True))
        self.n_dims_per_step = n_dims_per_step

        # The Cached... losses require a special treatment as their backward pass is incompatible with the
        # ForwardDecorator approach. Instead, we use a CachedLossDecorator to compute the loss for each
        # Matryoshka dimensionality given pre-computed embeddings passed to `calculate_loss`.
        self.cached_losses = (
            CachedMultipleNegativesRankingLoss,
            CachedGISTEmbedLoss,
            CachedMultipleNegativesSymmetricRankingLoss,
        )
        if isinstance(loss, self.cached_losses):
            # Bug fix: forward `n_dims_per_step` to the decorator as well. Previously it was
            # not passed, so dimension subsampling was silently ignored for Cached... losses.
            loss.calculate_loss = CachedLossDecorator(
                loss.calculate_loss, self.matryoshka_dims, self.matryoshka_weights, n_dims_per_step
            )

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        # For the Cached... losses, the CachedLossDecorator has been applied to the `calculate_loss` method,
        # so we can directly call the loss function.
        if isinstance(self.loss, self.cached_losses):
            return self.loss(sentence_features, labels)

        # Otherwise, we apply the ForwardDecorator to the model's forward pass, which will cache the output
        # embeddings for each Matryoshka dimensionality, allowing it to be reused for the smaller dimensions.
        original_forward = self.model.forward
        try:
            decorated_forward = ForwardDecorator(original_forward)
            self.model.forward = decorated_forward

            dim_indices = range(len(self.matryoshka_dims))
            if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
                dim_indices = random.sample(dim_indices, self.n_dims_per_step)
                # Keep indices ascending, i.e. dimensionalities descending, so the
                # ForwardDecorator's cache is always shrunk, never grown.
                dim_indices.sort()

            loss = 0.0
            for idx in dim_indices:
                dim = self.matryoshka_dims[idx]
                weight = self.matryoshka_weights[idx]
                decorated_forward.set_dim(dim)
                # If the labels seem to be embeddings, truncate them to match the soon-to-be-truncated predicted embeddings
                # This allows for MatryoshkaLoss with a direct distillation loss
                dim_labels = labels
                if (
                    isinstance(labels, torch.Tensor)
                    and labels.ndim == 2
                    and labels.size(-1) == self.model.get_sentence_embedding_dimension()
                ):
                    dim_labels = labels[:, :dim]

                loss += weight * self.loss(sentence_features, dim_labels)
        finally:
            # Always restore the undecorated forward, even if the wrapped loss raised.
            self.model.forward = original_forward
        return loss

    def get_config_dict(self) -> dict[str, Any]:
        """Return the configuration of this loss, e.g. for model card generation."""
        return {
            "loss": self.loss.__class__.__name__,
            "matryoshka_dims": self.matryoshka_dims,
            "matryoshka_weights": self.matryoshka_weights,
            "n_dims_per_step": self.n_dims_per_step,
        }

    @property
    def citation(self) -> str:
        return """
@misc{kusupati2024matryoshka,
    title={Matryoshka Representation Learning},
    author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
    year={2024},
    eprint={2205.13147},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
"""