You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

342 lines
14 KiB

# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import torch
from torch import nn
from ... import initialization as init
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import RopeParameters
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
from ..llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
)
if TYPE_CHECKING:
pass
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
SPIECE_UNDERLINE = ""
logger = logging.get_logger(__name__)
class GemmaConfig(PreTrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The legacy activation function. It is overwritten by the `hidden_activation`.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_parameters (`RopeParameters`, *optional*):
Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
with longer `max_position_embeddings`.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
use_bidirectional_attention (`bool`, *optional*):
If True, the model will attend to all text tokens instead of using a causal mask.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size: int | None = 256000,
hidden_size: int | None = 3072,
intermediate_size: int | None = 24576,
num_hidden_layers: int | None = 28,
num_attention_heads: int | None = 16,
num_key_value_heads: int | None = 16,
head_dim: int | None = 256,
hidden_act: str | None = "gelu_pytorch_tanh",
max_position_embeddings: int | None = 8192,
initializer_range: float | None = 0.02,
rms_norm_eps: int | None = 1e-6,
use_cache: bool | None = True,
pad_token_id: int | None = 0,
eos_token_id: int | None = 1,
bos_token_id: int | None = 2,
tie_word_embeddings: bool | None = True,
rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
attention_bias: bool | None = False,
attention_dropout: float | None = 0.0,
use_bidirectional_attention: bool | None = None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_bidirectional_attention = use_bidirectional_attention
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.tie_word_embeddings = tie_word_embeddings
self.rope_parameters = rope_parameters
super().__init__(**kwargs)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
class GemmaMLP(LlamaMLP):
def __init__(self, config):
super().__init__(config)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
class GemmaRotaryEmbedding(LlamaRotaryEmbedding):
pass
class GemmaAttention(LlamaAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.is_causal = not getattr(config, "use_bidirectional_attention", False)
class GemmaPreTrainedModel(LlamaPreTrainedModel):
@torch.no_grad()
def _init_weights(self, module):
PreTrainedModel._init_weights(self, module)
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
init.zeros_(module.weight)
class GemmaModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
cache_position: torch.LongTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
)
class GemmaForCausalLM(LlamaForCausalLM):
def forward(**super_kwargs):
r"""
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
return super().forward(**super_kwargs)
class GemmaForSequenceClassification(LlamaForSequenceClassification):
pass
class GemmaForTokenClassification(LlamaForTokenClassification):
pass
__all__ = [
"GemmaConfig",
"GemmaModel",
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
"GemmaPreTrainedModel",
]