from collections.abc import Callable

import torch

from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from ..mistral.modeling_mistral import (
    MistralAttention,
    MistralDecoderLayer,
    MistralForCausalLM,
    MistralModel,
    MistralPreTrainedModel,
    apply_rotary_pos_emb,
    eager_attention_forward,
)


logger = logging.get_logger(__name__)


def _get_llama_4_attn_scale(position_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
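    # Llama 4-style long-context attention scaling. While position_ids < max_position_embeddings,
    # floor(position_ids / max_position_embeddings) == 0 and log(1) == 0, so the scale stays 1.0;
    # it then grows logarithmically with each further multiple of max_position_embeddings.
    # Illustrative values only (not from any config): with beta = 0.1 and
    # max_position_embeddings = 4096, position 4096 scales by 1 + 0.1 * ln(2) ≈ 1.07 and
    # position 8192 by 1 + 0.1 * ln(3) ≈ 1.11.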
    scaling = 1 + beta * torch.log(1 + torch.floor(position_ids / max_position_embeddings))
    return scaling.unsqueeze(-1)


class Ministral3Attention(MistralAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
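
        # Project hidden states and split heads: (batch, seq, hidden) -> (batch, num_heads, seq, head_dim).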
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
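
        # Apply RoPE, then rescale the queries with the Llama 4-style log factor. cache_position
        # serves as the absolute position index here; beta and the original maximum length are
        # read from the RoPE parameters in the config.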
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        query_states = query_states * _get_llama_4_attn_scale(
            cache_position,
            self.config.rope_parameters.get("llama_4_scaling_beta"),
            self.config.rope_parameters.get("original_max_position_embeddings"),
        ).to(query_states.dtype)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
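
        # Dispatch to the attention backend selected in the config (e.g. flash attention or SDPA),
        # falling back to the eager PyTorch implementation.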
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
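

# The classes below only rename their Mistral counterparts. This appears to follow transformers'
# "modular" model pattern, in which pass-through subclasses like these are expanded into a full
# generated modeling file by the library's modular converter.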
class Ministral3DecoderLayer(MistralDecoderLayer):
    pass


@auto_docstring
class Ministral3PreTrainedModel(MistralPreTrainedModel):
    pass


@auto_docstring
class Ministral3Model(MistralModel):
    pass


@auto_docstring
class Ministral3ForCausalLM(MistralForCausalLM):
    pass


class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
    pass


class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
    pass


class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
    pass


__all__ = [
    "Ministral3ForCausalLM",
    "Ministral3ForQuestionAnswering",
    "Ministral3Model",
    "Ministral3PreTrainedModel",
    "Ministral3ForSequenceClassification",
    "Ministral3ForTokenClassification",
]
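

# A minimal usage sketch. Note that "mistralai/Ministral-3" is a placeholder checkpoint id,
# not a confirmed one; substitute whatever checkpoint actually ships with this model:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("mistralai/Ministral-3")
#     model = AutoModelForCausalLM.from_pretrained("mistralai/Ministral-3")
#     inputs = tokenizer("Hello", return_tensors="pt")
#     output_ids = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))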