# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_gemma3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
from typing import Any, Literal
|
|
|
|
from ...configuration_utils import PreTrainedConfig, layer_type_validation
|
|
from ...modeling_rope_utils import RopeParameters
|
|
from ...utils import logging
|
|
from ..siglip import SiglipVisionConfig
|
|
|
|
|
|
# Module-level logger (transformers' logging wrapper); used below to report when default sub-configs are created.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class Gemma3TextConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate a Gemma3Text
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma3Text-7B.
    e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 262208):
            Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Gemma3TextModel`]
        hidden_size (`int`, *optional*, defaults to 2304):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 26):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If explicitly set to `None`, it falls back to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
            Scaling factor used on the attention scores
        sliding_window (`int`, *optional*, defaults to 4096):
            Size of the sliding window used by the `"sliding_attention"` layers (see `layer_types`). When
            `use_bidirectional_attention=True`, the value stored on the config is `sliding_window // 2 + 1`.
        layer_types (`list`, *optional*):
            Attention type of each layer (`"sliding_attention"` or `"full_attention"`). If not given, every
            `sliding_window_pattern`-th layer (6 by default; read from the legacy `sliding_window_pattern` kwarg)
            uses full attention and all other layers use sliding-window attention.
        final_logit_softcapping (`float`, *optional*):
            Scaling factor when applying tanh softcapping on the logits.
        attn_logit_softcapping (`float`, *optional*):
            Scaling factor when applying tanh softcapping on the attention scores.
        rope_parameters (`dict`, *optional*):
            Dictionary mapping attention patterns (`"full_attention"`, `"sliding_attention"`) to `RopeParameters`.
            Each value should be a dictionary containing `rope_type` and optional scaling parameters. Missing entries
            default to `{"rope_type": "default"}`. Their `rope_theta` defaults to 1e6 for full attention and 1e4 for
            sliding attention. The legacy `rope_theta` / `rope_local_base_freq` kwargs override these defaults.
        use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
            If True, the model will attend to all text tokens instead of using a causal mask. This does not change
            behavior for vision tokens. Also rescales `sliding_window` (see above).
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings

    ```python
    >>> from transformers import Gemma3TextModel, Gemma3TextConfig
    >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
    >>> configuration = Gemma3TextConfig()
    >>> # Initializing a model from the gemma3_text-7b style configuration
    >>> model = Gemma3TextModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "gemma3_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    # Default RoPE base frequencies: "global" for full-attention layers, "local" for sliding-window layers
    # (consumed by `convert_rope_params_to_dict`).
    default_theta = {"global": 1_000_000.0, "local": 10_000.0}

    def __init__(
        self,
        vocab_size: int | None = 262_208,
        hidden_size: int | None = 2304,
        intermediate_size: int | None = 9216,
        num_hidden_layers: int | None = 26,
        num_attention_heads: int | None = 8,
        num_key_value_heads: int | None = 4,
        head_dim: int | None = 256,
        hidden_activation: str | None = "gelu_pytorch_tanh",
        max_position_embeddings: int | None = 131_072,
        initializer_range: float | None = 0.02,
        rms_norm_eps: float | None = 1e-6,
        use_cache: bool | None = True,
        pad_token_id: int | None = 0,
        eos_token_id: int | None = 1,
        bos_token_id: int | None = 2,
        attention_bias: bool | None = False,
        attention_dropout: float | None = 0.0,
        query_pre_attn_scalar: int | None = 256,
        sliding_window: int | None = 4096,
        layer_types: list[str] | None = None,
        final_logit_softcapping: float | None = None,
        attn_logit_softcapping: float | None = None,
        rope_parameters: dict[Literal["full_attention", "sliding_attention"], RopeParameters] | None = None,
        use_bidirectional_attention: bool | None = False,
        tie_word_embeddings: bool | None = True,
        **kwargs,
    ):
        # Documented fallback: an explicit `None` means "no key/value grouping" (plain multi-head attention).
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.layer_types = layer_types

        self.use_bidirectional_attention = use_bidirectional_attention
        # Guard against `sliding_window=None`, which previously raised TypeError here.
        if use_bidirectional_attention and self.sliding_window is not None:
            # Halve the window, +1 because the bound is exclusive (original note: "due to fa we set exclusive
            # bounds"), presumably so the window spans both sides of a token.
            # NOTE(review): this runs on every construction, so if the adjusted value is serialized and fed back
            # into __init__ it would be halved again -- confirm the round-trip behavior is intended.
            self.sliding_window = (self.sliding_window // 2) + 1

        # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)

        if self.layer_types is None:
            # Every `_sliding_window_pattern`-th layer (1-based) is full attention; the rest are sliding.
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types, self.num_hidden_layers)

        self.rope_parameters = rope_parameters
        super().__init__(**kwargs)

    def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation=None, **kwargs):
        """Normalize `self.rope_parameters` into the per-attention-type format, then standardize and validate it.

        Ensures both `"full_attention"` and `"sliding_attention"` entries exist, applies the legacy `rope_scaling`
        kwarg to the full-attention entry, and fills in each entry's `rope_theta` from the legacy
        `rope_theta` / `rope_local_base_freq` kwargs or from `default_theta`.

        Args:
            ignore_keys_at_rope_validation: Forwarded to `validate_rope(ignore_keys=...)`.
            **kwargs: Remaining config kwargs. The legacy keys `rope_scaling`, `rope_theta` and
                `rope_local_base_freq` are consumed.

        Returns:
            `dict`: `kwargs` with the legacy RoPE keys removed.
        """
        rope_scaling = kwargs.pop("rope_scaling", None)
        # Popped unconditionally so the legacy keys never leak into the returned kwargs.
        full_theta = kwargs.pop("rope_theta", self.default_theta["global"])
        sliding_theta = kwargs.pop("rope_local_base_freq", self.default_theta["local"])

        # If `rope_parameters` was passed in, it is assumed to already be in the new (per-layer-type) format.
        # Work on copies so the caller's dicts are not mutated in place.
        rope_parameters = dict(self.rope_parameters) if self.rope_parameters is not None else {}
        for layer_type in ("full_attention", "sliding_attention"):
            params = rope_parameters.get(layer_type)
            rope_parameters[layer_type] = dict(params) if params is not None else {"rope_type": "default"}

        # Legacy `rope_scaling` applies to the full-attention layers only. It is applied after the defaults are
        # filled in, so a `rope_parameters` lacking "full_attention" no longer raises.
        if rope_scaling is not None:
            rope_parameters["full_attention"].update(rope_scaling)

        rope_parameters["full_attention"].setdefault("rope_theta", full_theta)
        rope_parameters["sliding_attention"].setdefault("rope_theta", sliding_theta)
        self.rope_parameters = rope_parameters

        # Standardize and validate the correctness of rotary position embeddings parameters
        self.standardize_rope_params()
        self.validate_rope(ignore_keys=ignore_keys_at_rope_validation)
        return kwargs
|
|
|
|
|
|
class Gemma3Config(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate a
    Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults yields a text backbone built from a default [`Gemma3TextConfig`] and a vision backbone built from a
    default [`SiglipVisionConfig`].

    e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
            The config object (or a dict of `Gemma3TextConfig` kwargs) of the text backbone. Defaults to
            `Gemma3TextConfig()`.
        vision_config (`Union[SiglipVisionConfig, dict]`, *optional*):
            The config object (or a dict of `SiglipVisionConfig` kwargs) of the vision backbone. Defaults to
            `SiglipVisionConfig()`.
        mm_tokens_per_image (`int`, *optional*, defaults to 256):
            The number of tokens per image embedding.
        boi_token_index (`int`, *optional*, defaults to 255999):
            The begin-of-image token index to wrap the image prompt.
        eoi_token_index (`int`, *optional*, defaults to 256000):
            The end-of-image token index to wrap the image prompt.
        image_token_index (`int`, *optional*, defaults to 262144):
            The image token index to encode the image prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings

    Example:

    ```python
    >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig

    >>> # Initializing a Siglip-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing a Gemma3 Text config
    >>> text_config = Gemma3TextConfig()

    >>> # Initializing a Gemma3 gemma-3-4b style configuration
    >>> configuration = Gemma3Config(text_config=text_config, vision_config=vision_config)

    >>> # Initializing a model from the gemma-3-4b style configuration
    >>> model = Gemma3ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma3"
    # Expose the `*_token_index` attributes under the `*_token_id` names as well.
    attribute_map = {
        "image_token_id": "image_token_index",
        "boi_token_id": "boi_token_index",
        "eoi_token_id": "eoi_token_index",
    }
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
    }

    def __init__(
        self,
        text_config: Gemma3TextConfig | dict[str, Any] | None = None,
        vision_config: SiglipVisionConfig | dict[str, Any] | None = None,
        mm_tokens_per_image: int | None = 256,
        boi_token_index: int | None = 255_999,
        eoi_token_index: int | None = 256_000,
        image_token_index: int | None = 262_144,
        initializer_range: float | None = 0.02,
        tie_word_embeddings: bool | None = True,
        **kwargs,
    ):
        # Sub-configs may arrive as config objects, plain dicts (e.g. from a serialized config), or be omitted.
        if text_config is None:
            text_config = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig text config.")
        elif isinstance(text_config, dict):
            text_config = Gemma3TextConfig(**text_config)

        if isinstance(vision_config, dict):
            vision_config = SiglipVisionConfig(**vision_config)
        elif vision_config is None:
            vision_config = SiglipVisionConfig()
            logger.info("vision_config is None, using default SiglipVisionConfig vision config.")

        self.text_config = text_config
        self.vision_config = vision_config
        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.image_token_index = image_token_index
        self.initializer_range = initializer_range
        self.tie_word_embeddings = tie_word_embeddings

        super().__init__(**kwargs)
|
|
|
|
|
|
# Public API of this module: the two configuration classes defined above.
__all__ = ["Gemma3Config", "Gemma3TextConfig"]
|