# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import copy
import enum
import gc
import io
import json
import re
import tempfile
import threading
import time
import uuid
from collections.abc import Generator, Iterable
from contextlib import asynccontextmanager
from functools import lru_cache
from io import BytesIO
from threading import Thread
from typing import TYPE_CHECKING, Annotated, Optional, TypedDict, Union

import typer
from huggingface_hub import scan_cache_dir
from tokenizers.decoders import DecodeStream
from tqdm import tqdm

import transformers
from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizerBase
from transformers.utils.import_utils import (
    is_fastapi_available,
    is_librosa_available,
    is_openai_available,
    is_pydantic_available,
    is_uvicorn_available,
    is_vision_available,
)

from .. import (
    LogitsProcessorList,
    TextIteratorStreamer,
)
from ..utils import logging


if TYPE_CHECKING:
    from transformers import (
        PreTrainedModel,
        PreTrainedTokenizerFast,
        ProcessorMixin,
    )

    from ..generation.continuous_batching import ContinuousBatchingManager


# Optional modality-specific dependencies (audio decoding / image handling).
if is_librosa_available():
    import librosa

if is_vision_available():
    from PIL import Image

# The serving CLI needs all four of these packages; `Serve.__init__` raises an ImportError when this is False.
serve_dependencies_available = (
    is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
)
if serve_dependencies_available:
    import uvicorn
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import JSONResponse, StreamingResponse
    from openai.types.audio.transcription import Transcription
    from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
    from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
    from openai.types.chat.chat_completion import Choice
    from openai.types.chat.chat_completion_chunk import (
        ChatCompletionChunk,
        ChoiceDelta,
        ChoiceDeltaToolCall,
        ChoiceDeltaToolCallFunction,
    )
    from openai.types.chat.chat_completion_chunk import (
        Choice as ChoiceChunk,
    )
    from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
    from openai.types.responses import (
        Response,
        ResponseCompletedEvent,
        ResponseContentPartAddedEvent,
        ResponseContentPartDoneEvent,
        ResponseCreatedEvent,
        ResponseError,
        ResponseErrorEvent,
        ResponseFailedEvent,
        ResponseInProgressEvent,
        ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent,
        ResponseOutputMessage,
        ResponseOutputText,
        ResponseTextDeltaEvent,
        ResponseTextDoneEvent,
    )
    from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
    from pydantic import BaseModel, TypeAdapter, ValidationError

    # Expand OpenAI's request input types with an optional `generation_config` field
    class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False):
        """
        OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a json string).
        """

        generation_config: str

    class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
        """
        OpenAI's CompletionCreateParamsStreaming with an additional field for the generation config (as a json
        string).
        """

        generation_config: str

    class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
        """
        OpenAI's TranscriptionCreateParamsBase with an additional field for the generation config (as a json string).
        """

        file: bytes  # Overwritten -- pydantic isn't happy with `typing.IO[bytes]`, present in the original type
        generation_config: str
        # TypedDict keys cannot declare default values (PEP 589): a `= False` here would only create an unused class
        # attribute and would never populate the key. With `total=False` the key is simply absent unless the client
        # sends it, so readers should use `.get("stream", False)`.
        stream: bool

    # Unlike OpenAI's output types, input types are `TypedDict`s, which have no built-in validation. Wrapping them in
    # pydantic `TypeAdapter`s lets incoming requests be validated against these schemas.
    response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming)
    completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming)
    transcription_validator = TypeAdapter(TransformersTranscriptionCreateParams)

    # Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an
    # HTTPException.
UNUSED_RESPONSE_FIELDS = { "background", "include", "max_tool_calls", "previous_response_id", "prompt", "reasoning", "service_tier", "store", "text", "tool_choice", "top_logprobs", "truncation", "user", } UNUSED_CHAT_COMPLETION_FIELDS = { "audio", "function_call", "functions", "logprobs", "max_completion_tokens", "metadata", "modalities", "n", "parallel_tool_calls", "prediction", "presence_penalty", "reasoning_effort", "response_format", "service_tier", "stop", "store", "stream_options", "tool_choice", "top_logprobs", "user", "web_search_options", } UNUSED_TRANSCRIPTION_FIELDS = { "chunking_strategy", "include", "language", "prompt", "response_format", "timestamp_granularities", } logger = logging.get_logger(__name__) # Possible tokens that indicate the start/end of a tool call # TODO (joao, matt): streamline tool token detection logic _TOOL_CALL_TOKENS = { "qwen": { "start": "", "end": "", }, } _MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys()) X_REQUEST_ID = "x-request-id" def set_torch_seed(_seed): import torch torch.manual_seed(_seed) def reset_torch_cache(): import torch if torch.cuda.is_available(): torch.cuda.empty_cache() def torch_ones_like(_input_tensor): import torch return torch.ones_like(_input_tensor) class Modality(enum.Enum): LLM = "LLM" VLM = "VLM" STT = "STT" TTS = "TTS" def create_generation_config_from_req( req: dict, model_generation_config: GenerationConfig, **kwargs, ) -> GenerationConfig: """ Creates a generation config from the parameters of the request. If a generation config is passed in the request, it will be used as a baseline for parameterization. Otherwise, we will use the model's default generation config. Other parameters in the request will be applied on top of the baseline. Args: req (`dict`): The request which may optionally contain generation parameters. model_generation_config (`GenerationConfig`): The model's default generation config. kwargs (`dict`): Additional parameters to set in the generation config. 
Returns: The prepared `GenerationConfig` object. """ # If there is a generation config in the request, it is a json string serialization from a `GenerationConfig` # object. For simplicity, flags set here take precedence over all other flags. if req.get("generation_config") is not None: generation_config = GenerationConfig(**json.loads(req["generation_config"])) else: generation_config = copy.deepcopy(model_generation_config) non_standard_kwargs = generation_config.update(**kwargs) # Set extra kwargs that are not in the `GenerationConfig` class (e.g. continuous batching flags) for k, v in non_standard_kwargs.items(): if v is not None: setattr(generation_config, k, v) # Response-specific parameters if req.get("max_output_tokens") is not None: generation_config.max_new_tokens = int(req["max_output_tokens"]) # Completion-specific parameters if req.get("max_tokens") is not None: generation_config.max_new_tokens = int(req["max_tokens"]) if req.get("frequency_penalty") is not None: generation_config.repetition_penalty = float(req["frequency_penalty"]) if req.get("logit_bias") is not None: generation_config.sequence_bias = req["logit_bias"] if req.get("stop") is not None: generation_config.stop_strings = req["stop"] if req.get("temperature") is not None: generation_config.temperature = float(req["temperature"]) if float(req["temperature"]) == 0.0: generation_config.do_sample = False if req.get("top_p") is not None: generation_config.top_p = float(req["top_p"]) if req.get("seed") is not None: set_torch_seed(req["seed"]) return generation_config class ToolState: """Lightweight class to keep track of the tool call state.""" def __init__(self): self.reset() def reset(self): """Reset the tool call state (assumes we're outside a tool call).""" self.inside_tool_call = False self.has_tool_name_defined = False self.arg_nesting_level = 0 self.buffer = "" class TimedModel: """ A class that holds a PreTrainedModel instance and its associated processor. 
Automatically deletes the instances after a specified timeout. """ def __init__( self, model: "PreTrainedModel", timeout_seconds: int, processor: Union["ProcessorMixin", "PreTrainedTokenizerFast"] | None = None, ): self.model = model self._name_or_path = str(model.name_or_path) self.processor = processor self.timeout_seconds = timeout_seconds self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached) self._timer.start() def reset_timer(self): """Reset the timer for the deletion of the instances.""" self._timer.cancel() self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached) self._timer.start() def delete_model(self): """Delete the wrapped model and processor and clean up resources.""" if hasattr(self, "model") and self.model is not None: del self.model del self.processor self.model = None self.processor = None gc.collect() # Clear CUDA cache if available reset_torch_cache() # XXX: in case we manually delete the model, like on server shutdown self._timer.cancel() def timeout_reached(self): if self.timeout_seconds > 0: self.delete_model() logger.info( f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity" ) def is_deleted(self): """Check if the instances have been deleted.""" return not hasattr(self, "model") or self.model is None class Serve: # Defining a class to help with internal state but in practice it's just a method to call # TODO: refactor into a proper module with helpers + 1 main method def __init__( self, continuous_batching: Annotated[ bool, typer.Option(help="Whether to use continuous batching for chat completions.") ] = False, device: Annotated[ str, typer.Option( help="Device to use for inference; will default to `auto` and place the model on an accelerator if available." ), ] = "auto", dtype: Annotated[ str | None, typer.Option( help="Override the default `torch.dtype` and load the model under this dtype. 
If `'auto'` is passed, the dtype will be automatically derived from the model's weights." ), ] = "auto", trust_remote_code: Annotated[ bool, typer.Option(help="Whether to trust remote code when loading a model.") ] = False, attn_implementation: Annotated[ str | None, typer.Option( help="Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`." ), ] = None, quantization: Annotated[ str | None, typer.Option(help="Which quantization method to use. choices: 'bnb-4bit', 'bnb-8bit'"), ] = None, host: Annotated[str, typer.Option(help="Interface the server will listen to.")] = "localhost", port: Annotated[int, typer.Option(help="Port the server will listen to.")] = 8000, model_timeout: Annotated[ int, typer.Option(help="Time in seconds after which a model will be removed from memory.") ] = 300, log_level: Annotated[ str, typer.Option(help="Logging level as a string. Example: 'info' or 'warning'.") ] = "info", default_seed: Annotated[ int | None, typer.Option(help="The default seed for torch, should be an integer.") ] = None, enable_cors: Annotated[ bool, typer.Option( help="Whether to enable CORS. Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled." ), ] = False, input_validation: Annotated[bool, typer.Option(help="Whether to turn on strict input validation.")] = False, force_model: Annotated[ str | None, typer.Option( help="Name of the model to be forced on all requests. This is useful for testing Apps that don't allow changing models in the request." ), ] = None, non_blocking: Annotated[ bool, typer.Option(hidden=True, help="Whether to run the server in a separate thread.") ] = False, ) -> None: if not serve_dependencies_available: raise ImportError( "Missing dependencies for the serving CLI. 
Please install with `pip install transformers[serving]`" ) # Save input arguments self.continuous_batching = continuous_batching self.device = device self.dtype = dtype self.trust_remote_code = trust_remote_code self.attn_implementation = attn_implementation self.quantization = quantization self.host = host self.port = port self.model_timeout = model_timeout self.log_level = log_level self.default_seed = default_seed self.enable_cors = enable_cors self.input_validation = input_validation self.force_model = force_model self.non_blocking = non_blocking # Seed if default_seed is not None: set_torch_seed(default_seed) # Set up logging transformers_logger = logging.get_logger("transformers") transformers_logger.setLevel(logging.log_levels[log_level.lower()]) cb_logger = logging.get_logger("transformers.generation.continuous_batching") cb_logger.setLevel(logging.log_levels[log_level.lower()]) # Internal state: # 1. Tracks models in memory, to prevent reloading the model unnecessarily self.loaded_models: dict[str, TimedModel] = {} self.running_continuous_batching_manager: ContinuousBatchingManager | None = None # 2. preserves information about the last call and last KV cache, to determine whether we can reuse the KV # cache and avoid re-running prefill self.last_messages = None self.last_kv_cache = None self.last_model = None if self.model_timeout is None: self.model_timeout = -1 if self.force_model else 300 if self.force_model: model_id_and_revision = self.process_model_name(self.force_model) self.last_model = model_id_and_revision self.load_model_and_processor(model_id_and_revision) @asynccontextmanager async def lifespan(app: FastAPI): yield for model in self.loaded_models.values(): model.delete_model() if self.running_continuous_batching_manager is not None: self.running_continuous_batching_manager.stop(block=True, timeout=5) app = FastAPI(lifespan=lifespan) # Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled. 
However, for # security purposes, it's disabled by default if self.enable_cors: app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) logger.warning_once( "CORS allow origin is set to `*`. This is not recommended for production environments." ) from fastapi import Request @app.post("/v1/chat/completions") def chat_completion(request: Request, body: dict): self.validate_chat_completion_request(request=body) if self.continuous_batching: return self.continuous_batching_chat_completion(body, request.state.request_id) else: return self.generate_chat_completion(body) @app.post("/v1/responses") def responses(request: dict): self.validate_response_request(request=request) # Support non-streaming mode when `stream=false` is provided stream = request.get("stream", True) if not stream: response_obj = self.generate_response_non_streaming(request) return JSONResponse(response_obj) output = self.generate_response(request) return StreamingResponse(output, media_type="text/event-stream") @app.post("/v1/audio/transcriptions") async def audio_transcriptions(request: Request): # Parses the multipart/form-data request into the request format used by other endpoints async with request.form() as form: parsed_request = TransformersTranscriptionCreateParams( file=await form["file"].read(), model=form["model"], # TODO: add other fields ) logger.debug( f"Received file: {form['file'].filename}; MIME type: {form['file'].content_type}; " f"size: {form['file'].size / 1024:.2f} KiB" ) self.validate_transcription_request(request=parsed_request) output = self.generate_transcription(parsed_request) return StreamingResponse(output, media_type="text/event-stream") @app.options("/v1/models") @app.get("/v1/models") def get_all_models(): return JSONResponse({"object": "list", "data": self.get_gen_models()}) @app.get("/health") def healthcheck(): return JSONResponse({"status": "ok"}) @app.middleware("http") async def 
get_or_set_request_id(request: Request, call_next): request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) request.state.request_id = request_id response = await call_next(request) response.headers[X_REQUEST_ID] = request_id return response config = uvicorn.Config(app, host=self.host, port=self.port, log_level=self.log_level) self.server = uvicorn.Server(config) if self.non_blocking: self.start_server() else: self.server.run() def start_server(self): def _run(): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) # serve() is a coroutine; it exits when server.should_exit becomes True self._loop.run_until_complete(self.server.serve()) self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False) self._thread.start() def kill_server(self): if not self._thread: raise ValueError("The server cannot be killed as it was not launched in a separate thread.") if not self._thread.is_alive(): raise ValueError("The server is already killed.") self.server.should_exit = True if self._thread and self._thread.is_alive(): self._thread.join(timeout=2) def _validate_request( self, request: dict, schema: TypedDict, validator: TypeAdapter, unused_fields: set, ): """ Validates the request against the schema, and checks for unexpected keys. Args: request (`dict`): The request to validate. schema (`TypedDict`): The schema of the request to validate. It is a `TypedDict` definition. validator (`TypeAdapter`): The validator to use to validate the request. Built from `schema`. unused_fields (`set`): Fields accepted by `schema`, but not used in `transformers serve`. Raises: HTTPException: If the request is invalid or contains unexpected or unused fields. """ logger.debug(f"Validating request: {request}") # Validate unexpected keys -- Pydantic doesn't validate extra keys in the request. 
input_keys = set(request.keys()) possible_keys = schema.__mutable_keys__ unexpected_keys = input_keys - possible_keys if unexpected_keys: logger.error(f"Unexpected keys in the request: {unexpected_keys}") raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected_keys}") if self.input_validation: # Validate expected keys try: validator.validate_python(request) except ValidationError as e: logger.error(f"Validation error: {e.errors()}") raise HTTPException(status_code=422, detail=e.errors()) # Validate unused fields unused_fields_in_request = input_keys & unused_fields if unused_fields_in_request: logger.error(f"Unused fields in the request: {unused_fields_in_request}") raise HTTPException( status_code=422, detail=f"Unused fields in the request: {unused_fields_in_request}" ) def validate_response_request(self, request: dict): self._validate_request( request=request, schema=TransformersResponseCreateParamsStreaming, validator=response_validator, unused_fields=UNUSED_RESPONSE_FIELDS, ) def validate_chat_completion_request(self, request: dict): self._validate_request( request=request, schema=TransformersCompletionCreateParamsStreaming, validator=completion_validator, unused_fields=UNUSED_CHAT_COMPLETION_FIELDS, ) def validate_transcription_request(self, request: dict): self._validate_request( request=request, schema=TransformersTranscriptionCreateParams, validator=transcription_validator, unused_fields=UNUSED_TRANSCRIPTION_FIELDS, ) def build_chat_completion_chunk( self, request_id: str = "", content: int | None = None, model: str | None = None, role: str | None = None, finish_reason: str | None = None, tool_calls: list[ChoiceDeltaToolCall] | None = None, decode_stream: DecodeStream | None = None, tokenizer: Optional["PreTrainedTokenizerFast"] = None, ) -> ChatCompletionChunk: """ Builds a chunk of a streaming OpenAI Chat Completion response. IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). 
Some downstream apps, like Cursor, assume that when the field exists, it has data. Args: request_id (`str`): The request ID. content (`str`, *optional*): Content of the response from the model. model (`str`, *optional*): The model that generated the content. role (`str`, *optional*): The role of the next content, until a new role is defined. finish_reason (`str`, *optional*): The reason the generation by the model has finished. tool_calls (`list[ChoiceDeltaToolCall]`, *optional*): Data about the tool calls, when they are triggered. Returns: `str`: The built chunk, a string containing a JSON string with the payload. """ if decode_stream is not None and content is not None and tokenizer is not None: content = decode_stream.step(tokenizer._tokenizer, content) chunk = ChatCompletionChunk( id=request_id, created=int(time.time()), model=model, choices=[ ChoiceChunk( delta=ChoiceDelta( content=content, role=role, tool_calls=tool_calls, ), index=0, finish_reason=finish_reason, ) ], system_fingerprint="", object="chat.completion.chunk", ) return chunk @staticmethod def chunk_to_sse_element(chunk: ChatCompletionChunk | BaseModel) -> str: """ Builds an event of a streaming OpenAI Response model or a ChatCompletion chunk. IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps, like Cursor, assume that when the field exists, it has data. Args: chunk (`BaseModel` or `ChatCompletionChunk`): The response to build an event from. One of the multiple OpenAI Response output types Returns: `str`: The built chunk, a string containing a JSON string with the payload. """ return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" @staticmethod @lru_cache def get_gen_models(cache_dir: str | None = None) -> list[dict[str, any]]: """ List LLMs and VLMs in the cache. 
""" from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, ) generative_models = [] logger.warning("Scanning the cache directory for LLMs and VLMs.") for repo in tqdm(scan_cache_dir(cache_dir).repos): if repo.repo_type != "model": continue refs = repo.refs for ref, revision_info in refs.items(): files = revision_info.files config_path = next((f.file_path for f in files if f.file_name == "config.json"), None) if not config_path: continue config = json.loads(config_path.open().read()) if not (isinstance(config, dict) and "architectures" in config): continue architectures = config["architectures"] llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values() vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() if any(arch for arch in architectures if arch in [*llms, *vlms]): author = repo.repo_id.split("/") if "/" in repo.repo_id else "" repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "") generative_models.append( { "owned_by": author, "id": repo_handle, "object": "model", "created": repo.last_modified, } ) return generative_models def continuous_batching_chat_completion(self, req: dict, request_id: str) -> StreamingResponse | JSONResponse: """ Generates an OpenAI Chat Completion using continuous batching. Args: req (`dict`): The request to generate an OpenAI Chat Completion for. Returns: `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. """ model_id_and_revision = self.process_model_name(req["model"]) must_discard_cache = model_id_and_revision != self.last_model self.last_model = model_id_and_revision # When switching models, terminate a continuous batching manager if it is running. 
if must_discard_cache: if self.running_continuous_batching_manager is not None: self.running_continuous_batching_manager.stop(block=True, timeout=2) self.running_continuous_batching_manager = None model, processor = self.load_model_and_processor(model_id_and_revision) tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor generation_config = create_generation_config_from_req( req, model_generation_config=model.generation_config, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, use_cache=False, do_sample=False, scheduler="fifo", ) if self.running_continuous_batching_manager is None: self.running_continuous_batching_manager = model.init_continuous_batching( generation_config=generation_config ) # TODO (Joao, Lysandre): the logits processors should be fixed in continuous batching and correctly applied in non-cb self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() self.running_continuous_batching_manager.start() # TODO (Joao, Lysandre): this should also work with tool support inputs = processor.apply_chat_template( req["messages"], return_tensors="pt", add_generation_prompt=True, return_dict=True ).to(model.device)["input_ids"][0] def stream_chat_completion(request_id, decode_stream): from ..generation.continuous_batching import RequestStatus try: # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit # they come from the assistant. 
yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) n_tokens_generated = 0 for result in self.running_continuous_batching_manager.request_id_iter(request_id): n_tokens_generated += 1 # Always yield the token content (even for the final FINISHED token) if result.generated_tokens: token_id = result.generated_tokens[-1] yield self.build_chat_completion_chunk( request_id=request_id, content=token_id, model=model_id_and_revision, decode_stream=decode_stream, tokenizer=tokenizer, ) if result.status == RequestStatus.FINISHED: generated_all_tokens = n_tokens_generated >= generation_config.max_new_tokens # If the tokenizer has an eos_token, we can have a more robust check. if hasattr(tokenizer, "eos_token"): final_token_is_eos = result == tokenizer.eos_token generated_all_tokens = generated_all_tokens and not final_token_is_eos reason = "length" if generated_all_tokens else "stop" yield self.build_chat_completion_chunk( request_id, finish_reason=reason, model=model_id_and_revision, ) break except Exception as e: logger.error(str(e)) self.running_continuous_batching_manager.cancel_request(request_id) yield f'data: {{"error": "{str(e)}"}}' def buffer_chat_completion(_request_id): result = None while self.running_continuous_batching_manager.is_running() and result is None: result = self.running_continuous_batching_manager.get_result(request_id=_request_id, timeout=1) content = tokenizer.decode(result.generated_tokens) chat_completion_result = ChatCompletion( id=_request_id, created=int(time.time()), object="chat.completion", model=model_id_and_revision, choices=[ Choice( # TODO check the index index=0, message=ChatCompletionMessage(content=content, role="assistant"), finish_reason="stop", ) ], # TODO implement function calling # TODO implement usage ) return chat_completion_result async def cancellation_wrapper_stream(_request_id): # Enables cancellation in an async context try: decode_stream = DecodeStream(inputs.tolist(), False) 
for _chunk in stream_chat_completion(_request_id, decode_stream): yield self.chunk_to_sse_element(_chunk) await asyncio.sleep(0) except asyncio.CancelledError: self.running_continuous_batching_manager.cancel_request(_request_id) logger.warning(f"Request {_request_id} was cancelled.") def cancellation_wrapper_buffer(_request_id): # Enables cancellation in an async context try: return buffer_chat_completion(_request_id) except asyncio.CancelledError: self.running_continuous_batching_manager.cancel_request(_request_id) logger.warning(f"Request {_request_id} was cancelled.") request_id = self.running_continuous_batching_manager.add_request( inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens, streaming=req.get("stream") ) if req.get("stream"): return StreamingResponse(cancellation_wrapper_stream(request_id), media_type="text/event-stream") else: chunk = cancellation_wrapper_buffer(request_id) json_chunk = chunk.model_dump_json(exclude_none=True) return JSONResponse(json_chunk, media_type="application/json") @staticmethod def get_model_modality(model: "PreTrainedModel", processor=None) -> Modality: if processor is not None: if isinstance(processor, PreTrainedTokenizerBase): return Modality.LLM from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, ) model_classname = model.__class__.__name__ if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values(): modality = Modality.VLM elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): modality = Modality.LLM else: raise ValueError(f"Unknown modality: {model_classname}") return modality @staticmethod def get_processor_inputs_from_inbound_messages(messages, modality: Modality): processor_inputs = [] for message in messages: parsed_message = {"role": message["role"], "content": []} if modality == Modality.LLM: # Input: `content` is a string or a list of dictionaries with a "text" key. 
# Output: `content` is a string. if isinstance(message["content"], str): parsed_content = message["content"] elif isinstance(message["content"], list): parsed_content = [] for content in message["content"]: if content["type"] == "text": parsed_content.append(content["text"]) parsed_content = " ".join(parsed_content) parsed_message["content"] = parsed_content elif modality == Modality.VLM: # Input: `content` is a string or a list of dictionaries with a "type" key (possible types: "text", # "image_url"). # Output: `content` is a list of dictionaries with a "type" key if isinstance(message["content"], str): parsed_message["content"].append({"type": "text", "text": message["content"]}) else: for content in message["content"]: if content["type"] == "text": parsed_message["content"].append(content) elif content["type"] == "image_url": if "base64" in content["image_url"]["url"]: image_data = re.sub("^data:image/.+;base64,", "", content["image_url"]["url"]) image = Image.open(BytesIO(base64.b64decode(image_data))) file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) url = file.name image.save(file.name) else: url = content["image_url"]["url"] parsed_message["content"].append({"type": "image", "url": url}) processor_inputs.append(parsed_message) return processor_inputs def generate_chat_completion(self, req: dict) -> StreamingResponse | JSONResponse: """ Generates an OpenAI Chat Completion using `generate`. Args: req (`dict`): The request to generate an OpenAI Chat Completion for. Returns: `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. """ # TODO: This should throw an error in case the specified model in the request is different to the forced model. if self.force_model is not None: req["model"] = self.force_model messages: Iterable[ChatCompletionMessageParam] = req["messages"] # HACK for tiny-agents: it sends a request after the assistant message (???). 
Let's assume we can't have a # request whose last message is from the assistant. if messages[-1]["role"] == "assistant": return model_id_and_revision = self.process_model_name(req["model"]) must_discard_cache = model_id_and_revision != self.last_model self.last_model = model_id_and_revision model, processor = self.load_model_and_processor(model_id_and_revision) modality = self.get_model_modality(model, processor=processor) processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality) # ====== TOOL PREPROCESSING LOGIC ====== tool_model_family = None for supported_model_families in _MODELS_WITH_TOOL_SUPPORT: if supported_model_families in model.config.architectures[0].lower(): tool_model_family = supported_model_families break # TODO: trigger 2 constrained generations after the tool call start token is emitted: # 1. force generation to pick from the tool names # 2. force generation to pick from that tool's arguments # ====== END OF TOOL PREPROCESSING LOGIC ====== inputs = processor.apply_chat_template( processor_inputs, add_generation_prompt=True, tools=req.get("tools"), return_tensors="pt", return_dict=True, tokenize=True, ) inputs = inputs.to(model.device) request_id = req.get("request_id", "req_0") # Temporary hack for GPTOSS 1: don't filter special tokens skip_special_tokens = True if "gptoss" in model.config.architectures[0].lower(): skip_special_tokens = False generation_streamer = TextIteratorStreamer( processor, skip_special_tokens=skip_special_tokens, skip_prompt=True, ) generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) last_kv_cache = None if self.is_continuation(req) and not must_discard_cache: seq_len = self.last_kv_cache.get_seq_length() if inputs["input_ids"].shape[-1] > seq_len: last_kv_cache = self.last_kv_cache generation_kwargs = { **inputs, "streamer": generation_streamer, "generation_config": generation_config, "return_dict_in_generate": True, "past_key_values": 
last_kv_cache, } def stream_chat_completion(streamer, _request_id): # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output # classes and piping the reasoning trace into a new field filter_cot = False cot_trace_end = None if "gptoss" in model.config.architectures[0].lower(): filter_cot = True cot_trace_end = "<|channel|>final<|message|>" # Thin wrapper to save the KV cache after generation def generate_with_cache(**kwargs): generate_output = model.generate(**kwargs) self.last_kv_cache = generate_output.past_key_values thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) results = "" try: thread.start() tool_state = ToolState() # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit # they come from the assistant. yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) result = "" n_tokens_generated = 0 for result in streamer: n_tokens_generated += 1 # Temporary hack for GPT-OSS 3: don't emit the final "<|return|>" if "gptoss" in model.config.architectures[0].lower(): result = result.removesuffix("<|return|>") results += result # (related to temporary hack 2) if filter_cot: if cot_trace_end in results: # end of reasoning trace observed -> stop filtering filter_cot = False continue else: continue # ====== TOOL CALL LOGIC ====== if tool_model_family is not None: # Start of a tool call: reset state variables, set `inside_tool_call` if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]: tool_state.inside_tool_call = True continue # End of tool call: reset `inside_tool_call`, emit a `finish_reason` if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]: tool_state.reset() yield self.build_chat_completion_chunk( request_id=_request_id, role=None, finish_reason="tool_calls", model=model_id_and_revision, ) continue # Inside a tool call if tool_state.inside_tool_call: tool_state.buffer += result # First 
step: extract the tool name (may need several tokens, and we can't emit a delta # until we have the full name) if not tool_state.has_tool_name_defined: tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer) if tool_name is None: continue else: tool_name = tool_name.group(1) tool_state.has_tool_name_defined = True tool = ChoiceDeltaToolCall( function=ChoiceDeltaToolCallFunction(name=tool_name), index=0, type="function", id=_request_id + "_tool_call", # Only the first tool call delta has an id ) # Second step: extract tool arguments. The tool arguments can be seen as a json string # within the tool json string. We emit a delta for the arguments. else: # Empty text: skip if result == "": continue # Until we see the `"arguments": {` in the buffer, we skip # TODO: other models will likely need more elaborate processing here if '"arguments": {' not in tool_state.buffer: continue # Handle nesting. We want to exclude the last } from the emitted arguments (it's # closing the outermost nesting level, outside the arguments block) tool_state.arg_nesting_level += result.count("{") tool_state.arg_nesting_level -= result.count("}") if tool_state.arg_nesting_level < 0: result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}" tool = ChoiceDeltaToolCall( function=ChoiceDeltaToolCallFunction(arguments=result), index=0, type="function", ) yield self.build_chat_completion_chunk( request_id=_request_id, role=None, tool_calls=[tool], model=model_id_and_revision, ) continue # ====== END OF TOOL CALL LOGIC ====== # All non-tool related tokens are emitted as assistant messages. Empty text is skipped. if result != "": yield self.build_chat_completion_chunk( _request_id, content=result, model=model_id_and_revision ) generated_all_tokens = n_tokens_generated >= generation_config.max_new_tokens # If the tokenizer has an eos_token, we can have a more robust check. 
if hasattr(streamer.tokenizer, "eos_token"): final_token_is_eos = result == streamer.tokenizer.eos_token generated_all_tokens = generated_all_tokens and not final_token_is_eos reason = "length" if generated_all_tokens else "stop" yield self.build_chat_completion_chunk(_request_id, finish_reason=reason, model=model_id_and_revision) thread.join() except Exception as e: logger.error(str(e)) yield f'data: {{"error": "{str(e)}"}}' finally: thread.join() if req.get("stream"): return StreamingResponse( map(self.chunk_to_sse_element, stream_chat_completion(generation_streamer, request_id)), media_type="text/event-stream", ) else: content = [] finish_reason = "stop" generator = stream_chat_completion(generation_streamer, request_id) usage = None for chunk in generator: choice = chunk.choices[0] if getattr(choice.delta, "content", None): content.append(choice.delta.content) if choice.finish_reason: finish_reason = choice.finish_reason if getattr(chunk, "usage", None): usage = chunk.usage chat_completion_result = ChatCompletion( id=request_id, created=int(time.time()), object="chat.completion", model=model_id_and_revision, choices=[ Choice( # TODO check the index index=0, message=ChatCompletionMessage(content="".join(content), role="assistant"), finish_reason=finish_reason, ) ], # TODO implement function calling usage=usage, ) result = chat_completion_result.model_dump(exclude_none=True) return JSONResponse(result, media_type="application/json") def generate_response(self, req: dict) -> Generator[str, None, None]: """ Generates an OpenAI Response using `generate`. Args: req (`dict`): The request to generate an OpenAI Response for. Returns: `Generator[str, None, None]`: A generator that yields the OpenAI Response events. 
""" # TODO -- Implement non-streaming mode model_id_and_revision = self.process_model_name(req["model"]) must_discard_cache = model_id_and_revision != self.last_model self.last_model = model_id_and_revision model, processor = self.load_model_and_processor(model_id_and_revision) if isinstance(req["input"], str): inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] inputs.append({"role": "user", "content": req["input"]}) elif isinstance(req["input"], list): if "instructions" in req: if req["input"][0]["role"] != "system": inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]] else: inputs = req["input"] inputs[0]["content"] = req["instructions"] else: inputs = req["input"] elif isinstance(req["input"], dict): inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] inputs.append(req["input"]) else: raise TypeError("inputs should be a list, dict, or str") inputs = processor.apply_chat_template( inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True )["input_ids"] inputs = inputs.to(model.device) request_id = req.get("previous_response_id", "req_0") # Temporary hack for GPT-OSS 1: don't filter special tokens skip_special_tokens = True if "gptoss" in model.config.architectures[0].lower(): skip_special_tokens = False generation_streamer = TextIteratorStreamer( processor, skip_special_tokens=skip_special_tokens, skip_prompt=True, ) generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) last_kv_cache = None if self.is_continuation(req) and not must_discard_cache: seq_len = self.last_kv_cache.get_seq_length() if inputs["input_ids"].shape[-1] > seq_len: last_kv_cache = self.last_kv_cache generation_kwargs = { "inputs": inputs, "attention_mask": torch_ones_like(inputs), "streamer": generation_streamer, "generation_config": generation_config, "return_dict_in_generate": True, "past_key_values": 
last_kv_cache, } def stream_response(streamer, _request_id): # Temporary hack for GPT-OSS 2: filter out the CoT tokens. Full solution here implies defining new output # classes and piping the reasoning trace into a new field filter_cot = False cot_trace_end = None if "gptoss" in model.config.architectures[0].lower(): filter_cot = True cot_trace_end = "<|channel|>final<|message|>" # Thin wrapper to save the KV cache after generation def generate_with_cache(**kwargs): generate_output = model.generate(**kwargs) self.last_kv_cache = generate_output.past_key_values thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) sequence_number = 0 output_index = 0 content_index = 0 try: thread.start() created_at = time.time() # the spec expects a unix timestamp in seconds # We start by acknowledging the request (the request has `status="queued"`), and then by moving it to # in progress (`status="in_progress"`) response_created = ResponseCreatedEvent( type="response.created", sequence_number=sequence_number, response=Response( id=f"resp_{request_id}", created_at=created_at, status="queued", model=model_id_and_revision, instructions=req.get("instructions"), text={"format": {"type": "text"}}, object="response", tools=[], output=[], parallel_tool_calls=req.get("parallel_tool_calls", False), tool_choice="auto", metadata=req.get("metadata"), ), ) sequence_number += 1 yield self.chunk_to_sse_element(response_created) response_in_progress = ResponseInProgressEvent( type="response.in_progress", sequence_number=sequence_number, response=Response( id=f"resp_{request_id}", created_at=created_at, status="in_progress", model=model_id_and_revision, instructions=req.get("instructions"), text={"format": {"type": "text"}}, object="response", tools=[], output=[], parallel_tool_calls=req.get("parallel_tool_calls", False), tool_choice="auto", metadata=req.get("metadata"), ), ) sequence_number += 1 yield self.chunk_to_sse_element(response_in_progress) # Start the output item. 
Emit the assistant role to start the stream. Other chunks won't have a role, # as it is implicit response_output_item_added = ResponseOutputItemAddedEvent( type="response.output_item.added", sequence_number=sequence_number, output_index=output_index, item=ResponseOutputMessage( id=f"msg_{request_id}", type="message", status="in_progress", role="assistant", content=[] ), ) sequence_number += 1 yield self.chunk_to_sse_element(response_output_item_added) # Start the content part of the event response_content_part_added = ResponseContentPartAddedEvent( type="response.content_part.added", item_id=f"msg_{request_id}", sequence_number=sequence_number, output_index=output_index, content_index=content_index, part=ResponseOutputText(type="output_text", text="", annotations=[]), ) sequence_number += 1 yield self.chunk_to_sse_element(response_content_part_added) # Stream the actual generated text results = "" for result in streamer: # Temporary hack for GPTOS 3: don't emit the final "<|return|>" if "gptoss" in model.config.architectures[0].lower(): result = result.removesuffix("<|return|>") results += result # (related to temporary hack 2) if filter_cot: if cot_trace_end in results: # end of reasoning trace observed -> stop filtering filter_cot = False results = "" # reset the results -> results will now track the final response continue else: response_output_text_delta = ResponseTextDeltaEvent( type="response.output_text.delta", item_id=f"msg_{request_id}", sequence_number=sequence_number, output_index=output_index, content_index=content_index, delta=result, logprobs=[], ) sequence_number += 1 yield self.chunk_to_sse_element(response_output_text_delta) else: # Normal path: emit token deltas when not filtering CoT if result: response_output_text_delta = ResponseTextDeltaEvent( type="response.output_text.delta", item_id=f"msg_{request_id}", sequence_number=sequence_number, output_index=output_index, content_index=content_index, delta=result, logprobs=[], ) sequence_number += 1 
yield self.chunk_to_sse_element(response_output_text_delta) # Signal the end of the text generation response_output_text_done = ResponseTextDoneEvent( type="response.output_text.done", item_id=f"msg_{request_id}", sequence_number=sequence_number, output_index=output_index, content_index=0, text=results, logprobs=[], ) sequence_number += 1 yield self.chunk_to_sse_element(response_output_text_done) # Complete the content part response_content_part_done = ResponseContentPartDoneEvent( type="response.content_part.done", item_id=f"msg_{request_id}", sequence_number=sequence_number, output_index=output_index, content_index=content_index, part=ResponseOutputText(type="output_text", text=response_output_text_done.text, annotations=[]), ) sequence_number += 1 content_index += 1 yield self.chunk_to_sse_element(response_content_part_done) # Complete the output item response_output_item_done = ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=sequence_number, output_index=output_index, item=ResponseOutputMessage( id=f"msg_{request_id}", type="message", status="completed", role="assistant", content=[response_content_part_done.part], annotations=[], ), ) sequence_number += 1 output_index += 1 yield self.chunk_to_sse_element(response_output_item_done) # Finally, Complete the event response_completed = ResponseCompletedEvent( type="response.completed", sequence_number=sequence_number, response=Response( id=f"resp_{request_id}", created_at=created_at, status="completed", model=model_id_and_revision, instructions=req.get("instructions"), text={"format": {"type": "text"}}, output=[response_output_item_done.item], object="response", tools=[], parallel_tool_calls=req.get("parallel_tool_calls", False), tool_choice="auto", metadata=req.get("metadata"), ), ) sequence_number += 1 yield self.chunk_to_sse_element(response_completed) thread.join() except Exception as e: logger.error(f"Exception in response generation: {str(e)}") error_event = ResponseErrorEvent( 
type="error", sequence_number=sequence_number, message=str(e), ) sequence_number += 1 yield self.chunk_to_sse_element(error_event) response_failed = ResponseFailedEvent( type="response.failed", sequence_number=sequence_number, response=Response( id=f"resp_{request_id}", created_at=created_at, status="failed", model=model_id_and_revision, instructions=req.get("instructions"), text={"format": {"type": "text"}}, output=[], object="response", tools=[], parallel_tool_calls=False, tool_choice="auto", metadata=req.get("metadata"), error=ResponseError( code="server_error", message=str(e), ), ), ) sequence_number += 1 yield self.chunk_to_sse_element(response_failed) finally: thread.join() return stream_response(generation_streamer, request_id) def generate_response_non_streaming(self, req: dict) -> dict: """ Generates an OpenAI Response in non-streaming mode (single JSON payload). Args: req (`dict`): The request to generate an OpenAI Response for. Returns: `dict`: The OpenAI `Response` serialized as a dict. 
""" model_id_and_revision = self.process_model_name(req["model"]) must_discard_cache = model_id_and_revision != self.last_model self.last_model = model_id_and_revision model, processor = self.load_model_and_processor(model_id_and_revision) if isinstance(req["input"], str): inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] inputs.append({"role": "user", "content": req["input"]}) elif isinstance(req["input"], list): if "instructions" in req: if req["input"][0]["role"] != "system": inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]] else: inputs = req["input"] inputs[0]["content"] = req["instructions"] else: inputs = req["input"] elif isinstance(req["input"], dict): inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] inputs.append(req["input"]) else: raise ValueError("inputs should be a list, dict, or str") inputs = processor.apply_chat_template( inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True )["input_ids"] inputs = inputs.to(model.device) request_id = req.get("previous_response_id", "req_0") # Temporary hack for GPTOSS 1: don't filter special tokens skip_special_tokens = True if "gptoss" in model.config.architectures[0].lower(): skip_special_tokens = False generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) last_kv_cache = None if self.is_continuation(req) and not must_discard_cache: seq_len = self.last_kv_cache.get_seq_length() if inputs.shape[-1] > seq_len: last_kv_cache = self.last_kv_cache generate_output = model.generate( inputs=inputs, attention_mask=torch_ones_like(inputs), generation_config=generation_config, return_dict_in_generate=True, past_key_values=last_kv_cache, ) # save KV cache self.last_kv_cache = generate_output.past_key_values # Decode full text full_text = processor.batch_decode(generate_output.sequences, skip_special_tokens=skip_special_tokens)[0] 
created_at = time.time() response_output_item = ResponseOutputMessage( id=f"msg_{request_id}", type="message", status="completed", role="assistant", content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], annotations=[], ) response_completed = Response( id=f"resp_{request_id}", created_at=created_at, status="completed", model=model_id_and_revision, instructions=req.get("instructions"), text={"format": {"type": "text"}}, output=[response_output_item], object="response", tools=[], parallel_tool_calls=req.get("parallel_tool_calls", False), tool_choice="auto", metadata=req.get("metadata"), ) return response_completed.model_dump(exclude_none=True) def generate_transcription(self, req: dict) -> Generator[str, None, None]: """ Generates an OpenAI Transcription using the audio file. Args: req (`dict`): The request containing the audio file and model information. Returns: `Generator[str, None, None]`: A generator that yields the transcription result. """ # TODO: implement streaming transcription (currently, it's not streaming) if not is_librosa_available(): raise ImportError( "Missing librosa dependency for audio transcription. 
Please install with `pip install librosa`" ) model_id_and_revision = self.process_model_name(req["model"]) audio_model, audio_processor = self.load_audio_model_and_processor(model_id_and_revision) generation_streamer = TextIteratorStreamer( audio_processor.tokenizer, skip_special_tokens=True, skip_prompt=True ) generation_config = create_generation_config_from_req( req, model_generation_config=audio_model.generation_config ) # Read the binary audio file using librosa model_sampling_rate = audio_processor.feature_extractor.sampling_rate audio_bytes = io.BytesIO(req["file"]) audio_array, _ = librosa.load(audio_bytes, sr=model_sampling_rate, mono=True) audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to( audio_model.device ) audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype) generation_kwargs = { "streamer": generation_streamer, "generation_config": generation_config, "return_dict_in_generate": True, } def _generate_transcription(): generated_ids = audio_model.generate(**audio_inputs, **generation_kwargs) transcription_text = audio_processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)[0] transcription = Transcription(text=transcription_text) yield f"{transcription.model_dump_json(exclude_none=True)}" return _generate_transcription() def is_continuation(self, req: dict) -> bool: """ Determines whether the current request is a continuation of the last request. In other words, if it is the same chat session. Args: req (`dict`): The request to check. Returns: `True` if the request is a continuation of the last request, `False` otherwise. 
""" messages = req.get("messages") or req.get("input") # ChatCompletion and Response have different fields req_continues_last_messages = True # No cached messages: this is a new request if self.last_messages is None: req_continues_last_messages = False # The new request has no new rounds of conversation: this is a new request elif len(self.last_messages) >= len(messages): req_continues_last_messages = False # Otherwise, check that the last messages are a subset of the new request else: for i in range(len(self.last_messages)): if self.last_messages[i] != messages[i]: req_continues_last_messages = False break self.last_messages = messages return req_continues_last_messages def get_quantization_config(self) -> BitsAndBytesConfig | None: """ Returns the quantization config for the given CLI arguments. Returns: `Optional[BitsAndBytesConfig]`: The quantization config. """ if self.quantization == "bnb-4bit": quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, ) elif self.quantization == "bnb-8bit": quantization_config = BitsAndBytesConfig(load_in_8bit=True) else: quantization_config = None if quantization_config is not None: logger.info(f"Quantization applied with the following config: {quantization_config}") return quantization_config def process_model_name(self, model_id: str) -> str: """ Applies the `force_model` CLI argument and canonicalizes the model name to the format "model_id@revision". If the model_id DOESN'T contain an @, it defaults to "model_id@main". Args: model_id (`str`): The model ID. Returns: `str`: The canonicalized model name to be used """ if self.force_model is not None: model_id = self.force_model if "@" in model_id: return model_id return f"{model_id}@main" def _load_model_and_data_processor(self, model_id_and_revision: str): """ Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI arguments. 
Args: model_id_and_revision (`str`): The model ID and revision to load. model_cls (`type[PreTrainedModel]`): The model class to load. Returns: `tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]`: The loaded model and data processor (tokenizer, audio processor, etc.). """ import torch from transformers import AutoConfig, AutoProcessor logger.info(f"Loading {model_id_and_revision}") if "@" in model_id_and_revision: model_id, revision = model_id_and_revision.split("@", 1) else: model_id, revision = model_id_and_revision, "main" try: data_processor = AutoProcessor.from_pretrained( model_id, revision=revision, trust_remote_code=self.trust_remote_code, ) except OSError: try: data_processor = AutoTokenizer.from_pretrained( model_id, revision=revision, trust_remote_code=self.trust_remote_code, ) except OSError: raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.") dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) quantization_config = self.get_quantization_config() model_kwargs = { "revision": revision, "attn_implementation": self.attn_implementation, "dtype": dtype, "device_map": self.device, "trust_remote_code": self.trust_remote_code, "quantization_config": quantization_config, } config = AutoConfig.from_pretrained(model_id, **model_kwargs) architecture = getattr(transformers, config.architectures[0]) model = architecture.from_pretrained(model_id, **model_kwargs) has_default_max_length = ( model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20 ) has_short_max_new_tokens = ( model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024 ) if has_default_max_length or has_short_max_new_tokens: model.generation_config.max_new_tokens = 1024 logger.info(f"Loaded model {model_id_and_revision}") return model, data_processor def load_model_and_processor( self, model_id_and_revision: str ) -> tuple["PreTrainedModel", 
"PreTrainedTokenizerFast"]: """ Loads the text model and processor from the given model ID and revision into the ServeCommand instance. Args: model_id_and_revision (`str`): The model ID and revision to load. Returns: `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and processor. """ if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): model, processor = self._load_model_and_data_processor(model_id_and_revision) self.loaded_models[model_id_and_revision] = TimedModel( model, timeout_seconds=self.model_timeout, processor=processor, ) else: self.loaded_models[model_id_and_revision].reset_timer() model = self.loaded_models[model_id_and_revision].model processor = self.loaded_models[model_id_and_revision].processor return model, processor def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple["PreTrainedModel", "ProcessorMixin"]: """ Loads the audio model and processor from the given model ID and revision into the ServeCommand instance. Args: model_id_and_revision (`str`): The model ID and revision to load. Returns: `tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor. """ if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): audio_model, audio_processor = self._load_model_and_data_processor(model_id_and_revision) self.loaded_models[model_id_and_revision] = TimedModel( audio_model, timeout_seconds=self.model_timeout, processor=audio_processor, ) else: self.loaded_models[model_id_and_revision].reset_timer() audio_model = self.loaded_models[model_id_and_revision].model audio_processor = self.loaded_models[model_id_and_revision].processor return audio_model, audio_processor # set docstring separately to make it look nice (Typer doesn't play well with the class command) Serve.__doc__ = """ Run a FastAPI server to serve models on-demand with an OpenAI compatible API. 
Models will be loaded and unloaded automatically based on usage and a timeout. \b The server will expose the following endpoints: - POST /v1/chat/completions: Generates chat completions. - POST /v1/responses: Generates responses. - POST /v1/audio/transcriptions: Generates transcriptions from audio. - GET /v1/models: Lists available models for 3rd party tools. Requires FastAPI and Uvicorn to be installed. """ if __name__ == "__main__": serve = Serve()