# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable

import torch
from torch import nn

from ... import initialization as init
from ...cache_utils import Cache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, logging
from ...utils.generic import check_model_inputs
from ..bamba.configuration_bamba import BambaConfig
from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding
from ..granitemoeshared.modeling_granitemoeshared import (
    GraniteFlashAttentionKwargs,
    GraniteMoeSharedAttention,
    GraniteMoeSharedDecoderLayer,
    GraniteMoeSharedForCausalLM,
    GraniteMoeSharedMLP,
    GraniteMoeSharedModel,
    GraniteMoeSharedMoE,
    GraniteMoeSharedPreTrainedModel,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from .configuration_granitemoehybrid import GraniteMoeHybridConfig


logger = logging.get_logger(__name__)


class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
    def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        super().__init__(config, layer_idx)

    def forward(  # FIXME: @ARTHUR this forward is also the classic attention forward
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,  # None or rope embeddings
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class GraniteMoeHybridMambaLayer(BambaMixer):
    def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        super().__init__(BambaConfig(config), layer_idx)


class GraniteMoeHybridRMSNormGated(BambaRMSNormGated):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(hidden_size, eps)


class GraniteMoeHybridMLP(GraniteMoeSharedMLP):
    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__(config)


class GraniteMoeHybridRotaryEmbedding(Gemma2RotaryEmbedding):
    pass


class GraniteMoeHybridMoE(GraniteMoeSharedMoE):
    pass


class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
    def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.shared_mlp = GraniteMoeHybridMLP(config)
        # Either attention or mamba will be initialized, depending on the layer type.
        self.self_attn = None
        self.mamba = None

        if config.layers_block_type[layer_idx] == "mamba":
            self.mamba = GraniteMoeHybridMambaLayer(config, layer_idx)
        else:
            self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
        self.layer_type = config.layers_block_type[layer_idx]

        # Allow non-MoE (dense) layers: accept 0 experts and skip the MoE block
        # entirely when num_local_experts == 0.
        self.block_sparse_moe = GraniteMoeHybridMoE(config) if config.num_local_experts > 0 else None
        self.has_experts = getattr(config, "num_local_experts", 0) > 0

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[GraniteFlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        if self.mamba is not None:
            hidden_states = self.mamba(
                hidden_states=hidden_states,
                cache_position=cache_position,
                cache_params=past_key_values,
                attention_mask=attention_mask,
                **kwargs,
            )
        else:
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = residual + hidden_states * self.residual_multiplier

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.has_experts:
            moe_hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
        else:
            hidden_states = self.shared_mlp(hidden_states)
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states


class GraniteMoeHybridPreTrainedModel(GraniteMoeSharedPreTrainedModel):
    config: GraniteMoeHybridConfig
    _no_split_modules = ["GraniteMoeHybridDecoderLayer"]
    _is_stateful = True

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, GraniteMoeHybridMambaLayer):
            init.ones_(module.dt_bias)
            init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
            init.ones_(module.D)
        elif isinstance(module, GraniteMoeHybridRMSNormGated):
            init.ones_(module.weight)


class GraniteMoeHybridModel(GraniteMoeSharedModel):
    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.embedding_multiplier = config.embedding_multiplier
        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None

    @auto_docstring
    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[GraniteFlashAttentionKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        inputs_embeds = inputs_embeds * self.embedding_multiplier

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            self.config,
            inputs_embeds,
            attention_mask,
            cache_position,
            past_key_values,
        )
        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)

        # embed positions
        hidden_states = inputs_embeds

        position_embeddings = None
        if self.rotary_emb is not None:
            position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=layer_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        if past_key_values and not past_key_values.has_previous_state:
            past_key_values.has_previous_state = True

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_mamba_mask(self, attention_mask, cache_position):
        """
        No need for zeroing states when
            1. Cached forward
            2. Attending to all inputs
        """
        mamba_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            mamba_mask = None
        return mamba_mask


class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__(config)
        self.model = GraniteMoeHybridModel(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM

        >>> model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny")
        >>> tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-tiny")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        return super().forward(**super_kwargs)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
        if past_key_values is None and use_cache:
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )
        return model_inputs


__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"]
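

# Minimal usage sketch (illustrative, not part of the library API): this module
# cannot be executed directly because of its relative imports, so the snippet
# below is meant to be run from user code with `transformers` installed. It
# reuses the "ibm-granite/granite-4.0-h-tiny" checkpoint from the docstring
# example above.
#
#     from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-tiny")
#     model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny")
#     inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
#
#     # `generate` allocates a HybridMambaAttentionDynamicCache through
#     # `prepare_inputs_for_generation`, so attention KV states and Mamba SSM
#     # states are both carried across decoding steps.
#     output_ids = model.generate(**inputs, max_length=30)
#     print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])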