import torch

from ..generation.continuous_batching import PagedAttentionCache
from ..modeling_flash_attention_utils import lazy_import_paged_flash_attention


def paged_attention_forward(
    module: torch.nn.Module,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    cache: PagedAttentionCache | None = None,
    cu_seq_lens_q=None,
    cu_seq_lens_k=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    r"""Perform the forward pass of attention with a paged key-value cache.

    This function handles the cache updates and performs the attention computation
    using `flash_attn_varlen_func` for efficient processing.

    Args:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
            If a block table is used, this can be the full key cache.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
            If a block table is used, this can be the full value cache.
        cu_seq_lens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seq_lens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
    """
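    # Example varlen layout (illustrative, not computed here): for packed query lengths
    # [3, 5, 2], cu_seq_lens_q == tensor([0, 3, 8, 10], dtype=torch.int32) and
    # max_seqlen_q == 5; cu_seq_lens_k / max_seqlen_k follow the same convention
    # (possibly keyed per layer type, see below).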
    flash_attn_varlen_func = lazy_import_paged_flash_attention(module.config._attn_implementation)

    sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window - 1, 0)
    layer_type = "full_attention" if sliding_window == (-1, -1) else "sliding_attention"

    # .update changes the shape of k and v from [1, num_kv_heads, seqlen_kv, head_dim] to [-1, num_kv_heads, head_dim]
    if cache is not None:
        k, v = cache.update(
            key_states=k,
            value_states=v,
            layer_idx=module.layer_idx,
            read_index=kwargs["read_index"],
            write_index=kwargs["write_index"],
        )

    # Retrieve the cumulative sequence lengths for the current layer
    if isinstance(cu_seq_lens_k, dict):
        cu_seq_lens_k = cu_seq_lens_k[layer_type]
        max_seqlen_k = max_seqlen_k[layer_type]

    custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}

    attn_output = flash_attn_varlen_func(
        q.transpose(1, 2).squeeze(0).contiguous(),
        k.contiguous(),
        v.contiguous(),
        cu_seq_lens_q.to(torch.int32),
        cu_seq_lens_k.to(torch.int32).clone(),
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale=module.scaling,
        causal=True,  # kind of a must, it automatically aligns the mask for q < k
        window_size=sliding_window,  # -1 means infinite context window
        **custom_kwargs,
    )
    # Some kernels return (attn_output, softmax_lse); keep only the attention output
    if isinstance(attn_output, tuple):
        attn_output = attn_output[0]
    return attn_output, None
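

# Illustrative sketch (not part of this module's API): how the varlen metadata consumed
# by paged_attention_forward is typically built from per-sequence token counts.
# `_build_cu_seq_lens` and `seq_lens` are hypothetical names used only for this example.
def _build_cu_seq_lens(seq_lens: list[int]) -> tuple[torch.Tensor, int]:
    """Return (cu_seq_lens, max_seqlen) for a packed batch with the given lengths."""
    cu_seq_lens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seq_lens[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0)
    return cu_seq_lens, max(seq_lens)


# Example: _build_cu_seq_lens([3, 5, 2]) -> (tensor([0, 3, 8, 10], dtype=torch.int32), 5)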