# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/dbrx/modular_dbrx.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_dbrx.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable from typing import Any, Optional import torch from torch import nn from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_dbrx import DbrxConfig class DbrxRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: DbrxConfig, device=None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_type = self.config.rope_parameters["rope_type"] rope_init_fn: Callable = self.compute_default_rope_parameters if self.rope_type != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) @staticmethod def compute_default_rope_parameters( config: DbrxConfig | None = None, device: Optional["torch.device"] = None, seq_len: int | None = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation Args: config ([`~transformers.PreTrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ base = config.rope_parameters["rope_theta"] dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads attention_factor = 1.0 # Unused in this type of RoPE # Compute the inverse frequencies inv_freq = 1.0 / ( base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) ) return inv_freq, attention_factor @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor | None, scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class DbrxAttention(nn.Module): """Modular DBRX attention component that can be reused across different model architectures.""" def __init__( self, config, layer_idx: int | None = None, **kwargs, ): super().__init__() self.config = config self.hidden_size = config.d_model self.num_heads = config.n_heads self.head_dim = self.hidden_size // self.num_heads self.max_position_embeddings = config.max_seq_len self.layer_idx = layer_idx attn_config = config.attn_config self.attention_dropout = attn_config.attn_pdrop self.clip_qkv = attn_config.clip_qkv self.num_key_value_heads = attn_config.kv_n_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 self.rope_theta = attn_config.rope_theta self.is_causal = True self.Wqkv = nn.Linear( self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False ) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_embeddings: torch.LongTensor | None = None, past_key_values: Cache | None = None, cache_position: torch.LongTensor | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) qkv_states = self.Wqkv(hidden_states) min_val = -self.clip_qkv if self.clip_qkv is not None else None qkv_states = qkv_states.clamp(min=min_val, max=self.clip_qkv) query_states, key_states, value_states = qkv_states.split( [ self.hidden_size, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim, ], dim=2, ) query_states = query_states.view(hidden_shape).transpose(1, 2) key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.out_proj(attn_output) return attn_output, attn_weights class DbrxExpertGLU(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.ffn_hidden_size = config.ffn_hidden_size self.moe_num_experts = config.moe_num_experts self.w1 = nn.Parameter(torch.empty(self.moe_num_experts * self.ffn_hidden_size, self.hidden_size)) self.v1 = nn.Parameter(torch.empty(self.moe_num_experts * self.ffn_hidden_size, self.hidden_size)) self.w2 = nn.Parameter(torch.empty(self.moe_num_experts * self.ffn_hidden_size, self.hidden_size)) act_fn_name = config.ffn_act_fn.get("name", "silu") self.activation_fn = ACT2FN[act_fn_name] def forward( self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor ) -> torch.Tensor: gate_proj = x.matmul(expert_w1) up_proj = x.matmul(expert_v1) gate_proj = self.activation_fn(gate_proj) intermediate_states = gate_proj * up_proj down_proj = intermediate_states.matmul(expert_w2.t()) return down_proj class DbrxExperts(nn.Module): def __init__(self, config): super().__init__() self.mlp = DbrxExpertGLU(config) self.hidden_size = config.hidden_size self.ffn_hidden_size = config.ffn_hidden_size self.num_experts = config.moe_num_experts def forward( self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.ffn_hidden_size) next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() split_expert_shape = (-1, self.ffn_hidden_size, self.hidden_size) for expert_idx in expert_hit: expert_idx = expert_idx[0] with torch.no_grad(): idx, token_idx = torch.where(expert_mask[expert_idx]) v1 = self.mlp.v1.view(split_expert_shape)[expert_idx] w1 = self.mlp.w1.view(split_expert_shape)[expert_idx] w2 = self.mlp.w2.view(split_expert_shape)[expert_idx] states = self.mlp(hidden_states[token_idx], w1, v1, w2) states = states.view(-1, self.ffn_hidden_size) * top_k_weights[token_idx, idx, None] next_states.index_add_(0, token_idx, states) next_states = next_states.view(batch_size, -1, self.ffn_hidden_size) return next_states class DbrxRouter(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.ffn_hidden_size self.moe_jitter_eps = config.moe_jitter_eps self.layer = nn.Linear(self.hidden_size, config.moe_num_experts, bias=False) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: if self.training and self.moe_jitter_eps is not None: hidden_states *= torch.empty_like(hidden_states).uniform_( 1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps ) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) router_logits = self.layer(hidden_states) return router_logits class DbrxFFN(nn.Module): """Modular DBRX MLP/FFN component with MoE support.""" def __init__(self, config, **kwargs): super().__init__() self.router = DbrxRouter(config.ffn_config) self.experts = DbrxExperts(config.ffn_config) self.moe_normalize_expert_weights = config.ffn_config.moe_normalize_expert_weights self.top_k = config.ffn_config.moe_top_k def route_tokens_to_experts(self, router_logits): router_logits = torch.nn.functional.softmax(router_logits, dim=1, dtype=router_logits.dtype) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) if self.moe_normalize_expert_weights is not None: router_top_value = router_top_value / torch.norm( router_top_value, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True ) return router_top_value, router_indices def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: router_logits = self.router(hidden_states) top_k_weights, top_k_index = self.route_tokens_to_experts(router_logits) output = self.experts(hidden_states, top_k_index, top_k_weights) return output class DbrxNormAttentionNorm(nn.Module): def __init__(self, config: DbrxConfig, layer_idx: int | None = None): super().__init__() self.layer_idx = layer_idx self.resid_pdrop = config.resid_pdrop self.norm_1 = nn.LayerNorm(config.d_model, bias=False) self.attn = DbrxAttention( config=config, layer_idx=layer_idx, ) self.norm_2 = nn.LayerNorm(config.d_model, bias=False) def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.LongTensor, attention_mask: torch.Tensor | None = None, past_key_values: Cache | None = None, cache_position: torch.LongTensor | None = None, **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: residual_states = hidden_states hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype) hidden_states, _ = self.attn( hidden_states=hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training) hidden_states = hidden_states + residual_states residual_states = hidden_states hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype) return residual_states, hidden_states class DbrxBlock(GradientCheckpointingLayer): def __init__(self, config: DbrxConfig, layer_idx: int): super().__init__() self.hidden_size = config.d_model self.resid_pdrop = config.resid_pdrop self.layer_idx = layer_idx self.norm_attn_norm = DbrxNormAttentionNorm( config=config, layer_idx=layer_idx, ) self.ffn = DbrxFFN(config=config) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_embeddings: torch.LongTensor | None = None, past_key_values: Cache | None = None, cache_position: torch.LongTensor | None = None, **kwargs: Any, ): resid_states, hidden_states = self.norm_attn_norm( hidden_states=hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) hidden_states = self.ffn(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training) hidden_states = resid_states + hidden_states return hidden_states class DbrxPreTrainedModel(PreTrainedModel): config: DbrxConfig base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["DbrxBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flex_attn = True _supports_attention_backend = True _supports_flash_attn = True _supports_sdpa = True _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _can_record_outputs = { "hidden_states": DbrxBlock, "attentions": DbrxAttention, } @torch.no_grad() def _init_weights(self, module: nn.Module): super()._init_weights(module) std = self.config.initializer_range if isinstance(module, DbrxExpertGLU): init.normal_(module.w1, mean=0.0, std=std) init.normal_(module.v1, mean=0.0, std=std) init.normal_(module.w2, mean=0.0, std=std) @auto_docstring class DbrxModel(DbrxPreTrainedModel): """Transformer decoder consisting of *config.num_hidden_layers*. Each layer is a [`DbrxBlock`] layer. Args: config ([`DbrxConfig`]): Model configuration class with all parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ def __init__(self, config: DbrxConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.emb_pdrop = config.emb_pdrop self.rotary_emb = DbrxRotaryEmbedding(config) self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.blocks = nn.ModuleList([DbrxBlock(config, layer_idx) for layer_idx in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model, bias=False) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Embedding: return self.wte def set_input_embeddings(self, value: nn.Embedding): self.wte = value @check_model_inputs @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> MoeModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.blocks[: self.config.num_hidden_layers]: hidden_states = decoder_layer( hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = self.norm_f(hidden_states) return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE last_hidden_state=hidden_states, past_key_values=past_key_values, ) def load_balancing_loss_func( gate_logits: torch.Tensor | tuple[torch.Tensor] | None, num_experts: int | None = None, top_k=2, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor | int: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. Args: gate_logits: Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of shape [batch_size X sequence_length, num_experts]. num_experts: Number of experts top_k: The number of experts to route per-token, can be also interpreted as the `top-k` routing parameter. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. Returns: The auxiliary loss. """ if gate_logits is None or not isinstance(gate_logits, tuple): return 0 if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) if attention_mask is None: # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( expert_attention_mask, dim=0 ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} _tp_plan = {"lm_head": "colwise_gather_output"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: DbrxConfig): super().__init__(config) self.transformer = DbrxModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.ffn_config.moe_loss_weight self.num_experts = config.ffn_config.moe_num_experts self.num_experts_per_tok = config.ffn_config.moe_top_k self.post_init() def get_input_embeddings(self) -> nn.Embedding: return self.transformer.get_input_embeddings() def set_input_embeddings(self, value: nn.Embedding): self.transformer.set_input_embeddings(value) def get_output_embeddings(self) -> nn.Linear: return self.lm_head def set_output_embeddings(self, new_embeddings: nn.Linear): self.lm_head = new_embeddings def set_decoder(self, decoder: DbrxModel): self.transformer = decoder def get_decoder(self) -> DbrxModel: return self.transformer @can_return_tuple @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, output_router_logits: bool | None = None, cache_position: torch.LongTensor | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >> from transformers import AutoTokenizer, DbrxForCausalLM >> model = DbrxForCausalLM.from_pretrained("transformers-community/dbrx-instruct") >> tokenizer = AutoTokenizer.from_pretrained("transformers-community/dbrx-instruct") >> prompt = "Hey, are you conscious? Can you talk to me?" >> inputs = tokenizer(prompt, return_tensors="pt") >> # Generate >> generate_ids = model.generate(inputs.input_ids, max_length=30) >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ``` """ output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: MoeModelOutputWithPast = self.transformer( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_router_logits=output_router_logits, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, ) if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) __all__ = ["DbrxForCausalLM", "DbrxModel", "DbrxPreTrainedModel"]