You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
180 lines
9.3 KiB
180 lines
9.3 KiB
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/olmo3/modular_olmo3.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_olmo3.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# Copyright 2025 the HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from ...configuration_utils import PreTrainedConfig, layer_type_validation
|
|
from ...modeling_rope_utils import RopeParameters
|
|
|
|
|
|
class Olmo3Config(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`Olmo3Model`]. It is used to instantiate an OLMo3
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the [allenai/OLMo-3-0725-1B](https://huggingface.co/allenai/OLMo-3-0725-1B).

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the Olmo3 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`Olmo3Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_parameters (`RopeParameters` or `dict[str, RopeParameters]`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        sliding_window (`int`, *optional*, defaults to 4096):
            Size of the sliding window for sliding window attention.
        layer_types (`list[str]`, *optional*):
            Attention type for each layer (`"sliding_attention"` or `"full_attention"`). Defaults to sliding window
            attention for 3 out of 4 layers, and full attention for every 4th layer.

    ```python
    >>> from transformers import Olmo3Model, Olmo3Config

    >>> # Initializing an Olmo3 7B style configuration
    >>> configuration = Olmo3Config()

    >>> # Initializing a model from the Olmo3 7B style configuration
    >>> model = Olmo3Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "olmo3"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise_gather_output",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.k_proj": "colwise_gather_output",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.v_proj": "colwise_gather_output",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.o_proj": "rowwise_split_input",  # input is replicated due to the added norm on q and k
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size: int | None = 50304,
        hidden_size: int | None = 4096,
        intermediate_size: int | None = 11008,
        num_hidden_layers: int | None = 32,
        num_attention_heads: int | None = 32,
        num_key_value_heads: int | None = None,
        hidden_act: str | None = "silu",
        max_position_embeddings: int | None = 2048,
        initializer_range: float | None = 0.02,
        use_cache: bool | None = True,
        pad_token_id: int | None = 1,
        bos_token_id: int | None = None,
        eos_token_id: int | None = 50279,
        tie_word_embeddings: bool | None = False,
        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
        attention_bias: bool | None = False,
        attention_dropout: float | None = 0.0,
        rms_norm_eps: float | None = 1e-5,
        sliding_window: int | None = 4096,
        layer_types: list[str] | None = None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # For backward compatibility: when unspecified, use one key/value head per
        # attention head, i.e. plain multi-head attention.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.tie_word_embeddings = tie_word_embeddings
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.rms_norm_eps = rms_norm_eps
        self.sliding_window = sliding_window
        self.layer_types = layer_types
        if self.layer_types is None:
            # Default pattern: every 4th layer (1-based index divisible by 4) uses full
            # attention; all other layers use sliding-window attention.
            self.layer_types = [
                "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types, self.num_hidden_layers)

        # Stored as given. NOTE(review): presumably standardized/validated by the
        # base `PreTrainedConfig` machinery -- confirm before relying on its shape here.
        self.rope_parameters = rope_parameters

        # Forward any remaining keyword arguments to the base configuration class.
        super().__init__(**kwargs)
|
|
|
|
|
|
# Public names exported by this module (used by `from ... import *`).
__all__ = ["Olmo3Config"]
|