# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OLMoE model."""
from collections.abc import Callable
import torch
from torch import nn
from ... import initialization as init
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, is_grouped_mm_available, logging
from ...utils.generic import OutputRecorder
from ..gemma.modeling_gemma import GemmaMLP
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM, MixtralModel
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeTopKRouter
from .configuration_olmoe import OlmoeConfig


logger = logging.get_logger(__name__)


class OlmoeRMSNorm(LlamaRMSNorm):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__(hidden_size, eps)


class OlmoeRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class OlmoeMLP(GemmaMLP):
    pass


class OlmoeAttention(LlamaAttention):
    def __init__(self, config: OlmoeConfig, layer_idx: int | None = None):
        super().__init__(config, layer_idx)
        self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(
            (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
        )
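        # Unlike Llama, OLMoE normalizes the full projected query/key states (QK-norm)
        # before the head reshape, so k_norm is sized to
        # num_key_value_heads * head_dim rather than hidden_size.
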
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)
        if self.config.clip_qkv is not None:  # diff with Llama: optionally clamp QKV activations for stability
            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
        query_states = query_states.view(*hidden_shape).transpose(1, 2)
        key_states = key_states.view(*hidden_shape).transpose(1, 2)
        value_states = value_states.view(*hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
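
        # pick the attention kernel registered for the configured implementation,
        # falling back to the eager reference implementation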
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class OlmoeExperts(MixtralExperts):
    pass


class OlmoeTopKRouter(Qwen2MoeTopKRouter):
    pass


class OlmoeSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate = OlmoeTopKRouter(config)
        self.experts = OlmoeExperts(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        _, top_k_weights, top_k_index = self.gate(hidden_states)
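        # dispatch each token's hidden state to its top-k experts and combine the
        # expert outputs with the corresponding routing weights before restoring
        # the (batch, seq, hidden) shape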
        final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states


class OlmoeDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: OlmoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size
        self.self_attn = OlmoeAttention(config=config, layer_idx=layer_idx)
        self.mlp = OlmoeSparseMoeBlock(config)
        self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


@auto_docstring
class OlmoePreTrainedModel(PreTrainedModel):
    config: OlmoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OlmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
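    # these recorders let the generic forward surface router logits, per-layer
    # hidden states, and attention weights in the model outputs when requested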
    _can_record_outputs = {
        "router_logits": OutputRecorder(OlmoeTopKRouter, index=0),
        "hidden_states": OlmoeDecoderLayer,
        "attentions": OlmoeAttention,
    }
    _can_compile_fullgraph = (
        is_grouped_mm_available()
    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
    _supports_attention_backend = True

    @torch.no_grad()
    def _init_weights(self, module):
        PreTrainedModel._init_weights(self, module)
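        # the expert and router weights are plain nn.Parameter tensors rather than
        # nn.Linear modules, so the generic initialization above does not cover them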
        if isinstance(module, OlmoeExperts):
            init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, OlmoeTopKRouter):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)


@auto_docstring
class OlmoeModel(MixtralModel):
    def __init__(self, config: OlmoeConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = OlmoeRotaryEmbedding(config=config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
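
        # lazily allocate a KV cache on the first forward call when caching is enabled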
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(  # diff with mixtral: no sliding
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


class OlmoeForCausalLM(MixtralForCausalLM, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.model = OlmoeModel(config)
        self.num_experts = config.num_experts

    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OlmoeForCausalLM

        >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
        ```
        """
        return super().forward(**super_kwargs)


__all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]