You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
139 lines
4.9 KiB
139 lines
4.9 KiB
import base64
|
|
import time
|
|
from abc import ABC
|
|
from typing import Any, Optional, Union
|
|
from urllib.parse import urlparse
|
|
|
|
from huggingface_hub.hf_api import InferenceProviderMapping
|
|
from huggingface_hub.inference._common import RequestParameters, _as_dict
|
|
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
|
|
from huggingface_hub.utils import get_session, hf_raise_for_status
|
|
from huggingface_hub.utils.logging import get_logger
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Delay (in seconds) between two consecutive polls of the task result URL.
_POLLING_INTERVAL = 0.5
|
|
|
|
|
|
class WavespeedAITask(TaskProviderHelper, ABC):
    """Base helper shared by all WaveSpeed AI tasks.

    The initial request does not carry the generated content: it returns a URL
    that `get_response` polls until the task reports `"completed"` or `"failed"`.
    """

    def __init__(self, task: str):
        super().__init__(provider="wavespeed", base_url="https://api.wavespeed.ai", task=task)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the API route for `mapped_model` (`api_key` is not used)."""
        return f"/api/v3/{mapped_model}"

    def get_response(
        self,
        response: Union[bytes, dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Poll the task created by the initial request and return its first output.

        Args:
            response: Raw response of the initial request. It must contain the
                polling URL under `data.urls.get`.
            request_params: Parameters of the initial request. Its `url` decides
                which host is polled and its `headers` are reused for polling.

        Returns:
            The raw content downloaded from the first URL in the task's `outputs`.

        Raises:
            ValueError: If the polling URL is missing, `request_params` is not
                provided, the task fails, the task completes without outputs, or
                an unknown status is reported.
            Exception: Whatever `hf_raise_for_status` raises when a polling
                request or the output download is unsuccessful.
        """
        response_dict = _as_dict(response)
        # `or {}` guards against explicit `null` values, so that a malformed
        # response ends up in the ValueError below instead of an AttributeError.
        data = response_dict.get("data") or {}
        result_path = (data.get("urls") or {}).get("get")

        if not result_path:
            raise ValueError("No result URL found in the response")
        if request_params is None:
            raise ValueError("A `RequestParameters` object should be provided to get responses with WaveSpeed AI.")

        # Poll the same host the initial request was sent to.
        parsed_url = urlparse(request_params.url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
        if parsed_url.netloc == "router.huggingface.co":
            # Requests going through the HF router need the provider prefix.
            base_url = f"{base_url}/wavespeed"

        # Only the path of the returned URL is kept; the host comes from `base_url`.
        if isinstance(result_path, str):
            result_url_path = urlparse(result_path).path
        else:
            result_url_path = result_path

        result_url = f"{base_url}{result_url_path}"

        logger.info("Processing request, polling for results...")

        # Poll until the task is completed or has failed.
        while True:
            time.sleep(_POLLING_INTERVAL)
            result_response = get_session().get(result_url, headers=request_params.headers)
            hf_raise_for_status(result_response)

            task_result = result_response.json().get("data") or {}
            status = task_result.get("status")

            if status == "completed":
                outputs = task_result.get("outputs")
                if not outputs:
                    raise ValueError("No output URL in completed response")

                # Download the first output. The status is checked so that an HTTP
                # error body is never returned as if it were the generated content.
                output_response = get_session().get(outputs[0])
                hf_raise_for_status(output_response)
                return output_response.content
            elif status == "failed":
                error_msg = task_result.get("error", "Task failed with no specific error message")
                raise ValueError(f"WaveSpeed AI task failed: {error_msg}")
            elif status in ("processing", "created"):
                continue
            else:
                raise ValueError(f"Unknown status: {status}")
|
|
|
|
|
|
class WavespeedAITextToImageTask(WavespeedAITask):
    """WaveSpeed AI helper for the "text-to-image" task."""

    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> Optional[dict]:
        """Build the request payload: `inputs` as the prompt plus the filtered parameters.

        `provider_mapping_info` is accepted but not used.
        """
        # Start from the prompt, then layer the filtered parameters on top
        # (a "prompt" entry in `parameters` therefore takes precedence).
        payload = {"prompt": inputs}
        payload.update(filter_none(parameters))
        return payload
|
|
|
|
|
|
class WavespeedAITextToVideoTask(WavespeedAITextToImageTask):
    """WaveSpeed AI helper for the "text-to-video" task.

    Inherits the payload preparation of the text-to-image helper.
    """

    def __init__(self):
        # The parent's `__init__` takes no task argument, so the base helper is
        # initialized directly with this class's task name.
        WavespeedAITask.__init__(self, task="text-to-video")
|
|
|
|
|
|
class WavespeedAIImageToImageTask(WavespeedAITask):
    """WaveSpeed AI helper for the "image-to-image" task."""

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> Optional[dict]:
        """Build the request payload from an input image and extra parameters.

        Args:
            inputs: The input image, given as an `http(s)` URL (sent as is), any
                other string (treated as a local file path and read from disk),
                or binary data.
            parameters: Extra parameters. Entries filtered out by `filter_none`
                are dropped. This dict is left unmodified.
            provider_mapping_info: Accepted but not used.

        Returns:
            A dict with the image under `"image"`, the filtered parameters and,
            when provided, the `"prompt"` entry.
        """
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            # Remote images are passed by URL.
            image = inputs
        else:
            if isinstance(inputs, str):
                # Any other string is treated as a file path: read it first.
                with open(inputs, "rb") as f:
                    raw_image = f.read()
            else:
                # Otherwise the input is expected to be binary data.
                raw_image = inputs
            image_b64 = base64.b64encode(raw_image).decode("utf-8")
            # NOTE: the data URL always declares JPEG, whatever the actual format is.
            image = f"data:image/jpeg;base64,{image_b64}"

        # Read the prompt without popping it, so the caller's dict is not mutated.
        prompt = parameters.get("prompt")
        other_parameters = {key: value for key, value in parameters.items() if key != "prompt"}

        payload = {"image": image, **filter_none(other_parameters)}
        if prompt is not None:
            payload["prompt"] = prompt

        return payload
|
|
|
|
|
|
class WavespeedAIImageToVideoTask(WavespeedAIImageToImageTask):
    """WaveSpeed AI helper for the "image-to-video" task.

    Inherits the payload preparation of the image-to-image helper.
    """

    def __init__(self):
        # The parent's `__init__` takes no task argument, so the base helper is
        # initialized directly with this class's task name.
        WavespeedAITask.__init__(self, task="image-to-video")
|