You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

105 lines
4.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch
from ..utils import is_torch_npu_available, is_torch_xpu_available, logging
from ..utils.import_utils import is_torch_greater_or_equal
logger = logging.get_logger(__name__)

# Capability probes run once at import time; the attention helpers in this module only read the cached results.
# torch >= 2.5 (non-XPU) and >= 2.8 (XPU) are the minimum versions `use_gqa_in_sdpa` checks for native GQA.
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
_is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)

# Backend availability: XPU changes the GQA decision, NPU triggers the boolean-mask conversion in the SDPA forward.
_is_torch_xpu_available, _is_torch_npu_available = is_torch_xpu_available(), is_torch_npu_available()
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along dim 1.

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: an input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned unchanged.
    """
    # Unpacking first also enforces a 4-D input, even on the n_rep == 1 fast path.
    bsz, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    widened = hidden_states.unsqueeze(2).expand(bsz, n_kv_heads, n_rep, seq_len, dim)
    total_heads = n_kv_heads * n_rep
    return widened.reshape(bsz, total_heads, seq_len, dim)
def use_gqa_in_sdpa(attention_mask: torch.Tensor | None, key: torch.Tensor) -> bool:
    """
    Decide whether SDPA may be asked to handle grouped-query attention natively (`enable_gqa=True`).

    - XPU: requires torch >= 2.8; the attention mask is not consulted.
    - Every other device (intended for CUDA and Ascend NPU): requires torch >= 2.5 and no attention mask,
      because with a mask SDPA would fall back to the math kernel.

    `key` is part of the signature but is not inspected by this implementation.
    """
    if not _is_torch_xpu_available:
        return _is_torch_greater_or_equal_than_2_5 and attention_mask is None
    return _is_torch_greater_or_equal_than_2_8
def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float = 0.0,
    scaling: float | None = None,
    is_causal: bool | None = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """
    Compute attention with `torch.nn.functional.scaled_dot_product_attention` (SDPA).

    Args:
        module: The calling attention module. Two optional attributes are read:
            - `num_key_value_groups`: when present, key/value heads are either expanded with `repeat_kv`
              or SDPA is asked for native grouped-query attention (`enable_gqa=True`), depending on
              `use_gqa_in_sdpa`.
            - `is_causal`: fallback used when the `is_causal` argument is None (defaults to True if absent).
        query, key, value: Inputs forwarded to SDPA. Dim 2 of `query` is treated as the query length; a length
            of 1 disables the causal flag. The layout is presumably (batch, heads, seq_len, head_dim), matching
            `repeat_kv` (TODO confirm against callers).
        attention_mask: Forwarded as `attn_mask`. Whenever a mask is given, the causal flag is not set. On Ascend
            NPU, a non-boolean mask is first converted to a boolean one (see below).
        dropout: Forwarded to SDPA as `dropout_p`.
        scaling: Forwarded to SDPA as `scale`; None leaves SDPA's default scaling in place.
        is_causal: Explicit override for the module's `is_causal` attribute.
        **kwargs: Only `output_attentions` is inspected, and only to emit a warning. Other keys are ignored.

    Returns:
        `(attn_output, None)`:
            - `attn_output` is the SDPA output with dims 1 and 2 swapped and made contiguous, presumably
              (batch, seq_len, heads, head_dim) (TODO confirm).
            - The second element is always None, because attention weights are never produced on this path.
    """
    # SDPA never materializes the attention weights, so `output_attentions=True` cannot be honored here.
    if kwargs.get("output_attentions", False):
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True`."
            " Please set your attention to `eager` if you want any of these features."
        )

    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        if not use_gqa_in_sdpa(attention_mask, key):
            # Native GQA not usable: expand key/value heads up to the attention-head count explicitly.
            key = repeat_kv(key, module.num_key_value_groups)
            value = repeat_kv(value, module.num_key_value_groups)
        else:
            # Let SDPA handle the grouped key/value heads itself, skipping the `repeat_kv` expansion.
            sdpa_kwargs = {"enable_gqa": True}

    # Instead of relying on the value set in the module directly, we use the `is_causal` argument if it is present
    # (falling back to the module attribute, and to True if the module has none).
    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)

    # SDPA's Flash Attention (and cuDNN) kernels rely on the `is_causal` flag. However, there are certain conditions:
    # - Not in decoding phase (otherwise we want full attention on the single query token)
    # - Attention mask is not to be provided (even if it is a causal pattern)
    # - Internally, we marked this as compatible with causal, i.e. it is a decoder attention type
    #
    # Quirks on the conditionals:
    # - We avoid inline passing this to the SDPA function directly to support both torch.compile's dynamic shapes and
    #   full graph options. Otherwise, dynamic shapes are prevented from compiling.
    # - It is important to check first for the shape, otherwise compile will fail with
    #   `argument 'is_causal' must be bool, not SymBool`.
    is_causal = query.shape[2] > 1 and attention_mask is None and is_causal

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot
    # utilize the FlashAttentionScore operator and falls back to small-operator concatenation. To invoke
    # FlashAttentionScore, the attention_mask must be converted to boolean type.
    # This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
    if _is_torch_npu_available:
        if attention_mask is not None and attention_mask.dtype != torch.bool:
            # Convert to boolean type, forcing SDPA to call FlashAttentionScore for better performance.
            # Non-zero entries become False and zero entries become True. This matches SDPA's boolean convention
            # (True = position takes part in attention) only if the incoming mask is additive with 0 for attended
            # positions. NOTE(review): assumption, confirm against the mask producers.
            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
        **sdpa_kwargs,
    )
    # Swap dims 1 and 2 (presumably heads <-> sequence) and make the result contiguous for downstream reshapes.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None