# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import copy
import enum
import gc
import io
import json
import re
import tempfile
import threading
import time
import uuid
from collections.abc import Generator, Iterable
from contextlib import asynccontextmanager
from functools import lru_cache
from io import BytesIO
from threading import Thread
from typing import TYPE_CHECKING, Annotated, Any, Optional, TypedDict, Union

import typer
from huggingface_hub import scan_cache_dir
from tokenizers.decoders import DecodeStream
from tqdm import tqdm

import transformers
from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizerBase
from transformers.utils.import_utils import (
    is_fastapi_available,
    is_librosa_available,
    is_openai_available,
    is_pydantic_available,
    is_uvicorn_available,
    is_vision_available,
)

from .. import (
    LogitsProcessorList,
    TextIteratorStreamer,
)
from ..utils import logging

if TYPE_CHECKING:
    from transformers import (
        PreTrainedModel,
        PreTrainedTokenizerFast,
        ProcessorMixin,
    )

from ..generation.continuous_batching import ContinuousBatchingManager


if is_librosa_available():
    import librosa

if is_vision_available():
    from PIL import Image

serve_dependencies_available = (
    is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
)
if serve_dependencies_available:
    import uvicorn
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import JSONResponse, StreamingResponse
    from openai.types.audio.transcription import Transcription
    from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
    from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
    from openai.types.chat.chat_completion import Choice
    from openai.types.chat.chat_completion_chunk import (
        ChatCompletionChunk,
        ChoiceDelta,
        ChoiceDeltaToolCall,
        ChoiceDeltaToolCallFunction,
    )
    from openai.types.chat.chat_completion_chunk import (
        Choice as ChoiceChunk,
    )
    from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
    from openai.types.responses import (
        Response,
        ResponseCompletedEvent,
        ResponseContentPartAddedEvent,
        ResponseContentPartDoneEvent,
        ResponseCreatedEvent,
        ResponseError,
        ResponseErrorEvent,
        ResponseFailedEvent,
        ResponseInProgressEvent,
        ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent,
        ResponseOutputMessage,
        ResponseOutputText,
        ResponseTextDeltaEvent,
        ResponseTextDoneEvent,
    )
    from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
    from pydantic import BaseModel, TypeAdapter, ValidationError

    # Expand OpenAI's request input types with an optional `generation_config` field
    class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False):
        """
        OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a JSON string).
        """

        generation_config: str

    class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
        """
        OpenAI's CompletionCreateParamsStreaming with an additional field for the generation config (as a JSON
        string) and support for passing the request id.
        """

        generation_config: str

    class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
        """
        OpenAI's TranscriptionCreateParamsBase with an additional field for the generation config (as a JSON string).
        """

        file: bytes  # Overwritten -- pydantic isn't happy with `typing.IO[bytes]`, present in the original type
        generation_config: str
        stream: bool = False

    # Contrary to OpenAI's output types, the input types are `TypedDict`s, which don't have built-in validation.
    response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming)
    completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming)
    transcription_validator = TypeAdapter(TransformersTranscriptionCreateParams)
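# Illustrative client-side sketch (not part of the server): extra `transformers` generation flags can be
# tunneled through the non-standard `generation_config` field as a JSON string, on top of the regular
# OpenAI payload. The model name, host, and port below are assumptions.
#
#   import json, requests
#
#   payload = {
#       "model": "Qwen/Qwen2.5-0.5B-Instruct",
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "stream": True,
#       "generation_config": json.dumps({"temperature": 0.7, "top_k": 50}),
#   }
#   requests.post("http://localhost:8000/v1/chat/completions", json=payload)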
# Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an
# HTTPException.
UNUSED_RESPONSE_FIELDS = {
    "background",
    "include",
    "max_tool_calls",
    "previous_response_id",
    "prompt",
    "reasoning",
    "service_tier",
    "store",
    "text",
    "tool_choice",
    "top_logprobs",
    "truncation",
    "user",
}

UNUSED_CHAT_COMPLETION_FIELDS = {
    "audio",
    "function_call",
    "functions",
    "logprobs",
    "max_completion_tokens",
    "metadata",
    "modalities",
    "n",
    "parallel_tool_calls",
    "prediction",
    "presence_penalty",
    "reasoning_effort",
    "response_format",
    "service_tier",
    "stop",
    "store",
    "stream_options",
    "tool_choice",
    "top_logprobs",
    "user",
    "web_search_options",
}

UNUSED_TRANSCRIPTION_FIELDS = {
    "chunking_strategy",
    "include",
    "language",
    "prompt",
    "response_format",
    "timestamp_granularities",
}
logger = logging.get_logger(__name__)

# Possible tokens that indicate the start/end of a tool call
# TODO (joao, matt): streamline tool token detection logic
_TOOL_CALL_TOKENS = {
    "qwen": {
        "start": "<tool_call>",
        "end": "</tool_call>",
    },
}
_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())
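# Illustrative example (Qwen-style) of the raw text a tool-call-capable model emits, which the streaming
# logic further below detects and converts into OpenAI tool-call deltas. Tool name and arguments are made up.
#
#   <tool_call>
#   {"name": "get_weather", "arguments": {"city": "Paris"}}
#   </tool_call>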
X_REQUEST_ID = "x-request-id"
def set_torch_seed(_seed):
    import torch

    torch.manual_seed(_seed)


def reset_torch_cache():
    import torch

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def torch_ones_like(_input_tensor):
    import torch

    return torch.ones_like(_input_tensor)
class Modality(enum.Enum):
    LLM = "LLM"
    VLM = "VLM"
    STT = "STT"
    TTS = "TTS"
def create_generation_config_from_req(
    req: dict,
    model_generation_config: GenerationConfig,
    **kwargs,
) -> GenerationConfig:
    """
    Creates a generation config from the parameters of the request. If a generation config is passed in the request,
    it will be used as a baseline for parameterization. Otherwise, we will use the model's default generation config.
    Other parameters in the request will be applied on top of the baseline.

    Args:
        req (`dict`):
            The request which may optionally contain generation parameters.
        model_generation_config (`GenerationConfig`):
            The model's default generation config.
        kwargs (`dict`):
            Additional parameters to set in the generation config.

    Returns:
        The prepared `GenerationConfig` object.
    """
    # If there is a generation config in the request, it is a json string serialization from a `GenerationConfig`
    # object. For simplicity, flags set here take precedence over all other flags.
    if req.get("generation_config") is not None:
        generation_config = GenerationConfig(**json.loads(req["generation_config"]))
    else:
        generation_config = copy.deepcopy(model_generation_config)

    non_standard_kwargs = generation_config.update(**kwargs)
    # Set extra kwargs that are not in the `GenerationConfig` class (e.g. continuous batching flags)
    for k, v in non_standard_kwargs.items():
        if v is not None:
            setattr(generation_config, k, v)

    # Response-specific parameters
    if req.get("max_output_tokens") is not None:
        generation_config.max_new_tokens = int(req["max_output_tokens"])

    # Completion-specific parameters
    if req.get("max_tokens") is not None:
        generation_config.max_new_tokens = int(req["max_tokens"])
    if req.get("frequency_penalty") is not None:
        generation_config.repetition_penalty = float(req["frequency_penalty"])
    if req.get("logit_bias") is not None:
        generation_config.sequence_bias = req["logit_bias"]
    if req.get("stop") is not None:
        generation_config.stop_strings = req["stop"]
    if req.get("temperature") is not None:
        generation_config.temperature = float(req["temperature"])
        if float(req["temperature"]) == 0.0:
            generation_config.do_sample = False
    if req.get("top_p") is not None:
        generation_config.top_p = float(req["top_p"])
    if req.get("seed") is not None:
        set_torch_seed(req["seed"])

    return generation_config
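# Illustrative mapping (sketch): an OpenAI-style request such as {"temperature": 0.0, "max_tokens": 32}
# produces a GenerationConfig with do_sample=False and max_new_tokens=32, applied on top of the model's
# default generation config (or on top of the request's `generation_config` baseline, when provided).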
class ToolState:
    """Lightweight class to keep track of the tool call state."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset the tool call state (assumes we're outside a tool call)."""
        self.inside_tool_call = False
        self.has_tool_name_defined = False
        self.arg_nesting_level = 0
        self.buffer = ""
class TimedModel:
    """
    A class that holds a PreTrainedModel instance and its associated processor.
    Automatically deletes the instances after a specified timeout.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        timeout_seconds: int,
        processor: Union["ProcessorMixin", "PreTrainedTokenizerFast"] | None = None,
    ):
        self.model = model
        self._name_or_path = str(model.name_or_path)
        self.processor = processor
        self.timeout_seconds = timeout_seconds
        self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
        self._timer.start()

    def reset_timer(self):
        """Reset the timer for the deletion of the instances."""
        self._timer.cancel()
        self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
        self._timer.start()

    def delete_model(self):
        """Delete the wrapped model and processor and clean up resources."""
        if hasattr(self, "model") and self.model is not None:
            del self.model
            del self.processor
            self.model = None
            self.processor = None
            gc.collect()

            # Clear CUDA cache if available
            reset_torch_cache()

        # XXX: in case we manually delete the model, like on server shutdown
        self._timer.cancel()

    def timeout_reached(self):
        if self.timeout_seconds > 0:
            self.delete_model()
            logger.info(
                f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity"
            )

    def is_deleted(self):
        """Check if the instances have been deleted."""
        return not hasattr(self, "model") or self.model is None
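# Usage sketch (illustrative): wrap a freshly loaded model so it is evicted after a period of inactivity.
#
#   timed = TimedModel(model, timeout_seconds=300, processor=tokenizer)
#   timed.reset_timer()  # call whenever a request touches the model
#   timed.is_deleted()   # True once the timeout fired and resources were released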
class Serve:
    # This class mostly exists to hold internal state; in practice, it's just one method to call
    # TODO: refactor into a proper module with helpers + 1 main method
    def __init__(
        self,
        continuous_batching: Annotated[
            bool, typer.Option(help="Whether to use continuous batching for chat completions.")
        ] = False,
        device: Annotated[
            str,
            typer.Option(
                help="Device to use for inference; will default to `auto` and place the model on an accelerator if available."
            ),
        ] = "auto",
        dtype: Annotated[
            str | None,
            typer.Option(
                help="Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype will be automatically derived from the model's weights."
            ),
        ] = "auto",
        trust_remote_code: Annotated[
            bool, typer.Option(help="Whether to trust remote code when loading a model.")
        ] = False,
        attn_implementation: Annotated[
            str | None,
            typer.Option(
                help="Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install it manually by running `pip install flash-attn --no-build-isolation`."
            ),
        ] = None,
        quantization: Annotated[
            str | None,
            typer.Option(help="Which quantization method to use. Choices: 'bnb-4bit', 'bnb-8bit'."),
        ] = None,
        host: Annotated[str, typer.Option(help="Interface the server will listen on.")] = "localhost",
        port: Annotated[int, typer.Option(help="Port the server will listen on.")] = 8000,
        model_timeout: Annotated[
            int, typer.Option(help="Time in seconds after which a model will be removed from memory.")
        ] = 300,
        log_level: Annotated[
            str, typer.Option(help="Logging level as a string. Example: 'info' or 'warning'.")
        ] = "info",
        default_seed: Annotated[
            int | None, typer.Option(help="The default seed for torch; should be an integer.")
        ] = None,
        enable_cors: Annotated[
            bool,
            typer.Option(
                help="Whether to enable CORS. Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled."
            ),
        ] = False,
        input_validation: Annotated[bool, typer.Option(help="Whether to turn on strict input validation.")] = False,
        force_model: Annotated[
            str | None,
            typer.Option(
                help="Name of the model to be forced on all requests. This is useful for testing apps that don't allow changing models in the request."
            ),
        ] = None,
        non_blocking: Annotated[
            bool, typer.Option(hidden=True, help="Whether to run the server in a separate thread.")
        ] = False,
    ) -> None:
        if not serve_dependencies_available:
            raise ImportError(
                "Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`"
            )

        # Save input arguments
        self.continuous_batching = continuous_batching
        self.device = device
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code
        self.attn_implementation = attn_implementation
        self.quantization = quantization
        self.host = host
        self.port = port
        self.model_timeout = model_timeout
        self.log_level = log_level
        self.default_seed = default_seed
        self.enable_cors = enable_cors
        self.input_validation = input_validation
        self.force_model = force_model
        self.non_blocking = non_blocking

        # Seed
        if default_seed is not None:
            set_torch_seed(default_seed)

        # Set up logging
        transformers_logger = logging.get_logger("transformers")
        transformers_logger.setLevel(logging.log_levels[log_level.lower()])

        cb_logger = logging.get_logger("transformers.generation.continuous_batching")
        cb_logger.setLevel(logging.log_levels[log_level.lower()])

        # Internal state:
        # 1. Tracks models in memory, to prevent reloading the model unnecessarily
        self.loaded_models: dict[str, TimedModel] = {}
        self.running_continuous_batching_manager: ContinuousBatchingManager | None = None

        # 2. Preserves information about the last call and last KV cache, to determine whether we can reuse the KV
        # cache and avoid re-running prefill
        self.last_messages = None
        self.last_kv_cache = None
        self.last_model = None

        if self.model_timeout is None:
            self.model_timeout = -1 if self.force_model else 300

        if self.force_model:
            model_id_and_revision = self.process_model_name(self.force_model)
            self.last_model = model_id_and_revision
            self.load_model_and_processor(model_id_and_revision)
        @asynccontextmanager
        async def lifespan(app: FastAPI):
            yield
            for model in self.loaded_models.values():
                model.delete_model()
            if self.running_continuous_batching_manager is not None:
                self.running_continuous_batching_manager.stop(block=True, timeout=5)

        app = FastAPI(lifespan=lifespan)

        # Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled. However, for
        # security purposes, it's disabled by default
        if self.enable_cors:
            app.add_middleware(
                CORSMiddleware,
                allow_origins=["*"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )
            logger.warning_once(
                "CORS allow origin is set to `*`. This is not recommended for production environments."
            )

        from fastapi import Request

        @app.post("/v1/chat/completions")
        def chat_completion(request: Request, body: dict):
            self.validate_chat_completion_request(request=body)

            if self.continuous_batching:
                return self.continuous_batching_chat_completion(body, request.state.request_id)
            else:
                return self.generate_chat_completion(body)

        @app.post("/v1/responses")
        def responses(request: dict):
            self.validate_response_request(request=request)
            # Support non-streaming mode when `stream=false` is provided
            stream = request.get("stream", True)
            if not stream:
                response_obj = self.generate_response_non_streaming(request)
                return JSONResponse(response_obj)

            output = self.generate_response(request)
            return StreamingResponse(output, media_type="text/event-stream")

        @app.post("/v1/audio/transcriptions")
        async def audio_transcriptions(request: Request):
            # Parses the multipart/form-data request into the request format used by the other endpoints
            async with request.form() as form:
                parsed_request = TransformersTranscriptionCreateParams(
                    file=await form["file"].read(),
                    model=form["model"],
                    # TODO: add other fields
                )
                logger.debug(
                    f"Received file: {form['file'].filename}; MIME type: {form['file'].content_type}; "
                    f"size: {form['file'].size / 1024:.2f} KiB"
                )
                self.validate_transcription_request(request=parsed_request)

                output = self.generate_transcription(parsed_request)
                return StreamingResponse(output, media_type="text/event-stream")

        @app.options("/v1/models")
        @app.get("/v1/models")
        def get_all_models():
            return JSONResponse({"object": "list", "data": self.get_gen_models()})

        @app.get("/health")
        def healthcheck():
            return JSONResponse({"status": "ok"})
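        # Quick smoke test (illustrative; assumes the default host and port):
        #   curl http://localhost:8000/health
        #   curl http://localhost:8000/v1/models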
@app.middleware("http")
|
|
async def get_or_set_request_id(request: Request, call_next):
|
|
request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
|
|
request.state.request_id = request_id
|
|
response = await call_next(request)
|
|
response.headers[X_REQUEST_ID] = request_id
|
|
return response
|
|
|
|
config = uvicorn.Config(app, host=self.host, port=self.port, log_level=self.log_level)
|
|
self.server = uvicorn.Server(config)
|
|
|
|
if self.non_blocking:
|
|
self.start_server()
|
|
else:
|
|
self.server.run()
|
|
|
|
def start_server(self):
|
|
def _run():
|
|
self._loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(self._loop)
|
|
# serve() is a coroutine; it exits when server.should_exit becomes True
|
|
self._loop.run_until_complete(self.server.serve())
|
|
|
|
self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False)
|
|
self._thread.start()
|
|
|
|
def kill_server(self):
|
|
if not self._thread:
|
|
raise ValueError("The server cannot be killed as it was not launched in a separate thread.")
|
|
|
|
if not self._thread.is_alive():
|
|
raise ValueError("The server is already killed.")
|
|
|
|
self.server.should_exit = True
|
|
if self._thread and self._thread.is_alive():
|
|
self._thread.join(timeout=2)
|
|
|
|
    def _validate_request(
        self,
        request: dict,
        schema: TypedDict,
        validator: TypeAdapter,
        unused_fields: set,
    ):
        """
        Validates the request against the schema, and checks for unexpected keys.

        Args:
            request (`dict`):
                The request to validate.
            schema (`TypedDict`):
                The schema of the request to validate. It is a `TypedDict` definition.
            validator (`TypeAdapter`):
                The validator to use to validate the request. Built from `schema`.
            unused_fields (`set`):
                Fields accepted by `schema`, but not used in `transformers serve`.

        Raises:
            HTTPException: If the request is invalid or contains unexpected or unused fields.
        """
        logger.debug(f"Validating request: {request}")

        # Validate unexpected keys -- Pydantic doesn't validate extra keys in the request.
        input_keys = set(request.keys())
        possible_keys = schema.__mutable_keys__
        unexpected_keys = input_keys - possible_keys
        if unexpected_keys:
            logger.error(f"Unexpected keys in the request: {unexpected_keys}")
            raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected_keys}")

        if self.input_validation:
            # Validate expected keys
            try:
                validator.validate_python(request)
            except ValidationError as e:
                logger.error(f"Validation error: {e.errors()}")
                raise HTTPException(status_code=422, detail=e.errors())

            # Validate unused fields
            unused_fields_in_request = input_keys & unused_fields
            if unused_fields_in_request:
                logger.error(f"Unused fields in the request: {unused_fields_in_request}")
                raise HTTPException(
                    status_code=422, detail=f"Unused fields in the request: {unused_fields_in_request}"
                )

    def validate_response_request(self, request: dict):
        self._validate_request(
            request=request,
            schema=TransformersResponseCreateParamsStreaming,
            validator=response_validator,
            unused_fields=UNUSED_RESPONSE_FIELDS,
        )

    def validate_chat_completion_request(self, request: dict):
        self._validate_request(
            request=request,
            schema=TransformersCompletionCreateParamsStreaming,
            validator=completion_validator,
            unused_fields=UNUSED_CHAT_COMPLETION_FIELDS,
        )

    def validate_transcription_request(self, request: dict):
        self._validate_request(
            request=request,
            schema=TransformersTranscriptionCreateParams,
            validator=transcription_validator,
            unused_fields=UNUSED_TRANSCRIPTION_FIELDS,
        )
    def build_chat_completion_chunk(
        self,
        request_id: str = "",
        content: int | str | None = None,
        model: str | None = None,
        role: str | None = None,
        finish_reason: str | None = None,
        tool_calls: list[ChoiceDeltaToolCall] | None = None,
        decode_stream: DecodeStream | None = None,
        tokenizer: Optional["PreTrainedTokenizerFast"] = None,
    ) -> ChatCompletionChunk:
        """
        Builds a chunk of a streaming OpenAI Chat Completion response.

        IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps,
        like Cursor, assume that when the field exists, it has data.

        Args:
            request_id (`str`):
                The request ID.
            content (`int` or `str`, *optional*):
                Content of the response from the model: either a text delta, or a token id to be decoded through
                `decode_stream` and `tokenizer`.
            model (`str`, *optional*):
                The model that generated the content.
            role (`str`, *optional*):
                The role of the next content, until a new role is defined.
            finish_reason (`str`, *optional*):
                The reason the generation by the model has finished.
            tool_calls (`list[ChoiceDeltaToolCall]`, *optional*):
                Data about the tool calls, when they are triggered.

        Returns:
            `ChatCompletionChunk`: The built chunk.
        """
        if decode_stream is not None and content is not None and tokenizer is not None:
            content = decode_stream.step(tokenizer._tokenizer, content)

        chunk = ChatCompletionChunk(
            id=request_id,
            created=int(time.time()),
            model=model,
            choices=[
                ChoiceChunk(
                    delta=ChoiceDelta(
                        content=content,
                        role=role,
                        tool_calls=tool_calls,
                    ),
                    index=0,
                    finish_reason=finish_reason,
                )
            ],
            system_fingerprint="",
            object="chat.completion.chunk",
        )

        return chunk
    @staticmethod
    def chunk_to_sse_element(chunk: ChatCompletionChunk | BaseModel) -> str:
        """
        Builds an event of a streaming OpenAI Response model or a ChatCompletion chunk.

        IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps,
        like Cursor, assume that when the field exists, it has data.

        Args:
            chunk (`BaseModel` or `ChatCompletionChunk`):
                The response to build an event from. One of the multiple OpenAI Response output types.

        Returns:
            `str`: The built event, a string containing a JSON string with the payload.
        """
        return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
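    # Wire format (illustrative): each chunk becomes one server-sent event terminated by a blank line, e.g.
    #   data: {"id": "req_0", "choices": [{"delta": {"content": "Hello"}, "index": 0}], ...}
    # `exclude_none=True` drops empty fields, since some clients treat a present-but-null field as data.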
    @staticmethod
    @lru_cache
    def get_gen_models(cache_dir: str | None = None) -> list[dict[str, Any]]:
        """
        List LLMs and VLMs in the cache.
        """
        from transformers.models.auto.modeling_auto import (
            MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
            MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
        )

        generative_models = []

        logger.warning("Scanning the cache directory for LLMs and VLMs.")
        for repo in tqdm(scan_cache_dir(cache_dir).repos):
            if repo.repo_type != "model":
                continue

            refs = repo.refs
            for ref, revision_info in refs.items():
                files = revision_info.files
                config_path = next((f.file_path for f in files if f.file_name == "config.json"), None)

                if not config_path:
                    continue

                config = json.loads(config_path.open().read())

                if not (isinstance(config, dict) and "architectures" in config):
                    continue

                architectures = config["architectures"]
                llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
                vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()

                if any(arch for arch in architectures if arch in [*llms, *vlms]):
                    author = repo.repo_id.split("/")[0] if "/" in repo.repo_id else ""
                    repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "")
                    generative_models.append(
                        {
                            "owned_by": author,
                            "id": repo_handle,
                            "object": "model",
                            "created": repo.last_modified,
                        }
                    )

        return generative_models
    def continuous_batching_chat_completion(self, req: dict, request_id: str) -> StreamingResponse | JSONResponse:
        """
        Generates an OpenAI Chat Completion using continuous batching.

        Args:
            req (`dict`): The request to generate an OpenAI Chat Completion for.

        Returns:
            `StreamingResponse` or `JSONResponse`: The streamed OpenAI Chat Completion chunks, or the buffered
            completion when `stream` is not set.
        """
        model_id_and_revision = self.process_model_name(req["model"])
        must_discard_cache = model_id_and_revision != self.last_model

        self.last_model = model_id_and_revision

        # When switching models, terminate a continuous batching manager if it is running.
        if must_discard_cache:
            if self.running_continuous_batching_manager is not None:
                self.running_continuous_batching_manager.stop(block=True, timeout=2)
                self.running_continuous_batching_manager = None

        model, processor = self.load_model_and_processor(model_id_and_revision)
        tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor

        generation_config = create_generation_config_from_req(
            req,
            model_generation_config=model.generation_config,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=False,
            do_sample=False,
            scheduler="fifo",
        )

        if self.running_continuous_batching_manager is None:
            self.running_continuous_batching_manager = model.init_continuous_batching(
                generation_config=generation_config
            )

            # TODO (Joao, Lysandre): the logits processors should be fixed in continuous batching and correctly
            # applied in non-cb
            self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
            self.running_continuous_batching_manager.start()

        # TODO (Joao, Lysandre): this should also work with tool support
        inputs = processor.apply_chat_template(
            req["messages"], return_tensors="pt", add_generation_prompt=True, return_dict=True
        ).to(model.device)["input_ids"][0]

        def stream_chat_completion(request_id, decode_stream):
            from ..generation.continuous_batching import RequestStatus

            try:
                # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
                # they come from the assistant.
                yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)

                n_tokens_generated = 0
                for result in self.running_continuous_batching_manager.request_id_iter(request_id):
                    n_tokens_generated += 1

                    # Always yield the token content (even for the final FINISHED token)
                    if result.generated_tokens:
                        token_id = result.generated_tokens[-1]
                        yield self.build_chat_completion_chunk(
                            request_id=request_id,
                            content=token_id,
                            model=model_id_and_revision,
                            decode_stream=decode_stream,
                            tokenizer=tokenizer,
                        )

                    if result.status == RequestStatus.FINISHED:
                        generated_all_tokens = n_tokens_generated >= generation_config.max_new_tokens

                        # If the tokenizer has an eos_token, we can have a more robust check: "length" only
                        # applies when the last generated token is not the EOS token.
                        if hasattr(tokenizer, "eos_token_id") and result.generated_tokens:
                            final_token_is_eos = result.generated_tokens[-1] == tokenizer.eos_token_id
                            generated_all_tokens = generated_all_tokens and not final_token_is_eos

                        reason = "length" if generated_all_tokens else "stop"

                        yield self.build_chat_completion_chunk(
                            request_id,
                            finish_reason=reason,
                            model=model_id_and_revision,
                        )
                        break

            except Exception as e:
                logger.error(str(e))
                self.running_continuous_batching_manager.cancel_request(request_id)
                yield f'data: {{"error": "{str(e)}"}}'

        def buffer_chat_completion(_request_id):
            result = None
            while self.running_continuous_batching_manager.is_running() and result is None:
                result = self.running_continuous_batching_manager.get_result(request_id=_request_id, timeout=1)

            content = tokenizer.decode(result.generated_tokens)

            chat_completion_result = ChatCompletion(
                id=_request_id,
                created=int(time.time()),
                object="chat.completion",
                model=model_id_and_revision,
                choices=[
                    Choice(
                        # TODO check the index
                        index=0,
                        message=ChatCompletionMessage(content=content, role="assistant"),
                        finish_reason="stop",
                    )
                ],
                # TODO implement function calling
                # TODO implement usage
            )

            return chat_completion_result

        async def cancellation_wrapper_stream(_request_id):
            # Enables cancellation in an async context
            try:
                decode_stream = DecodeStream(inputs.tolist(), False)
                for _chunk in stream_chat_completion(_request_id, decode_stream):
                    yield self.chunk_to_sse_element(_chunk)
                    await asyncio.sleep(0)
            except asyncio.CancelledError:
                self.running_continuous_batching_manager.cancel_request(_request_id)
                logger.warning(f"Request {_request_id} was cancelled.")

        def cancellation_wrapper_buffer(_request_id):
            # Enables cancellation in an async context
            try:
                return buffer_chat_completion(_request_id)
            except asyncio.CancelledError:
                self.running_continuous_batching_manager.cancel_request(_request_id)
                logger.warning(f"Request {_request_id} was cancelled.")

        request_id = self.running_continuous_batching_manager.add_request(
            inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens, streaming=req.get("stream")
        )

        if req.get("stream"):
            return StreamingResponse(cancellation_wrapper_stream(request_id), media_type="text/event-stream")
        else:
            chunk = cancellation_wrapper_buffer(request_id)
            # Serialize to a dict (not a JSON string) so that JSONResponse doesn't double-encode the payload
            payload = chunk.model_dump(exclude_none=True)
            return JSONResponse(payload, media_type="application/json")
    @staticmethod
    def get_model_modality(model: "PreTrainedModel", processor=None) -> Modality:
        if processor is not None:
            if isinstance(processor, PreTrainedTokenizerBase):
                return Modality.LLM

        from transformers.models.auto.modeling_auto import (
            MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
            MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
        )

        model_classname = model.__class__.__name__
        if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
            modality = Modality.VLM
        elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            modality = Modality.LLM
        else:
            raise ValueError(f"Unknown modality: {model_classname}")

        return modality
    @staticmethod
    def get_processor_inputs_from_inbound_messages(messages, modality: Modality):
        processor_inputs = []

        for message in messages:
            parsed_message = {"role": message["role"], "content": []}

            if modality == Modality.LLM:
                # Input: `content` is a string or a list of dictionaries with a "text" key.
                # Output: `content` is a string.
                if isinstance(message["content"], str):
                    parsed_content = message["content"]
                elif isinstance(message["content"], list):
                    parsed_content = []
                    for content in message["content"]:
                        if content["type"] == "text":
                            parsed_content.append(content["text"])
                    parsed_content = " ".join(parsed_content)
                parsed_message["content"] = parsed_content

            elif modality == Modality.VLM:
                # Input: `content` is a string or a list of dictionaries with a "type" key (possible types: "text",
                # "image_url").
                # Output: `content` is a list of dictionaries with a "type" key
                if isinstance(message["content"], str):
                    parsed_message["content"].append({"type": "text", "text": message["content"]})
                else:
                    for content in message["content"]:
                        if content["type"] == "text":
                            parsed_message["content"].append(content)
                        elif content["type"] == "image_url":
                            if "base64" in content["image_url"]["url"]:
                                image_data = re.sub("^data:image/.+;base64,", "", content["image_url"]["url"])
                                image = Image.open(BytesIO(base64.b64decode(image_data)))

                                file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
                                url = file.name

                                image.save(file.name)
                            else:
                                url = content["image_url"]["url"]

                            parsed_message["content"].append({"type": "image", "url": url})
            processor_inputs.append(parsed_message)
        return processor_inputs
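    # Illustrative normalization for the VLM branch (sketch; field values are made up):
    #   in:  {"role": "user", "content": [{"type": "text", "text": "What is this?"},
    #                                     {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}]}
    #   out: {"role": "user", "content": [{"type": "text", "text": "What is this?"},
    #                                     {"type": "image", "url": "https://example.com/cat.png"}]}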
    def generate_chat_completion(self, req: dict) -> StreamingResponse | JSONResponse:
        """
        Generates an OpenAI Chat Completion using `generate`.

        Args:
            req (`dict`): The request to generate an OpenAI Chat Completion for.

        Returns:
            `StreamingResponse` or `JSONResponse`: The streamed OpenAI Chat Completion chunks, or the buffered
            completion when `stream` is not set.
        """
        # TODO: This should throw an error in case the model specified in the request is different from the forced
        # model.
        if self.force_model is not None:
            req["model"] = self.force_model

        messages: Iterable[ChatCompletionMessageParam] = req["messages"]

        # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
        # request whose last message is from the assistant.
        if messages[-1]["role"] == "assistant":
            return

        model_id_and_revision = self.process_model_name(req["model"])
        must_discard_cache = model_id_and_revision != self.last_model

        self.last_model = model_id_and_revision
        model, processor = self.load_model_and_processor(model_id_and_revision)

        modality = self.get_model_modality(model, processor=processor)
        processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality)

        # ====== TOOL PREPROCESSING LOGIC ======
        tool_model_family = None
        for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
            if supported_model_families in model.config.architectures[0].lower():
                tool_model_family = supported_model_families
                break
        # TODO: trigger 2 constrained generations after the tool call start token is emitted:
        # 1. force generation to pick from the tool names
        # 2. force generation to pick from that tool's arguments
        # ====== END OF TOOL PREPROCESSING LOGIC ======

        inputs = processor.apply_chat_template(
            processor_inputs,
            add_generation_prompt=True,
            tools=req.get("tools"),
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )
        inputs = inputs.to(model.device)
        request_id = req.get("request_id", "req_0")

        # Temporary hack for GPT-OSS 1: don't filter special tokens
        skip_special_tokens = True
        if "gptoss" in model.config.architectures[0].lower():
            skip_special_tokens = False

        generation_streamer = TextIteratorStreamer(
            processor,
            skip_special_tokens=skip_special_tokens,
            skip_prompt=True,
        )
        generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)

        last_kv_cache = None
        if self.is_continuation(req) and not must_discard_cache:
            seq_len = self.last_kv_cache.get_seq_length()
            if inputs["input_ids"].shape[-1] > seq_len:
                last_kv_cache = self.last_kv_cache

        generation_kwargs = {
            **inputs,
            "streamer": generation_streamer,
            "generation_config": generation_config,
            "return_dict_in_generate": True,
            "past_key_values": last_kv_cache,
        }

        def stream_chat_completion(streamer, _request_id):
            # Temporary hack for GPT-OSS 2: filter out the CoT tokens. The full solution here implies defining new
            # output classes and piping the reasoning trace into a new field
            filter_cot = False
            cot_trace_end = None
            if "gptoss" in model.config.architectures[0].lower():
                filter_cot = True
                cot_trace_end = "<|channel|>final<|message|>"

            # Thin wrapper to save the KV cache after generation
            def generate_with_cache(**kwargs):
                generate_output = model.generate(**kwargs)
                self.last_kv_cache = generate_output.past_key_values

            thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
            results = ""

            try:
                thread.start()
                tool_state = ToolState()

                # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
                # they come from the assistant.
                yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)

                result = ""
                n_tokens_generated = 0

                for result in streamer:
                    n_tokens_generated += 1

                    # Temporary hack for GPT-OSS 3: don't emit the final "<|return|>"
                    if "gptoss" in model.config.architectures[0].lower():
                        result = result.removesuffix("<|return|>")
                    results += result

                    # (related to temporary hack 2)
                    if filter_cot:
                        if cot_trace_end in results:  # end of reasoning trace observed -> stop filtering
                            filter_cot = False
                            continue
                        else:
                            continue

                    # ====== TOOL CALL LOGIC ======
                    if tool_model_family is not None:
                        # Start of a tool call: reset state variables, set `inside_tool_call`
                        if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
                            tool_state.inside_tool_call = True
                            continue

                        # End of tool call: reset `inside_tool_call`, emit a `finish_reason`
                        if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
                            tool_state.reset()
                            yield self.build_chat_completion_chunk(
                                request_id=_request_id,
                                role=None,
                                finish_reason="tool_calls",
                                model=model_id_and_revision,
                            )
                            continue

                        # Inside a tool call
                        if tool_state.inside_tool_call:
                            tool_state.buffer += result

                            # First step: extract the tool name (may need several tokens, and we can't emit a delta
                            # until we have the full name)
                            if not tool_state.has_tool_name_defined:
                                tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
                                if tool_name is None:
                                    continue
                                else:
                                    tool_name = tool_name.group(1)
                                tool_state.has_tool_name_defined = True
                                tool = ChoiceDeltaToolCall(
                                    function=ChoiceDeltaToolCallFunction(name=tool_name),
                                    index=0,
                                    type="function",
                                    id=_request_id + "_tool_call",  # Only the first tool call delta has an id
                                )

                            # Second step: extract tool arguments. The tool arguments can be seen as a json string
                            # within the tool json string. We emit a delta for the arguments.
                            else:
                                # Empty text: skip
                                if result == "":
                                    continue
                                # Until we see the `"arguments": {` in the buffer, we skip
                                # TODO: other models will likely need more elaborate processing here
                                if '"arguments": {' not in tool_state.buffer:
                                    continue

                                # Handle nesting. We want to exclude the last } from the emitted arguments (it's
                                # closing the outermost nesting level, outside the arguments block)
                                tool_state.arg_nesting_level += result.count("{")
                                tool_state.arg_nesting_level -= result.count("}")
                                if tool_state.arg_nesting_level < 0:
                                    result = "".join(result.split("}")[:-2]) + "}"  # e.g. "4}}\n" -> "4}"

                                tool = ChoiceDeltaToolCall(
                                    function=ChoiceDeltaToolCallFunction(arguments=result),
                                    index=0,
                                    type="function",
                                )

                            yield self.build_chat_completion_chunk(
                                request_id=_request_id,
                                role=None,
                                tool_calls=[tool],
                                model=model_id_and_revision,
                            )
                            continue
                    # ====== END OF TOOL CALL LOGIC ======

                    # All non-tool related tokens are emitted as assistant messages. Empty text is skipped.
                    if result != "":
                        yield self.build_chat_completion_chunk(
                            _request_id, content=result, model=model_id_and_revision
                        )

                generated_all_tokens = n_tokens_generated >= generation_config.max_new_tokens

                # If the tokenizer has an eos_token, we can have a more robust check: "length" only applies when
                # the final emitted text is not the EOS token.
                if hasattr(streamer.tokenizer, "eos_token"):
                    final_token_is_eos = result == streamer.tokenizer.eos_token
                    generated_all_tokens = generated_all_tokens and not final_token_is_eos

                reason = "length" if generated_all_tokens else "stop"

                yield self.build_chat_completion_chunk(_request_id, finish_reason=reason, model=model_id_and_revision)

                thread.join()
            except Exception as e:
                logger.error(str(e))
                yield f'data: {{"error": "{str(e)}"}}'

            finally:
                thread.join()

        if req.get("stream"):
            return StreamingResponse(
                map(self.chunk_to_sse_element, stream_chat_completion(generation_streamer, request_id)),
                media_type="text/event-stream",
            )
        else:
            content = []
            finish_reason = "stop"

            generator = stream_chat_completion(generation_streamer, request_id)
            usage = None

            for chunk in generator:
                choice = chunk.choices[0]
                if getattr(choice.delta, "content", None):
                    content.append(choice.delta.content)
                if choice.finish_reason:
                    finish_reason = choice.finish_reason
                if getattr(chunk, "usage", None):
                    usage = chunk.usage

            chat_completion_result = ChatCompletion(
                id=request_id,
                created=int(time.time()),
                object="chat.completion",
                model=model_id_and_revision,
                choices=[
                    Choice(
                        # TODO check the index
                        index=0,
                        message=ChatCompletionMessage(content="".join(content), role="assistant"),
                        finish_reason=finish_reason,
                    )
                ],
                # TODO implement function calling
                usage=usage,
            )

            result = chat_completion_result.model_dump(exclude_none=True)

            return JSONResponse(result, media_type="application/json")
    def generate_response(self, req: dict) -> Generator[str, None, None]:
        """
        Generates an OpenAI Response using `generate`.

        Args:
            req (`dict`): The request to generate an OpenAI Response for.

        Returns:
            `Generator[str, None, None]`: A generator that yields the OpenAI Response events.
        """
        # TODO -- Implement non-streaming mode
        model_id_and_revision = self.process_model_name(req["model"])
        must_discard_cache = model_id_and_revision != self.last_model
        self.last_model = model_id_and_revision
        model, processor = self.load_model_and_processor(model_id_and_revision)

        if isinstance(req["input"], str):
            inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
            inputs.append({"role": "user", "content": req["input"]})
        elif isinstance(req["input"], list):
            if "instructions" in req:
                if req["input"][0]["role"] != "system":
                    inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
                else:
                    inputs = req["input"]
                    inputs[0]["content"] = req["instructions"]
            else:
                inputs = req["input"]
        elif isinstance(req["input"], dict):
            inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
            inputs.append(req["input"])
        else:
            raise TypeError("inputs should be a list, dict, or str")

        inputs = processor.apply_chat_template(
            inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True
        )["input_ids"]
        inputs = inputs.to(model.device)
        request_id = req.get("previous_response_id", "req_0")

        # Temporary hack for GPT-OSS 1: don't filter special tokens
        skip_special_tokens = True
        if "gptoss" in model.config.architectures[0].lower():
            skip_special_tokens = False

        generation_streamer = TextIteratorStreamer(
            processor,
            skip_special_tokens=skip_special_tokens,
            skip_prompt=True,
        )
        generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)

        last_kv_cache = None
        if self.is_continuation(req) and not must_discard_cache:
            seq_len = self.last_kv_cache.get_seq_length()
            if inputs.shape[-1] > seq_len:
                last_kv_cache = self.last_kv_cache

        generation_kwargs = {
            "inputs": inputs,
            "attention_mask": torch_ones_like(inputs),
            "streamer": generation_streamer,
            "generation_config": generation_config,
            "return_dict_in_generate": True,
            "past_key_values": last_kv_cache,
        }

        def stream_response(streamer, _request_id):
            # Temporary hack for GPT-OSS 2: filter out the CoT tokens. The full solution here implies defining new
            # output classes and piping the reasoning trace into a new field
            filter_cot = False
            cot_trace_end = None
            if "gptoss" in model.config.architectures[0].lower():
                filter_cot = True
                cot_trace_end = "<|channel|>final<|message|>"

            # Thin wrapper to save the KV cache after generation
            def generate_with_cache(**kwargs):
                generate_output = model.generate(**kwargs)
                self.last_kv_cache = generate_output.past_key_values

            thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
            sequence_number = 0
            output_index = 0
            content_index = 0

            try:
                thread.start()
                created_at = time.time()  # the spec expects a unix timestamp in seconds

                # We start by acknowledging the request (the request has `status="queued"`), and then by moving it
                # to in progress (`status="in_progress"`)
                response_created = ResponseCreatedEvent(
                    type="response.created",
                    sequence_number=sequence_number,
                    response=Response(
                        id=f"resp_{request_id}",
                        created_at=created_at,
                        status="queued",
                        model=model_id_and_revision,
                        instructions=req.get("instructions"),
                        text={"format": {"type": "text"}},
                        object="response",
                        tools=[],
                        output=[],
                        parallel_tool_calls=req.get("parallel_tool_calls", False),
                        tool_choice="auto",
                        metadata=req.get("metadata"),
                    ),
                )
                sequence_number += 1
                yield self.chunk_to_sse_element(response_created)

                response_in_progress = ResponseInProgressEvent(
                    type="response.in_progress",
                    sequence_number=sequence_number,
                    response=Response(
                        id=f"resp_{request_id}",
                        created_at=created_at,
                        status="in_progress",
                        model=model_id_and_revision,
                        instructions=req.get("instructions"),
                        text={"format": {"type": "text"}},
                        object="response",
                        tools=[],
                        output=[],
                        parallel_tool_calls=req.get("parallel_tool_calls", False),
                        tool_choice="auto",
                        metadata=req.get("metadata"),
                    ),
                )
                sequence_number += 1
                yield self.chunk_to_sse_element(response_in_progress)

                # Start the output item. Emit the assistant role to start the stream. Other chunks won't have a
                # role, as it is implicit
                response_output_item_added = ResponseOutputItemAddedEvent(
                    type="response.output_item.added",
                    sequence_number=sequence_number,
                    output_index=output_index,
                    item=ResponseOutputMessage(
                        id=f"msg_{request_id}", type="message", status="in_progress", role="assistant", content=[]
                    ),
                )
                sequence_number += 1
                yield self.chunk_to_sse_element(response_output_item_added)

                # Start the content part of the event
                response_content_part_added = ResponseContentPartAddedEvent(
                    type="response.content_part.added",
                    item_id=f"msg_{request_id}",
                    sequence_number=sequence_number,
                    output_index=output_index,
                    content_index=content_index,
                    part=ResponseOutputText(type="output_text", text="", annotations=[]),
                )
                sequence_number += 1
                yield self.chunk_to_sse_element(response_content_part_added)

                # Stream the actual generated text
                results = ""
                for result in streamer:
                    # Temporary hack for GPT-OSS 3: don't emit the final "<|return|>"
                    if "gptoss" in model.config.architectures[0].lower():
                        result = result.removesuffix("<|return|>")
                    results += result

                    # (related to temporary hack 2): skip tokens until the end of the reasoning trace is observed
                    if filter_cot:
                        if cot_trace_end in results:  # end of reasoning trace observed -> stop filtering
                            filter_cot = False
                            results = ""  # reset the results -> results will now track the final response
                        continue

                    # Normal path: emit token deltas when not filtering CoT
                    if result:
                        response_output_text_delta = ResponseTextDeltaEvent(
                            type="response.output_text.delta",
                            item_id=f"msg_{request_id}",
                            sequence_number=sequence_number,
                            output_index=output_index,
                            content_index=content_index,
                            delta=result,
                            logprobs=[],
                        )
                        sequence_number += 1
                        yield self.chunk_to_sse_element(response_output_text_delta)

                # Signal the end of the text generation
                response_output_text_done = ResponseTextDoneEvent(
                    type="response.output_text.done",
                    item_id=f"msg_{request_id}",
                    sequence_number=sequence_number,
                    output_index=output_index,
                    content_index=0,
                    text=results,
                    logprobs=[],
                )
                sequence_number += 1
                yield self.chunk_to_sse_element(response_output_text_done)

                # Complete the content part
                response_content_part_done = ResponseContentPartDoneEvent(
                    type="response.content_part.done",
                    item_id=f"msg_{request_id}",
                    sequence_number=sequence_number,
                    output_index=output_index,
                    content_index=content_index,
                    part=ResponseOutputText(type="output_text", text=response_output_text_done.text, annotations=[]),
                )
                sequence_number += 1
                content_index += 1
                yield self.chunk_to_sse_element(response_content_part_done)

                # Complete the output item
                response_output_item_done = ResponseOutputItemDoneEvent(
                    type="response.output_item.done",
                    sequence_number=sequence_number,
                    output_index=output_index,
                    item=ResponseOutputMessage(
                        id=f"msg_{request_id}",
                        type="message",
                        status="completed",
                        role="assistant",
                        content=[response_content_part_done.part],
                        annotations=[],
                    ),
                )
                sequence_number += 1
                output_index += 1
                yield self.chunk_to_sse_element(response_output_item_done)

                # Finally, complete the event
                response_completed = ResponseCompletedEvent(
                    type="response.completed",
                    sequence_number=sequence_number,
                    response=Response(
                        id=f"resp_{request_id}",
                        created_at=created_at,
                        status="completed",
                        model=model_id_and_revision,
                        instructions=req.get("instructions"),
                        text={"format": {"type": "text"}},
                        output=[response_output_item_done.item],
                        object="response",
                        tools=[],
                        parallel_tool_calls=req.get("parallel_tool_calls", False),
                        tool_choice="auto",
                        metadata=req.get("metadata"),
                    ),
                )
                sequence_number += 1
                yield self.chunk_to_sse_element(response_completed)

                thread.join()
            except Exception as e:
                logger.error(f"Exception in response generation: {str(e)}")
                error_event = ResponseErrorEvent(
                    type="error",
                    sequence_number=sequence_number,
                    message=str(e),
                )
                sequence_number += 1
                yield self.chunk_to_sse_element(error_event)

                response_failed = ResponseFailedEvent(
                    type="response.failed",
                    sequence_number=sequence_number,
                    response=Response(
                        id=f"resp_{request_id}",
                        created_at=created_at,
                        status="failed",
                        model=model_id_and_revision,
                        instructions=req.get("instructions"),
                        text={"format": {"type": "text"}},
                        output=[],
                        object="response",
                        tools=[],
                        parallel_tool_calls=False,
                        tool_choice="auto",
                        metadata=req.get("metadata"),
                        error=ResponseError(
                            code="server_error",
                            message=str(e),
                        ),
                    ),
                )
                sequence_number += 1
                yield self.chunk_to_sse_element(response_failed)

            finally:
                thread.join()

        return stream_response(generation_streamer, request_id)
def generate_response_non_streaming(self, req: dict) -> dict:
|
|
"""
|
|
Generates an OpenAI Response in non-streaming mode (single JSON payload).
|
|
|
|
Args:
|
|
req (`dict`): The request to generate an OpenAI Response for.
|
|
|
|
Returns:
|
|
`dict`: The OpenAI `Response` serialized as a dict.
|
|
"""
|
|
model_id_and_revision = self.process_model_name(req["model"])
|
|
must_discard_cache = model_id_and_revision != self.last_model
|
|
self.last_model = model_id_and_revision
|
|
model, processor = self.load_model_and_processor(model_id_and_revision)
|
|
|
|
if isinstance(req["input"], str):
|
|
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
|
|
inputs.append({"role": "user", "content": req["input"]})
|
|
elif isinstance(req["input"], list):
|
|
if "instructions" in req:
|
|
if req["input"][0]["role"] != "system":
|
|
inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
|
|
else:
|
|
inputs = req["input"]
|
|
inputs[0]["content"] = req["instructions"]
|
|
else:
|
|
inputs = req["input"]
|
|
elif isinstance(req["input"], dict):
|
|
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
|
|
inputs.append(req["input"])
|
|
else:
|
|
raise ValueError("inputs should be a list, dict, or str")
|
|
|
|
inputs = processor.apply_chat_template(
|
|
inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True
|
|
)["input_ids"]
|
|
inputs = inputs.to(model.device)
|
|
request_id = req.get("previous_response_id", "req_0")
|
|
|
|
# Temporary hack for GPTOSS 1: don't filter special tokens
|
|
skip_special_tokens = True
|
|
if "gptoss" in model.config.architectures[0].lower():
|
|
skip_special_tokens = False
|
|
|
|
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
|
|
|
last_kv_cache = None
|
|
if self.is_continuation(req) and not must_discard_cache:
|
|
seq_len = self.last_kv_cache.get_seq_length()
|
|
if inputs.shape[-1] > seq_len:
|
|
last_kv_cache = self.last_kv_cache
|
|
|
|
generate_output = model.generate(
|
|
inputs=inputs,
|
|
attention_mask=torch_ones_like(inputs),
|
|
generation_config=generation_config,
|
|
return_dict_in_generate=True,
|
|
past_key_values=last_kv_cache,
|
|
)
|
|
# save KV cache
|
|
self.last_kv_cache = generate_output.past_key_values
|
|
|
|
# Decode full text
|
|
full_text = processor.batch_decode(generate_output.sequences, skip_special_tokens=skip_special_tokens)[0]

        created_at = time.time()
        response_output_item = ResponseOutputMessage(
            id=f"msg_{request_id}",
            type="message",
            status="completed",
            role="assistant",
            content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
            annotations=[],
        )
        response_completed = Response(
            id=f"resp_{request_id}",
            created_at=created_at,
            status="completed",
            model=model_id_and_revision,
            instructions=req.get("instructions"),
            text={"format": {"type": "text"}},
            output=[response_output_item],
            object="response",
            tools=[],
            parallel_tool_calls=req.get("parallel_tool_calls", False),
            tool_choice="auto",
            metadata=req.get("metadata"),
        )
        return response_completed.model_dump(exclude_none=True)

    def generate_transcription(self, req: dict) -> Generator[str, None, None]:
        """
        Generates an OpenAI Transcription using the audio file.

        Args:
            req (`dict`): The request containing the audio file and model information.

        Returns:
            `Generator[str, None, None]`: A generator that yields the transcription result.
        """
        # TODO: implement streaming transcription (currently, it's not streaming)
        if not is_librosa_available():
            raise ImportError(
                "Missing librosa dependency for audio transcription. Please install with `pip install librosa`"
            )
        model_id_and_revision = self.process_model_name(req["model"])
        audio_model, audio_processor = self.load_audio_model_and_processor(model_id_and_revision)

        generation_streamer = TextIteratorStreamer(
            audio_processor.tokenizer, skip_special_tokens=True, skip_prompt=True
        )
        generation_config = create_generation_config_from_req(
            req, model_generation_config=audio_model.generation_config
        )

        # Read the binary audio file using librosa
        model_sampling_rate = audio_processor.feature_extractor.sampling_rate
        audio_bytes = io.BytesIO(req["file"])
        audio_array, _ = librosa.load(audio_bytes, sr=model_sampling_rate, mono=True)
        audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to(
            audio_model.device
        )
        audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype)

        generation_kwargs = {
            "streamer": generation_streamer,
            "generation_config": generation_config,
            "return_dict_in_generate": True,
        }

        def _generate_transcription():
            generated_ids = audio_model.generate(**audio_inputs, **generation_kwargs)
            transcription_text = audio_processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)[0]
            transcription = Transcription(text=transcription_text)
            yield transcription.model_dump_json(exclude_none=True)

        return _generate_transcription()

    def is_continuation(self, req: dict) -> bool:
        """
        Determines whether the current request is a continuation of the last request, i.e. whether it belongs to
        the same chat session.

        Args:
            req (`dict`): The request to check.

        Returns:
            `bool`: `True` if the request is a continuation of the last request, `False` otherwise.
        """
        messages = req.get("messages") or req.get("input")  # ChatCompletion and Response have different fields
        req_continues_last_messages = True

        # No cached messages: this is a new request
        if self.last_messages is None:
            req_continues_last_messages = False
        # The new request has no new rounds of conversation: this is a new request
        elif len(self.last_messages) >= len(messages):
            req_continues_last_messages = False
        # Otherwise, check that the last messages are a prefix of the new request
        else:
            for i in range(len(self.last_messages)):
                if self.last_messages[i] != messages[i]:
                    req_continues_last_messages = False
                    break

        self.last_messages = messages
        return req_continues_last_messages

    def get_quantization_config(self) -> BitsAndBytesConfig | None:
        """
        Returns the quantization config for the given CLI arguments.

        Returns:
            `BitsAndBytesConfig | None`: The quantization config, or `None` if no quantization was requested.
        """
        if self.quantization == "bnb-4bit":
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        elif self.quantization == "bnb-8bit":
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            quantization_config = None

        if quantization_config is not None:
            logger.info(f"Quantization applied with the following config: {quantization_config}")

        return quantization_config

    def process_model_name(self, model_id: str) -> str:
        """
        Applies the `force_model` CLI argument and canonicalizes the model name to the format "model_id@revision".
        If the model_id doesn't contain an "@", the revision defaults to "main", i.e. "model_id@main".

        Args:
            model_id (`str`): The model ID.

        Returns:
            `str`: The canonicalized model name to be used.
        """
        if self.force_model is not None:
            model_id = self.force_model
        if "@" in model_id:
            return model_id
        return f"{model_id}@main"

    def _load_model_and_data_processor(self, model_id_and_revision: str):
        """
        Generic method to load a model and a data processor from a model ID and revision, making use of the serve
        CLI arguments.

        Args:
            model_id_and_revision (`str`):
                The model ID and revision to load.

        Returns:
            `tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]`: The loaded model and
            data processor (tokenizer, audio processor, etc.).
        """
        import torch

        from transformers import AutoConfig, AutoProcessor

        logger.info(f"Loading {model_id_and_revision}")

        if "@" in model_id_and_revision:
            model_id, revision = model_id_and_revision.split("@", 1)
        else:
            model_id, revision = model_id_and_revision, "main"

        try:
            data_processor = AutoProcessor.from_pretrained(
                model_id,
                revision=revision,
                trust_remote_code=self.trust_remote_code,
            )
        except OSError:
            try:
                data_processor = AutoTokenizer.from_pretrained(
                    model_id,
                    revision=revision,
                    trust_remote_code=self.trust_remote_code,
                )
            except OSError:
                raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.")

        dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype)
        quantization_config = self.get_quantization_config()

        model_kwargs = {
            "revision": revision,
            "attn_implementation": self.attn_implementation,
            "dtype": dtype,
            "device_map": self.device,
            "trust_remote_code": self.trust_remote_code,
            "quantization_config": quantization_config,
        }

        config = AutoConfig.from_pretrained(model_id, **model_kwargs)
        architecture = getattr(transformers, config.architectures[0])
        model = architecture.from_pretrained(model_id, **model_kwargs)

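        # Many checkpoints ship with generation defaults meant for short outputs (max_length=20 or a small
        # max_new_tokens); raise the ceiling so served chat replies are not cut off mid-sentence.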
        has_default_max_length = (
            model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20
        )
        has_short_max_new_tokens = (
            model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024
        )
        if has_default_max_length or has_short_max_new_tokens:
            model.generation_config.max_new_tokens = 1024

        logger.info(f"Loaded model {model_id_and_revision}")
        return model, data_processor

    def load_model_and_processor(
        self, model_id_and_revision: str
    ) -> tuple["PreTrainedModel", "PreTrainedTokenizerFast"]:
        """
        Loads the text model and processor from the given model ID and revision into the `Serve` instance.

        Args:
            model_id_and_revision (`str`):
                The model ID and revision to load.

        Returns:
            `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and processor.
        """
        if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
            model, processor = self._load_model_and_data_processor(model_id_and_revision)
            self.loaded_models[model_id_and_revision] = TimedModel(
                model,
                timeout_seconds=self.model_timeout,
                processor=processor,
            )
        else:
            self.loaded_models[model_id_and_revision].reset_timer()
            model = self.loaded_models[model_id_and_revision].model
            processor = self.loaded_models[model_id_and_revision].processor

        return model, processor

    def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple["PreTrainedModel", "ProcessorMixin"]:
        """
        Loads the audio model and processor from the given model ID and revision into the `Serve` instance.

        Args:
            model_id_and_revision (`str`):
                The model ID and revision to load.

        Returns:
            `tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor.
        """
        if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
            audio_model, audio_processor = self._load_model_and_data_processor(model_id_and_revision)
            self.loaded_models[model_id_and_revision] = TimedModel(
                audio_model,
                timeout_seconds=self.model_timeout,
                processor=audio_processor,
            )
        else:
            self.loaded_models[model_id_and_revision].reset_timer()
            audio_model = self.loaded_models[model_id_and_revision].model
            audio_processor = self.loaded_models[model_id_and_revision].processor

        return audio_model, audio_processor


# set docstring separately to make it look nice (Typer doesn't play well with the class command)
Serve.__doc__ = """
Run a FastAPI server to serve models on demand with an OpenAI-compatible API.

Models will be loaded and unloaded automatically based on usage and a timeout.

\b
The server will expose the following endpoints:
- POST /v1/chat/completions: Generates chat completions.
- POST /v1/responses: Generates responses.
- POST /v1/audio/transcriptions: Generates transcriptions from audio.
- GET /v1/models: Lists available models for third-party tools.

Requires FastAPI and Uvicorn to be installed.
"""

if __name__ == "__main__":
    serve = Serve()