# Copyright 2025 the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch.nn as nn from ...modeling_rope_utils import RopeParameters from ...utils import auto_docstring, can_return_tuple from ..llama.configuration_llama import LlamaConfig from ..llama.modeling_llama import ( LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel, ) from ..nemotron.modeling_nemotron import NemotronMLP class Jais2Config(LlamaConfig): r""" This is the configuration class to store the configuration of a [`Jais2Model`]. It is used to instantiate a Jais2 model according to the specified arguments, defining the model architecture. [inceptionai/Jais-2-8B-Chat](https://huggingface.co/inceptionai/Jais-2-8B-Chat). Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 150272): Vocabulary size of the Jais2 model. hidden_size (`int`, *optional*, defaults to 3328): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 26624): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 26): Number of attention heads for each attention layer. num_key_value_heads (`int`, *optional*): Number of key_value heads for Grouped Query Attention. hidden_act (`str`, *optional*, defaults to `"relu2"`): The non-linear activation function in the decoder. max_position_embeddings (`int`, *optional*, defaults to 8192): The maximum sequence length. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer. layer_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether to return last key/values attentions. pad_token_id (`int`, *optional*): Padding token id. bos_token_id (`int`, *optional*, defaults to 0): Beginning of stream token id. eos_token_id (`int`, *optional*, defaults to 150024): End of stream token id. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings. attention_bias (`bool`, *optional*, defaults to `True`): Whether to use a bias in the query, key, value and output projection layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. mlp_bias (`bool`, *optional*, defaults to `True`): Whether to use a bias in up_proj, down_proj and gate_proj layers. head_dim (`int`, *optional*): The attention head dimension. rope_parameters (`dict`, *optional*): The RoPE parameters. """ model_type = "jais2" base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } def __init__( self, vocab_size: int | None = 150272, hidden_size: int | None = 3328, intermediate_size: int | None = 26624, num_hidden_layers: int | None = 32, num_attention_heads: int | None = 26, num_key_value_heads: int | None = None, hidden_act: str | None = "relu2", max_position_embeddings: int | None = 8192, initializer_range: float | None = 0.02, layer_norm_eps: float | None = 1e-5, use_cache: bool | None = True, pad_token_id: int | None = None, bos_token_id: int | None = 0, eos_token_id: int | None = 150024, tie_word_embeddings: bool | None = False, attention_bias: bool | None = True, attention_dropout: float | None = 0.0, mlp_bias: bool | None = True, head_dim: int | None = None, rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, **kwargs, ): super().__init__( vocab_size=vocab_size, hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, hidden_act=hidden_act, max_position_embeddings=max_position_embeddings, initializer_range=initializer_range, use_cache=use_cache, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, attention_bias=attention_bias, attention_dropout=attention_dropout, mlp_bias=mlp_bias, head_dim=head_dim, rope_parameters=rope_parameters, **kwargs, ) self.layer_norm_eps = layer_norm_eps del self.rms_norm_eps del self.pretraining_tp class Jais2MLP(NemotronMLP): pass class Jais2DecoderLayer(LlamaDecoderLayer): def __init__(self, config: Jais2Config, layer_idx: int): super().__init__(config, layer_idx) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) class Jais2PreTrainedModel(LlamaPreTrainedModel): pass class Jais2Model(LlamaModel): def __init__(self, config: Jais2Config): super().__init__(config) self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) class Jais2ForCausalLM(LlamaForCausalLM): @can_return_tuple @auto_docstring def forward(self, **super_kwargs): r""" Example: ```python >>> from transformers import AutoTokenizer, Jais2ForCausalLM >>> model = Jais2ForCausalLM.from_pretrained("inceptionai/Jais-2-8B-Chat") >>> tokenizer = AutoTokenizer.from_pretrained("inceptionai/Jais-2-8B-Chat") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" return super().forward(**super_kwargs) __all__ = [ "Jais2Config", "Jais2Model", "Jais2ForCausalLM", "Jais2PreTrainedModel", ]