You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

296 lines
12 KiB

import torch
from torch import nn
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import RopeParameters
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from ..mistral.configuration_mistral import MistralConfig
from ..qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2ForCausalLM,
Qwen2ForQuestionAnswering,
Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2MLP,
Qwen2Model,
Qwen2PreTrainedModel,
Qwen2RMSNorm,
Qwen2RotaryEmbedding,
)
class MinistralConfig(MistralConfig, PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MinistralModel`]. It is used to instantiate a
    Ministral model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Ministral-8B-Instruct-2410.

    [mistralai/Ministral-8B-Instruct-2410](https://huggingface.co/mistralai/Ministral-8B-Instruct-2410)

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Ministral model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`MinistralModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. Passing
            `None` explicitly falls back to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
            The attention head dimension. The constructor stores the value exactly as passed, so `None` stays `None`
            on the config object; the fallback is not computed in this class.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with. Ministral's sliding window attention
            allows sequence of up to 4096*32 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_parameters (`RopeParameters`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention window size. If not specified, will default to `4096`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        layer_types (`list`, *optional*):
            Attention pattern for each layer. When omitted, every layer uses `"sliding_attention"` if
            `sliding_window` is set and `"full_attention"` otherwise.

    ```python
    >>> from transformers import MinistralModel, MinistralConfig

    >>> # Initializing a Ministral 8B style configuration
    >>> configuration = MinistralConfig()

    >>> # Initializing a model from the Ministral 8B style configuration
    >>> model = MinistralModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "ministral"

    def __init__(
        self,
        vocab_size: int | None = 32000,
        hidden_size: int | None = 4096,
        intermediate_size: int | None = 14336,
        num_hidden_layers: int | None = 32,
        num_attention_heads: int | None = 32,
        num_key_value_heads: int | None = 8,
        head_dim: int | None = None,
        hidden_act: str | None = "silu",
        max_position_embeddings: int | None = 4096 * 32,
        initializer_range: float | None = 0.02,
        rms_norm_eps: float | None = 1e-6,
        use_cache: bool | None = True,
        pad_token_id: int | None = None,
        bos_token_id: int | None = 1,
        eos_token_id: int | None = 2,
        tie_word_embeddings: bool | None = False,
        rope_parameters: RopeParameters | None = None,
        sliding_window: int | None = 4096,
        attention_dropout: float | None = 0.0,
        layer_types: list[str] | None = None,
        **kwargs,
    ):
        # Special tokens and embedding tying
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings

        # Model dimensions
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.head_dim = head_dim

        # for backward compatibility: a missing key/value head count means plain multi-head attention
        self.num_key_value_heads = num_attention_heads if num_key_value_heads is None else num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_dropout = attention_dropout

        # Without an explicit per-layer pattern, all layers share one attention type chosen by `sliding_window`
        if layer_types is None:
            uniform_type = "full_attention" if sliding_window is None else "sliding_attention"
            layer_types = [uniform_type] * num_hidden_layers
        self.layer_types = layer_types

        self.rope_parameters = rope_parameters

        # Invoke the base initializer explicitly: with these bases, `super().__init__` would resolve to
        # `MistralConfig.__init__` first.
        PreTrainedConfig.__init__(self, **kwargs)
class MinistralMLP(Qwen2MLP):
    # Ministral reuses the Qwen2 MLP as-is; this subclass adds nothing beyond the Ministral name.
    pass
class MinistralAttention(Qwen2Attention):
    # Qwen2 attention whose query/key/value projections are rebuilt without bias terms.

    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)

        # `self.head_dim` is available here because the parent constructor has already run.
        model_width = config.hidden_size
        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim

        # Match Mistral: q/k/v do not have bias
        # The three layers are created in the same q, k, v order as before.
        self.q_proj = nn.Linear(model_width, query_width, bias=False)
        self.k_proj = nn.Linear(model_width, key_value_width, bias=False)
        self.v_proj = nn.Linear(model_width, key_value_width, bias=False)
class MinistralRMSNorm(Qwen2RMSNorm):
    # Ministral reuses the Qwen2 RMS norm as-is; nothing is overridden here.
    pass
class MinistralDecoderLayer(Qwen2DecoderLayer):
    # Ministral reuses the Qwen2 decoder layer as-is; nothing is overridden here.
    pass
class MinistralPreTrainedModel(Qwen2PreTrainedModel):
    # Ministral reuses the Qwen2 pretrained-model base as-is; nothing is overridden here.
    pass
class MinistralRotaryEmbedding(Qwen2RotaryEmbedding):
    # Ministral reuses the Qwen2 rotary embedding as-is; nothing is overridden here.
    pass
class MinistralModel(Qwen2Model):
    # Qwen2-derived backbone. `forward` picks the attention mask for each layer from a mapping keyed by
    # `decoder_layer.attention_type` ("full_attention" or "sliding_attention").

    def __init__(self, config: MinistralConfig):
        super().__init__(config)
        # Remove the flag left on the instance by the parent constructor; `forward` below never reads it.
        del self.has_sliding_layers

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` has to be supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Start a fresh cache only when caching is requested and the caller did not bring one.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # Default positions continue right after whatever the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # The layers that will actually run; also used below to decide which masks are needed.
        decoder_layers = self.layers[: self.config.num_hidden_layers]

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks. The sliding-window mask is built only when one of the layers run below asks
            # for it, so a model made purely of "full_attention" layers (e.g. a config with
            # `sliding_window=None`) never reaches the sliding-window mask builder.
            # NOTE(review): that builder is expected to reject a config without `sliding_window` - confirm.
            causal_mask_mapping = {"full_attention": create_causal_mask(**mask_kwargs)}
            if any(layer.attention_type == "sliding_attention" for layer in decoder_layers):
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds
        # Rotary embeddings are computed once and handed to every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in decoder_layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            # The cache is only handed back when caching was requested.
            past_key_values=past_key_values if use_cache else None,
        )
class MinistralForCausalLM(Qwen2ForCausalLM):
    # Inherited wholesale from the Qwen2 causal-LM class; nothing is overridden here.
    pass
class MinistralForSequenceClassification(Qwen2ForSequenceClassification):
    # Inherited wholesale from the Qwen2 sequence-classification class; nothing is overridden here.
    pass
class MinistralForTokenClassification(Qwen2ForTokenClassification):
    # Inherited wholesale from the Qwen2 token-classification class; nothing is overridden here.
    pass
class MinistralForQuestionAnswering(Qwen2ForQuestionAnswering):
    # Inherited wholesale from the Qwen2 question-answering class; nothing is overridden here.
    pass
# Names exported by `from <this module> import *`: the config plus the model classes defined above.
__all__ = [
    "MinistralConfig",
    "MinistralPreTrainedModel",
    "MinistralModel",
    "MinistralForCausalLM",
    "MinistralForSequenceClassification",
    "MinistralForTokenClassification",
    "MinistralForQuestionAnswering",
]