#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/afmoe/modular_afmoe.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_afmoe.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# NOTE(review): per the banner above, the fixes in this file (AfmoeAttention head_dim fallback and
# `reshape` of the attention output, AfmoeExperts dead-branch removal) must also be applied to
# modular_afmoe.py, otherwise regeneration will revert them.
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from typing import Optional

import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs, maybe_autocast
from .configuration_afmoe import AfmoeConfig


class AfmoeRotaryEmbedding(nn.Module):
    """
    Rotary position embedding that turns position ids into `(cos, sin)` tables.

    The inverse frequencies come from `compute_default_rope_parameters` for `rope_type == "default"`,
    otherwise from `ROPE_INIT_FUNCTIONS[rope_type]`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: AfmoeConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: AfmoeConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Build the rotary tables for `position_ids`.

        Args:
            x: Only its `device` and `dtype` are used (tables are moved/cast to match).
            position_ids: Indexed as `position_ids[:, None, :]`, i.e. `(batch, seq)`.

        Returns:
            `(cos, sin)`, each `(batch, seq, dim)` where `dim` is twice the number of inverse
            frequencies, computed in float32 and cast to `x.dtype`.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@use_kernel_forward_from_hub("RMSNorm")
class AfmoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        AfmoeRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize over the last dim in float32, scale by `weight`, then cast back to the input dtype."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)  # main diff with Llama

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class AfmoeMLP(nn.Module):
    """
    Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free.

    `intermediate_size` defaults to `config.intermediate_size`; MoE experts and shared experts pass a
    different width.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class AfmoeTokenChoiceRouter(nn.Module):
    """
    Token-choice top-K router for MoE routing.

    This router assigns each token to the top-K experts based on sigmoid scores, matching the released checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.route_scale = config.route_scale
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
        """
        Args:
            hidden_states: `(batch, seq, hidden)`; flattened to `(batch * seq, hidden)` internally.
            expert_bias: Per-expert bias added to the scores for expert *selection* only.

        Returns:
            `(top_scores, selected_experts)`, each `(batch * seq, top_k)`. `top_scores` are float32 sigmoid
            scores of the chosen experts (without `expert_bias`), normalized to sum to 1 per token and then
            multiplied by `route_scale`.
        """
        _, _, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))

        # The bias only influences which experts are picked; the weights come from the unbiased scores.
        _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
        top_scores = scores.gather(dim=1, index=selected_experts)
        denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20  # epsilon guards against division by zero
        top_scores = top_scores / denominator
        top_scores = top_scores * self.route_scale
        return top_scores, selected_experts


class AfmoeExperts(nn.ModuleList):
    """
    Container holding the routed experts.

    This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
    """

    def __init__(self, config: AfmoeConfig):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        for _ in range(self.num_experts):
            self.append(AfmoeMLP(config, intermediate_size=config.moe_intermediate_size))

    def forward(
        self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Run each routed token through its selected experts and sum the weighted results per token.

        Args:
            hidden_states: (batch, seq, hidden)
            selected_experts: (batch, seq, top_k)
            routing_weights: (batch, seq, top_k)

        Returns:
            Tensor of shape (batch, seq, hidden) in the dtype of `hidden_states`.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        if seq_len == 0:
            return hidden_states.new_zeros(batch_size, 0, hidden_dim)
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        top_k = selected_experts.shape[-1]

        # Map every token routing decision to a unique position so we can process expert by expert.
        # Token i is repeated top_k times, matching the token-major order of `selected_experts.reshape(-1)`.
        token_indices = torch.arange(
            hidden_states_flat.shape[0], device=hidden_states.device, dtype=torch.long
        ).repeat_interleave(top_k)
        expert_indices = selected_experts.reshape(-1)
        routing_weights = routing_weights.reshape(-1)

        # Group assignments by expert id so each expert sees one contiguous slice.
        sorting = torch.argsort(expert_indices, stable=True)
        token_indices = token_indices[sorting]
        expert_indices = expert_indices[sorting]
        routing_weights = routing_weights[sorting]

        dispatched_tokens = hidden_states_flat.index_select(0, token_indices)
        expert_outputs = torch.zeros_like(dispatched_tokens)

        # `unique_consecutive` yields one (expert, run length) pair per contiguous run; run lengths are
        # always >= 1, so every iteration processes a non-empty slice.
        unique_experts, counts = torch.unique_consecutive(expert_indices, return_counts=True)
        start = 0
        for expert_id, count in zip(unique_experts.tolist(), counts.tolist()):
            end = start + count
            expert_outputs[start:end] = self[expert_id](dispatched_tokens[start:end])
            start = end

        # Weight in float32, then scatter-add back to each token's row (a token receives top_k contributions).
        weighted_outputs = (expert_outputs.to(torch.float32) * routing_weights.unsqueeze(-1)).to(hidden_states.dtype)
        aggregated = torch.zeros_like(hidden_states_flat)
        scatter_indices = token_indices.unsqueeze(-1).expand_as(weighted_outputs)
        aggregated.scatter_add_(0, scatter_indices, weighted_outputs)
        return aggregated.view(batch_size, seq_len, hidden_dim)


class AfmoeMoE(nn.Module):
    """
    Mixture of Experts (MoE) module for AFMoE.

    This module implements a sparse MoE layer with both shared experts (always active)
    and routed experts (activated based on token-choice routing).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.router = AfmoeTokenChoiceRouter(config)
        self.shared_experts = AfmoeMLP(config, config.moe_intermediate_size * config.num_shared_experts)
        self.experts = AfmoeExperts(config)
        # Non-trainable (requires_grad=False) bias used by the router when selecting experts.
        self.expert_bias = nn.Parameter(torch.zeros(config.num_experts), requires_grad=False)

    def forward(self, hidden_states):
        """Return `shared_experts(x) + routed_experts(x)` for `hidden_states` of shape (batch, seq, hidden)."""
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)

        # Get routing decisions
        top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
        top_scores = top_scores.view(batch_size, seq_len, self.config.num_experts_per_tok)
        selected_experts = selected_experts.view(batch_size, seq_len, self.config.num_experts_per_tok)

        # Process through shared experts
        shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
        routed_output = self.experts(hidden_states, selected_experts, top_scores)
        return shared_output + routed_output


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference (eager) scaled dot-product attention.

    Expands key/value heads by `module.num_key_value_groups`, adds the (sliced) additive mask, applies a
    float32 softmax and dropout. Returns the output transposed to (batch, seq, heads, head_dim) and made
    contiguous, together with the attention weights.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


@use_kernelized_func(apply_rotary_pos_emb)
class AfmoeAttention(nn.Module):
    """
    Multi-headed attention module with optional sliding window and gating.

    This attention mechanism supports both full attention and sliding window attention,
    and includes Q/K normalization (per head, before the transpose) and sigmoid gating of the
    output before `o_proj`. Rotary position embeddings are applied only on sliding-window
    ("sliding_attention") layers; full-attention layers apply no rotary embedding.
    """

    def __init__(self, config: AfmoeConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Same fallback as `AfmoeRotaryEmbedding.compute_default_rope_parameters`: a missing *or* `None`
        # `head_dim` falls back to hidden_size // num_attention_heads, so projections and RoPE agree.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        # AFMoE-specific attributes: per-layer attention type, Q/K norms and the output gate.
        self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_local_attention else None

        self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.gate_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_value: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: (batch, seq, hidden).
            position_embeddings: `(cos, sin)` from `AfmoeRotaryEmbedding`; only used on sliding-window layers.
            attention_mask: Mask forwarded unchanged to the selected attention implementation.
            past_key_value: Optional cache; updated in place with this layer's key/value states.
            cache_position: Forwarded to the cache update.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` is (batch, seq, hidden).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        value_states = self.v_proj(hidden_states).view(hidden_shape)
        gate_states = self.gate_proj(hidden_states)

        query_states = self.q_norm(query_states).transpose(1, 2)
        key_states = self.k_norm(key_states).transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # RoPE only on local (sliding-window) layers; global layers attend without positional rotation.
        if self.is_local_attention:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # `reshape` (not `view`) so a non-contiguous output from any attention backend is handled;
        # for contiguous outputs it is identical to `view`.
        output = output.reshape(*input_shape, -1).contiguous()
        output = output * torch.sigmoid(gate_states)
        attn_output = self.o_proj(output)
        return attn_output, attn_weights


class AfmoeDecoderLayer(GradientCheckpointingLayer):
    """
    AFMoE decoder layer with dual normalization.

    This layer applies self-attention followed by either a dense MLP or MoE block,
    with dual normalization (pre and post) around each component.
    """

    def __init__(self, config: AfmoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
        self.attention_type = config.layer_types[layer_idx]

        # Dual normalization for attention
        self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Dual normalization for FFN
        self.pre_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # MoE or dense FFN: the first `num_dense_layers` layers are dense, the rest are MoE.
        self.moe_enabled = layer_idx >= config.num_dense_layers
        if self.moe_enabled:
            self.mlp = AfmoeMoE(config)
        else:
            self.mlp = AfmoeMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_value: Cache | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        """
        `x + post_norm(attn(pre_norm(x)))`, then `x + post_norm(mlp(pre_norm(x)))`.

        Returns the updated hidden states with the same shape as the input.
        """
        residual = hidden_states

        # Self Attention with dual normalization
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # FFN with dual normalization
        residual = hidden_states
        hidden_states = self.pre_mlp_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class AfmoePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config: AfmoeConfig
    base_model_prefix = "model"
    _no_split_modules = ["AfmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _can_record_outputs = {
        "hidden_states": AfmoeDecoderLayer,
        "attentions": AfmoeAttention,
    }
    _keep_in_fp32_modules = [
        "input_layernorm",
        "post_attention_layernorm",
        "pre_mlp_layernorm",
        "post_mlp_layernorm",
        "q_norm",
        "k_norm",
        "norm",
        "expert_bias",
    ]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        super()._init_weights(module)
        # AFMoE-specific: router gate weights and the expert-selection bias start at zero.
        if isinstance(module, AfmoeTokenChoiceRouter):
            init.zeros_(module.gate.weight)
        elif isinstance(module, AfmoeMoE):
            init.zeros_(module.expert_bias)


@auto_docstring
class AfmoeModel(AfmoePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AfmoeDecoderLayer`]

    Args:
        config: AfmoeConfig
    """

    def __init__(self, config: AfmoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [AfmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = AfmoeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.post_init()

    @auto_docstring
    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | MoeModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds

        # Apply muP input scaling if enabled
        if self.config.mup_enabled:
            hidden_states = hidden_states * (self.config.hidden_size**0.5)

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


@auto_docstring
class AfmoeForCausalLM(AfmoePreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = AfmoeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, AfmoeForCausalLM

        >>> model = AfmoeForCausalLM.from_pretrained("meta-afmoe/Afmoe-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-afmoe/Afmoe-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["AfmoeForCausalLM", "AfmoeModel", "AfmoePreTrainedModel"]