# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable

import torch
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import (
    ROPE_INIT_FUNCTIONS,
    RopeParameters,
    dynamic_rope_update,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
from ...utils.generic import maybe_autocast
from ..gemma.modeling_gemma import (
    GemmaAttention,
    GemmaForCausalLM,
    GemmaForSequenceClassification,
    GemmaForTokenClassification,
    GemmaMLP,
    GemmaModel,
    GemmaPreTrainedModel,
    GemmaRMSNorm,
    GemmaRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
)


logger = logging.get_logger(__name__)


class Gemma2Config(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma2-7B.
    e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Gemma2Model`]
        hidden_size (`int`, *optional*, defaults to 2304):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 26):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If explicitly set to `None`, it falls back to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_parameters (`RopeParameters`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
            scaling factor used on the attention scores
        sliding_window (`int`, *optional*, defaults to 4096):
            in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
        layer_types (`list`, *optional*):
            Attention pattern for each layer.
        final_logit_softcapping (`float`, *optional*, defaults to 30.0):
            scaling factor when applying tanh softcapping on the logits.
        attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
            scaling factor when applying tanh softcapping on the attention scores.
        use_bidirectional_attention (`bool`, *optional*):
            If True, the model will attend to all text tokens instead of using a causal mask.

    ```python
    >>> from transformers import Gemma2Model, Gemma2Config
    >>> # Initializing a Gemma2 gemma2-7b style configuration
    >>> configuration = Gemma2Config()
    >>> # Initializing a model from the gemma2-7b style configuration
    >>> model = Gemma2Model(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma2"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel sharding plan for the base model's linear projections.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: (input names, output names) for each top-level submodule.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size: int | None = 256000,
        hidden_size: int | None = 2304,
        intermediate_size: int | None = 9216,
        num_hidden_layers: int | None = 26,
        num_attention_heads: int | None = 8,
        num_key_value_heads: int | None = 4,
        head_dim: int | None = 256,
        hidden_activation: str | None = "gelu_pytorch_tanh",
        max_position_embeddings: int | None = 8192,
        initializer_range: float | None = 0.02,
        rms_norm_eps: float | None = 1e-6,
        use_cache: bool | None = True,
        pad_token_id: int | None = 0,
        eos_token_id: int | None = 1,
        bos_token_id: int | None = 2,
        tie_word_embeddings: bool | None = True,
        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
        attention_bias: bool | None = False,
        attention_dropout: float | None = 0.0,
        query_pre_attn_scalar: int | None = 256,
        sliding_window: int | None = 4096,
        layer_types: list[str] | None = None,
        final_logit_softcapping: float | None = 30.0,
        attn_logit_softcapping: float | None = 50.0,
        use_bidirectional_attention: bool | None = None,
        **kwargs,
    ):
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        # Honour the documented fallback: an explicit `None` means plain multi-head attention.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.layer_types = layer_types
        self.use_bidirectional_attention = use_bidirectional_attention

        # Default pattern alternates layers: even indices (0, 2, ...) use sliding-window attention,
        # odd indices (1, 3, ...) use full attention.
        if self.layer_types is None:
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types, self.num_hidden_layers)

        self.rope_parameters = rope_parameters
        super().__init__(**kwargs)


class Gemma2RMSNorm(GemmaRMSNorm):
    pass


class Gemma2MLP(GemmaMLP):
    def __init__(self, config):
        super().__init__(config)
        # Gemma2 reads its activation from `config.hidden_activation`.
        self.act_fn = ACT2FN[config.hidden_activation]


class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
    def __init__(self, config: Gemma2Config, device=None):
        # Initialise nn.Module directly (bypassing GemmaRotaryEmbedding.__init__) so that the
        # inverse frequencies below are built from the Gemma2 config's `rope_parameters`.
        # `self` must be passed explicitly when calling an unbound base-class __init__.
        nn.Module.__init__(self)
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: recomputed from the config, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return the RoPE `(cos, sin)` tables for `position_ids`, cast to `x.dtype`.

        `x` is used only for its device and dtype.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled so the angle computation stays in float32; MPS falls back to "cpu".
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float = 0.0,
    scaling: float | None = None,
    softcap: float | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference (pure PyTorch) attention with optional tanh soft-capping of the scores.

    Tensors are presumably laid out as (batch, num_heads, seq_len, head_dim) — TODO confirm against callers.
    Keys/values are expanded with `repeat_kv` by `module.num_key_value_groups` for grouped-query attention.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has its head and sequence axes swapped back
        (`transpose(1, 2)`) and the weights are the post-softmax, post-dropout probabilities.
    """
    if scaling is None:
        scaling = module.head_dim**-0.5

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    # Soft-capping bounds the scores to (-softcap, softcap) before masking and softmax.
    if softcap is not None:
        attn_weights = attn_weights / softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * softcap
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


class Gemma2Attention(GemmaAttention):
    def __init__(self, config: Gemma2Config, layer_idx: int):
        # Set before the parent __init__ so the layer type is known as early as possible.
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        super().__init__(config, layer_idx)
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = not getattr(config, "use_bidirectional_attention", False)
        # Gemma2 scales scores by query_pre_attn_scalar**-0.5 rather than head_dim**-0.5.
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run self-attention and return `(attn_output, attn_weights)`.

        `attn_weights` is whatever the selected attention backend returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            softcap=self.attn_logit_softcapping,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Gemma2DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        # Used by the model to pick the matching mask ("full_attention" / "sliding_attention").
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma2MLP(config)
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply one decoder layer and return the updated hidden states.

        Both sub-blocks (attention and MLP) are wrapped in pre- *and* post-RMSNorm before the residual add.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class Gemma2PreTrainedModel(GemmaPreTrainedModel):
    pass


class Gemma2Model(GemmaModel):
    def __init__(self, config: Gemma2Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_emb = Gemma2RotaryEmbedding(config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Embed the inputs, run all decoder layers and return the final normalised hidden states.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks; each layer looks up its own by `attention_type`.
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # embed positions
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # normalized
        # The sqrt(hidden_size) normalizer is deliberately cast to the hidden-states dtype; in float16,
        # for example, sqrt(3072)=55.4256 becomes 55.5.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


class Gemma2ForCausalLM(GemmaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        # Run the decoder stack; the result is a BaseModelOutputWithPast.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # (logits_to_keep=0 gives slice(0, None), i.e. all positions).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Final tanh soft-capping bounds the logits to (-cap, cap).
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
    pass


class Gemma2ForTokenClassification(GemmaForTokenClassification):
    pass


__all__ = [
    "Gemma2Config",
    "Gemma2ForCausalLM",
    "Gemma2Model",
    "Gemma2PreTrainedModel",
    "Gemma2ForSequenceClassification",
    "Gemma2ForTokenClassification",
]