import torch
from ..generation.continuous_batching import PagedAttentionCache
from ..modeling_flash_attention_utils import lazy_import_paged_flash_attention


def paged_attention_forward(
    module: torch.nn.Module,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    cache: PagedAttentionCache | None = None,
    cu_seq_lens_q=None,
    cu_seq_lens_k=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
r"""Perform the forward pass of attention with paged key-value cache.
This function handles the cache updates and performs the attention computation
using the flash_attn_varlen_func for efficient processing.
Args:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full k
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full v
cu_seq_lens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seq_lens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
"""
    flash_attn_varlen_func = lazy_import_paged_flash_attention(module.config._attn_implementation)
    sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window - 1, 0)
    layer_type = "full_attention" if sliding_window == (-1, -1) else "sliding_attention"
    # .update changes the shape of k and v from [1, num_kv_heads, seqlen_kv, head_dim] to [-1, num_kv_heads, head_dim]
    if cache is not None:
        k, v = cache.update(
            key_states=k,
            value_states=v,
            layer_idx=module.layer_idx,
            read_index=kwargs["read_index"],
            write_index=kwargs["write_index"],
        )
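    # Assumption: the continuous-batching scheduler supplies `read_index` and
    # `write_index` in kwargs, mapping each token to its slot in the paged cache;
    # their presence is required here and is not validated.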
    # Retrieve the cumulative sequence lengths and max length for the current layer type
    if isinstance(cu_seq_lens_k, dict):
        cu_seq_lens_k = cu_seq_lens_k[layer_type]
        max_seqlen_k = max_seqlen_k[layer_type]

    # Forward optional kernel-specific extras (e.g. `s_aux`, used for attention sinks) untouched
    custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
    attn_output = flash_attn_varlen_func(
        q.transpose(1, 2).squeeze(0).contiguous(),
        k.contiguous(),
        v.contiguous(),
        cu_seq_lens_q.to(torch.int32),
        cu_seq_lens_k.to(torch.int32).clone(),
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale=module.scaling,
        causal=True,  # effectively required: the kernel bottom-aligns the mask when seqlen_q < seqlen_k
        window_size=sliding_window,  # (-1, -1) means an infinite context window
        **custom_kwargs,
    )
    # Some kernel variants return a tuple (output, softmax_lse, ...); keep only the output
    if isinstance(attn_output, tuple):
        attn_output = attn_output[0]
    return attn_output, None
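

# A minimal sketch, not part of this module's API: how the cumulative
# sequence-length tensors consumed by `flash_attn_varlen_func` can be built from
# per-sequence token counts. `_make_cu_seq_lens` is a hypothetical helper shown
# for illustration only.
def _make_cu_seq_lens(seq_lens: list[int]) -> tuple[torch.Tensor, int]:
    """Return (cu_seq_lens, max_seqlen) for the given per-sequence lengths."""
    cu_seq_lens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seq_lens[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0)
    return cu_seq_lens, max(seq_lens)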