from __future__ import annotations import logging import os import shutil import tempfile from collections.abc import Callable from fnmatch import fnmatch from pathlib import Path from typing import TYPE_CHECKING, Any import huggingface_hub from huggingface_hub import list_repo_files if TYPE_CHECKING: from sentence_transformers import CrossEncoder, SentenceTransformer, SparseEncoder logger = logging.getLogger(__name__) def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]: """ Wraps the save_pretrained method of a model to save to a subfolder. Args: _save_pretrained_fn: The original save_pretrained function subfolder: The subfolder to save to Returns: A wrapped function that saves to the specified subfolder """ def wrapper(save_directory: str | Path, **kwargs) -> None: os.makedirs(Path(save_directory) / subfolder, exist_ok=True) return _save_pretrained_fn(Path(save_directory) / subfolder, **kwargs) return wrapper def backend_should_export( load_path: Path, is_local: bool, model_kwargs: dict[str, Any], target_file_name: str, target_file_glob: str, backend_name: str, ) -> tuple[bool, dict[str, Any]]: """ Determines whether the model should be exported to the backend, or if it can be loaded directly. Also update the `file_name` and `subfolder` model_kwargs if necessary. These are the cases: 1. If export is set in model_kwargs, just return export 2. If `/` exists; set export to False 3. If `/` exists; set export to False and set subfolder to the backend (e.g. "onnx") 4. If `` contains a folder, add those folders to the subfolder and set the file_name to the last part We will warn if: 1. The expected file does not exist in the model directory given the optional file_name and subfolder. If there are valid files for this backend, but they're don't align with file_name, then we give a useful warning. 2. Multiple files are found in the model directory that match the target file name and the user did not specify the desired file name via `model_kwargs={"file_name": ""}` Args: load_path: The model repository or directory, as a Path instance is_local: Whether the model is local or remote, i.e. whether load_path is a local directory model_kwargs: The model_kwargs dictionary. Notable keys are "export", "file_name", and "subfolder" target_file_name: The expected file name in the model directory, e.g. "model.onnx" or "openvino_model.xml" target_file_glob: The glob pattern to match the target file name, e.g. "*.onnx" or "openvino*.xml" backend_name: The human-readable name of the backend for use in warnings, e.g. "ONNX" or "OpenVINO" Returns: Tuple[bool, dict[str, Any]]: A tuple of the export boolean and the updated model_kwargs dictionary. Notable keys in model_kwargs are "export", "file_name", and "subfolder". """ export = model_kwargs.pop("export", None) if export: return export, model_kwargs backend = backend_name.lower() file_name = model_kwargs.get("file_name", target_file_name) subfolder = model_kwargs.get("subfolder", None) primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else Path(file_name).as_posix() secondary_full_path = ( Path(subfolder, backend, file_name).as_posix() if subfolder else Path(backend, file_name).as_posix() ) glob_pattern = f"{subfolder}/**/{target_file_glob}" if subfolder else f"**/{target_file_glob}" # Get the list of files in the model directory that match the target file name if is_local: model_file_names = [path.relative_to(load_path).as_posix() for path in load_path.glob(glob_pattern)] else: all_files = list_repo_files( load_path.as_posix(), repo_type="model", revision=model_kwargs.get("revision", None), token=model_kwargs.get("token", None), ) model_file_names = [fname for fname in all_files if fnmatch(fname, glob_pattern)] # First check if the expected file exists in the root of the model directory # If it doesn't, check if it exists in the backend subfolder. # If it does, set the subfolder to include the backend model_found = primary_full_path in model_file_names if not model_found: model_found = secondary_full_path in model_file_names if model_found: if len(model_file_names) > 1 and "file_name" not in model_kwargs: logger.warning( f"Multiple {backend_name} files found in {load_path.as_posix()!r}: {model_file_names}, defaulting to {secondary_full_path!r}. " f'Please specify the desired file name via `model_kwargs={{"file_name": ""}}`.' ) model_kwargs["subfolder"] = Path(subfolder, backend).as_posix() if subfolder else backend model_kwargs["file_name"] = file_name if export is None: export = not model_found # If the file_name contains subfolders, set it as the subfolder instead file_name_parts = Path(file_name).parts if len(file_name_parts) > 1: model_kwargs["file_name"] = file_name_parts[-1] model_kwargs["subfolder"] = Path(model_kwargs.get("subfolder", ""), *file_name_parts[:-1]).as_posix() if export: logger.warning(f"No {file_name!r} found in {load_path.as_posix()!r}. Exporting the model to {backend_name}.") if model_file_names: logger.warning( f"If you intended to load one of the {model_file_names} {backend_name} files, " f'please specify the desired file name via `model_kwargs={{"file_name": "{model_file_names[0]}"}}`.' ) return export, model_kwargs def backend_warn_to_save(model_name_or_path: str, is_local: bool, backend_name: str) -> None: """ Warns the user to save the model if they just exported it. Args: model_name_or_path: The model name or path is_local: Whether the model is local backend_name: The name of the backend (ONNX or OpenVINO) """ to_log = f"Saving the exported {backend_name} model is heavily recommended to avoid having to export it again." if is_local: to_log += f" Do so with `model.save_pretrained({model_name_or_path!r})`." else: to_log += f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`." logger.warning(to_log) def save_or_push_to_hub_model( export_function: Callable, export_function_name: str, config, model_name_or_path: str, push_to_hub: bool = False, create_pr: bool = False, file_suffix: str | None = None, backend: str = "onnx", model: SentenceTransformer | SparseEncoder | CrossEncoder | None = None, ): from sentence_transformers import CrossEncoder, SentenceTransformer, SparseEncoder if backend == "onnx": file_name = f"model_{file_suffix}.onnx" elif backend == "openvino": file_name = f"openvino_model_{file_suffix}.xml" with tempfile.TemporaryDirectory() as save_dir: export_function(save_dir) # OpenVINO models are saved in a nested directory if backend == "openvino": save_dir = Path(save_dir) / backend # and we need to attach the file_suffix for both the .xml and .bin files shutil.move(save_dir / "openvino_model.xml", save_dir / file_name) shutil.move(save_dir / "openvino_model.bin", (save_dir / file_name).with_suffix(".bin")) save_dir = save_dir.as_posix() # Because we upload folders and save_dir now has unnecessary files (tokenizer.json, config.json, etc.), # we move the main file to a nested directory if backend == "onnx": dst_dir = Path(save_dir) / backend dst_dir.mkdir(parents=True, exist_ok=True) source = Path(save_dir) / file_name destination = dst_dir / file_name shutil.move(source, destination) save_dir = dst_dir.as_posix() if push_to_hub: commit_description = "" if create_pr: opt_config_string = repr(config).replace("(", "(\n\t").replace(", ", ",\n\t").replace(")", "\n)") if isinstance(model, SparseEncoder): commit_description = f"""\ Hello! *This pull request has been automatically generated from the [`{export_function_name}`](https://sbert.net/docs/package_reference/util.html#sentence_transformers.backend.{export_function_name}) function from the Sentence Transformers library.* ## Config ```python {opt_config_string} ``` ## Tip: Consider testing this pull request before merging by loading the model from this PR with the `revision` argument: ```python from sentence_transformers import SparseEncoder # TODO: Fill in the PR number pr_number = 2 model = SparseEncoder( "{model_name_or_path}", revision=f"refs/pr/{{pr_number}}", backend="{backend}", model_kwargs={{"file_name": "{file_name}"}}, ) # Verify that everything works as expected embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."]) print(embeddings.shape) similarities = model.similarity(embeddings, embeddings) print(similarities) ``` """ elif model is None or isinstance(model, SentenceTransformer): commit_description = f"""\ Hello! *This pull request has been automatically generated from the [`{export_function_name}`](https://sbert.net/docs/package_reference/util.html#sentence_transformers.backend.{export_function_name}) function from the Sentence Transformers library.* ## Config ```python {opt_config_string} ``` ## Tip: Consider testing this pull request before merging by loading the model from this PR with the `revision` argument: ```python from sentence_transformers import SentenceTransformer # TODO: Fill in the PR number pr_number = 2 model = SentenceTransformer( "{model_name_or_path}", revision=f"refs/pr/{{pr_number}}", backend="{backend}", model_kwargs={{"file_name": "{file_name}"}}, ) # Verify that everything works as expected embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."]) print(embeddings.shape) similarities = model.similarity(embeddings, embeddings) print(similarities) ``` """ elif isinstance(model, CrossEncoder): commit_description = f"""\ Hello! *This pull request has been automatically generated from the [`{export_function_name}`](https://sbert.net/docs/package_reference/util.html#sentence_transformers.backend.{export_function_name}) function from the Sentence Transformers library.* ## Config ```python {opt_config_string} ``` ## Tip: Consider testing this pull request before merging by loading the model from this PR with the `revision` argument: ```python from sentence_transformers import CrossEncoder # TODO: Fill in the PR number pr_number = 2 model = CrossEncoder( "{model_name_or_path}", revision=f"refs/pr/{{pr_number}}", backend="{backend}", model_kwargs={{"file_name": "{file_name}"}}, ) # Verify that everything works as expected query = "Which planet is known as the Red Planet?" passages = [ "Venus is often called Earth's twin because of its similar size and proximity.", "Mars, known for its reddish appearance, is often referred to as the Red Planet.", "Jupiter, the largest planet in our solar system, has a prominent red spot.", "Saturn, famous for its rings, is sometimes mistaken for the Red Planet." ] scores = model.predict([(query, passage) for passage in passages]) print(scores) ``` """ huggingface_hub.upload_folder( folder_path=save_dir, path_in_repo=backend, repo_id=model_name_or_path, repo_type="model", commit_message=f"Add exported {backend} model {file_name!r}", commit_description=commit_description, create_pr=create_pr, ) else: dst_dir = Path(model_name_or_path) / backend # Create destination if it does not exist dst_dir.mkdir(parents=True, exist_ok=True) source = Path(save_dir) / file_name destination = dst_dir / file_name shutil.copy(source, destination) # OpenVINO has a second file to save: the .bin file if backend == "openvino": bin_source = (Path(save_dir) / file_name).with_suffix(".bin") bin_destination = (Path(dst_dir) / file_name).with_suffix(".bin") shutil.copy(bin_source, bin_destination)