You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
3.8 KiB
110 lines
3.8 KiB
# mypy: allow-untyped-defs
|
|
"""Registry for flash attention implementations.
|
|
|
|
This module contains the registration system for flash attention implementations.
|
|
It has no torch dependencies to avoid circular imports during initialization.
|
|
"""
|
|
|
|
from collections.abc import Callable
|
|
from typing import Literal, Protocol
|
|
|
|
|
|
class FlashAttentionHandle(Protocol):
    """Structural type for the handle a registration callable may return.

    Any object exposing a ``remove()`` method taking no arguments and
    returning ``None`` satisfies this protocol.
    """

    def remove(self) -> None:
        """Protocol stub; this module declares the method but never calls it."""
        ...
|
|
|
|
|
|
# Implementation identifiers spelled out as literals for type checkers.
_FlashAttentionImpl = Literal["FA4"]
# Signature of a registration callable: any arguments, optionally returning a
# handle object.
_RegisterFn = Callable[..., FlashAttentionHandle | None]

# Registered implementations: identifier -> registration callable.
_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}

# Handles returned by registration callables, keyed by implementation
# identifier; holding them here keeps them referenced.
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}
# Identifier of the most recently activated implementation; ``None`` until an
# activation has happened.
_FLASH_ATTENTION_ACTIVE: str | None = None
|
|
|
|
|
|
def register_flash_attention_impl(
    impl: str | _FlashAttentionImpl,
    *,
    register_fn: _RegisterFn,
) -> None:
    """
    Record ``register_fn`` as the activation callable for ``impl``.

    Registration only stores the callable; it is not invoked here. It runs
    later, when :func:`activate_flash_attention_impl` is called with the same
    identifier. Registering an identifier that is already present replaces
    the previously stored callable.

    .. note::
        This entry point is meant for SDPA backend providers. End users
        should call :func:`activate_flash_attention_impl` to switch to an
        implementation that has already been registered.

    Args:
        impl: Implementation identifier (e.g., ``"FA4"``).
        register_fn: Keyword-only callable that performs the actual
            dispatcher registration, i.e. registers the custom kernels with
            the PyTorch dispatcher. It may optionally return an object
            implementing :class:`FlashAttentionHandle` to keep any necessary
            state alive.

    Example:
        >>> def my_impl_register(module_path: str = "my_flash_impl"):
        ...     # Register custom kernels with torch dispatcher
        ...     pass  # doctest: +SKIP
        >>> register_flash_attention_impl(
        ...     "MyImpl", register_fn=my_impl_register
        ... )  # doctest: +SKIP
    """
    # Plain assignment: the latest registration for an identifier wins.
    _FLASH_ATTENTION_IMPLS[impl] = register_fn
|
|
|
|
|
|
def activate_flash_attention_impl(
    impl: str | _FlashAttentionImpl,
) -> None:
    """
    Run the registration callable stored for ``impl`` and mark it active.

    Activating the implementation that is already active is a no-op: its
    registration callable is not invoked again.

    .. note::
        Backend providers should NOT activate their implementation
        automatically on import. Activation is an explicit opt-in, either
        through this function or through environment variables, so that
        several provider libraries can coexist.

    Args:
        impl: Identifier of the implementation to activate. See
            :func:`~torch.nn.attention.list_flash_attention_impls` for the
            available implementations.

    Raises:
        ValueError: If nothing has been registered under ``impl``.

    If the registration callable returns a :class:`FlashAttentionHandle`,
    the registry stores it under ``impl`` and keeps it alive for the
    lifetime of the process (until explicit uninstall support exists).

    Example:
        >>> activate_flash_attention_impl("FA4")  # doctest: +SKIP
    """
    global _FLASH_ATTENTION_ACTIVE

    # Guard clause: refuse identifiers that were never registered.
    if (activator := _FLASH_ATTENTION_IMPLS.get(impl)) is None:
        known_impls = list_flash_attention_impls()
        raise ValueError(
            f"Unknown flash attention impl '{impl}'. "
            f"Available implementations: {known_impls}"
        )

    # TODO: Actually switching to a new impl requires unregistering the
    # current impl, reinstalling the default impl, and then registering the
    # new impl.
    if _FLASH_ATTENTION_ACTIVE == impl:
        return

    maybe_handle = activator()
    if maybe_handle is not None:
        # Hold on to the handle so the state it owns is not released.
        _FLASH_ATTENTION_HANDLES[impl] = maybe_handle
    _FLASH_ATTENTION_ACTIVE = impl
|
|
|
|
|
|
def list_flash_attention_impls() -> list[str]:
    """Return the identifiers of all registered flash attention impls, sorted."""
    # Iterating a dict yields its keys, so no explicit ``.keys()`` is needed.
    return sorted(_FLASH_ATTENTION_IMPLS)
|
|
|
|
|
|
def current_flash_attention_impl() -> str | None:
    """
    Report which flash attention impl is currently activated.

    Returns:
        The identifier of the activated implementation, or ``None`` when no
        custom impl has been activated.
    """
    return _FLASH_ATTENTION_ACTIVE
|