You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

327 lines
13 KiB

from __future__ import annotations
import logging
import os
import shutil
import tempfile
from collections.abc import Callable
from fnmatch import fnmatch
from pathlib import Path
from typing import TYPE_CHECKING, Any
import huggingface_hub
from huggingface_hub import list_repo_files
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder, SentenceTransformer, SparseEncoder
logger = logging.getLogger(__name__)
def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]:
"""
Wraps the save_pretrained method of a model to save to a subfolder.
Args:
_save_pretrained_fn: The original save_pretrained function
subfolder: The subfolder to save to
Returns:
A wrapped function that saves to the specified subfolder
"""
def wrapper(save_directory: str | Path, **kwargs) -> None:
os.makedirs(Path(save_directory) / subfolder, exist_ok=True)
return _save_pretrained_fn(Path(save_directory) / subfolder, **kwargs)
return wrapper
def backend_should_export(
    load_path: Path,
    is_local: bool,
    model_kwargs: dict[str, Any],
    target_file_name: str,
    target_file_glob: str,
    backend_name: str,
) -> tuple[bool, dict[str, Any]]:
    """
    Determines whether the model should be exported to the backend, or if it can be loaded directly.
    Also update the `file_name` and `subfolder` model_kwargs if necessary.

    These are the cases:

    1. If export is set (truthy) in model_kwargs, just return export
    2. If `<subfolder>/<file_name>` exists; set export to False
    3. If `<backend>/<file_name>` exists; set export to False and set subfolder to the backend (e.g. "onnx")
    4. If `<file_name>` contains a folder, add those folders to the subfolder and set the file_name to the last part

    We will warn if:

    1. The expected file does not exist in the model directory given the optional file_name and subfolder.
       If there are valid files for this backend, but they don't align with file_name, then we give a useful warning.
    2. Multiple files are found in the model directory that match the target file name and the user did not
       specify the desired file name via `model_kwargs={"file_name": "<file_name>"}`

    Note that ``model_kwargs`` is modified in place: "export" is removed from it, and "file_name"
    and "subfolder" may be written to it.

    Args:
        load_path: The model repository or directory, as a Path instance
        is_local: Whether the model is local or remote, i.e. whether load_path is a local directory
        model_kwargs: The model_kwargs dictionary. Notable keys are "export", "file_name", and "subfolder".
            A "subfolder" of None is treated the same as no subfolder.
        target_file_name: The expected file name in the model directory, e.g. "model.onnx" or "openvino_model.xml"
        target_file_glob: The glob pattern to match the target file name, e.g. "*.onnx" or "openvino*.xml"
        backend_name: The human-readable name of the backend for use in warnings, e.g. "ONNX" or "OpenVINO"

    Returns:
        Tuple[bool, dict[str, Any]]: A tuple of the export boolean and the updated model_kwargs dictionary.
            Notable keys in model_kwargs are "export", "file_name", and "subfolder".
    """
    export = model_kwargs.pop("export", None)
    if export:
        return export, model_kwargs

    backend = backend_name.lower()
    file_name = model_kwargs.get("file_name", target_file_name)
    subfolder = model_kwargs.get("subfolder", None)
    primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else Path(file_name).as_posix()
    secondary_full_path = (
        Path(subfolder, backend, file_name).as_posix() if subfolder else Path(backend, file_name).as_posix()
    )

    # Get the list of files in the model directory that match the target file name
    if is_local:
        glob_pattern = f"{subfolder}/**/{target_file_glob}" if subfolder else f"**/{target_file_glob}"
        model_file_names = [path.relative_to(load_path).as_posix() for path in load_path.glob(glob_pattern)]
    else:
        all_files = list_repo_files(
            load_path.as_posix(),
            repo_type="model",
            revision=model_kwargs.get("revision", None),
            token=model_kwargs.get("token", None),
        )
        # Mirror the local `Path.glob` semantics: any file at or below the subfolder whose
        # *basename* matches the backend glob. Matching the full path against "**/<glob>" with
        # fnmatch would not work, because fnmatch does not treat "/" specially and the literal "/"
        # in that pattern excludes files located directly in the root (or directly in `subfolder`).
        prefix = f"{Path(subfolder).as_posix()}/" if subfolder else ""
        model_file_names = [
            fname
            for fname in all_files
            if fname.startswith(prefix) and fnmatch(fname.rsplit("/", 1)[-1], target_file_glob)
        ]

    # First check if the expected file exists in the root of the model directory
    # If it doesn't, check if it exists in the backend subfolder.
    # If it does, set the subfolder to include the backend
    model_found = primary_full_path in model_file_names
    if not model_found:
        model_found = secondary_full_path in model_file_names
        if model_found:
            if len(model_file_names) > 1 and "file_name" not in model_kwargs:
                logger.warning(
                    f"Multiple {backend_name} files found in {load_path.as_posix()!r}: {model_file_names}, defaulting to {secondary_full_path!r}. "
                    f'Please specify the desired file name via `model_kwargs={{"file_name": "<file_name>"}}`.'
                )
            model_kwargs["subfolder"] = Path(subfolder, backend).as_posix() if subfolder else backend
    model_kwargs["file_name"] = file_name

    # An explicit `export=False` is respected; only decide ourselves when export was not given
    if export is None:
        export = not model_found

    # If the file_name contains subfolders, set it as the subfolder instead
    file_name_parts = Path(file_name).parts
    if len(file_name_parts) > 1:
        model_kwargs["file_name"] = file_name_parts[-1]
        # `or ""` guards against an explicit `subfolder=None`, which Path() cannot accept
        model_kwargs["subfolder"] = Path(model_kwargs.get("subfolder") or "", *file_name_parts[:-1]).as_posix()

    if export:
        logger.warning(f"No {file_name!r} found in {load_path.as_posix()!r}. Exporting the model to {backend_name}.")
        if model_file_names:
            logger.warning(
                f"If you intended to load one of the {model_file_names} {backend_name} files, "
                f'please specify the desired file name via `model_kwargs={{"file_name": "{model_file_names[0]}"}}`.'
            )

    return export, model_kwargs
def backend_warn_to_save(model_name_or_path: str, is_local: bool, backend_name: str) -> None:
    """
    Log a warning recommending that a freshly exported model be persisted.

    Args:
        model_name_or_path: The model name or path, quoted verbatim in the suggested command
        is_local: Whether the model is local; selects between suggesting ``save_pretrained``
            (local) and ``push_to_hub`` with ``create_pr=True`` (remote)
        backend_name: The name of the backend (ONNX or OpenVINO)
    """
    if is_local:
        how_to_save = f" Do so with `model.save_pretrained({model_name_or_path!r})`."
    else:
        how_to_save = f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`."

    logger.warning(
        f"Saving the exported {backend_name} model is heavily recommended to avoid having to export it again."
        + how_to_save
    )
def save_or_push_to_hub_model(
export_function: Callable,
export_function_name: str,
config,
model_name_or_path: str,
push_to_hub: bool = False,
create_pr: bool = False,
file_suffix: str | None = None,
backend: str = "onnx",
model: SentenceTransformer | SparseEncoder | CrossEncoder | None = None,
):
from sentence_transformers import CrossEncoder, SentenceTransformer, SparseEncoder
if backend == "onnx":
file_name = f"model_{file_suffix}.onnx"
elif backend == "openvino":
file_name = f"openvino_model_{file_suffix}.xml"
with tempfile.TemporaryDirectory() as save_dir:
export_function(save_dir)
# OpenVINO models are saved in a nested directory
if backend == "openvino":
save_dir = Path(save_dir) / backend
# and we need to attach the file_suffix for both the .xml and .bin files
shutil.move(save_dir / "openvino_model.xml", save_dir / file_name)
shutil.move(save_dir / "openvino_model.bin", (save_dir / file_name).with_suffix(".bin"))
save_dir = save_dir.as_posix()
# Because we upload folders and save_dir now has unnecessary files (tokenizer.json, config.json, etc.),
# we move the main file to a nested directory
if backend == "onnx":
dst_dir = Path(save_dir) / backend
dst_dir.mkdir(parents=True, exist_ok=True)
source = Path(save_dir) / file_name
destination = dst_dir / file_name
shutil.move(source, destination)
save_dir = dst_dir.as_posix()
if push_to_hub:
commit_description = ""
if create_pr:
opt_config_string = repr(config).replace("(", "(\n\t").replace(", ", ",\n\t").replace(")", "\n)")
if isinstance(model, SparseEncoder):
commit_description = f"""\
Hello!
*This pull request has been automatically generated from the [`{export_function_name}`](https://sbert.net/docs/package_reference/util.html#sentence_transformers.backend.{export_function_name}) function from the Sentence Transformers library.*
## Config
```python
{opt_config_string}
```
## Tip:
Consider testing this pull request before merging by loading the model from this PR with the `revision` argument:
```python
from sentence_transformers import SparseEncoder
# TODO: Fill in the PR number
pr_number = 2
model = SparseEncoder(
"{model_name_or_path}",
revision=f"refs/pr/{{pr_number}}",
backend="{backend}",
model_kwargs={{"file_name": "{file_name}"}},
)
# Verify that everything works as expected
embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."])
print(embeddings.shape)
similarities = model.similarity(embeddings, embeddings)
print(similarities)
```
"""
elif model is None or isinstance(model, SentenceTransformer):
commit_description = f"""\
Hello!
*This pull request has been automatically generated from the [`{export_function_name}`](https://sbert.net/docs/package_reference/util.html#sentence_transformers.backend.{export_function_name}) function from the Sentence Transformers library.*
## Config
```python
{opt_config_string}
```
## Tip:
Consider testing this pull request before merging by loading the model from this PR with the `revision` argument:
```python
from sentence_transformers import SentenceTransformer
# TODO: Fill in the PR number
pr_number = 2
model = SentenceTransformer(
"{model_name_or_path}",
revision=f"refs/pr/{{pr_number}}",
backend="{backend}",
model_kwargs={{"file_name": "{file_name}"}},
)
# Verify that everything works as expected
embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."])
print(embeddings.shape)
similarities = model.similarity(embeddings, embeddings)
print(similarities)
```
"""
elif isinstance(model, CrossEncoder):
commit_description = f"""\
Hello!
*This pull request has been automatically generated from the [`{export_function_name}`](https://sbert.net/docs/package_reference/util.html#sentence_transformers.backend.{export_function_name}) function from the Sentence Transformers library.*
## Config
```python
{opt_config_string}
```
## Tip:
Consider testing this pull request before merging by loading the model from this PR with the `revision` argument:
```python
from sentence_transformers import CrossEncoder
# TODO: Fill in the PR number
pr_number = 2
model = CrossEncoder(
"{model_name_or_path}",
revision=f"refs/pr/{{pr_number}}",
backend="{backend}",
model_kwargs={{"file_name": "{file_name}"}},
)
# Verify that everything works as expected
query = "Which planet is known as the Red Planet?"
passages = [
"Venus is often called Earth's twin because of its similar size and proximity.",
"Mars, known for its reddish appearance, is often referred to as the Red Planet.",
"Jupiter, the largest planet in our solar system, has a prominent red spot.",
"Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]
scores = model.predict([(query, passage) for passage in passages])
print(scores)
```
"""
huggingface_hub.upload_folder(
folder_path=save_dir,
path_in_repo=backend,
repo_id=model_name_or_path,
repo_type="model",
commit_message=f"Add exported {backend} model {file_name!r}",
commit_description=commit_description,
create_pr=create_pr,
)
else:
dst_dir = Path(model_name_or_path) / backend
# Create destination if it does not exist
dst_dir.mkdir(parents=True, exist_ok=True)
source = Path(save_dir) / file_name
destination = dst_dir / file_name
shutil.copy(source, destination)
# OpenVINO has a second file to save: the .bin file
if backend == "openvino":
bin_source = (Path(save_dir) / file_name).with_suffix(".bin")
bin_destination = (Path(dst_dir) / file_name).with_suffix(".bin")
shutil.copy(bin_source, bin_destination)