You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
776 lines
32 KiB
776 lines
32 KiB
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/dbrx/modular_dbrx.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_dbrx.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from collections.abc import Callable
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from ... import initialization as init
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...generation import GenerationMixin
|
|
from ...integrations import use_kernel_func_from_hub
|
|
from ...masking_utils import create_causal_mask
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
|
from ...utils.generic import check_model_inputs, maybe_autocast
|
|
from .configuration_dbrx import DbrxConfig
|
|
|
|
|
|
class DbrxRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) module that produces the ``cos``/``sin`` tables used by DBRX attention.

    The inverse frequencies are computed once at construction time. For ``rope_type == "default"`` they come from
    :meth:`compute_default_rope_parameters`; otherwise the initializer registered in ``ROPE_INIT_FUNCTIONS`` for
    ``config.rope_parameters["rope_type"]`` is used.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: DbrxConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: they follow the module across devices but are not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: DbrxConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        rotary_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # inv_freq[i] = 1 / base^(2i / rotary_dim) for i in [0, rotary_dim / 2)
        even_indices = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (base ** (even_indices / rotary_dim))
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype`` and scaled by ``attention_scaling``."""
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # MPS (and non-string device types) fall back to the "cpu" autocast device type.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
|
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split into a front half ``a`` (``size // 2`` elements) and a back half ``b`` (the rest),
    and the result is ``concat(-b, a)`` along that dimension.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
|
|
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        # x * cos + rotate_half(x) * sin, applied identically to queries and keys.
        return (tensor * cos_b) + (rotate_half(tensor) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    With ``n_rep == 1`` the input tensor itself is returned (no copy).
    """
    # Unpacking first also validates that the input is 4-D, even when no repetition is needed.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
    tiled = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain PyTorch scaled dot-product attention with grouped-query key/value expansion.

    Keys and values are repeated ``module.num_key_value_groups`` times along the head axis. The additive
    ``attention_mask`` (if any) is cropped to the key length. Softmax runs in float32 and is cast back to the query
    dtype. Dropout is applied only when ``module.training`` is true.

    Returns:
        ``(attn_output, attn_weights)``; ``attn_output`` has its axes 1 and 2 swapped relative to the matmul result
        and is made contiguous.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
|
|
class DbrxAttention(nn.Module):
    """Modular DBRX attention component that can be reused across different model architectures.

    A single fused ``Wqkv`` projection produces the queries (``d_model`` wide) followed by the keys and the values
    (``kv_n_heads * head_dim`` wide each). When ``attn_config.clip_qkv`` is set, the fused projection is clamped to
    ``[-clip_qkv, clip_qkv]``. Rotary embeddings are then applied to queries and keys, the KV cache is updated when
    one is passed, and the attention backend selected by ``config._attn_implementation`` computes the output.
    """

    def __init__(
        self,
        config,
        layer_idx: int | None = None,
        **kwargs,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.d_model
        self.num_heads = config.n_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_seq_len
        self.layer_idx = layer_idx

        attn_config = config.attn_config
        self.attention_dropout = attn_config.attn_pdrop
        self.clip_qkv = attn_config.clip_qkv
        self.num_key_value_heads = attn_config.kv_n_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = attn_config.rope_theta
        self.is_causal = True

        self.Wqkv = nn.Linear(
            self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
        )
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input whose last dimension is ``d_model``. The ``transpose(1, 2)`` calls below assume
                ``(batch, seq_len, d_model)``.
            attention_mask: Optional mask forwarded unchanged to the attention backend.
            position_embeddings: ``(cos, sin)`` pair consumed by ``apply_rotary_pos_emb``.
            past_key_values: Optional cache; its ``update`` is called with this layer's keys and values.
            cache_position: Forwarded to the cache update.

        Returns:
            The ``out_proj`` output and whatever the attention backend returns as attention weights.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        qkv_states = self.Wqkv(hidden_states)
        # `Tensor.clamp` rejects min=None together with max=None, so clip only when `clip_qkv` is configured.
        if self.clip_qkv is not None:
            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)

        query_states, key_states, value_states = qkv_states.split(
            [
                self.hidden_size,
                self.num_key_value_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
            ],
            dim=-1,
        )

        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
|
class DbrxExpertGLU(nn.Module):
    """Gated-linear-unit expert MLP whose weights for all experts are stacked into three parameters.

    ``w1``, ``v1`` and ``w2`` each have shape ``(moe_num_experts * ffn_hidden_size, hidden_size)``. Callers slice out
    a single expert's matrices and pass them to :meth:`forward`.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.ffn_hidden_size = config.ffn_hidden_size
        self.moe_num_experts = config.moe_num_experts

        stacked_shape = (self.moe_num_experts * self.ffn_hidden_size, self.hidden_size)
        self.w1 = nn.Parameter(torch.empty(stacked_shape))
        self.v1 = nn.Parameter(torch.empty(stacked_shape))
        self.w2 = nn.Parameter(torch.empty(stacked_shape))

        # Activation name comes from `config.ffn_act_fn["name"]`, defaulting to SiLU.
        self.activation_fn = ACT2FN[config.ffn_act_fn.get("name", "silu")]

    def forward(
        self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
    ) -> torch.Tensor:
        """Compute ``(act(x @ w1) * (x @ v1)) @ w2.T`` for one expert's weight slices."""
        gated = self.activation_fn(x.matmul(expert_w1))
        linear = x.matmul(expert_v1)
        return (gated * linear).matmul(expert_w2.t())
|
|
|
|
|
|
class DbrxExperts(nn.Module):
    """Dispatches tokens to their selected experts and sums the routing-weighted expert outputs per token."""

    def __init__(self, config):
        super().__init__()
        self.mlp = DbrxExpertGLU(config)
        self.hidden_size = config.hidden_size
        self.ffn_hidden_size = config.ffn_hidden_size
        self.num_experts = config.moe_num_experts

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Mix expert outputs for every token.

        Args:
            hidden_states: Token states; flattened to ``(-1, ffn_hidden_size)`` internally.
            top_k_index: Selected expert ids per flattened token, indexed as ``[token, slot]``.
            top_k_weights: Routing weights aligned with ``top_k_index``.

        Returns:
            Tensor viewed as ``(hidden_states.shape[0], -1, ffn_hidden_size)``.
        """
        leading_dim = hidden_states.shape[0]
        tokens = hidden_states.reshape(-1, self.ffn_hidden_size)
        mixed = torch.zeros_like(tokens)

        # Routing bookkeeping is index arithmetic only; no gradients needed.
        with torch.no_grad():
            # After permute: [expert, slot, token] one-hot assignment mask.
            assignment = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
            used_experts = torch.greater(assignment.sum(dim=(-1, -2)), 0).nonzero()

        per_expert_shape = (-1, self.ffn_hidden_size, self.hidden_size)
        w1_by_expert = self.mlp.w1.view(per_expert_shape)
        v1_by_expert = self.mlp.v1.view(per_expert_shape)
        w2_by_expert = self.mlp.w2.view(per_expert_shape)

        for row in used_experts:
            expert = row[0]
            with torch.no_grad():
                slot, token_ids = torch.where(assignment[expert])
            expert_out = self.mlp(
                tokens[token_ids],
                w1_by_expert[expert],
                v1_by_expert[expert],
                w2_by_expert[expert],
            )
            weighted = expert_out.view(-1, self.ffn_hidden_size) * top_k_weights[token_ids, slot, None]
            # A token routed to several experts accumulates each weighted contribution.
            mixed.index_add_(0, token_ids, weighted)

        return mixed.view(leading_dim, -1, self.ffn_hidden_size)
|
|
|
|
|
|
class DbrxRouter(nn.Module):
    """Linear router that scores every token against each expert.

    In training mode, when ``moe_jitter_eps`` is set, the router input is multiplied by uniform noise drawn from
    ``[1 - eps, 1 + eps)``. The noise is applied out of place, so only the routing decision sees it and the caller's
    tensor is left unmodified. The same tensor is also fed to the experts by ``DbrxFFN.forward``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.ffn_hidden_size
        self.moe_jitter_eps = config.moe_jitter_eps
        self.layer = nn.Linear(self.hidden_size, config.moe_num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return router logits of shape ``(num_tokens, moe_num_experts)``, flattening all leading dimensions."""
        if self.training and self.moe_jitter_eps is not None:
            # Out of place on purpose: an in-place `*=` would also perturb the caller's tensor.
            jitter = torch.empty_like(hidden_states).uniform_(1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps)
            hidden_states = hidden_states * jitter
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        router_logits = self.layer(hidden_states)
        return router_logits
|
|
|
|
|
|
class DbrxFFN(nn.Module):
    """Modular DBRX MLP/FFN component with MoE support.

    A ``DbrxRouter`` scores each token, the ``moe_top_k`` highest-probability experts are selected (optionally
    renormalized), and ``DbrxExperts`` combines the selected experts' outputs.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.router = DbrxRouter(config.ffn_config)
        self.experts = DbrxExperts(config.ffn_config)

        self.moe_normalize_expert_weights = config.ffn_config.moe_normalize_expert_weights
        self.top_k = config.ffn_config.moe_top_k

    def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Select the ``top_k`` experts for every token.

        Args:
            router_logits: Router scores with experts on the last dimension.

        Returns:
            ``(weights, indices)``, each with ``top_k`` entries on the last dimension. Weights are softmax
            probabilities. When ``moe_normalize_expert_weights`` is set, they are divided by their p-norm,
            using that value as ``p``.
        """
        # Softmax, top-k and normalization all operate on the same (last) axis.
        router_probs = torch.nn.functional.softmax(router_logits, dim=-1, dtype=router_logits.dtype)
        router_top_value, router_indices = torch.topk(router_probs, self.top_k, dim=-1)
        if self.moe_normalize_expert_weights is not None:
            router_top_value = router_top_value / torch.norm(
                router_top_value, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
            )
        return router_top_value, router_indices

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` to experts and return the combined expert output tensor."""
        router_logits = self.router(hidden_states)
        top_k_weights, top_k_index = self.route_tokens_to_experts(router_logits)
        output = self.experts(hidden_states, top_k_index, top_k_weights)
        return output
|
|
|
|
|
|
class DbrxNormAttentionNorm(nn.Module):
    """Pre-norm self-attention sub-block followed by the LayerNorm that prepares the FFN input.

    ``forward`` returns ``(residual, normed)``. ``residual`` is ``dropout(attn(norm_1(x))) + x``, and ``normed`` is
    ``norm_2(residual)`` cast to the residual's dtype.
    """

    def __init__(self, config: DbrxConfig, layer_idx: int | None = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.resid_pdrop = config.resid_pdrop
        self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
        self.attn = DbrxAttention(config=config, layer_idx=layer_idx)
        self.norm_2 = nn.LayerNorm(config.d_model, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_dtype = hidden_states.dtype
        attn_input = self.norm_1(hidden_states).to(input_dtype)

        attn_out, _ = self.attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )

        attn_out = nn.functional.dropout(attn_out, p=self.resid_pdrop, training=self.training)
        residual = attn_out + hidden_states

        normed = self.norm_2(residual).to(residual.dtype)
        return residual, normed
|
|
|
|
|
|
class DbrxBlock(GradientCheckpointingLayer):
    """One DBRX decoder layer: norm/attention/norm sub-block followed by the MoE FFN with a residual connection."""

    def __init__(self, config: DbrxConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.d_model
        self.resid_pdrop = config.resid_pdrop
        self.layer_idx = layer_idx
        self.norm_attn_norm = DbrxNormAttentionNorm(config=config, layer_idx=layer_idx)
        self.ffn = DbrxFFN(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Any,
    ):
        """Return the layer output: attention residual plus dropout-regularized FFN output."""
        residual, ffn_input = self.norm_attn_norm(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )

        ffn_output = nn.functional.dropout(self.ffn(ffn_input), p=self.resid_pdrop, training=self.training)
        return residual + ffn_output
|
|
|
|
|
|
class DbrxPreTrainedModel(PreTrainedModel):
    """Shared base class for DBRX models: capability flags, output recording and weight initialization."""

    config: DbrxConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DbrxBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flex_attn = True
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
    _can_record_outputs = {
        "hidden_states": DbrxBlock,
        "attentions": DbrxAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """Run the base initialization, then draw stacked expert weights with `init.normal_`.

        The expert weights (``w1``, ``v1``, ``w2``) use mean 0 and std ``config.initializer_range``.
        """
        super()._init_weights(module)
        std = self.config.initializer_range
        if not isinstance(module, DbrxExpertGLU):
            return
        for expert_weight in (module.w1, module.v1, module.w2):
            init.normal_(expert_weight, mean=0.0, std=std)
|
|
|
|
|
|
@auto_docstring
class DbrxModel(DbrxPreTrainedModel):
    """Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DbrxBlock`] layer.

    Args:
        config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: DbrxConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.emb_pdrop = config.emb_pdrop
        self.rotary_emb = DbrxRotaryEmbedding(config)
        self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
        self.blocks = nn.ModuleList([DbrxBlock(config, layer_idx) for layer_idx in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model, bias=False)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.wte

    def set_input_embeddings(self, value: nn.Embedding):
        self.wte = value

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        # Embedding dropout controlled by `config.emb_pdrop`; identity when the module is not training.
        inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.blocks[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm_f(hidden_states)

        return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
|
|
|
|
|
|
def load_balancing_loss_func(
|
|
gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
|
|
num_experts: int | None = None,
|
|
top_k=2,
|
|
attention_mask: torch.Tensor | None = None,
|
|
) -> torch.Tensor | int:
|
|
r"""
|
|
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
|
|
|
See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
|
|
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
|
experts is too unbalanced.
|
|
|
|
Args:
|
|
gate_logits:
|
|
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
|
|
shape [batch_size X sequence_length, num_experts].
|
|
num_experts:
|
|
Number of experts
|
|
top_k:
|
|
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
|
parameter.
|
|
attention_mask (`torch.Tensor`, *optional*):
|
|
The attention_mask used in forward function
|
|
shape [batch_size X sequence_length] if not None.
|
|
|
|
Returns:
|
|
The auxiliary loss.
|
|
"""
|
|
if gate_logits is None or not isinstance(gate_logits, tuple):
|
|
return 0
|
|
|
|
if isinstance(gate_logits, tuple):
|
|
compute_device = gate_logits[0].device
|
|
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
|
|
|
|
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
|
|
|
|
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
|
|
|
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
|
|
|
if attention_mask is None:
|
|
# Compute the percentage of tokens routed to each experts
|
|
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
|
|
|
# Compute the average probability of routing to these experts
|
|
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
|
else:
|
|
batch_size, sequence_length = attention_mask.shape
|
|
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
|
|
|
|
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
|
expert_attention_mask = (
|
|
attention_mask[None, :, :, None, None]
|
|
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
|
.reshape(-1, top_k, num_experts)
|
|
.to(compute_device)
|
|
)
|
|
|
|
# Compute the percentage of tokens routed to each experts
|
|
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
|
|
expert_attention_mask, dim=0
|
|
)
|
|
|
|
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
|
router_per_expert_attention_mask = (
|
|
attention_mask[None, :, :, None]
|
|
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
|
.reshape(-1, num_experts)
|
|
.to(compute_device)
|
|
)
|
|
|
|
# Compute the average probability of routing to these experts
|
|
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
|
|
router_per_expert_attention_mask, dim=0
|
|
)
|
|
|
|
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
|
return overall_loss * num_experts
|
|
|
|
|
|
class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
    """DBRX decoder (``DbrxModel``) topped with a linear language-modeling head and an optional MoE auxiliary loss."""

    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config: DbrxConfig):
        super().__init__(config)
        self.transformer = DbrxModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.ffn_config.moe_loss_weight
        self.num_experts = config.ffn_config.moe_num_experts
        self.num_experts_per_tok = config.ffn_config.moe_top_k
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.transformer.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Embedding):
        self.transformer.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder: DbrxModel):
        self.transformer = decoder

    def get_decoder(self) -> DbrxModel:
        return self.transformer

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >> from transformers import AutoTokenizer, DbrxForCausalLM

        >> model = DbrxForCausalLM.from_pretrained("transformers-community/dbrx-instruct")
        >> tokenizer = AutoTokenizer.from_pretrained("transformers-community/dbrx-instruct")

        >> prompt = "Hey, are you conscious? Can you talk to me?"
        >> inputs = tokenizer(prompt, return_tensors="pt")

        >> # Generate
        >> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```
        """
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # The base model returns a `MoeModelOutputWithPast` (last hidden state, cache, optional recorded outputs).
        outputs: MoeModelOutputWithPast = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # `load_balancing_loss_func` returns the plain int 0 when no router logits were collected;
                # only a tensor has `.to` and needs moving to the loss device.
                aux_term = aux_loss.to(loss.device) if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss += self.router_aux_loss_coef * aux_term

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
|
|
|
|
|
|
# Public API of this module: the causal-LM model, the bare decoder, and the shared pretrained base class.
__all__ = ["DbrxForCausalLM", "DbrxModel", "DbrxPreTrainedModel"]
|