# Router module for sentence-transformers (scraped page metadata removed: 418 lines, 19 KiB, updated 4 days ago).
from __future__ import annotations
import json
import os
from pathlib import Path
try:
from typing import Self
except ImportError:
from typing_extensions import Self
from torch import Tensor, nn
from transformers.utils import logging
from sentence_transformers.models.InputModule import InputModule
from sentence_transformers.models.Module import Module
from sentence_transformers.util import import_from_string, load_dir_path
logger = logging.get_logger(__name__)
class Router(InputModule):
    # Names of keyword arguments that this module's ``forward`` accepts; the surrounding
    # model inspects this set to decide which encode-kwargs to route into ``forward``.
    forward_kwargs = {"task"}
    # Constructor arguments persisted to / restored from the module config file.
    config_keys: list[str] = ["default_route", "allow_empty_key"]
    # File name used to store this module's configuration alongside the sub-modules.
    config_file_name = "router_config.json"
    def __init__(
        self, sub_modules: dict[str, list[Module]], default_route: str | None = None, allow_empty_key: bool = True
    ) -> None:
        r"""
        This model allows to create asymmetric SentenceTransformer models that apply different modules depending on the specified route,
        such as "query" or "document". Especially useful for models that have different encoders for queries and documents.

        Notably, the ``task`` argument of ``model.encode`` can be used to specify which route to use, and
        ``model.encode_query`` and ``model.encode_document`` are shorthands for using ``task="query"`` and
        ``task="document"``, respectively. These methods also optionally apply ``prompts`` specific to queries
        or documents.

        .. note::

            When training models with the :class:`~sentence_transformers.models.Router` module, you must use the
            ``router_mapping`` argument in the :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`
            or :class:`~sentence_transformers.sparse_encoder.training_args.SparseEncoderTrainingArguments` to map the
            training dataset columns to the correct route ("query" or "document"). For example, if your training dataset(s)
            have ``["question", "positive", "negative"]`` columns, then you can use the following mapping::

                args = SparseEncoderTrainingArguments(
                    ...,
                    router_mapping={
                        "question": "query",
                        "positive": "document",
                        "negative": "document",
                    }
                )

            Additionally, it is common to use a different learning rate for the different routes. For this, you should
            use the ``learning_rate_mapping`` argument in the :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`
            or :class:`~sentence_transformers.sparse_encoder.training_args.SparseEncoderTrainingArguments` to map parameter patterns
            to their learning rates. For example, if you want to use a learning rate of ``1e-3`` for an SparseStaticEmbedding module and
            ``2e-5`` for the rest of the model, you can do this::

                args = SparseEncoderTrainingArguments(
                    ...,
                    learning_rate=2e-5,
                    learning_rate_mapping={
                        r"SparseStaticEmbedding\.*": 1e-3,
                    }
                )

        In the below examples, the ``Router`` model is used to create asymmetric models with different encoders for
        queries and documents. In these examples, the "query" route is efficient (e.g., using SparseStaticEmbedding),
        while the "document" route uses a more complex model (e.g. a Transformers module). This allows for efficient
        query encoding while still using a powerful document encoder, but the combinations are not limited to this.

        Example:
            ::

                from sentence_transformers import SentenceTransformer
                from sentence_transformers.models import Router, Normalize

                # Use a regular SentenceTransformer for the document embeddings, and a static embedding model for the query embeddings
                document_embedder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
                query_embedder = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1")
                router = Router.for_query_document(
                    query_modules=list(query_embedder.children()),
                    document_modules=list(document_embedder.children()),
                )
                normalize = Normalize()

                # Create an asymmetric model with different encoders for queries and documents
                model = SentenceTransformer(
                    modules=[router, normalize],
                )

                # ... requires more training to align the vector spaces

                # Use the query & document routes
                query_embedding = model.encode_query("What is the capital of France?")
                document_embedding = model.encode_document("Paris is the capital of France.")

            ::

                from sentence_transformers.models import Router
                from sentence_transformers.sparse_encoder import SparseEncoder
                from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling

                # Load an asymmetric model with different encoders for queries and documents
                doc_encoder = MLMTransformer("opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill")
                router = Router.for_query_document(
                    query_modules=[
                        SparseStaticEmbedding.from_json(
                            "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill",
                            tokenizer=doc_encoder.tokenizer,
                            frozen=True,
                        ),
                    ],
                    document_modules=[
                        doc_encoder,
                        SpladePooling(pooling_strategy="max", activation_function="log1p_relu"),
                    ],
                )

                model = SparseEncoder(modules=[router], similarity_fn_name="dot")

                query = "What's the weather in ny now?"
                document = "Currently New York is rainy."

                query_embed = model.encode_query(query)
                document_embed = model.encode_document(document)

                sim = model.similarity(query_embed, document_embed)
                print(f"Similarity: {sim}")

                # Visualize top tokens for each text
                top_k = 10
                print(f"Top tokens {top_k} for each text:")

                decoded_query = model.decode(query_embed, top_k=top_k)
                decoded_document = model.decode(document_embed)

                for i in range(min(top_k, len(decoded_query))):
                    query_token, query_score = decoded_query[i]
                    doc_score = next((score for token, score in decoded_document if token == query_token), 0)
                    if doc_score != 0:
                        print(f"Token: {query_token}, Query score: {query_score:.4f}, Document score: {doc_score:.4f}")

                '''
                Similarity: tensor([[11.1105]], device='cuda:0')
                Top tokens 10 for each text:
                Token: ny, Query score: 5.7729, Document score: 0.8049
                Token: weather, Query score: 4.5684, Document score: 0.9710
                Token: now, Query score: 3.5895, Document score: 0.4720
                Token: ?, Query score: 3.3313, Document score: 0.0286
                Token: what, Query score: 2.7699, Document score: 0.0787
                Token: in, Query score: 0.4989, Document score: 0.0417
                '''

        Note:
            These models are not necessarily stronger than non-asymmetric models. Rudimentary experiments indicate
            that non-Router models perform better in many cases.

        Args:
            sub_modules: Mapping of route keys to lists of modules. Each key corresponds to a specific task type,
                often "query" or "document", and the list contains the modules to be applied for that task type.
            default_route: The default route to use if no task type is specified. If None, an exception will be thrown
                if no task type is specified. If ``allow_empty_key`` is True, the first key in sub_modules will be used as
                the default route. Defaults to None.
            allow_empty_key: If True, allows the default route to be set to the first key in `sub_modules` if
                ``default_route`` is None. Defaults to True.

        Raises:
            ValueError: If ``sub_modules`` is empty/None, or ``default_route`` is not one of its keys.
        """
        super().__init__()
        if sub_modules is None or len(sub_modules) == 0:
            raise ValueError("The routes dictionary cannot be empty.")
        if default_route is not None and default_route not in sub_modules:
            raise ValueError(f"Default route '{default_route}' not found in route keys: {list(sub_modules.keys())}")
        # Wrap each route's module list in an nn.Sequential so parameters register properly.
        self.sub_modules = nn.ModuleDict(
            {route_name: nn.Sequential(*modules) for route_name, modules in sub_modules.items()}
        )
        # If allow_empty_key is True, we can set a default route to the first key in sub_modules.
        if allow_empty_key and default_route is None:
            default_route = next(iter(sub_modules.keys()))
        self.default_route = default_route
        self.allow_empty_key = allow_empty_key
@classmethod
def for_query_document(
cls,
query_modules: list[Module],
document_modules: list[Module],
default_route: str | None = None,
allow_empty_key: bool = True,
) -> Self:
"""
Creates a Router model specifically for query and document modules, allowing convenient usage via `model.encode_query`
and `model.encode_document`.
Args:
query_modules: List of modules to be applied for the "query" task type.
document_modules: List of modules to be applied for the "document" task type.
default_route: The default route to use if no task type is specified. If None, an exception will be thrown
if no task type is specified. If ``allow_empty_key`` is True, the first key in sub_modules will be used as
the default route. Defaults to None.
allow_empty_key: If True, allows the default route to be set to the first key in `sub_modules` if
``default_route`` is None. Defaults to True.
Returns:
Router: An instance of the Router model with the specified query and document modules.
"""
return cls(
sub_modules={"query": query_modules, "document": document_modules},
default_route=default_route or "document",
allow_empty_key=allow_empty_key,
)
def forward(self, features: dict[str, Tensor], task: str | None = None, **kwargs) -> dict[str, Tensor]:
if task is None:
task = features.get("task", self.default_route)
if task is None:
if self.training:
raise ValueError(
"You must provide a `router_mapping` argument on the training arguments, "
"or set a default route in the `Router` module."
)
else:
raise ValueError(
"You must provide a `task` argument when calling this method, "
"or set a default route in the `Router` module."
)
if task not in self.sub_modules:
raise ValueError(
f"No route found for task type '{task}'. Available routes: {list(self.sub_modules.keys())}"
)
kwargs["task"] = task
for module in self.sub_modules[task]:
module_kwargs = {
key: value
for key, value in kwargs.items()
if hasattr(module, "forward_kwargs") and key in module.forward_kwargs
}
features = module(features, **module_kwargs)
return features
def get_sentence_embedding_dimension(self) -> int:
for sub_modules in self.sub_modules.values():
for module in reversed(sub_modules):
if hasattr(module, "get_sentence_embedding_dimension"):
return module.get_sentence_embedding_dimension()
return None
def save(self, output_path: str, safe_serialization: bool = True, **kwargs):
model_lookup = {}
model_types = {}
model_structure = {}
for name, models in self.sub_modules.items():
model_structure[name] = []
for module_idx, model in enumerate(models):
model_id = f"{name}_{module_idx}_{type(model).__name__}"
model_lookup[model_id] = model
model_types[model_id] = f"{type(model).__module__}.{type(model).__name__}"
model_structure[name].append(model_id)
for model_id, model in model_lookup.items():
model_path = os.path.join(output_path, str(model_id))
os.makedirs(model_path, exist_ok=True)
try:
model.save(model_path, safe_serialization=safe_serialization, **kwargs)
except TypeError:
# Fallback for legacy models that do not support kwargs
model.save(model_path)
with open(os.path.join(output_path, self.config_file_name), "w", encoding="utf8") as fOut:
json.dump(
{
"types": model_types,
"structure": model_structure,
"parameters": self.get_config_dict(),
},
fOut,
indent=4,
)
def tokenize(self, texts: list[str] | list[tuple[str, str]], task: str | None = None, **kwargs):
"""Tokenizes a text and maps tokens to token-ids"""
if isinstance(texts[0], dict):
# Extract the task type key from the dictionaries
if task is None:
tasks = set(key for text in texts for key in text.keys())
if len(tasks) > 1:
raise ValueError(
"You cannot pass a list of dictionaries with different task types. "
"Please ensure all dictionaries have the same task type key, or pass a single `task` argument."
)
task = tasks.pop()
# Remove dictionary structure
texts = [text[task] for text in texts]
if task is None:
task = self.default_route
if task is None:
if self.training:
raise ValueError(
"You must provide a `router_mapping` argument on the training arguments, "
"or set a default route in the `Router` module."
)
else:
raise ValueError(
"You must provide a `task` argument when calling this method, "
"or set a default route in the `Router` module."
)
if task not in self.sub_modules:
raise ValueError(
f"No route found for task type '{task}'. Available routes: {list(self.sub_modules.keys())}"
)
input_module = self.sub_modules[task][0]
tokenized = input_module.tokenize(texts, **kwargs)
tokenized["task"] = task
return tokenized
@classmethod
def load(
cls,
model_name_or_path: str,
subfolder: str = "",
token: bool | str | None = None,
cache_folder: str | None = None,
revision: str | None = None,
local_files_only: bool = False,
**kwargs,
) -> Self:
hub_kwargs = {
"token": token,
"cache_folder": cache_folder,
"revision": revision,
"local_files_only": local_files_only,
}
# Try the official config file first, then fall back to the legacy config file
config = cls.load_config(model_name_or_path=model_name_or_path, subfolder=subfolder, **hub_kwargs)
if not config:
config = cls.load_config(
model_name_or_path=model_name_or_path, config_filename="config.json", subfolder=subfolder, **hub_kwargs
)
modules = {}
for model_id, model_type in config["types"].items():
module_class: Module = import_from_string(model_type)
try:
module = module_class.load(
model_name_or_path, subfolder=Path(subfolder, model_id).as_posix(), **hub_kwargs, **kwargs
)
except TypeError:
local_path = load_dir_path(
model_name_or_path=model_name_or_path, subfolder=Path(subfolder, model_id).as_posix(), **hub_kwargs
)
module = module_class.load(local_path)
modules[model_id] = module
model_structure = {}
for key_name, models_list in config["structure"].items():
model_structure[key_name] = []
for model_id in models_list:
model_structure[key_name].append(modules[model_id])
model = cls(model_structure, **config["parameters"])
return model
@property
def tokenizer(self):
# We might have multiple tokenizers, one for each route, but we can only return one here.
for sub_modules in self.sub_modules.values():
input_module: InputModule = sub_modules[0]
if hasattr(input_module, "tokenizer") and input_module.tokenizer is not None:
return input_module.tokenizer
return None
@property
def max_seq_length(self) -> int:
# Collect all unique max_seq_length values
max_seq_lengths = set()
for modules in self.sub_modules.values():
input_module: InputModule = modules[0]
if modules and hasattr(input_module, "max_seq_length"):
max_seq_lengths.add(input_module.max_seq_length)
if not max_seq_lengths:
return None
elif len(max_seq_lengths) == 1:
# Only one unique max_seq_length
return max_seq_lengths.pop()
else:
logger.warning_once(f"Different max_seq_lengths detected: {max_seq_lengths}. Using the maximum value.")
return max(max_seq_lengths)
@max_seq_length.setter
def max_seq_length(self, value) -> None:
# Check which modules have max_seq_length
has_max_seq_length_keys = []
for key, models in self.sub_modules.items():
if models and hasattr(models[0], "max_seq_length"):
has_max_seq_length_keys.append(key)
if len(has_max_seq_length_keys) == 0:
logger.warning("No modules have a max_seq_length attribute to set.")
return
for key in has_max_seq_length_keys:
input_module: InputModule = self.sub_modules[key][0]
input_module.max_seq_length = value
# For backwards compatibility, we ensure that the legacy `Asym` alias points to the new `Router` class.
Asym = Router