# mypy: allow-untyped-defs
"""Registry for flash attention implementations.
This module contains the registration system for flash attention implementations.
It has no torch dependencies to avoid circular imports during initialization.
"""
from collections.abc import Callable
from typing import Literal, Protocol


class FlashAttentionHandle(Protocol):
    """Handle returned by a backend's ``register_fn``; keeps any state needed
    by the dispatcher registration alive until :meth:`remove` is called."""

    def remove(self) -> None: ...
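
# A minimal sketch of a conforming handle (illustrative only; ``_KeepAlive``
# and ``lib`` are hypothetical names, not part of this module):
#
#     class _KeepAlive:
#         def __init__(self, lib):
#             self._lib = lib  # e.g. whatever object owns the registered kernels
#
#         def remove(self) -> None:
#             self._lib = None  # drop the reference so registration state can be freed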


_RegisterFn = Callable[..., FlashAttentionHandle | None]
_FlashAttentionImpl = Literal["FA4"]

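# Module-level registry state:
#   _FLASH_ATTENTION_IMPLS maps impl name -> its activation callable.
#   _FLASH_ATTENTION_ACTIVE is the currently active impl name (None = default).
#   _FLASH_ATTENTION_HANDLES pins returned handles for the process lifetime.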
_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}
_FLASH_ATTENTION_ACTIVE: str | None = None
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}


def register_flash_attention_impl(
impl: str | _FlashAttentionImpl,
*,
register_fn: _RegisterFn,
) -> None:
"""
Register the callable that activates a flash attention impl.
.. note::
This function is intended for SDPA backend providers to register their
implementations. End users should use :func:`activate_flash_attention_impl`
to activate a registered implementation.
Args:
impl: Implementation identifier (e.g., ``"FA4"``).
register_fn: Callable that performs the actual dispatcher registration.
This function will be invoked by :func:`activate_flash_attention_impl`
and should register custom kernels with the PyTorch dispatcher.
It may optionally return a handle implementing
:class:`FlashAttentionHandle` to keep any necessary state alive.
Example:
>>> def my_impl_register(module_path: str = "my_flash_impl"):
... # Register custom kernels with torch dispatcher
... pass # doctest: +SKIP
>>> register_flash_attention_impl(
... "MyImpl", register_fn=my_impl_register
... ) # doctest: +SKIP
"""
    _FLASH_ATTENTION_IMPLS[impl] = register_fn


def activate_flash_attention_impl(
impl: str | _FlashAttentionImpl,
) -> None:
"""
Activate into the dispatcher a previously registered flash attention impl.
.. note::
Backend providers should NOT automatically activate their implementation
on import. Users should explicitly opt-in by calling this function or via
environment variables to ensure multiple provider libraries can coexist.
Args:
impl: Implementation identifier to activate. See
:func:`~torch.nn.attention.list_flash_attention_impls` for available
implementations.
If the backend's :func:`register_flash_attention_impl` callable
returns a :class:`FlashAttentionHandle`, the registry keeps that
handle alive for the lifetime of the process (until explicit
uninstall support exists).
Example:
>>> activate_flash_attention_impl("FA4") # doctest: +SKIP
"""
global _FLASH_ATTENTION_ACTIVE
register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
if register_fn is None:
raise ValueError(
f"Unknown flash attention impl '{impl}'. "
f"Available implementations: {list_flash_attention_impls()}"
)
    # TODO: The only way to actually switch impls is to unregister the current
    # impl, reinstall the default impl, and then register the new impl.
if _FLASH_ATTENTION_ACTIVE == impl:
return
handle = register_fn()
if handle is not None:
_FLASH_ATTENTION_HANDLES[impl] = handle
    _FLASH_ATTENTION_ACTIVE = impl


def list_flash_attention_impls() -> list[str]:
"""Return the names of all available flash attention implementations."""
    return sorted(_FLASH_ATTENTION_IMPLS.keys())


def current_flash_attention_impl() -> str | None:
"""
Return the currently activated flash attention impl name, if any.
``None`` indicates that no custom impl has been activated.
"""
return _FLASH_ATTENTION_ACTIVE
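

# End-to-end usage sketch (comments only; ``my_register_fn`` stands in for a
# hypothetical provider callable and is not part of this module):
#
#     # In the provider library, at import time -- register, but do NOT activate:
#     register_flash_attention_impl("MyImpl", register_fn=my_register_fn)
#
#     # In user code -- explicitly opt in:
#     list_flash_attention_impls()           # e.g. ["MyImpl", ...]
#     activate_flash_attention_impl("MyImpl")
#     current_flash_attention_impl()         # "MyImpl"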