# Export utilities for optimizing ONNX-backed Sentence Transformers models via Optimum.
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal

from sentence_transformers.backend.utils import save_or_push_to_hub_model

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sentence_transformers import CrossEncoder, SentenceTransformer, SparseEncoder

try:
    from optimum.onnxruntime.configuration import OptimizationConfig
except ImportError:
    pass


def export_optimized_onnx_model(
    model: SentenceTransformer | SparseEncoder | CrossEncoder,
    optimization_config: OptimizationConfig | Literal["O1", "O2", "O3", "O4"],
    model_name_or_path: str,
    push_to_hub: bool = False,
    create_pr: bool = False,
    file_suffix: str | None = None,
) -> None:
    """
    Export an optimized ONNX model from a SentenceTransformer, SparseEncoder, or CrossEncoder model.

    The O1-O4 optimization levels are defined by Optimum and are documented here:
    https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/optimization

    The optimization levels are:

    - O1: basic general optimizations.
    - O2: basic and extended general optimizations, transformers-specific fusions.
    - O3: same as O2 with GELU approximation.
    - O4: same as O3 with mixed precision (fp16, GPU-only)

    See the following pages for more information & benchmarks:

    - `Sentence Transformer > Usage > Speeding up Inference <https://sbert.net/docs/sentence_transformer/usage/efficiency.html>`_
    - `Cross Encoder > Usage > Speeding up Inference <https://sbert.net/docs/cross_encoder/usage/efficiency.html>`_

    Args:
        model (SentenceTransformer | SparseEncoder | CrossEncoder): The SentenceTransformer, SparseEncoder,
            or CrossEncoder model to be optimized. Must be loaded with `backend="onnx"`.
        optimization_config (OptimizationConfig | Literal["O1", "O2", "O3", "O4"]): The optimization configuration or level.
        model_name_or_path (str): The path or Hugging Face Hub repository name where the optimized model will be saved.
        push_to_hub (bool, optional): Whether to push the optimized model to the Hugging Face Hub. Defaults to False.
        create_pr (bool, optional): Whether to create a pull request when pushing to the Hugging Face Hub. Defaults to False.
        file_suffix (str | None, optional): The suffix to add to the optimized model file name. Defaults to None.

    Raises:
        ImportError: If the required packages `optimum` and `onnxruntime` are not installed.
        ValueError: If the provided model is not a valid SentenceTransformer, SparseEncoder, or CrossEncoder model loaded with `backend="onnx"`.
        ValueError: If the provided optimization_config is not valid.

    Returns:
        None
    """
    from sentence_transformers import CrossEncoder, SentenceTransformer, SparseEncoder

    try:
        from optimum.onnxruntime import ORTModel, ORTOptimizer
        from optimum.onnxruntime.configuration import AutoOptimizationConfig
    except ImportError:
        raise ImportError(
            "Please install Optimum and ONNX Runtime to use this function. "
            "You can install them with pip: `pip install sentence-transformers[onnx]` "
            "or `pip install sentence-transformers[onnx-gpu]`"
        )

    def _wraps_ort_transformer(candidate) -> bool:
        # True when the first module of an encoder-style model wraps an ONNX Runtime model.
        return (
            bool(len(candidate))
            and hasattr(candidate[0], "auto_model")
            and isinstance(candidate[0].auto_model, ORTModel)
        )

    # SentenceTransformer and SparseEncoder share the module-list layout; CrossEncoder
    # exposes its backbone directly on `.model`.
    is_onnx_encoder = isinstance(model, (SentenceTransformer, SparseEncoder)) and _wraps_ort_transformer(model)
    is_onnx_cross_encoder = isinstance(model, CrossEncoder) and isinstance(model.model, ORTModel)
    if not (is_onnx_encoder or is_onnx_cross_encoder):
        raise ValueError(
            'The model must be a Transformer-based SentenceTransformer, SparseEncoder, or CrossEncoder model loaded with `backend="onnx"`.'
        )

    ort_model: ORTModel = model[0].auto_model if is_onnx_encoder else model.model
    optimizer = ORTOptimizer.from_pretrained(ort_model)

    if isinstance(optimization_config, str):
        if optimization_config not in AutoOptimizationConfig._LEVELS:
            raise ValueError(
                "optimization_config must be an OptimizationConfig instance or one of 'O1', 'O2', 'O3', 'O4'."
            )

        # Reuse the level name (e.g. "O2") as the default file suffix, then expand
        # the level string into a concrete OptimizationConfig instance.
        file_suffix = file_suffix or optimization_config
        optimization_config = getattr(AutoOptimizationConfig, optimization_config)()

    if file_suffix is None:
        file_suffix = "optimized"

    save_or_push_to_hub_model(
        export_function=lambda save_dir: optimizer.optimize(optimization_config, save_dir, file_suffix=file_suffix),
        export_function_name="export_optimized_onnx_model",
        config=optimization_config,
        model_name_or_path=model_name_or_path,
        push_to_hub=push_to_hub,
        create_pr=create_pr,
        file_suffix=file_suffix,
        backend="onnx",
        model=model,
    )