You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
7.1 KiB
208 lines
7.1 KiB
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from huggingface_hub import hf_hub_download, snapshot_download
|
|
from tqdm.autonotebook import tqdm
|
|
|
|
|
|
class disabled_tqdm(tqdm):
    """
    Progress bar subclass whose instances are always created with ``disable=True``.

    Any ``disable`` value supplied by the caller is overridden, so this class can be
    passed wherever a tqdm class is accepted in order to silence progress output.

    Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324.
    """

    def __init__(self, *args, **kwargs):
        # Replace whatever `disable` value the caller passed (if any) with True.
        kwargs.update(disable=True)
        super().__init__(*args, **kwargs)

    def __delattr__(self, attr: str) -> None:
        """Fix for https://github.com/huggingface/huggingface_hub/issues/1603"""
        try:
            super().__delattr__(attr)
        except AttributeError:
            # Deleting a missing `_lock` is silently ignored; every other
            # missing attribute still surfaces as an AttributeError.
            if attr == "_lock":
                return
            raise
|
|
|
|
|
|
def is_sentence_transformer_model(
    model_name_or_path: str,
    token: bool | str | None = None,
    cache_folder: str | None = None,
    revision: str | None = None,
    local_files_only: bool = False,
) -> bool:
    """
    Report whether a model name or path refers to a SentenceTransformer model.

    The decision is based on whether a ``modules.json`` file can be resolved for
    the model via :func:`load_file_path`.

    Args:
        model_name_or_path (str): The name or path of the model.
        token (Optional[Union[bool, str]]): The token to be used for authentication. Defaults to None.
        cache_folder (Optional[str]): The folder to cache the model files. Defaults to None.
        revision (Optional[str]): The revision of the model. Defaults to None.
        local_files_only (bool): Whether to only use local files for the model. Defaults to False.

    Returns:
        bool: True if a ``modules.json`` path was resolved, False otherwise.
    """
    modules_json_path = load_file_path(
        model_name_or_path,
        "modules.json",
        token=token,
        cache_folder=cache_folder,
        revision=revision,
        local_files_only=local_files_only,
    )
    # `load_file_path` yields None when nothing was found; coerce to a strict bool.
    return bool(modules_json_path)
|
|
|
|
|
|
def load_file_path(
    model_name_or_path: str,
    filename: str | Path,
    subfolder: str = "",
    token: bool | str | None = None,
    cache_folder: str | None = None,
    revision: str | None = None,
    local_files_only: bool = False,
) -> str | None:
    """
    Resolve the path of a single model file, locally first and via the Hub otherwise.

    Args:
        model_name_or_path (str): The model name or path.
        filename (Union[str, Path]): The name of the file to load.
        subfolder (str): The subfolder within the model directory (if applicable).
        token (Optional[Union[bool, str]]): The token to access the remote file (if applicable).
        cache_folder (Optional[str]): The folder to cache the downloaded file (if applicable).
        revision (Optional[str], optional): The revision of the file (if applicable). Defaults to None.
        local_files_only (bool, optional): Whether to only consider local files. Defaults to False.

    Returns:
        Optional[str]: The path to the file. ``None`` is returned when the file does not
        exist locally and the Hub download raises any exception.
    """
    # First preference: the file already exists under the given local directory.
    local_candidate = Path(model_name_or_path, subfolder, filename)
    if local_candidate.exists():
        return str(local_candidate)

    # Otherwise treat `model_name_or_path` as a Hub repository id. The relative
    # path is re-split so that any directory part of `filename` ends up in `subfolder`.
    relative_path = Path(subfolder, filename)
    try:
        downloaded_path = hf_hub_download(
            model_name_or_path,
            filename=relative_path.name,
            subfolder=relative_path.parent.as_posix(),
            revision=revision,
            library_name="sentence-transformers",
            token=token,
            cache_dir=cache_folder,
            local_files_only=local_files_only,
        )
    except Exception:
        # Best effort: any failure to fetch the file is reported as "not found".
        return None
    return downloaded_path
|
|
|
|
|
|
def load_dir_path(
    model_name_or_path: str,
    subfolder: str | Path,
    token: bool | str | None = None,
    cache_folder: str | None = None,
    revision: str | None = None,
    local_files_only: bool = False,
) -> str | None:
    """
    Loads the subfolder path for a given model name or path.

    If ``model_name_or_path/subfolder`` exists locally it is returned directly.
    Otherwise a snapshot of the repository is downloaded (restricted to the files
    under ``subfolder`` unless ``subfolder`` is ``""`` or ``"."``), falling back to a
    cache-only lookup if the download attempt raises.

    Args:
        model_name_or_path (str): The name or path of the model.
        subfolder (Union[str, Path]): The subfolder to load. ``Path`` values are
            converted to their POSIX string form.
        token (Optional[Union[bool, str]]): The token for authentication.
        cache_folder (Optional[str]): The folder to cache the downloaded files.
        revision (Optional[str], optional): The revision of the model. Defaults to None.
        local_files_only (bool, optional): Whether to only use local files. Defaults to False.

    Returns:
        str: The path to the subfolder, always as a string. This function does not
        return ``None``; the ``None`` in the annotation is retained for interface
        compatibility only.

    Raises:
        Exception: Whatever ``snapshot_download`` raises when the cache-only retry
            also fails.
    """
    if isinstance(subfolder, Path):
        subfolder = subfolder.as_posix()

    # If the directory is local
    dir_path = Path(model_name_or_path, subfolder)
    if dir_path.exists():
        return str(dir_path)

    download_kwargs = {
        "repo_id": model_name_or_path,
        "revision": revision,
        # Only restrict the download when an actual subfolder was requested.
        "allow_patterns": f"{subfolder}/**" if subfolder not in ["", "."] else None,
        "library_name": "sentence-transformers",
        "token": token,
        "cache_dir": cache_folder,
        "local_files_only": local_files_only,
        "tqdm_class": disabled_tqdm,
    }
    # Try to download from the remote
    try:
        repo_path = snapshot_download(**download_kwargs)
    except Exception:
        # Otherwise, try local (i.e. cache) only
        download_kwargs["local_files_only"] = True
        repo_path = snapshot_download(**download_kwargs)
    # Return a string here too, so both branches agree with the annotated return type.
    return str(Path(repo_path, subfolder))
|
|
|
|
|
|
def http_get(url: str, path: str) -> None:
    """Stream the resource at ``url`` into the file ``path``, showing a progress bar.

    The response body is read in 1024-byte chunks and written to a temporary
    ``"<path>_part"`` file. Once every chunk has been written, that file is moved
    onto ``path`` with ``os.replace``. If anything goes wrong while writing or
    moving, the temporary file is removed and the exception is re-raised. The
    parent directory of ``path`` is created when ``path`` has a directory part.

    Args:
        url (str): The HTTP(S) URL to download.
        path (str): Destination file path on the local filesystem.

    Raises:
        ImportError: If the optional ``httpx`` dependency is not installed.
        httpx.HTTPStatusError: If the HTTP request returns a non-success status code.
        OSError: If the file cannot be written to ``path``.

    Returns:
        None
    """
    try:
        import httpx
    except ImportError:
        raise ImportError("httpx is required to use this function. Please install it via `pip install httpx`.")

    parent_dir = os.path.dirname(path)
    if parent_dir != "":
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = path + "_part"
    with httpx.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()

        # The total is only known when the server sent a Content-Length header.
        size_header = response.headers.get("Content-Length")
        if size_header is None:
            expected_bytes = None
        else:
            expected_bytes = int(size_header)

        bar = tqdm(
            desc=f"Downloading {os.path.basename(path)}",
            total=expected_bytes,
            unit="B",
            unit_scale=True,
            leave=False,
        )

        try:
            with open(partial_path, "wb") as sink:
                for block in response.iter_bytes(chunk_size=1024):
                    if not block:
                        continue
                    bar.update(len(block))
                    sink.write(block)
            # Only reached after the partial file has been fully written and closed.
            os.replace(partial_path, path)
        except Exception:
            # Do not leave a half-written partial file behind on failure.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        finally:
            bar.close()
|