You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
188 lines
9.4 KiB
188 lines
9.4 KiB
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/cwm/modular_cwm.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_cwm.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# Copyright 2025
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
from ...configuration_utils import PreTrainedConfig, layer_type_validation
|
|
|
|
|
|
class CwmConfig(PreTrainedConfig):
    """
    Configuration for Code World Model (CWM).

    This is an inherited Llama3-compatible configuration with layer-interleaved
    sliding-window attention. Configures a `CwmModel`. Designed to yield a configuration mirroring the model in the
    [facebook/cwm](https://huggingface.co/facebook/cwm) architecture by default. Other models include:
    - [facebook/cwm-sft](https://huggingface.co/facebook/cwm-sft)
    - [facebook/cwm-pretrain](https://huggingface.co/facebook/cwm-pretrain)

    Args:
        vocab_size (`int`, *optional*, defaults to 128256):
            Vocabulary size of the CWM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`CwmModel`]
        hidden_size (`int`, *optional*, defaults to 6144):
            Dimension of the hidden representations
        intermediate_size (`int`, *optional*, defaults to 21504):
            Dimension of the MLP representations
        num_hidden_layers (`int`, *optional*, defaults to 64):
            Number of hidden layers in the Transformer decoder
        num_attention_heads (`int`, *optional*, defaults to 48):
            Number of attention heads for each attention layer in the Transformer decoder
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention (GQA).
            If explicitly set to `None`, will default to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension. If explicitly set to `None`, it is computed as
            `hidden_size // num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            The maximum sequence length that this model might ever be used with. CWM's attention allows sequence
            lengths up to 131072 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        eos_token_id (`int` or `list[int]`, *optional*, defaults to `[128001, 128008, 128009]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
            A list value is copied, so configurations never share the same list object.
        bos_token_id (`int`, *optional*, defaults to 128000):
            The id of the *beginning-of-sequence* token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Tensor parallelism degree used during pretraining. See [this
            document](https://huggingface.co/docs/transformers/parallelism) and [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        rope_parameters (`RopeParameters`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`. If not provided, a `"llama3"` rope type is used with
            `rope_theta=1e6`, `factor=16.0`, `high_freq_factor=4.0`, `low_freq_factor=1.0` and
            `original_max_position_embeddings=8192`.
        sliding_window (`int`, *optional*, defaults to 8192):
            Sliding window attention window size. A falsy value (`None` or `0`) is stored as `None`.
        layer_types (`List[str]`, *optional*):
            List of layer types for each layer. Each element should be either "full_attention" or "sliding_attention".
            If not specified, every 4th layer (layer indices 0, 4, 8, ...) uses "full_attention" and all other
            layers use "sliding_attention".
    """

    model_type = "cwm"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `CwmModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    default_theta = 1_000_000.0

    def __init__(
        self,
        vocab_size: int = 128256,
        hidden_size: int = 6144,
        intermediate_size: int = 21504,
        num_hidden_layers: int = 64,
        num_attention_heads: int = 48,
        num_key_value_heads: int | None = 8,
        head_dim: int | None = 128,
        hidden_act: str = "silu",
        max_position_embeddings: int = 131072,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-5,
        use_cache: bool = True,
        pad_token_id: int | None = None,
        # NOTE: this default list is created once at function-definition time; it is copied below before storing.
        eos_token_id: int | list[int] | None = [128001, 128008, 128009],
        bos_token_id: int = 128000,
        tie_word_embeddings: bool = False,
        attention_dropout: float = 0.0,
        pretraining_tp: int = 1,
        mlp_bias: bool = False,
        rope_parameters: dict | None = None,
        # CWM interleaved sliding window fields
        sliding_window: int = 8192,
        layer_types: list[str] | None = None,  # ["full_attention"|"sliding_attention"] per layer
        **kwargs,
    ):
        if rope_parameters is None:
            # Fresh dict per call, so instances never share a mutable default.
            rope_parameters = {
                "rope_theta": 1_000_000.0,
                "factor": 16.0,
                "high_freq_factor": 4.0,
                "low_freq_factor": 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            }

        if layer_types is None:
            # Default pattern: every 4th layer (starting at layer 0) uses full attention, others use sliding attention
            window_pattern = 4
            layer_types = [
                ("full_attention" if (i % window_pattern == 0) else "sliding_attention")
                for i in range(num_hidden_layers)
            ]
        else:
            layer_type_validation(layer_types, num_hidden_layers)

        # A falsy window (None or 0) disables the sliding window entirely.
        self.sliding_window = int(sliding_window) if sliding_window else None
        self.layer_types = list(layer_types)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        self.rope_parameters = rope_parameters

        self.tie_word_embeddings = tie_word_embeddings
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        # Copy list-valued ids. Storing the shared default list directly would let an in-place edit on one
        # config (e.g. `config.eos_token_id.append(...)`) silently change every other default-built config.
        self.eos_token_id = list(eos_token_id) if isinstance(eos_token_id, list) else eos_token_id
        super().__init__(**kwargs)
|
|
|
|
|
|
# Names exported by `from <this module> import *`: only the configuration class.
__all__: list[str] = ["CwmConfig"]
|