from __future__ import annotations

import json
import os
from pathlib import Path

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self

from torch import Tensor, nn
from transformers.utils import logging

from sentence_transformers.models.InputModule import InputModule
from sentence_transformers.models.Module import Module
from sentence_transformers.util import import_from_string, load_dir_path

logger = logging.get_logger(__name__)


class Router(InputModule):
    forward_kwargs = {"task"}
    config_keys: list[str] = ["default_route", "allow_empty_key"]
    config_file_name = "router_config.json"

    def __init__(
        self, sub_modules: dict[str, list[Module]], default_route: str | None = None, allow_empty_key: bool = True
    ) -> None:
        r"""
        This model allows you to create asymmetric SentenceTransformer models that apply different modules
        depending on the specified route, such as "query" or "document". It is especially useful for models
        that use different encoders for queries and documents.

        Notably, the ``task`` argument of ``model.encode`` can be used to specify which route to use, and
        ``model.encode_query`` and ``model.encode_document`` are shorthands for using ``task="query"`` and
        ``task="document"``, respectively. These methods also optionally apply ``prompts`` specific to queries
        or documents.

        .. note::

            When training models with the :class:`~sentence_transformers.models.Router` module, you must use the
            ``router_mapping`` argument in the :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`
            or :class:`~sentence_transformers.sparse_encoder.training_args.SparseEncoderTrainingArguments` to map the
            training dataset columns to the correct route ("query" or "document"). For example, if your training dataset(s)
            have ``["question", "positive", "negative"]`` columns, then you can use the following mapping::

                args = SparseEncoderTrainingArguments(
                    ...,
                    router_mapping={
                        "question": "query",
                        "positive": "document",
                        "negative": "document",
                    }
                )

            Additionally, it is common to use a different learning rate for the different routes. For this, you should
            use the ``learning_rate_mapping`` argument in the :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`
            or :class:`~sentence_transformers.sparse_encoder.training_args.SparseEncoderTrainingArguments` to map parameter patterns
            to their learning rates. For example, if you want to use a learning rate of ``1e-3`` for a SparseStaticEmbedding module and
            ``2e-5`` for the rest of the model, you can do this::

                args = SparseEncoderTrainingArguments(
                    ...,
                    learning_rate=2e-5,
                    learning_rate_mapping={
                        r"SparseStaticEmbedding\.*": 1e-3,
                    }
                )

        In the examples below, the ``Router`` model is used to create asymmetric models with different encoders for
        queries and documents. In these examples, the "query" route is efficient (e.g., using SparseStaticEmbedding),
        while the "document" route uses a more complex model (e.g., a Transformer module). This allows for efficient
        query encoding while still using a powerful document encoder, but the combinations are not limited to this.

        Example:
            ::

                from sentence_transformers import SentenceTransformer
                from sentence_transformers.models import Router, Normalize

                # Use a regular SentenceTransformer for the document embeddings,
                # and a static embedding model for the query embeddings
                document_embedder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
                query_embedder = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1")
                router = Router.for_query_document(
                    query_modules=list(query_embedder.children()),
                    document_modules=list(document_embedder.children()),
                )
                normalize = Normalize()

                # Create an asymmetric model with different encoders for queries and documents
                model = SentenceTransformer(
                    modules=[router, normalize],
                )

                # ... requires more training to align the vector spaces

                # Use the query & document routes
                query_embedding = model.encode_query("What is the capital of France?")
                document_embedding = model.encode_document("Paris is the capital of France.")

            ::

                from sentence_transformers.models import Router
                from sentence_transformers.sparse_encoder import SparseEncoder
                from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling

                # Load an asymmetric model with different encoders for queries and documents
                doc_encoder = MLMTransformer("opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill")
                router = Router.for_query_document(
                    query_modules=[
                        SparseStaticEmbedding.from_json(
                            "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill",
                            tokenizer=doc_encoder.tokenizer,
                            frozen=True,
                        ),
                    ],
                    document_modules=[
                        doc_encoder,
                        SpladePooling(pooling_strategy="max", activation_function="log1p_relu"),
                    ],
                )

                model = SparseEncoder(modules=[router], similarity_fn_name="dot")

                query = "What's the weather in ny now?"
                document = "Currently New York is rainy."

                query_embed = model.encode_query(query)
                document_embed = model.encode_document(document)

                sim = model.similarity(query_embed, document_embed)
                print(f"Similarity: {sim}")

                # Visualize the top tokens for each text
                top_k = 10
                print(f"Top {top_k} tokens for each text:")

                decoded_query = model.decode(query_embed, top_k=top_k)
                decoded_document = model.decode(document_embed)

                for i in range(min(top_k, len(decoded_query))):
                    query_token, query_score = decoded_query[i]
                    doc_score = next((score for token, score in decoded_document if token == query_token), 0)
                    if doc_score != 0:
                        print(f"Token: {query_token}, Query score: {query_score:.4f}, Document score: {doc_score:.4f}")

                '''
                Similarity: tensor([[11.1105]], device='cuda:0')
                Top 10 tokens for each text:
                Token: ny, Query score: 5.7729, Document score: 0.8049
                Token: weather, Query score: 4.5684, Document score: 0.9710
                Token: now, Query score: 3.5895, Document score: 0.4720
                Token: ?, Query score: 3.3313, Document score: 0.0286
                Token: what, Query score: 2.7699, Document score: 0.0787
                Token: in, Query score: 0.4989, Document score: 0.0417
                '''

        Note:
            These models are not necessarily stronger than non-asymmetric models. Rudimentary experiments indicate
            that non-Router models perform better in many cases.

        Args:
            sub_modules: Mapping of route keys to lists of modules. Each key corresponds to a specific task type,
                often "query" or "document", and the list contains the modules to be applied for that task type.
            default_route: The default route to use if no task type is specified. If None and ``allow_empty_key``
                is False, an exception is thrown when no task type is specified; if ``allow_empty_key`` is True,
                the first key in ``sub_modules`` is used as the default route. Defaults to None.
            allow_empty_key: If True, allows the default route to be set to the first key in ``sub_modules`` if
                ``default_route`` is None. Defaults to True.
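
        With a default route set, plain ``model.encode`` calls without a ``task`` argument fall back to that route.
        A minimal sketch of this behavior, assuming a ``router`` built as in the examples above::

            model = SentenceTransformer(modules=[router])
            model.encode("Some text")                # routed via the default route
            model.encode("Some text", task="query")  # explicitly routed via the "query" route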
        """
        super().__init__()
        if sub_modules is None or len(sub_modules) == 0:
            raise ValueError("The routes dictionary cannot be empty.")
        if default_route is not None and default_route not in sub_modules:
            raise ValueError(f"Default route '{default_route}' not found in route keys: {list(sub_modules.keys())}")

        self.sub_modules = nn.ModuleDict(
            {route_name: nn.Sequential(*modules) for route_name, modules in sub_modules.items()}
        )

        # If allow_empty_key is True, we can set the default route to the first key in sub_modules.
        if allow_empty_key and default_route is None:
            default_route = next(iter(sub_modules.keys()))
        self.default_route = default_route
        self.allow_empty_key = allow_empty_key

    @classmethod
    def for_query_document(
        cls,
        query_modules: list[Module],
        document_modules: list[Module],
        default_route: str | None = None,
        allow_empty_key: bool = True,
    ) -> Self:
        """
        Creates a Router model specifically for query and document modules, allowing convenient usage via
        ``model.encode_query`` and ``model.encode_document``.

        Args:
            query_modules: List of modules to be applied for the "query" task type.
            document_modules: List of modules to be applied for the "document" task type.
            default_route: The default route to use if no task type is specified. Defaults to None, in which
                case "document" is used as the default route.
            allow_empty_key: If True, allows the default route to be set to the first key in ``sub_modules`` if
                ``default_route`` is None. Defaults to True.

        Returns:
            Router: An instance of the Router model with the specified query and document modules.
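
        Example:
            ::

                from sentence_transformers import SentenceTransformer
                from sentence_transformers.models import Router

                # A minimal sketch mirroring the class docstring: route queries and documents
                # through the modules of two existing models
                query_embedder = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1")
                document_embedder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
                router = Router.for_query_document(
                    query_modules=list(query_embedder.children()),
                    document_modules=list(document_embedder.children()),
                )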
        """
        return cls(
            sub_modules={"query": query_modules, "document": document_modules},
            default_route=default_route or "document",
            allow_empty_key=allow_empty_key,
        )

    def forward(self, features: dict[str, Tensor], task: str | None = None, **kwargs) -> dict[str, Tensor]:
        if task is None:
            task = features.get("task", self.default_route)
        if task is None:
            if self.training:
                raise ValueError(
                    "You must provide a `router_mapping` argument on the training arguments, "
                    "or set a default route in the `Router` module."
                )
            else:
                raise ValueError(
                    "You must provide a `task` argument when calling this method, "
                    "or set a default route in the `Router` module."
                )

        if task not in self.sub_modules:
            raise ValueError(
                f"No route found for task type '{task}'. Available routes: {list(self.sub_modules.keys())}"
            )

        kwargs["task"] = task
        for module in self.sub_modules[task]:
            module_kwargs = {
                key: value
                for key, value in kwargs.items()
                if hasattr(module, "forward_kwargs") and key in module.forward_kwargs
            }
            features = module(features, **module_kwargs)
        return features

    def get_sentence_embedding_dimension(self) -> int | None:
        # Return the dimension reported by the last module (searched in reverse) of the first route that exposes one
        for sub_modules in self.sub_modules.values():
            for module in reversed(sub_modules):
                if hasattr(module, "get_sentence_embedding_dimension"):
                    return module.get_sentence_embedding_dimension()
        return None

    def save(self, output_path: str, safe_serialization: bool = True, **kwargs):
        model_lookup = {}
        model_types = {}
        model_structure = {}

        for name, models in self.sub_modules.items():
            model_structure[name] = []
            for module_idx, model in enumerate(models):
                model_id = f"{name}_{module_idx}_{type(model).__name__}"
                model_lookup[model_id] = model
                model_types[model_id] = f"{type(model).__module__}.{type(model).__name__}"
                model_structure[name].append(model_id)

        for model_id, model in model_lookup.items():
            model_path = os.path.join(output_path, str(model_id))
            os.makedirs(model_path, exist_ok=True)
            try:
                model.save(model_path, safe_serialization=safe_serialization, **kwargs)
            except TypeError:
                # Fallback for legacy models that do not support kwargs
                model.save(model_path)
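
        # A sketch of the resulting router_config.json (keys match the json.dump call below):
        # {
        #     "types": {"<route>_<idx>_<ClassName>": "<module path>.<ClassName>", ...},
        #     "structure": {"<route>": ["<route>_<idx>_<ClassName>", ...], ...},
        #     "parameters": {"default_route": ..., "allow_empty_key": ...}
        # }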
        with open(os.path.join(output_path, self.config_file_name), "w", encoding="utf8") as fOut:
            json.dump(
                {
                    "types": model_types,
                    "structure": model_structure,
                    "parameters": self.get_config_dict(),
                },
                fOut,
                indent=4,
            )

    def tokenize(
        self, texts: list[str] | list[dict[str, str]] | list[tuple[str, str]], task: str | None = None, **kwargs
    ):
        """Tokenizes the texts and maps tokens to token-ids, using the first module of the selected route."""
        if isinstance(texts[0], dict):
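            # e.g. texts = [{"query": "What is the capital of France?"}], where the
            # dictionary key names the route to use for that text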
            # Extract the task type key from the dictionaries
            if task is None:
                tasks = set(key for text in texts for key in text.keys())
                if len(tasks) > 1:
                    raise ValueError(
                        "You cannot pass a list of dictionaries with different task types. "
                        "Please ensure all dictionaries have the same task type key, or pass a single `task` argument."
                    )
                task = tasks.pop()

            # Remove the dictionary structure
            texts = [text[task] for text in texts]

        if task is None:
            task = self.default_route
        if task is None:
            if self.training:
                raise ValueError(
                    "You must provide a `router_mapping` argument on the training arguments, "
                    "or set a default route in the `Router` module."
                )
            else:
                raise ValueError(
                    "You must provide a `task` argument when calling this method, "
                    "or set a default route in the `Router` module."
                )
        if task not in self.sub_modules:
            raise ValueError(
                f"No route found for task type '{task}'. Available routes: {list(self.sub_modules.keys())}"
            )

        input_module = self.sub_modules[task][0]
        tokenized = input_module.tokenize(texts, **kwargs)
        tokenized["task"] = task
        return tokenized

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ) -> Self:
        hub_kwargs = {
            "token": token,
            "cache_folder": cache_folder,
            "revision": revision,
            "local_files_only": local_files_only,
        }
        # Try the official config file first, then fall back to the legacy config file
        config = cls.load_config(model_name_or_path=model_name_or_path, subfolder=subfolder, **hub_kwargs)
        if not config:
            config = cls.load_config(
                model_name_or_path=model_name_or_path, config_filename="config.json", subfolder=subfolder, **hub_kwargs
            )

        modules = {}
        for model_id, model_type in config["types"].items():
            module_class: type[Module] = import_from_string(model_type)
            try:
                module = module_class.load(
                    model_name_or_path, subfolder=Path(subfolder, model_id).as_posix(), **hub_kwargs, **kwargs
                )
            except TypeError:
                # Fallback for legacy modules whose `load` only accepts a local directory path
                local_path = load_dir_path(
                    model_name_or_path=model_name_or_path, subfolder=Path(subfolder, model_id).as_posix(), **hub_kwargs
                )
                module = module_class.load(local_path)
            modules[model_id] = module

        model_structure = {}
        for key_name, models_list in config["structure"].items():
            model_structure[key_name] = []
            for model_id in models_list:
                model_structure[key_name].append(modules[model_id])

        model = cls(model_structure, **config["parameters"])
        return model

    @property
    def tokenizer(self):
        # We might have multiple tokenizers, one for each route, but we can only return one here.
        for sub_modules in self.sub_modules.values():
            input_module: InputModule = sub_modules[0]
            if hasattr(input_module, "tokenizer") and input_module.tokenizer is not None:
                return input_module.tokenizer
        return None

    @property
    def max_seq_length(self) -> int | None:
        # Collect all unique max_seq_length values from the first module of each route
        max_seq_lengths = set()
        for modules in self.sub_modules.values():
            if not modules:
                continue
            input_module: InputModule = modules[0]
            if hasattr(input_module, "max_seq_length"):
                max_seq_lengths.add(input_module.max_seq_length)

        if not max_seq_lengths:
            return None
        elif len(max_seq_lengths) == 1:
            # Only one unique max_seq_length
            return max_seq_lengths.pop()
        else:
            logger.warning_once(f"Different max_seq_lengths detected: {max_seq_lengths}. Using the maximum value.")
            return max(max_seq_lengths)

    @max_seq_length.setter
    def max_seq_length(self, value) -> None:
        # Check which routes have a first module with a max_seq_length attribute
        has_max_seq_length_keys = []
        for key, models in self.sub_modules.items():
            if models and hasattr(models[0], "max_seq_length"):
                has_max_seq_length_keys.append(key)

        if len(has_max_seq_length_keys) == 0:
            logger.warning("No modules have a max_seq_length attribute to set.")
            return

        for key in has_max_seq_length_keys:
            input_module: InputModule = self.sub_modules[key][0]
            input_module.max_seq_length = value


# For backwards compatibility, we ensure that the legacy `Asym` alias points to the new `Router` class.
Asym = Router