You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.1 KiB
87 lines
3.1 KiB
import torch
|
|
|
|
from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
|
|
from ..utils import logging
|
|
|
|
|
|
# Module-level logger; `flash_attention_forward` uses it (via `warning_once`) to report
# that `output_attentions=True` is not supported.
logger = logging.get_logger(__name__)
|
|
|
|
# Evaluated once at import time and forwarded unchanged as `use_top_left_mask` on every
# `_flash_attention_forward` call made from this module.
# NOTE(review): presumably reflects the causal-mask alignment of the installed flash-attn
# version — confirm against `flash_attn_supports_top_left_mask`.
_use_top_left_mask = flash_attn_supports_top_left_mask()
|
|
|
|
|
|
def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype:
|
|
"""If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise."""
|
|
if query.dtype == torch.float32:
|
|
if torch.is_autocast_enabled():
|
|
return torch.get_autocast_dtype("cuda")
|
|
# Handle the case where the model is quantized
|
|
elif hasattr(module.config, "_is_quantized"):
|
|
return module.config.dtype
|
|
else:
|
|
return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
|
|
return None
|
|
|
|
|
|
def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float = 0.0,
    scaling: float | None = None,
    sliding_window: int | None = None,
    softcap: float | None = None,
    is_causal: bool | None = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Run flash attention for ``module`` by delegating to ``_flash_attention_forward``.

    Args:
        module: Attention module. Its ``config._attn_implementation`` is forwarded, its
            ``is_causal`` is the fallback when ``is_causal`` is ``None``, and its
            ``layer_idx`` is forwarded when present.
        query: Query tensor. ``query.shape[2]`` is read as the query length before dims 1
            and 2 are swapped.
            NOTE(review): presumably (batch, heads, seq_len, head_dim) — confirm with callers.
        key: Key tensor; dims 1 and 2 are swapped before the call.
        value: Value tensor; dims 1 and 2 are swapped before the call.
        attention_mask: Forwarded as-is.
        dropout: Forwarded as ``dropout``.
        scaling: Forwarded as ``softmax_scale``.
        sliding_window: Forwarded as-is.
        softcap: Forwarded as-is.
        is_causal: Explicit causal flag; ``module.is_causal`` is used when ``None``.
        **kwargs: Forwarded untouched. ``output_attentions`` is additionally inspected here
            only to emit a warning.

    Returns:
        ``(attn_output, None)``; the second item is always ``None``.

    Raises:
        ValueError: If any dimension of ``query`` is zero.
    """
    wants_attn_weights = kwargs.get("output_attentions", False)
    if wants_attn_weights:
        logger.warning_once(
            "Flash Attention does not support `output_attentions=True`."
            " Please set your attention to `eager` if you want any of these features."
        )

    # Must be read before the transpose below moves this axis.
    query_length = query.shape[2]

    if 0 in query.shape:
        raise ValueError(
            "Tensor query has shape with a zero dimension.\n"
            "FlashAttention does not support inputs with dim=0.\n"
            "Please check your input shapes or use SDPA instead."
        )

    # FA2 uses non-transposed inputs, so swap dims 1 and 2 on all three tensors.
    query, key, value = (tensor.transpose(1, 2) for tensor in (query, key, value))

    # With PEFT the layer norms are usually kept in float32 for training stability, which
    # silently upcasts the hidden states to float32 as well. Work out which dtype they
    # should be cast back to so the kernel receives a supported dtype. This round trip may
    # slow down training and inference, so keeping LayerNorms out of fp32 is recommended
    # (RMSNorm modules usually handle this correctly already).
    cast_dtype = get_target_dtype(query, module)

    # An explicitly passed `is_causal` takes precedence over the module attribute.
    if is_causal is None:
        is_causal = module.is_causal

    attn_impl = module.config._attn_implementation
    layer_index = getattr(module, "layer_idx", None)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=query_length,
        is_causal=is_causal,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=cast_dtype,
        attn_implementation=attn_impl,
        layer_idx=layer_index,
        **kwargs,
    )

    return (attn_output, None)
|