You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

382 lines
19 KiB

# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from abc import ABC, abstractmethod
from collections import deque
from ...utils.metrics import attach_tracer, traced
from .cache import PagedAttentionCache
from .requests import RequestState, RequestStatus, logger
class Scheduler(ABC):
    """
    Common base for the schedulers used by the continuous batch processor. A scheduler keeps track of requests from
    the moment they are added to the waiting queue until they are selected for a batch. Subclasses implement
    `schedule_batch` to decide in which order requests are considered; the budget checks, cache allocation and state
    transitions shared by all strategies live here.
    """

    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False):
        """Creates empty request queues around `cache`."""
        self.cache = cache
        self.retain_cache_on_finish = retain_cache_on_finish
        # Requests that were already scheduled at least once, keyed by request id
        self.active_requests: dict[str, RequestState] = {}
        # Requests that were never scheduled, keyed by request id, along with their arrival order
        self.waiting_requests: dict[str, RequestState] = {}
        self.waiting_requests_order: deque[str] = deque()
        # Cancellations are recorded under a lock and applied later by `clear_cancelled_requests`
        self._cancellation_lock = threading.Lock()
        self._requests_to_cancel: set[str] = set()
        # Requests with children that were scheduled while holding back their last token
        self._requests_to_fork: list[RequestState] = []
        # This state is used to avoid infinite loops when offloading requests
        self.block_new_requests = False
        # This is to compute the cache used by a new request being scheduled
        if cache.num_full_attention_groups:
            self.cache_budget_module = None
        else:
            self.cache_budget_module = cache.config.sliding_window

    @traced
    def add_waiting_request(self, state: RequestState):
        """Appends a request to the waiting queue. When the cache is retained on finish and a request with the same
        id is still in the active requests, the new request takes over its allocated blocks and position offset, and
        skips as many leading tokens as the previous request had initial tokens."""
        request_id = state.request_id
        if self.retain_cache_on_finish and request_id in self.active_requests:
            previous_state = self.active_requests.pop(request_id)
            num_known_tokens = len(previous_state.initial_tokens)
            # XXX: check for indexing error?
            state.tokens_to_process = state.tokens_to_process[num_known_tokens:]
            state.allocated_blocks = previous_state.allocated_blocks
            state.position_offset = previous_state.position_offset
        self.waiting_requests[request_id] = state
        self.waiting_requests_order.append(request_id)

    @abstractmethod
    def schedule_batch(self, token_budget: int, cache_budget: int) -> list[RequestState] | None:
        """Selects the requests that will be part of the next batch, according to the prioritization rules of the
        scheduler. The token_budget is the maximum number of tokens that can be processed in a batch, and the
        cache_budget is the maximum number of KV cache entries that can be read in a batch."""

    @traced
    def has_pending_requests(self) -> bool:
        """Returns True if at least one request is active or waiting."""
        return bool(self.active_requests) or bool(self.waiting_requests)

    @traced
    def finish_request(self, request_id: str, evict_from_cache: bool = True) -> None:
        """Ends the processing of a request, which happens when it finished generating or encountered an error. If
        evict_from_cache is True, its cache blocks are freed and it is dropped from the active requests; otherwise
        nothing is done.
        """
        if not evict_from_cache:
            return
        self.cache.free_blocks(request_id)
        self.active_requests.pop(request_id, None)

    @traced
    def get_active_request_static_outputs(self, request_id: str) -> list[int]:
        """Returns the generated tokens of an active request, or an empty list if the request is not active."""
        state = self.active_requests.get(request_id)
        if state is None:
            return []
        return state.generated_tokens

    @traced
    def set_request_cancellation(self, request_id: str):
        """Records that a request should be cancelled."""
        with self._cancellation_lock:
            self._requests_to_cancel.add(request_id)

    @traced
    def clear_cancelled_requests(self):
        """Drops every request marked for cancellation from the active and waiting queues, frees its cache blocks,
        then resets the set of requests to cancel."""
        with self._cancellation_lock:
            for request_id in self._requests_to_cancel:
                self.active_requests.pop(request_id, None)
                self.waiting_requests.pop(request_id, None)
                # `deque.remove` raises a ValueError when the id is not in the queue
                try:
                    self.waiting_requests_order.remove(request_id)
                except ValueError:
                    pass
                self.cache.free_blocks(request_id)
            self._requests_to_cancel = set()

    @traced
    def request_is_cancelled(self, request_id: str) -> bool:
        """Returns True if the request is marked for cancellation or is neither active nor waiting."""
        if request_id in self._requests_to_cancel:
            return True
        is_tracked = request_id in self.active_requests or request_id in self.waiting_requests
        return not is_tracked

    @traced
    def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
        """Makes sure the request has enough cache blocks to store its next tokens. New blocks are requested from the
        cache when the request has no block yet or when the slots left in its allocated blocks cannot hold
        len_next_tokens tokens. Returns False if the cache could not allocate the blocks, True otherwise.
        """
        current_len = state.current_len()
        block_size = self.cache.block_size
        # Number of slots still unused in the blocks already allocated to the request
        free_slots = state.allocated_blocks * block_size - current_len
        if free_slots >= len_next_tokens and state.allocated_blocks != 0:
            return True
        # Allocate enough blocks to cover the requested length
        blocks_needed = (len_next_tokens - free_slots + 1) // block_size + 1
        newly_allocated = self.cache.allocate_blocks(blocks_needed, state.request_id, state.allocated_blocks)
        if newly_allocated is None:
            return False
        state.allocated_blocks += newly_allocated
        return True

    def _infer_request_tokens(self, state: RequestState, request_ids_to_remove_from_waiting: set[str]) -> list[int]:
        """Returns the tokens the request would bring to the current batch. If prefix sharing is enabled and the
        request is pending, a prefix match is looked up first: on a match, the request becomes active with a split
        prefill and the matched tokens are skipped."""
        matched_tokens = 0
        if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
            matched_tokens = self.cache.search_prefix_match(state.request_id, state.tokens_to_process)

        if matched_tokens > 0:
            self.active_requests[state.request_id] = state
            request_ids_to_remove_from_waiting.add(state.request_id)
            state.status = RequestStatus.SPLIT_PENDING_REMAINDER
            # We keep track of the number of allocated blocks to avoid double allocation
            state.allocated_blocks += matched_tokens // self.cache.block_size
            # Even if we match the whole request, we keep at least 1 token to start decoding
            skipped_tokens = min(matched_tokens, len(state.tokens_to_process) - 1)
            # Two separate slices on purpose: the two attributes must not share the same list object
            state.remaining_prefill_tokens = state.tokens_to_process[skipped_tokens:]
            state.tokens_to_process = state.tokens_to_process[skipped_tokens:]
            state.position_offset += skipped_tokens

        # A request with a split prefill brings the remainder of its prompt
        if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
            return state.remaining_prefill_tokens
        # Any other request brings its tokens to process: the full prompt or the last predicted tokens
        return state.tokens_to_process

    def _schedule_request(
        self,
        state: RequestState,
        request_tokens: list[int],
        token_budget: int,
        request_ids_to_remove_from_waiting: set[str],
    ) -> None:
        """Marks a request as scheduled for the current batch and updates its status and tokens depending on whether
        request_tokens fit in the token budget or have to be split. Pending requests become active and are flagged
        for removal from the waiting queue. A request with children (for parallel decoding) is never prefilled
        entirely: one token is held back and the request is added to the list of requests to fork."""
        # This does not check the request state, but DECODING request already have children set to 0.
        if state.num_children > 0 and token_budget >= len(request_tokens) - 1:
            token_budget = len(request_tokens) - 1
            self._requests_to_fork.append(state)

        fits_in_budget = len(request_tokens) < token_budget
        next_status = RequestStatus.PREFILLING if fits_in_budget else RequestStatus.PREFILLING_SPLIT

        if state.status == RequestStatus.PENDING:
            self.active_requests[state.request_id] = state
            state.status = next_status
            request_ids_to_remove_from_waiting.add(state.request_id)
        elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
            state.status = next_status
            if fits_in_budget:
                # The whole remainder is processed in this batch
                state.tokens_to_process = state.remaining_prefill_tokens
                state.remaining_prefill_tokens = []

        # When the tokens do not fit, only the first token_budget tokens are processed, whatever the status
        if not fits_in_budget:
            state.remaining_prefill_tokens = request_tokens[token_budget:]
            state.tokens_to_process = request_tokens[:token_budget]

    def _process_candidates(
        self,
        candidates: list[RequestState],
        token_budget: int,
        cache_budget: int,
        request_ids_to_remove_from_waiting: set[str],
        safety_margin: float = 0.0,
    ) -> tuple[list[RequestState], bool]:
        """Goes through the candidates in order and schedules those that fit in the current batch.
        This is the logic shared by all schedulers: for each candidate, the token and cache budgets are checked,
        cache blocks are allocated if needed, the request state is updated and requests leaving the waiting queue are
        recorded in request_ids_to_remove_from_waiting. Returns the scheduled requests and a boolean telling whether
        at least one cache allocation failed.
        """
        scheduled_requests = []
        one_allocation_failed = False
        # The names of `safety_margins` and `num_free_blocks` appear in a log message below
        safety_margins = safety_margin * self.cache.num_blocks

        for state in candidates:
            num_free_blocks = self.cache.get_num_free_blocks()
            # If we are out the safety margin, we only accept decoding requests or the first prefill request
            if num_free_blocks < safety_margins and scheduled_requests and state.status != RequestStatus.DECODING:
                logger.info(
                    f"Outside safety margin, breaking out of scheduling loop. {num_free_blocks = } {safety_margins = }"
                )
                break

            # Skip the request if the cache it needs to read exceeds the cache budget
            cache_needed = state.current_len()
            if self.cache_budget_module is not None:
                cache_needed %= self.cache_budget_module
            if cache_needed > cache_budget:
                continue

            # Tokens the request would bring to the batch, capped by the token budget for the allocation
            request_tokens = self._infer_request_tokens(state, request_ids_to_remove_from_waiting)
            num_tokens_to_store = min(len(request_tokens), token_budget)

            if not self._allocate_blocks_if_needed(state, num_tokens_to_store):
                one_allocation_failed = True
                # If we reached a waiting request and the cache is full, all subsequent waiting requests will need
                # allocation as well, so we can safely break out of the scheduling loop.
                if num_free_blocks == 0 and state.request_id in self.waiting_requests:
                    logger.info(f"Breaking mid-loop for request {state.request_id} because the cache is full")
                    break
                continue

            # From here on, the request is part of the batch
            self._schedule_request(state, request_tokens, token_budget, request_ids_to_remove_from_waiting)
            scheduled_requests.append(state)
            # Scheduling may have changed the number of tokens to process
            num_scheduled_tokens = len(state.tokens_to_process)
            token_budget -= num_scheduled_tokens
            cache_budget -= cache_needed

            # If using prefix sharing, we make note of the blocks that will be computed in the forward pass
            if self.cache.allow_block_sharing:
                block_size = self.cache.block_size
                tokens_after_forward = state.current_len() % block_size + num_scheduled_tokens
                self.cache.blocks_to_complete[state.request_id] = tokens_after_forward // block_size

            # A request that was still waiting leaves the waiting queue
            request_id = state.request_id
            if self.waiting_requests.pop(request_id, None) is not None:
                request_ids_to_remove_from_waiting.add(request_id)

            # No need to look at more candidates once a budget is used up
            if token_budget == 0 or cache_budget == 0:
                break

        return scheduled_requests, one_allocation_failed

    def _cleanup_waiting_queue(self, request_ids_to_remove_from_waiting: set[str]) -> None:
        """Rebuilds the waiting queue order without the given request ids."""
        self.waiting_requests_order = deque(
            request_id
            for request_id in self.waiting_requests_order
            if request_id not in request_ids_to_remove_from_waiting
        )
# TODO: move more of the logic shared by the FIFO and prefill-first schedulers into the base Scheduler class
@attach_tracer()
class FIFOScheduler(Scheduler):
    """Scheduler that considers decoding requests first, then split prefill requests, then waiting requests in their
    order of arrival. It also uses a safety margin to prevent cache exhaustion: with the default value, new requests
    stop being scheduled once 80% of the cache is full, which prioritizes decoding active requests."""

    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.2):
        """Initializes the FIFO scheduler. The safety margin is the fraction of free blocks under which new prefill
        requests stop being scheduled: safety_margin = 0.1 means that new prefill requests stop being scheduled when
        less than 10% of blocks are free, or equivalently when more than 90% of blocks are already allocated.
        """
        super().__init__(cache, retain_cache_on_finish)
        self.safety_margin = safety_margin

    @traced
    def schedule_batch(self, token_budget: int, cache_budget: int) -> list[RequestState] | None:
        """Schedules the next batch. Returns the scheduled requests, or None if no request was scheduled and at least
        one cache allocation failed."""
        split_statuses = (RequestStatus.SPLIT_PENDING_REMAINDER, RequestStatus.PREFILLING_SPLIT)
        active_states = list(self.active_requests.values())

        # Decoding requests come first, followed by requests whose prefill is split
        decoding_states = [state for state in active_states if state.status == RequestStatus.DECODING]
        split_prefill_states = [state for state in active_states if state.status in split_statuses]
        # Waiting requests come last, in order of arrival, unless new requests are blocked
        waiting_states: list[RequestState] = []
        if not self.block_new_requests:
            waiting_states = [self.waiting_requests[req_id] for req_id in self.waiting_requests_order]
        candidates = decoding_states + split_prefill_states + waiting_states

        request_ids_to_remove_from_waiting: set[str] = set()
        scheduled_requests, one_allocation_failed = self._process_candidates(
            candidates,
            token_budget,
            cache_budget,
            request_ids_to_remove_from_waiting,
            safety_margin=self.safety_margin,
        )

        # We remove waiting requests before checking requests were scheduled, because there might have been prefill
        # matches
        self._cleanup_waiting_queue(request_ids_to_remove_from_waiting)

        # If no requests were scheduled and the cache is full, we signal it by returning None
        cache_is_full = one_allocation_failed and not scheduled_requests
        return None if cache_is_full else scheduled_requests
# FIXME: prioritize adding from waiting reqs before scheduling `RequestStatus.DECODING` when cache space allows it
# TODO: further consolidate the code by making more of it common. The reference Scheduler is FIFO, not this one.
@attach_tracer()
class PrefillFirstScheduler(Scheduler):
    """Scheduler that considers split prefill requests before decoding requests. Split prefill requests are the
    continuation of prompts that were only partially processed, so they are completed before any decoding request or
    waiting request is looked at."""

    @traced
    def schedule_batch(self, token_budget: int, cache_budget: int) -> list[RequestState] | None:
        """Schedules the next batch. Returns the scheduled requests, or None if no request was scheduled and at least
        one cache allocation failed."""
        # XXX: when cache is full, state can stay on `PREFILLING_SPLIT` so we need to take those into account
        split_statuses = (RequestStatus.PREFILLING_SPLIT, RequestStatus.SPLIT_PENDING_REMAINDER)
        active_states = list(self.active_requests.values())

        # Requests with a split prefill come first, followed by decoding requests
        split_prefill_states = [state for state in active_states if state.status in split_statuses]
        decoding_states = [state for state in active_states if state.status == RequestStatus.DECODING]
        # Waiting requests come last, in order of arrival, unless new requests are blocked
        waiting_states: list[RequestState] = []
        if not self.block_new_requests:
            waiting_states = [self.waiting_requests[req_id] for req_id in self.waiting_requests_order]
        candidates = split_prefill_states + decoding_states + waiting_states

        request_ids_to_remove_from_waiting: set[str] = set()
        scheduled_requests, one_allocation_failed = self._process_candidates(
            candidates,
            token_budget,
            cache_budget,
            request_ids_to_remove_from_waiting,
            safety_margin=0.0,
        )

        # We remove waiting requests before checking requests were scheduled, because there might have been prefill
        # matches
        self._cleanup_waiting_queue(request_ids_to_remove_from_waiting)

        # If no requests were scheduled and the cache is full, we signal it by returning None
        cache_is_full = one_allocation_failed and not scheduled_requests
        return None if cache_is_full else scheduled_requests
# Maps the name of a scheduling strategy to the scheduler class that implements it
SCHEDULER_MAPPING: dict[str, type[Scheduler]] = {"fifo": FIFOScheduler, "prefill_first": PrefillFirstScheduler}