import torch

from ..utils import is_torch_npu_available, is_torch_xpu_available, logging
from ..utils.import_utils import is_torch_greater_or_equal


logger = logging.get_logger(__name__)


_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
_is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
_is_torch_xpu_available = is_torch_xpu_available()
_is_torch_npu_available = is_torch_npu_available()


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
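
    Example (illustrative shapes only):

        >>> kv = torch.randn(2, 4, 16, 64)   # (batch, num_key_value_heads, seqlen, head_dim)
        >>> repeat_kv(kv, n_rep=3).shape
        torch.Size([2, 12, 16, 64])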
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def use_gqa_in_sdpa(attention_mask: torch.Tensor | None, key: torch.Tensor) -> bool:
    # GQA can only be used under the following conditions:
    # 1. CUDA or Ascend NPU
    #    - torch version >= 2.5
    #    - attention_mask is None (otherwise it falls back to the math kernel)
    # 2. XPU
    #    - torch version >= 2.8
    if _is_torch_xpu_available:
        return _is_torch_greater_or_equal_than_2_8
    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None
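

# For example, on CUDA with torch >= 2.5 and attention_mask=None, use_gqa_in_sdpa returns True, so
# sdpa_attention_forward below passes `enable_gqa=True` to SDPA instead of materializing the repeated
# key/value heads with repeat_kv.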


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float = 0.0,
    scaling: float | None = None,
    is_causal: bool | None = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False):
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True`."
            " Please set your attention to `eager` if you want this feature."
        )

    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        if not use_gqa_in_sdpa(attention_mask, key):
            key = repeat_kv(key, module.num_key_value_groups)
            value = repeat_kv(value, module.num_key_value_groups)
        else:
            sdpa_kwargs = {"enable_gqa": True}

    # Instead of relying on the value set in the module directly, we use the `is_causal` passed in kwargs if it is provided.
    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)

    # SDPA's Flash Attention (and cuDNN) kernels rely on the `is_causal` flag. However, there are certain conditions:
    # - Not in the decoding phase (otherwise we want full attention on the single query token)
    # - No attention mask is provided (even if it follows a causal pattern)
    # - Internally, we marked this as compatible with causal, i.e. it is a decoder attention type
    #
    # Quirks on the conditionals:
    # - We avoid passing this inline to the SDPA function directly in order to support both torch.compile's dynamic
    #   shapes and full graph options. Otherwise, dynamic shapes are prevented from compiling.
    # - It is important to check the shape first, otherwise compile will fail with
    #   `argument 'is_causal' must be bool, not SymBool`.
    is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
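    # For instance, during cached autoregressive decoding query.shape[2] == 1, so `is_causal` resolves to False
    # and the single query token attends to all cached key/value positions.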

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot
    # use the FlashAttentionScore operator and instead falls back to small-operator concatenation. To invoke
    # FlashAttentionScore, the `attention_mask` must be converted to boolean type; this adaptation ensures the mask
    # meets that requirement.
    if _is_torch_npu_available:
        if attention_mask is not None and attention_mask.dtype != torch.bool:
            # Convert to boolean type so that SDPA calls FlashAttentionScore, improving performance.
            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
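            # Illustrative example: an additive float mask with 0.0 for "attend" and -inf for "masked" becomes a
            # boolean mask with True for "attend" and False for "masked", which is what SDPA expects.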

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
        **sdpa_kwargs,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
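

# Minimal usage sketch (illustrative only; `attn_module` is assumed to be any module exposing
# `num_key_value_groups` and `is_causal`, e.g. a decoder self-attention layer):
#
#     q = torch.randn(1, 8, 16, 64)   # (batch, num_attention_heads, seq_len, head_dim)
#     k = torch.randn(1, 2, 16, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
#     v = torch.randn(1, 2, 16, 64)
#     out, _ = sdpa_attention_forward(attn_module, q, k, v, attention_mask=None, scaling=64**-0.5)
#     # `out` has shape (batch, seq_len, num_attention_heads, head_dim) == (1, 16, 8, 64)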