"""UBER PROTOTYPE!!!"""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
from dataclasses import dataclass
from functools import cache
from typing import Any, TYPE_CHECKING
from typing_extensions import TypeVarTuple, Unpack
from . import _registry
if TYPE_CHECKING:
from types import ModuleType
import torch
from torch.library import Library
__all__ = [
"register_flash_attention_fa4",
]
_FA4_MODULE_PATH: str | None = None
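
# Handle returned by register_flash_attention_fa4(). Dropping the Library
# reference releases the dispatcher overrides once no other references remain.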
@dataclass
class _FA4Handle:
library: Library | None
def remove(self) -> None:
self.library = None
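
# Cached per-device compute-capability major version (FA4 targets SM 9.x / 10.x).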
@cache
def _get_device_major(device: torch.device) -> int:
major, _ = torch.cuda.get_device_capability(device)
return major
def register_flash_attention_fa4(
module_path: str = "flash_attn.cute.interface",
) -> _FA4Handle:
"""
Register FA4 flash attention kernels with the PyTorch dispatcher.
Args:
module_path: Python module path to the FA4 implementation.
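
    Example::

        handle = register_flash_attention_fa4()
        # ... run torch.nn.functional.scaled_dot_product_attention on CUDA ...
        handle.remove()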
"""
global _FA4_MODULE_PATH
_ = _fa4_import_module(module_path)
_FA4_MODULE_PATH = module_path
return _FA4Handle(_fa4_register_kernels())
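
# Cached import keeps repeated registrations cheap; the attribute check rejects
# modules that do not expose the expected FA4 entry points up front.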
@cache
def _fa4_import_module(module_path: str) -> ModuleType:
module = importlib.import_module(module_path)
if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
return module
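
# Override the CUDA implementations of the aten flash-attention ops. The
# returned Library must stay alive for the overrides to remain active.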
def _fa4_register_kernels() -> Library:
lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
lib.impl(
"_scaled_dot_product_flash_attention",
_fa4_scaled_dot_product_flash_attention_forward_impl,
"CUDA",
)
lib.impl(
"_scaled_dot_product_flash_attention_backward",
_fa4_scaled_dot_product_flash_attention_backward_impl,
"CUDA",
)
return lib
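
# Support checks return an error message (or None) rather than raising, so the
# op-level wrappers can attach their own context before raising.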
def _fa4_common_support_error(
query: torch.Tensor,
tensors: tuple[torch.Tensor, ...],
cum_seq_q: torch.Tensor | None,
require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
) -> str | None:
if not all(t.is_cuda for t in tensors):
return "inputs must be CUDA tensors"
if len({t.device for t in tensors}) != 1:
return "inputs must share device"
if query.dtype not in (torch.float16, torch.bfloat16):
return "query dtype must be float16 or bfloat16"
for name, tensor in require_fp32:
if tensor.dtype != torch.float32:
return f"{name} dtype must be float32"
if cum_seq_q is None and query.dim() != 4:
return "dense query must be 4D"
if cum_seq_q is not None and query.dim() != 3:
return "ragged query must be 3D"
if not torch.cuda.is_available():
return "CUDA not available"
if _get_device_major(query.device) not in (9, 10):
return "FA4 requires compute capability 9.0 or 10.0"
return None
def _fa4_forward_support_error(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float,
return_debug_mask: bool,
alibi_slopes: torch.Tensor | None,
seqused_k: torch.Tensor | None,
cum_seq_q: torch.Tensor | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if return_debug_mask:
return "return_debug_mask must be False"
if alibi_slopes is not None:
return "alibi_slopes not supported"
if seqused_k is not None:
if seqused_k.dtype != torch.int32:
return "seqused_k must be int32"
if not seqused_k.is_cuda:
return "seqused_k must be CUDA"
error = _fa4_common_support_error(
query,
(query, key, value),
cum_seq_q,
)
if error is not None:
if error == "inputs must share device":
return "query, key, value must be on same device"
return error
return None
def _fa4_backward_support_error(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
dropout_p: float,
cum_seq_q: torch.Tensor | None,
window_size_left: int | None,
window_size_right: int | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if window_size_left is not None or window_size_right is not None:
return "windowed attention not supported"
error = _fa4_common_support_error(
query,
(grad_out, query, key, value, out, logsumexp),
cum_seq_q,
require_fp32=(("logsumexp", logsumexp),),
)
if error is not None:
return error
return None
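
# The aten SDPA ops use BHSD layout while the FA4 kernels expect BSHD, so dense
# inputs and outputs are transposed at the boundary.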
Ts = TypeVarTuple("Ts")
def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
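
# Thin wrappers over the FA4 module's _flash_attn_fwd / _flash_attn_bwd entry
# points, translating aten-style arguments into the keyword interface used here.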
def _fa4_run_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
window_size_left: int | None,
window_size_right: int | None,
seqused_k: torch.Tensor | None,
out: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
kwargs: dict[str, Any] = {
"softmax_scale": scale,
"causal": is_causal,
"window_size_left": window_size_left,
"window_size_right": window_size_right,
"return_lse": True,
"cu_seqlens_q": cu_seq_q,
"cu_seqlens_k": cu_seq_k,
"seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
}
if out is not None:
kwargs["out"] = out
out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
return out, lse.contiguous()
def _fa4_run_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
dq, dk, dv = module._flash_attn_bwd(
query,
key,
value,
out,
grad_out,
logsumexp.contiguous(),
softmax_scale=scale,
causal=is_causal,
cu_seqlens_q=cu_seq_q,
cu_seqlens_k=cu_seq_k,
)
return dq, dk, dv
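
# CUDA override for aten::_flash_attention_forward. Dropout is unsupported, so
# the RNG state, philox offset, and debug mask outputs are placeholder tensors.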
def _fa4_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
return_debug_mask: bool,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
seqused_k: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
out: torch.Tensor | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
alibi_slopes,
seqused_k,
cum_seq_q,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
out, lse = _fa4_run_forward(
query,
key,
value,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
window_size_left,
window_size_right,
seqused_k,
out,
)
rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
return out, lse, rng_state, philox_offset, debug_mask
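
# CUDA override for aten::_flash_attention_backward; the rng_state and unused
# philox tensors are accepted only to satisfy the op schema.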
def _fa4_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
rng_state: torch.Tensor,
unused: torch.Tensor,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
cum_seq_q,
window_size_left,
window_size_right,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
dq, dk, dv = _fa4_run_backward(
grad_out,
query,
key,
value,
out,
logsumexp,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
)
return dq, dk, dv
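
# CUDA override for aten::_scaled_dot_product_flash_attention (dense layout
# only; the cum_seq_q / cum_seq_k outputs are returned as None).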
def _fa4_scaled_dot_product_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
q, k, v = _transpose_dense(query, key, value)
# Pre-allocate output with query's strides (BHSD layout), then create
# a BSHD view for the kernel. This ensures the returned output has
# the same memory layout as the input query.
out_bhsd = torch.empty_like(query)
out_bshd = out_bhsd.transpose(1, 2)
max_q_flash = q.size(1)
max_k_flash = k.size(1)
_, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
q,
k,
v,
None,
None,
max_q_flash,
max_k_flash,
dropout_p,
is_causal,
return_debug_mask,
scale=scale,
out=out_bshd,
)
max_q = query.size(2)
max_k = key.size(2)
return (
out_bhsd,
lse,
None,
None,
max_q,
max_k,
rng_state,
philox_offset,
debug_mask,
)
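
# CUDA override for aten::_scaled_dot_product_flash_attention_backward:
# transpose to BSHD, reuse the flash-attention backward path, transpose back.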
def _fa4_scaled_dot_product_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
*,
scale: float | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
max_q = query.size(2)
max_k = key.size(2)
dq, dk, dv = _fa4_flash_attention_backward_impl(
go,
q,
k,
v,
o,
logsumexp,
None,
None,
max_q,
max_k,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale=scale,
)
dq, dk, dv = _transpose_dense(dq, dk, dv)
return dq, dk, dv
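
# Make FA4 discoverable through the shared registry so callers can select it by name.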
_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)