You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
786 lines
35 KiB
786 lines
35 KiB
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from ... import initialization as init
|
|
from ...activations import ACT2FN
|
|
from ...integrations import lazy_load_kernel
|
|
from ...masking_utils import create_causal_mask
|
|
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
|
|
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import TransformersKwargs, auto_docstring, logging
|
|
from ...utils.generic import OutputRecorder, check_model_inputs
|
|
from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward
|
|
from ..mistral.modeling_mistral import MistralMLP
|
|
from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM
|
|
from .configuration_jamba import JambaConfig
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class JambaRMSNorm(LlamaRMSNorm):
    """Root-mean-square layer norm for Jamba; identical to the Llama implementation."""
|
|
|
|
|
|
class HybridMambaAttentionDynamicCache:
|
|
"""
|
|
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
|
|
(which has a constant shape regardless of seq_len).
|
|
|
|
This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
|
|
and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
|
|
For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
|
|
while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
|
|
For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
|
|
while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
|
|
and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
|
|
"""
|
|
|
|
is_compileable = False
|
|
|
|
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
|
|
self.dtype = dtype
|
|
self.layers_block_type = config.layers_block_type
|
|
self.has_previous_state = False # only used by mamba
|
|
intermediate_size = config.mamba_expand * config.hidden_size
|
|
ssm_state_size = config.mamba_d_state
|
|
conv_kernel_size = config.mamba_d_conv
|
|
self.conv_states = []
|
|
self.ssm_states = []
|
|
self.transformer_layers = []
|
|
for i in range(config.num_hidden_layers):
|
|
if self.layers_block_type[i] == "mamba":
|
|
self.conv_states += [
|
|
torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
|
|
]
|
|
self.ssm_states += [
|
|
torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
|
|
]
|
|
else:
|
|
self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
|
|
self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
|
|
self.transformer_layers.append(i)
|
|
|
|
self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
|
|
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
|
|
|
|
def __len__(self):
|
|
return len(self.key_cache)
|
|
|
|
def __getitem__(self, layer_idx):
|
|
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
|
|
|
def update(
|
|
self,
|
|
key_states: torch.Tensor,
|
|
value_states: torch.Tensor,
|
|
layer_idx: int,
|
|
cache_kwargs: dict[str, Any] | None = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
# Update the cache
|
|
if self.key_cache[layer_idx].shape[-1] == 0:
|
|
self.key_cache[layer_idx] = key_states
|
|
self.value_cache[layer_idx] = value_states
|
|
else:
|
|
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
|
|
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
|
|
|
|
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|
|
"""Reorders the cache for beam search, given the selected beam indices."""
|
|
if self.get_seq_length() > 0:
|
|
for layer_idx in range(len(self.key_cache)):
|
|
device = self.key_cache[layer_idx].device
|
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
|
device = self.value_cache[layer_idx].device
|
|
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
|
|
|
device = self.conv_states[layer_idx].device
|
|
self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
|
|
device = self.ssm_states[layer_idx].device
|
|
self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
|
"""Return the length and offset of the cache, used to generate the mask"""
|
|
kv_offset = 0
|
|
query_length = cache_position.shape[0]
|
|
kv_length = self.get_seq_length(layer_idx) + query_length
|
|
return kv_length, kv_offset
|
|
|
|
def get_seq_length(self, layer_idx: int | None = 0) -> int:
|
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
|
# take any layer that contains cache and not empty tensor
|
|
layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
|
|
if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0:
|
|
return 0
|
|
return self.key_cache[layer_idx].shape[-2]
|
|
|
|
|
|
class JambaAttention(LlamaAttention):
    """Multi-head attention for Jamba's attention layers (bias-free projections)."""

    def __init__(self, config: JambaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        attn_dim = config.num_attention_heads * self.head_dim
        kv_dim = config.num_key_value_heads * self.head_dim
        # Re-create the projections without bias (overriding whatever the parent built).
        self.q_proj = nn.Linear(config.hidden_size, attn_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(attn_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: HybridMambaAttentionDynamicCache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        batch_dims = hidden_states.shape[:-1]
        head_split_shape = (*batch_dims, -1, self.head_dim)

        # Project, split into heads and move the head axis before the sequence axis.
        query_states = self.q_proj(hidden_states).view(head_split_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(head_split_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(head_split_shape).transpose(1, 2)

        if past_key_values is not None:
            # Merge with previously cached keys/values for this layer.
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back and apply the output projection.
        attn_output = attn_output.reshape(*batch_dims, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
|
|
|
|
|
|
class JambaMambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)

    Two execution paths exist: `cuda_kernels_forward` (fused `mamba-ssm` / `causal-conv1d`
    kernels, used when `config.use_mamba_kernels` is set and the kernels are available) and
    `slow_forward` (a pure PyTorch sequential scan).
    """

    def __init__(self, config: JambaConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = config.mamba_expand * config.hidden_size
        self.time_step_rank = config.mamba_dt_rank
        self.use_conv_bias = config.mamba_conv_bias
        self.use_bias = config.mamba_proj_bias
        # Depthwise (groups == channels) causal convolution; left padding keeps causality.
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=self.use_conv_bias,
            kernel_size=self.conv_kernel_size,
            groups=self.intermediate_size,
            padding=self.conv_kernel_size - 1,
        )

        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        self.use_fast_kernels = config.use_mamba_kernels

        # projection of the input hidden states
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # time step projection (discretization)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.ssm_state_size + 1)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()

        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

        # Jamba-specific inner RMS norms applied to dt, B and C before the scan.
        self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
        self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
        self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)

        # Resolve the optional fused kernels once and publish them at module scope so the
        # forward paths below (and every mixer instance) share the same lookup result.
        global causal_conv1d_update, causal_conv1d_fn
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

        global selective_state_update, mamba_inner_fn, selective_scan_fn
        mamba_ssm = lazy_load_kernel("mamba-ssm")
        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
        mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
        selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)

        global is_fast_path_available
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                " is None. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
            )

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: HybridMambaAttentionDynamicCache | None = None,
        attention_mask: torch.LongTensor | None = None,
    ):
        """Fused-kernel forward path (requires the mamba-ssm and causal-conv1d kernels)."""
        batch_size, seq_len, _ = hidden_states.shape
        # Single-token decoding with a populated, batch-matching cache can use the cheap
        # incremental "update" kernels instead of a full scan.
        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_params.conv_states[self.layer_idx].shape[0]
            == cache_params.ssm_states[self.layer_idx].shape[0]
            == batch_size
        )
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        # We can't use `mamba_inner_fn` even if in training and without cache params because we have the
        # inner layernorms which isn't supported by this fused kernel
        hidden_states, gate = projected_states.chunk(2, dim=1)

        if attention_mask is not None:
            # Zero out padded positions so they cannot leak into the conv/ssm states.
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
        if use_precomputed_states:
            hidden_states = causal_conv1d_update(
                hidden_states.squeeze(-1),
                cache_params.conv_states[self.layer_idx],
                conv_weights,
                self.conv1d.bias,
                self.activation,
            )
            hidden_states = hidden_states.unsqueeze(-1)
        else:
            if cache_params is not None:
                # Left-pad to the conv kernel size so the cached state has a fixed width.
                conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
                cache_params.conv_states[self.layer_idx].copy_(conv_states)
            hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )

        time_step = self.dt_layernorm(time_step)
        B = self.b_layernorm(B)
        C = self.c_layernorm(C)

        # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
        # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
        # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
        # linear layers, and requires to call the forward pass directly.
        # Quantized model can't work with the original code:
        # ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
        time_proj_bias = self.dt_proj.bias.data
        with torch.no_grad():
            self.dt_proj.bias.data = torch.zeros_like(self.dt_proj.bias.data)
        discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
        with torch.no_grad():
            self.dt_proj.bias.data = time_proj_bias

        A = -torch.exp(self.A_log.float())
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
        if use_precomputed_states:
            # Incremental single-step update; mutates `ssm_states[self.layer_idx]` in place.
            scan_outputs = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states[..., 0],
                discrete_time_step[..., 0],
                A,
                B[:, 0],
                C[:, 0],
                self.D,
                gate[..., 0],
                time_proj_bias,
                dt_softplus=True,
            ).unsqueeze(-1)
        else:
            scan_outputs, ssm_state = selective_scan_fn(
                hidden_states,
                discrete_time_step,
                A,
                B.transpose(1, 2),
                C.transpose(1, 2),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                return_last_state=True,
            )
            if ssm_state is not None and cache_params is not None:
                cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))

        return contextualized_states

    # fmt: off
    def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache | None = None, attention_mask: torch.LongTensor | None = None):
        """Pure PyTorch forward path (sequential scan); used when the fused kernels are unavailable or disabled."""
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(input_states).transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
        # 2. Convolution sequence transformation
        if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
            if self.training:
                # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
                ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            else:
                ssm_state = cache_params.ssm_states[self.layer_idx]

            ssm_state = ssm_state.to(hidden_states.device)

            if cache_params.has_previous_state and seq_len == 1 and \
                    cache_params.conv_states[self.layer_idx].shape[0] == batch_size:
                # Decoding step: shift the conv window left and append the new token.
                conv_state = cache_params.conv_states[self.layer_idx]
                conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
                conv_state[:, :, -1] = hidden_states[:, :, 0]
                cache_params.conv_states[self.layer_idx] = conv_state
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
            else:
                # Prefill: store the left-padded window, then run the full convolution.
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                cache_params.conv_states[self.layer_idx] = conv_state
                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
        else:
            # No usable cache: start from a zeroed ssm state.
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size),
                device=hidden_states.device, dtype=dtype
            )
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 3. State Space Model sequence transformation
        # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )

        time_step = self.dt_layernorm(time_step)
        B = self.b_layernorm(B)
        C = self.c_layernorm(C)

        discrete_time_step = self.dt_proj(time_step)
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)

        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
        A = -torch.exp(self.A_log.float())
        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        scan_outputs = []
        for i in range(seq_len):
            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
            scan_outputs.append(scan_output[:, :, 0])
        scan_output = torch.stack(scan_outputs, dim=-1)
        scan_output = scan_output + (hidden_states * self.D[None, :, None])
        scan_output = (scan_output * self.act(gate))

        if use_cache:
            cache_params.ssm_states[self.layer_idx] = ssm_state

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))
        return contextualized_states
    # fmt: on

    def forward(
        self,
        hidden_states,
        cache_params: HybridMambaAttentionDynamicCache | None = None,
        attention_mask: torch.LongTensor | None = None,
    ):
        """Dispatch to the fused-kernel path when requested, otherwise to the PyTorch path."""
        if self.use_fast_kernels:
            if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
                raise ValueError(
                    "Fast Mamba kernels are not available. Make sure they are installed and that the mamba module is on a CUDA device"
                )
            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
        return self.slow_forward(hidden_states, cache_params, attention_mask)
|
|
|
|
|
|
class JambaMLP(MistralMLP):
    """Dense feed-forward network for Jamba; identical to the Mistral MLP."""
|
|
|
|
|
|
class JambaExperts(MixtralExperts):
    """MoE expert weights for Jamba; identical to the Mixtral experts implementation."""
|
|
|
|
|
|
class JambaSparseMoeBlock(nn.Module):
    """
    Sparse mixture-of-experts block.

    This implementation is strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations in terms of
    block-sparse operations to accommodate imbalanced assignments of tokens to experts,
    whereas standard MoE either (1) drop tokens at the cost of reduced performance or
    (2) set capacity factor to number of experts and thus waste computation and memory
    on padding.
    """

    def __init__(self, config: JambaConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok

        # Linear router producing one logit per expert for every token.
        self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = JambaExperts(config)

    def route_tokens_to_experts(self, hidden_states, router_logits):
        """Pick the top-k experts per token; returns (expert indices, routing weights)."""
        # Softmax in float32 for numerical stability, then keep only the top-k experts.
        routing_probs = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
        selected_weights, selected_experts = routing_probs.topk(self.top_k, dim=-1)
        return selected_experts, selected_weights.to(hidden_states.dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten (batch, seq) so routing operates on individual tokens.
        flat_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.router(flat_states)
        expert_index, expert_weights = self.route_tokens_to_experts(flat_states, router_logits)
        mixed_states = self.experts(flat_states, expert_index, expert_weights)
        return mixed_states.reshape(batch_size, sequence_length, hidden_dim)
|
|
|
|
|
|
class JambaAttentionDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer pairing self-attention with a dense or sparse-MoE feed-forward block."""

    def __init__(self, config: JambaConfig, layer_idx: int):
        super().__init__()
        num_experts = config.layers_num_experts[layer_idx] if config.layers_num_experts else 1
        self.self_attn = JambaAttention(config, layer_idx)

        # Layers configured with more than one expert get a sparse MoE FFN, others a dense MLP.
        self.feed_forward = (JambaSparseMoeBlock if num_experts > 1 else JambaMLP)(config)
        self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: HybridMambaAttentionDynamicCache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Self-attention sub-block (pre-norm) with residual connection.
        attn_output, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block (pre-norm) with residual connection.
        ff_output = self.feed_forward(self.pre_ff_layernorm(hidden_states))
        return hidden_states + ff_output
|
|
|
|
|
|
class JambaMambaDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer pairing a mamba mixer with a dense or sparse-MoE feed-forward block."""

    def __init__(self, config: JambaConfig, layer_idx: int):
        super().__init__()
        num_experts = config.layers_num_experts[layer_idx] if config.layers_num_experts else 1
        self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)

        # Layers configured with more than one expert get a sparse MoE FFN, others a dense MLP.
        self.feed_forward = (JambaSparseMoeBlock if num_experts > 1 else JambaMLP)(config)
        self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: HybridMambaAttentionDynamicCache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Mamba sub-block (pre-norm) with residual connection.
        mixer_output = self.mamba(
            hidden_states=self.input_layernorm(hidden_states),
            cache_params=past_key_values,
            attention_mask=attention_mask,
        )
        hidden_states = hidden_states + mixer_output

        # Feed-forward sub-block (pre-norm) with residual connection.
        ff_output = self.feed_forward(self.pre_ff_layernorm(hidden_states))
        return hidden_states + ff_output
|
|
|
|
|
|
# Maps entries of `config.layers_block_type` to the decoder layer class to instantiate.
ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
|
|
|
|
|
|
class JambaPreTrainedModel(PreTrainedModel):
    """Base class wiring Jamba modules into the Transformers loading/initialization machinery."""

    config: JambaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep whole decoder layers on a single device when auto-sharding the model.
    _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # NOTE(review): presumably set because the hybrid cache carries mamba state between
    # forward calls — confirm against `PreTrainedModel`'s use of this flag.
    _is_stateful = True
    # Declares which sub-modules' outputs can be captured by `check_model_inputs`
    # (router logits are recorded straight from the MoE `router` linear layer).
    _can_record_outputs = {
        "hidden_states": [JambaAttentionDecoderLayer, JambaMambaDecoderLayer],
        "attentions": JambaAttention,
        "router_logits": OutputRecorder(nn.Linear, layer_name="router"),
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize weights; the mamba mixer's S4D parameters and the MoE experts get special handling."""
        super()._init_weights(module)
        if isinstance(module, JambaMambaMixer):
            # Re-create the S4D real initialization used in `JambaMambaMixer.__init__`.
            A = torch.arange(1, module.ssm_state_size + 1)[None, :]
            A = A.expand(module.intermediate_size, -1).contiguous()
            init.copy_(module.A_log, torch.log(A))
            init.ones_(module.D)
        elif isinstance(module, JambaExperts):
            init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
|
|
|
|
|
|
@auto_docstring
class JambaModel(JambaPreTrainedModel):
    def __init__(self, config: JambaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Build an interleaved stack: each layer is attention or mamba, as dictated
        # by `config.layers_block_type`.
        decoder_layers = []
        for i in range(config.num_hidden_layers):
            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
            decoder_layers.append(layer_class(config, layer_idx=i))
        self.layers = nn.ModuleList(decoder_layers)

        self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: HybridMambaAttentionDynamicCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # When caching is requested but no cache was passed, create a fresh hybrid
        # cache (it pre-allocates the fixed-size conv/ssm states for every layer).
        if use_cache and past_key_values is None:
            past_key_values = HybridMambaAttentionDynamicCache(
                config=self.config,
                batch_size=inputs_embeds.shape[0],
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
            )

        # Positions of the new tokens within the full (cached + new) sequence.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Attention layers consume the expanded causal mask; mamba layers consume the
        # (possibly dropped) raw attention mask — see `_update_mamba_mask`.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
        hidden_states = inputs_embeds
        for decoder_layer in self.layers:
            layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.final_layernorm(hidden_states)

        # After the first full forward, mark the mamba states as initialized so
        # subsequent single-token calls can take the incremental decoding paths.
        if past_key_values and not past_key_values.has_previous_state:
            past_key_values.has_previous_state = True

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_mamba_mask(self, attention_mask, cache_position):
        """
        No need for zeroing states when
            1. Cached forward
            2. Attending to all inputs
        """
        mamba_mask = attention_mask
        if (cache_position is not None and cache_position[0] > 0) or (
            attention_mask is not None and torch.all(attention_mask == 1)
        ):
            mamba_mask = None
        return mamba_mask
|
|
|
|
|
|
class JambaForCausalLM(MixtralForCausalLM):
    """Jamba model with a language-modeling head, inheriting the Mixtral causal-LM logic."""

    def __init__(self, config: JambaConfig):
        super().__init__(config)
        self.num_experts = config.num_experts

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: HybridMambaAttentionDynamicCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, JambaForCausalLM

        >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
        >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Forward by keyword so the call cannot silently mis-bind if the parent's
        # positional signature ever changes. `output_router_logits` is intentionally
        # not forwarded: router logits are collected via
        # `JambaPreTrainedModel._can_record_outputs` instead.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
|
|
|
|
|
|
class JambaForSequenceClassification(GenericForSequenceClassification, JambaPreTrainedModel):
    """Jamba model with a sequence-classification head, via the generic implementation."""
|
|
|
|
|
|
# Public API of this module.
__all__ = ["JambaForCausalLM", "JambaForSequenceClassification", "JambaModel", "JambaPreTrainedModel"]
|