import torch

from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ..utils import logging


logger = logging.get_logger(__name__)


_use_top_left_mask = flash_attn_supports_top_left_mask()


def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype | None:
    """If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise."""
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            return torch.get_autocast_dtype("cuda")
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_is_quantized"):
            return module.config.dtype
        else:
            # Fall back to the weight dtype of the first Linear layer in the module
            return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
    return None


def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float = 0.0,
    scaling: float | None = None,
    sliding_window: int | None = None,
    softcap: float | None = None,
    is_causal: bool | None = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False):
        logger.warning_once(
            "Flash Attention does not support `output_attentions=True`."
            " Please set your attention to `eager` if you want this feature."
        )

    # The sequence length is read before the transpose below
    seq_len = query.shape[2]

    if any(dim == 0 for dim in query.shape):
        raise ValueError(
            "Tensor query has shape with a zero dimension.\n"
            "FlashAttention does not support inputs with dim=0.\n"
            "Please check your input shapes or use SDPA instead."
        )

    # FA2 uses non-transposed inputs
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # In PEFT, the layer norms are usually cast to float32 for training stability,
    # so the input hidden states get silently cast to float32. Hence, we need to
    # cast them back to the correct dtype to make sure everything works as expected.
    # This might slow down training & inference, so it is recommended not to cast the
    # LayerNorms to fp32 (usually our RMSNorm modules handle it correctly).
    target_dtype = get_target_dtype(query, module)

    # Instead of relying on the value set in the module directly, we use the `is_causal` argument if it was provided
    is_causal = is_causal if is_causal is not None else module.is_causal

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=is_causal,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
        **kwargs,
    )

    return attn_output, None
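
# Illustration (an assumption for documentation purposes, not executed by this module):
# how `get_target_dtype` resolves the dtype used to downcast fp32 queries before the
# flash attention call. With autocast active it follows the autocast dtype; otherwise,
# for a non-quantized model, it falls back to the weight dtype of the first
# `torch.nn.Linear` found in the module. `q_fp32` and `bf16_module` are hypothetical names.
#
#     with torch.autocast("cuda", dtype=torch.bfloat16):
#         get_target_dtype(q_fp32, bf16_module)   # -> torch.bfloat16 (autocast dtype)
#
#     get_target_dtype(q_fp32, bf16_module)       # -> torch.bfloat16 (first Linear's weight dtype)
#     get_target_dtype(q_bf16, bf16_module)       # -> None (query already FA-compatible)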
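
# Usage sketch (illustrative only): how an attention module would typically call
# `flash_attention_forward` from its `forward` method. The attribute names below
# (`self.attention_dropout`, `self.scaling`, `query_states`, etc.) are assumptions;
# what this file actually requires is that q/k/v arrive as
# (batch, num_heads, seq_len, head_dim) and that `module.config._attn_implementation`
# and `module.is_causal` exist on the calling module.
#
#     attn_output, _ = flash_attention_forward(
#         module=self,
#         query=query_states,               # (batch, num_heads, q_len, head_dim)
#         key=key_states,                   # (batch, num_kv_heads, kv_len, head_dim)
#         value=value_states,               # (batch, num_kv_heads, kv_len, head_dim)
#         attention_mask=attention_mask,    # padding mask or None
#         dropout=self.attention_dropout if self.training else 0.0,
#         scaling=self.scaling,
#         sliding_window=getattr(self, "sliding_window", None),
#     )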