
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Phimoe model."""
from collections.abc import Callable
import torch
from torch import nn
from ...modeling_layers import (
GenericForSequenceClassification,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...utils.generic import OutputRecorder, maybe_autocast
from ..llama.modeling_llama import LlamaAttention
from ..mixtral.modeling_mixtral import (
MixtralDecoderLayer,
MixtralExperts,
MixtralForCausalLM,
MixtralModel,
MixtralPreTrainedModel,
MixtralRotaryEmbedding,
)
from .configuration_phimoe import PhimoeConfig
class PhimoeRotaryEmbedding(MixtralRotaryEmbedding):
    """Rotary position embedding with Phimoe's long/short mscale switching.

    Unlike the Mixtral parent, `forward` re-derives the inverse frequencies for the
    current sequence length on every call and scales cos/sin by an "mscale" factor
    that switches from `short_mscale` to `long_mscale` once the sequence exceeds
    `original_max_position_embeddings` (longrope-style scaling).
    """

    def __init__(self, config: PhimoeConfig, device=None):
        # Bug fix: `nn.Module.__init__()` was called without `self`, which raises a
        # TypeError. The parent Mixtral __init__ is intentionally bypassed here.
        nn.Module.__init__(self)
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        # Default init fn comes from the parent; non-default rope types are looked up
        # in the shared registry.
        self.rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Untouched copy so scaling variants can restore the original frequencies.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def forward(self, x, position_ids=None, layer_type=None):
        # NOTE(review): despite the `None` default, `position_ids` is required — it is
        # used unconditionally below. Confirm callers always pass it.
        if layer_type is not None:
            raise ValueError(
                f"{self.__class__.__name__} does not support layer types, but got `layer_type={layer_type}`"
            )
        mscale = None
        seq_len = torch.max(position_ids) + 1
        if self.config.rope_parameters["rope_type"] != "default" and seq_len:
            # Long sequences use the long mscale; short ones the short mscale.
            mscale = (
                self.config.rope_parameters["long_mscale"]
                if seq_len > self.config.rope_parameters["original_max_position_embeddings"]
                else self.config.rope_parameters["short_mscale"]
            )
        # Recompute inv_freq for the current sequence length (long/short factor switch).
        inv_freq, attention_scaling = self.rope_init_fn(self.config, x.device, seq_len)
        mscale = attention_scaling if mscale is None else mscale
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * mscale
            sin = emb.sin() * mscale
        return cos.to(x.dtype), sin.to(x.dtype)
class PhimoeAttention(LlamaAttention):
    # Same computation as Llama attention; subclassed only so the module carries a
    # Phimoe name (it is also the recording target for `attentions` outputs below).
    pass
class PhimoeMultiplier(torch.autograd.Function):
    """Custom autograd node used by `sparsemixer` for expert-routing gradients.

    The forward pass simply scales the selected routing weight by ``mask_for_one``;
    the backward pass propagates the incoming gradient onto the raw router
    ``scores`` via the softmax-gradient identity (``p * (delta - p)`` folded with
    the routing weight).
    """

    @staticmethod
    def forward(
        ctx,
        scores: torch.Tensor,
        multiplier: torch.Tensor,
        selected_experts: torch.Tensor,
        masked_gates: torch.Tensor,
        mask_for_one: torch.Tensor,
    ):
        """Scale the routing weight and stash what backward needs.

        Args:
            ctx: Autograd context for saving tensors.
            scores: Raw router scores (gradient target).
            multiplier: Routing weight gathered at the selected expert.
            selected_experts: Index of the chosen expert per token.
            masked_gates: Softmaxed (masked) gate probabilities.
            mask_for_one: Scaling mask applied to the routing weight.

        Returns:
            torch.Tensor: ``multiplier * mask_for_one``.
        """
        ctx.save_for_backward(multiplier, selected_experts, masked_gates)
        return multiplier * mask_for_one

    @staticmethod
    def backward(ctx, grad_at_output: torch.Tensor):
        """Return the gradient w.r.t. ``scores``; the other four inputs get None.

        Args:
            ctx: Context holding (multiplier, selected_experts, masked_gates).
            grad_at_output: Gradient flowing in from downstream.

        Returns:
            tuple[torch.Tensor, None, None, None, None]
        """
        multiplier, selected_experts, masked_gates = ctx.saved_tensors
        # Fold the routing weight into the incoming gradient first.
        scaled_grad = grad_at_output * multiplier
        # Softmax gradient: -p * g everywhere ...
        grad_at_scores = masked_gates * -scaled_grad
        # ... then add g back at the selected expert position.
        grad_at_scores.scatter_add_(dim=-1, index=selected_experts, src=scaled_grad)
        return grad_at_scores, None, None, None, None
def sparsemixer(scores: torch.Tensor, jitter_eps: float, training: bool, top_k: int = 2) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sparse mixer function to select top-k experts and compute multipliers.
    Based on the paper: https://huggingface.co/papers/2409.12136
    We first replace the TopK(·) function as random sampling of discrete variables
    in model training. Then, following Liu et al. (2023a) and Liu et al. (2023b), we apply Heun's
    third order method to approximate the expert routing gradient and construct a modified
    back-propagation to give a mathematically sound gradient estimation for expert routing.
    Args:
        scores (torch.Tensor): Input scores tensor.
        jitter_eps (float): Jitter epsilon for numerical stability.
        training (bool): Flag indicating if the model is in training mode.
        top_k (int): Number of top experts to select.
            NOTE(review): the body hard-codes a top-1 pass followed by a top-2 pass,
            so only `top_k=2` is actually implemented; the argument is unused.
    Returns:
        tuple[torch.Tensor, torch.Tensor]: Multiplier and selected experts tensors.
    """
    # ---- Top-1 pass -------------------------------------------------------
    with torch.no_grad():
        # Compute mask for sparsity: drop experts whose score is too far below the max
        # (relative gap larger than 2 * jitter_eps).
        mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)

    # Apply mask
    masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
    if training:
        # Sample the expert via the Gumbel trick (argmax of logits + Gumbel noise).
        selected_experts = (
            (
                masked_gates
                - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log()
            )
            .max(dim=-1)[1]
            .unsqueeze(-1)
        )  # Gumbel sampling, more robust than the multinomial method
    else:
        # Inference: deterministic argmax.
        selected_experts = max_ind

    # Compute scores for gradients
    masked_gates = torch.softmax(masked_gates, dim=-1)
    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)

    if training:
        # Compute midpoint mask: with prob. 0.25 (or when the sampled expert is not
        # the argmax) attenuate the multiplier for the gradient estimator.
        max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
        mask_for_one = torch.logical_or(
            selected_experts == max_ind,
            torch.rand_like(max_scores) > 0.75,  # Heun's third-order method
        )
        # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
        mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
        # Route gradients back onto `scores` through the custom autograd function.
        multiplier = PhimoeMultiplier.apply(
            scores,
            multiplier_o,
            selected_experts,
            masked_gates,
            mask_for_one,
        )
    else:
        multiplier = multiplier_o

    # ---- Top-2 pass -------------------------------------------------------
    # Masked out first expert so the second pass picks a different one.
    masked_scores = torch.scatter(
        scores,
        -1,
        selected_experts,
        float("-inf"),
    )
    with torch.no_grad():
        # Compute mask for sparsity (same thresholding, on the masked scores).
        mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
        # NOTE(review): normalization below uses the original `scores`, not
        # `masked_scores` — mirrors the reference implementation; confirm intentional.
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)

    # Apply mask
    masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
    if training:
        selected_experts_top2 = (
            (
                masked_gates_top2
                - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format)
                .exponential_()
                .log()
            )
            .max(dim=-1)[1]
            .unsqueeze(-1)
        )  # Gumbel sampling, more robust than the multinomial method
    else:
        selected_experts_top2 = max_ind

    # Compute scores for gradients
    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
    multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)

    if training:
        # Compute midpoint mask
        max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
        mask_for_one_top2 = torch.logical_or(
            selected_experts_top2 == max_ind,
            # NOTE(review): `.uniform_()` re-samples the already-uniform `rand_like`
            # draw — redundant, but changing it would shift the RNG stream.
            torch.rand_like(max_scores).uniform_() > 0.75,  # Heun's third-order method
        )
        # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
        mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)
        multiplier_top2 = PhimoeMultiplier.apply(
            scores,
            multiplier_top2_o,
            selected_experts_top2,
            masked_gates_top2,
            mask_for_one_top2,
        )
    else:
        multiplier_top2 = multiplier_top2_o

    # Concatenate the two picks into (..., 2) routing weights / expert indices.
    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
    selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
    return (
        multiplier,
        selected_experts,
    )
class PhimoeExperts(MixtralExperts):
    # Expert feed-forward bank identical to Mixtral's; subclassed only for the name.
    pass
class PhimoeTopKRouter(nn.Linear):
    """Bias-free linear gate that scores tokens against experts and runs sparsemixer.

    Returns `(routing_weights, selected_experts)` for the top-2 experts per token.
    """

    def __init__(self, config: PhimoeConfig):
        super().__init__(config.hidden_size, config.num_local_experts, bias=False)
        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise
        self.top_k = config.num_experts_per_tok

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Training-only multiplicative input jitter, applied in place.
        if self.training and self.input_jitter_noise > 0:
            jitter = torch.empty_like(hidden_states)
            jitter.uniform_(1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise)
            hidden_states *= jitter
        # Per-token expert logits from the linear gate.
        router_logits = super().forward(hidden_states)
        # sparsemixer yields (routing_weights, selected_experts).
        return sparsemixer(
            router_logits, jitter_eps=self.router_jitter_noise, training=self.training, top_k=self.top_k
        )
class PhimoeSparseMoeBlock(nn.Module):
    """
    This implementation is
    strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations
    in terms of block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either
    (1) drop tokens at the cost of reduced performance or (2) set
    capacity factor to number of experts and thus waste computation
    and memory on padding.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        self.router = PhimoeTopKRouter(config)
        self.experts = PhimoeExperts(config)
        self.input_jitter_noise = config.input_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token through its top-k experts and recombine.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_dim).

        Returns:
            Tensor of the same shape as `hidden_states`.
        """
        # Fix: the shape unpack was duplicated before and after the jitter; the
        # jitter does not change the shape, so unpack exactly once.
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.input_jitter_noise > 0:
            # NOTE(review): PhimoeTopKRouter.forward applies the same input jitter
            # again during training — looks like a double application; confirm intent.
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
            )
        # Flatten (batch, seq) so routing and experts operate on a token-major 2D view.
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        routing_weights, selected_experts = self.router(hidden_states)
        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
class PhimoeDecoderLayer(MixtralDecoderLayer):
    """Mixtral decoder layer with Phimoe's LayerNorm in place of RMSNorm."""

    def __init__(self, config: PhimoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        def build_layernorm() -> nn.LayerNorm:
            # Phimoe uses nn.LayerNorm (the Mixtral parent installs RMSNorm);
            # `rms_norm_eps` is reused as the LayerNorm eps.
            return nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)

        self.input_layernorm = build_layernorm()
        self.post_attention_layernorm = build_layernorm()
class PhimoePreTrainedModel(MixtralPreTrainedModel):
    # Declares which submodule outputs the recording machinery may capture:
    # - "router_logits": element 0 of the tuple returned by each layer's `mlp.router`
    # - "hidden_states": outputs of each PhimoeDecoderLayer
    # - "attentions": outputs of each PhimoeAttention module
    _can_record_outputs = {
        "router_logits": OutputRecorder(PhimoeTopKRouter, layer_name="mlp.router", index=0),
        "hidden_states": PhimoeDecoderLayer,
        "attentions": PhimoeAttention,
    }
class PhimoeModel(MixtralModel):
    def __init__(self, config: PhimoeConfig):
        super().__init__(config)
        # Phimoe uses standard LayerNorm for the final norm (the Mixtral parent uses
        # RMSNorm); `rms_norm_eps` is reused as the LayerNorm eps.
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
class PhimoeForCausalLM(MixtralForCausalLM):
    """Causal LM head on top of PhimoeModel; unlike Mixtral, the head bias is configurable."""

    def __init__(self, config):
        super().__init__(config)
        # `lm_head_bias` comes from the config (Mixtral's head is always bias-free).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)

    # Copied from transformers.models.phi3.modeling_phi3.Phi3ForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
        # process

        # When the input first crosses `original_max_position_embeddings`, the rope
        # scaling switches factors; a cache built with the short factors would be
        # stale, so force a recompute at that single (slower) token position rather
        # than failing.
        if past_key_values and hasattr(self.config, "original_max_position_embeddings"):
            switch_point = self.config.original_max_position_embeddings
            if input_ids.shape[1] >= switch_point + 1 and cache_position[0] <= switch_point:
                past_key_values = None

        return super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
# Sequence-classification model: all behavior comes from the generic mixin; this class
# only wires it to the Phimoe backbone.
class PhimoeForSequenceClassification(GenericForSequenceClassification, PhimoePreTrainedModel): ...
# Explicit public API of this module.
__all__ = [
    "PhimoePreTrainedModel",
    "PhimoeModel",
    "PhimoeForCausalLM",
    "PhimoeForSequenceClassification",
]