# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch AFMoE model."""

from collections.abc import Callable

import torch
from torch import nn

from ... import initialization as init
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, logging
from ...utils.generic import check_model_inputs
from ..gpt_oss.modeling_gpt_oss import GptOssRMSNorm
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP
from .configuration_afmoe import AfmoeConfig


logger = logging.get_logger(__name__)


class AfmoeRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class AfmoeRMSNorm(GptOssRMSNorm):
    pass


class AfmoeMLP(Qwen2MoeMLP):
    pass


class AfmoeTokenChoiceRouter(nn.Module):
    """
    Token-choice top-K router for MoE routing.

    This router assigns each token to the top-K experts based on sigmoid scores, matching the released checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.route_scale = config.route_scale
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
        _, _, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
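
        # The frozen expert_bias only influences which experts are chosen; the mixing weights
        # below are gathered from the unbiased sigmoid scores, in the spirit of the
        # auxiliary-loss-free load balancing used by several recent MoE models.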
        _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
        top_scores = scores.gather(dim=1, index=selected_experts)
        denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
        top_scores = top_scores / denominator
        top_scores = top_scores * self.route_scale
        return top_scores, selected_experts


class AfmoeExperts(nn.ModuleList):
    """
    Container holding the routed experts.

    This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
    """

    def __init__(self, config: AfmoeConfig):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        for _ in range(self.num_experts):
            self.append(AfmoeMLP(config, intermediate_size=config.moe_intermediate_size))

    def forward(
        self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq, hidden)
            selected_experts: (batch, seq, top_k)
            routing_weights: (batch, seq, top_k)
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        if seq_len == 0:
            return hidden_states.new_zeros(batch_size, 0, hidden_dim)
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        top_k = selected_experts.shape[-1]

        # Map every token routing decision to a unique position so we can process expert by expert.
        token_indices = torch.arange(
            hidden_states_flat.shape[0], device=hidden_states.device, dtype=torch.long
        ).repeat_interleave(top_k)
        expert_indices = selected_experts.reshape(-1)
        routing_weights = routing_weights.reshape(-1)
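
        # Sort the flattened (token, expert) assignments by expert id so that all tokens routed
        # to the same expert are contiguous and each expert can run a single dense forward pass.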
        sorting = torch.argsort(expert_indices, stable=True)
        token_indices = token_indices[sorting]
        expert_indices = expert_indices[sorting]
        routing_weights = routing_weights[sorting]

        dispatched_tokens = hidden_states_flat.index_select(0, token_indices)
        expert_outputs = torch.zeros_like(dispatched_tokens)

        unique_experts, counts = torch.unique_consecutive(expert_indices, return_counts=True)
        start = 0
        for expert_id, count in zip(unique_experts.tolist(), counts.tolist()):
            if count == 0:
                continue
            end = start + count
            expert_input = dispatched_tokens[start:end]
            expert_output = self[expert_id](expert_input)
            expert_outputs[start:end] = expert_output
            start = end

        weighted_outputs = (expert_outputs.to(torch.float32) * routing_weights.unsqueeze(-1)).to(hidden_states.dtype)
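        # Scatter the weighted expert outputs back to their source token positions, summing the
        # contributions of the top_k experts selected for each token.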
        aggregated = torch.zeros_like(hidden_states_flat)
        scatter_indices = token_indices.unsqueeze(-1).expand_as(weighted_outputs)
        aggregated.scatter_add_(0, scatter_indices, weighted_outputs)
        return aggregated.view(batch_size, seq_len, hidden_dim)


class AfmoeMoE(nn.Module):
    """
    Mixture of Experts (MoE) module for AFMoE.

    This module implements a sparse MoE layer with both shared experts (always active) and
    routed experts (activated based on token-choice routing).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.router = AfmoeTokenChoiceRouter(config)
        self.shared_experts = AfmoeMLP(config, config.moe_intermediate_size * config.num_shared_experts)
        self.experts = AfmoeExperts(config)
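        # Per-expert routing bias; frozen for backprop (requires_grad=False). Load-balancing
        # schemes typically adjust it outside of gradient descent, but that update is not
        # implemented in this module.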
        self.expert_bias = nn.Parameter(torch.zeros(config.num_experts), requires_grad=False)

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)

        # Get routing decisions
        top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
        top_scores = top_scores.view(batch_size, seq_len, self.config.num_experts_per_tok)
        selected_experts = selected_experts.view(batch_size, seq_len, self.config.num_experts_per_tok)

        # Process through shared experts
        shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
        routed_output = self.experts(hidden_states, selected_experts, top_scores)
        return shared_output + routed_output


class AfmoeAttention(LlamaAttention):
    """
    Multi-headed attention module with optional sliding window and gating.

    This attention mechanism supports both full attention and sliding window attention,
    and includes Q/K normalization and gating of the output. It inherits from [`LlamaAttention`] to minimize the amount
    of custom logic we need to maintain.
    """

    def __init__(self, config: AfmoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
        # We only add AFMoE-specific attributes
        self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_local_attention else None

        self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.gate_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_value: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        value_states = self.v_proj(hidden_states).view(hidden_shape)
        gate_states = self.gate_proj(hidden_states)

        query_states = self.q_norm(query_states).transpose(1, 2)
        key_states = self.k_norm(key_states).transpose(1, 2)
        value_states = value_states.transpose(1, 2)
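
        # Rotary position embeddings are applied only on sliding-window (local) layers;
        # global-attention layers use no explicit positional encoding here.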
        if self.is_local_attention:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        output = output.view(*input_shape, -1).contiguous()
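        # Gate the attention output elementwise with a sigmoid of the dedicated gate projection
        # before the heads are mixed by o_proj.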
        output = output * torch.sigmoid(gate_states)
        attn_output = self.o_proj(output)
        return attn_output, attn_weights


class AfmoeDecoderLayer(GradientCheckpointingLayer):
    """
    AFMoE decoder layer with dual normalization.

    This layer applies self-attention followed by either a dense MLP or MoE block,
    with dual normalization (pre and post) around each component.
    """

    def __init__(self, config: AfmoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
        self.attention_type = config.layer_types[layer_idx]

        # Dual normalization for attention
        self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Dual normalization for FFN
        self.pre_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # MoE or dense FFN
        self.moe_enabled = layer_idx >= config.num_dense_layers
        if self.moe_enabled:
            self.mlp = AfmoeMoE(config)
        else:
            self.mlp = AfmoeMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_value: Cache | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        residual = hidden_states

        # Self Attention with dual normalization
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
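        # "Dual" (sandwich) normalization: the sub-layer output is normalized again before it is
        # added back to the residual stream.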
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # FFN with dual normalization
        residual = hidden_states
        hidden_states = self.pre_mlp_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)

        hidden_states = residual + hidden_states
        return hidden_states


class AfmoePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config: AfmoeConfig
    base_model_prefix = "model"
    _no_split_modules = ["AfmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _can_record_outputs = {
        "hidden_states": AfmoeDecoderLayer,
        "attentions": AfmoeAttention,
    }
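    # Normalization layers and the routing bias are kept in fp32 so they stay numerically
    # stable when the rest of the model runs in reduced precision.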
    _keep_in_fp32_modules = [
        "input_layernorm",
        "post_attention_layernorm",
        "pre_mlp_layernorm",
        "post_mlp_layernorm",
        "q_norm",
        "k_norm",
        "norm",
        "expert_bias",
    ]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        super()._init_weights(module)
        if isinstance(module, AfmoeTokenChoiceRouter):
            init.zeros_(module.gate.weight)
        elif isinstance(module, AfmoeMoE):
            init.zeros_(module.expert_bias)


@auto_docstring
class AfmoeModel(AfmoePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is an [`AfmoeDecoderLayer`].

    Args:
        config: AfmoeConfig
    """

    def __init__(self, config: AfmoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [AfmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = AfmoeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.post_init()

    @auto_docstring
    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | MoeModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
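            # Build one mask per attention type; each decoder layer later picks the variant that
            # matches its configured layer type.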
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds

        # Apply muP input scaling if enabled
        if self.config.mup_enabled:
            hidden_states = hidden_states * (self.config.hidden_size**0.5)

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class AfmoeForCausalLM(LlamaForCausalLM, AfmoePreTrainedModel, GenerationMixin):
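    # The causal-LM forward and loss computation are inherited from LlamaForCausalLM;
    # only the backbone is swapped for AfmoeModel.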
    def __init__(self, config):
        AfmoePreTrainedModel.__init__(self, config)
        self.model = AfmoeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()


__all__ = [
    "AfmoeForCausalLM",
    "AfmoeModel",
    "AfmoePreTrainedModel",
]