You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
118 lines
4.4 KiB
118 lines
4.4 KiB
from typing import Optional
|
|
|
|
import httpx
|
|
from huggingface_hub import Discussion, HfApi, get_repo_discussions
|
|
|
|
from .utils import cached_file, http_user_agent, logging
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
|
|
main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
|
|
for discussion in get_repo_discussions(repo_id=model_id, token=token):
|
|
if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
|
|
commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)
|
|
|
|
if main_commit == commits[1].commit_id:
|
|
return discussion
|
|
return None
|
|
|
|
|
|
def spawn_conversion(token: str, private: bool, model_id: str):
|
|
logger.info("Attempting to convert .bin model on the fly to safetensors.")
|
|
|
|
safetensors_convert_space_url = "https://safetensors-convert.hf.space"
|
|
sse_url = f"{safetensors_convert_space_url}/call/run"
|
|
|
|
def start(_sse_connection):
|
|
for line in _sse_connection.iter_lines():
|
|
if not isinstance(line, str):
|
|
line = line.decode()
|
|
if line.startswith("event:"):
|
|
status = line[7:]
|
|
logger.debug(f"Safetensors conversion status: {status}")
|
|
|
|
if status == "complete":
|
|
return
|
|
elif status == "heartbeat":
|
|
logger.debug("Heartbeat")
|
|
else:
|
|
logger.debug(f"Unknown status {status}")
|
|
else:
|
|
logger.debug(line)
|
|
|
|
data = {"data": [model_id, private, token]}
|
|
|
|
result = httpx.post(sse_url, follow_redirects=True, json=data).json()
|
|
event_id = result["event_id"]
|
|
|
|
with httpx.stream("GET", f"{sse_url}/{event_id}") as sse_connection:
|
|
try:
|
|
logger.debug("Spawning safetensors automatic conversion.")
|
|
start(sse_connection)
|
|
except Exception as e:
|
|
logger.warning(f"Error during conversion: {repr(e)}")
|
|
|
|
|
|
def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
|
|
private = api.model_info(model_id).private
|
|
|
|
logger.info("Attempting to create safetensors variant")
|
|
pr_title = "Adding `safetensors` variant of this model"
|
|
token = kwargs.get("token")
|
|
|
|
# This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it
|
|
# returns it. It checks that the PR was opened by the bot and not by another user so as to prevent
|
|
# security breaches.
|
|
pr = previous_pr(api, model_id, pr_title, token=token)
|
|
|
|
if pr is None or (not private and pr.author != "SFconvertbot"):
|
|
spawn_conversion(token, private, model_id)
|
|
pr = previous_pr(api, model_id, pr_title, token=token)
|
|
else:
|
|
logger.info("Safetensors PR exists")
|
|
if pr is None:
|
|
raise OSError(
|
|
"Could not create safetensors conversion PR. The repo does not appear to have a file named pytorch_model.bin or model.safetensors."
|
|
"If you are loading with variant, use `use_safetensors=False` to load the original model."
|
|
)
|
|
|
|
sha = f"refs/pr/{pr.num}"
|
|
|
|
return sha
|
|
|
|
|
|
def auto_conversion(
|
|
pretrained_model_name_or_path: str,
|
|
ignore_errors_during_conversion: bool = False,
|
|
safe_weights_name: str = "model.safetensors",
|
|
safe_weights_index_name: str = "model.safetensors.index.json",
|
|
**cached_file_kwargs,
|
|
):
|
|
try:
|
|
api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()})
|
|
sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
|
|
|
|
if sha is None:
|
|
return None, None
|
|
cached_file_kwargs["revision"] = sha
|
|
del cached_file_kwargs["_commit_hash"]
|
|
|
|
# This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
|
|
# description.
|
|
sharded = api.file_exists(
|
|
pretrained_model_name_or_path,
|
|
safe_weights_index_name,
|
|
revision=sha,
|
|
token=cached_file_kwargs.get("token"),
|
|
)
|
|
filename = safe_weights_index_name if sharded else safe_weights_name
|
|
|
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
|
return resolved_archive_file, sha, sharded
|
|
except Exception as e:
|
|
if not ignore_errors_during_conversion:
|
|
raise e
|