You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1198 lines
51 KiB

# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Chameleon model."""
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
logging,
torch_compilable_check,
)
from ...utils.generic import check_model_inputs, maybe_autocast
from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring
class ChameleonVQVAEModelOutput(BaseModelOutputWithPooling):
    r"""
    quantized_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Quantized last hidden state from the VQ-VAE model.
    image_tokens (`torch.LongTensor` of shape `(batch_size * height * width,)`):
        Flattened codebook indices of the image tokens predicted by the VQ-VAE model, one per spatial position of
        the encoded feature map.
    embedding_loss (`torch.FloatTensor`):
        The embedding loss computed during quantization.
    """

    quantized_last_hidden_state: torch.FloatTensor | None = None
    # Produced by `torch.argmin` in the vector quantizer, hence an integer (long) tensor.
    image_tokens: torch.LongTensor | None = None
    embedding_loss: torch.FloatTensor | None = None
# Adapted from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon
class ChameleonRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        ChameleonRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the last dimension; also the shape of the learned `weight`.
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalize in float32, cast back to the input dtype, then apply the learned gain."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module's `repr`."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
# Adapted from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
class ChameleonRotaryEmbedding(nn.Module):
    """Builds the cos/sin tables consumed by `apply_rotary_pos_emb`."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ChameleonConfig, device=None):
        """
        Args:
            config: Model configuration; `rope_parameters["rope_type"]` selects the initialization function.
            device: Device forwarded to the initialization function.
        """
        super().__init__()
        self.config = config
        self.rope_type = config.rope_parameters["rope_type"]

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        # The "default" variant is computed by this class; every other type is looked up in the registry.
        if self.rope_type == "default":
            init_fn: Callable = self.compute_default_rope_parameters
        else:
            init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        frequencies, scaling = init_fn(config, device)
        self.attention_scaling = scaling

        # Both buffers are non-persistent, i.e. they are not written to the state dict.
        self.register_buffer("inv_freq", frequencies, persistent=False)
        self.register_buffer("original_inv_freq", frequencies.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: ChameleonConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_theta = config.rope_parameters["rope_theta"]
        # An explicit (truthy) `head_dim` on the config wins; otherwise derive it from the hidden size.
        rotary_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Even indices 0, 2, 4, ... scaled into [0, 1): one exponent per rotated pair.
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / rotary_dim
        inv_freq = 1.0 / (rope_theta**exponents)

        scaling_factor = 1.0  # Unused in this type of RoPE
        return inv_freq, scaling_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only used for its device and dtype.
        """
        batch_size = position_ids.shape[0]
        batched_inv_freq = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        raw_device_type = x.device.type
        is_usable = isinstance(raw_device_type, str) and raw_device_type != "mps"
        device_type = raw_device_type if is_usable else "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (batched_inv_freq.float() @ positions.float()).transpose(1, 2)
            # Duplicate the angles so the table spans the full rotary dimension.
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Adapted from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input: `[a, b] -> [-b, a]` along the last dimension."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
# Adapted from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they broadcast against `q` and `k`. With
            `cos`/`sin` of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `q`/`k` are
            [batch_size, heads, seq_len, head_dim] and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_broadcast = cos.unsqueeze(unsqueeze_dim)
    sin_broadcast = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation formula for queries and keys.
        return (states * cos_broadcast) + (rotate_half(states) * sin_broadcast)

    return _rotate(q), _rotate(k)
# Adapted from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon
class ChameleonMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        """Build the three projections; sizes, bias flag and activation name are read from `config`."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    # Ignore copy
    def forward(self, x):
        """Apply the gated projection and map back to `hidden_size`."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
class ChameleonLayerNorm(nn.LayerNorm):
    """
    LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
    from each shard separately to each head, instead of reducing. We can apply each head's own
    gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
    in the last dimension. This module applies gamma/beta manually to fulfill this requirement.

    Args:
        hidden_size (`int` or sequence of `int`):
            Shape of the affine parameters, e.g. `(num_heads, head_dim)`. Statistics are always computed over the
            last entry only.
        *args, **kwargs:
            Forwarded unchanged to `torch.nn.LayerNorm` (`eps`, `elementwise_affine`, `bias`, ...).
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        # The parent created `weight`/`bias` with the full `hidden_size` shape; only the normalization axes are
        # narrowed to the last dimension here. A bare int is accepted as well, like in `nn.LayerNorm`.
        last_dim = hidden_size if isinstance(hidden_size, int) else hidden_size[-1]
        self.normalized_shape = (last_dim,)

    def forward(self, hidden_states):
        """Normalize over the last dimension, then apply the per-head affine parameters if they exist."""
        # Use the configured epsilon (default 1e-5) instead of a hard-coded constant so `eps=` is honored.
        hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=self.eps)
        # `weight`/`bias` are None when built with `elementwise_affine=False` (or `bias=False` for the bias).
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        if self.bias is not None:
            hidden_states = hidden_states + self.bias
        return hidden_states
# Adapted from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    # Unpacking first keeps the 4-D shape requirement even on the `n_rep == 1` fast path.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast along it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain matmul/softmax attention.

    Args:
        module: Attention module; its `num_key_value_groups` and `training` attributes are read.
        query, key, value: Projected states; `key`/`value` are repeated to match the query heads.
        attention_mask: Optional additive mask, cropped on its last axis to the key length.
        scaling: Factor applied to the raw attention scores.
        dropout: Dropout probability applied to the attention probabilities while `module.training` is set.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dimensions 1 and 2 swapped relative to the
        matmul result and is made contiguous.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    # Softmax runs in float32 and is cast back to the query dtype afterwards.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
class ChameleonAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Queries and keys are additionally normalized per head with `ChameleonLayerNorm` before the rotary
    position embedding is applied.
    """

    def __init__(self, config: ChameleonConfig, layer_idx: int | None = None):
        """
        Args:
            config: Model configuration supplying sizes, dropout and bias settings.
            layer_idx: Index of this layer, passed to the cache on update. A warning is logged when it is `None`.

        Raises:
            ValueError: If `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True
        self.model_parallel_size = config.model_parallel_size
        self.scaling = self.head_dim**-0.5

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        # Affine parameters have shape (heads, head_dim); statistics are taken over head_dim only.
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states: Input of shape `(batch, seq_len, hidden_size)`.
            attention_mask: Mask handed unchanged to the attention implementation.
            position_ids, output_attentions, use_cache: Accepted for interface compatibility; not read here.
            past_key_values: Optional cache; when given it is updated with this layer's keys and values.
            cache_position: Forwarded to the cache update.
            position_embeddings: The `(cos, sin)` pair for the rotary embedding. It is unpacked unconditionally,
                so it must be provided despite the `None` default.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)`; `attn_output` has shape `(batch, seq_len, hidden_size)`.
            NOTE(review): whether `attn_weights` can be `None` depends on the selected attention backend.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flatten batch and sequence so the (heads, head_dim) norm parameters broadcast per head.
        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
        query_states = self.q_norm(query_states)

        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
        key_states = self.k_norm(key_states)

        # Restore the batch axis and move heads in front of the sequence: (batch, heads, seq_len, head_dim).
        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; position_ids needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Backend is chosen from the config; `eager_attention_forward` is passed as the second argument
        # (presumably the default when no other backend is selected — confirm against `get_interface`).
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Attention dropout is only active in training mode.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the head axes back into the hidden dimension before the output projection.
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
# Adapted from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
class ChameleonDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder layer: each sub-block is normalized on the way in and added back to the residual."""

    def __init__(self, config: ChameleonConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
        self.mlp = ChameleonMLP(config)
        self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool | None = False,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): forwarded to the attention module.
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*): forwarded to the attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            position_embeddings (`torch.Tensor`, *optional*): rotary `(cos, sin)` pair for the attention module.
            kwargs (`dict`, *optional*):
                Extra keyword arguments, forwarded to the attention module.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
        """
        # Self Attention on the normalized input, added back onto the un-normalized residual.
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Fully Connected, same pre-norm + residual pattern.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
class ChameleonSwinDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer that normalizes each sub-block's *output* before adding it to the residual."""

    def __init__(self, config: ChameleonConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
        self.mlp = ChameleonMLP(config)
        self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool | None = False,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*): forwarded to the attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`torch.Tensor`, *optional*): rotary `(cos, sin)` pair for the attention module.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
        """
        # Self Attention runs on the raw input; its output is normalized before the residual add.
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + self.input_layernorm(attn_output)

        # Fully Connected, again normalized after the sub-block rather than before it.
        mlp_output = self.mlp(hidden_states)
        hidden_states = hidden_states + self.post_attention_layernorm(mlp_output)

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """
    A module for vector quantization using learned embedding vectors.

    This module implements the quantization process similar to the one described in
    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
    input vectors into discrete codebook vectors, which are learned during training.
    Current implementation improves over previous ones by avoiding costly matrix multiplications
    and allowing for post-hoc remapping of indices.
    """

    def __init__(self, config):
        """Read `num_embeddings`, `embed_dim` and (optionally, default 0.25) `beta` from `config`."""
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)

    def forward(self, hidden_state: torch.Tensor):
        """
        Snap every spatial position of a channels-first feature map to its nearest codebook vector.

        Returns:
            `(quantized, loss, indices)`: the quantized map in the input's channels-first layout, the scalar
            embedding loss, and the flattened codebook index of every position.
        """
        # Channels last, so that each row of the flattened view is one vector of size `embedding_dim`.
        channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat_vectors = channels_last.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        vector_sq = torch.sum(flat_vectors**2, dim=1, keepdim=True)
        codebook_sq = torch.sum(codebook**2, dim=1)
        cross_term = torch.einsum("bd,dn->bn", flat_vectors, codebook.transpose(0, 1))
        distances = vector_sq + codebook_sq - 2 * cross_term

        min_encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(min_encoding_indices).view(channels_last.shape)

        # compute loss for embedding: the first term only reaches the input (codes are detached),
        # the `beta`-weighted term only reaches the codebook (input is detached).
        input_term = torch.mean((quantized.detach() - channels_last) ** 2)
        codebook_term = torch.mean((quantized - channels_last.detach()) ** 2)
        loss = input_term + self.beta * codebook_term

        # preserve gradients: forward value is the code, backward passes straight through to the input
        quantized = channels_last + (quantized - channels_last).detach()

        # reshape back to match original input shape
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return quantized, loss, min_encoding_indices
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution with zero padding on the right and bottom edges only."""

    def __init__(self, in_channels):
        """The convolution keeps the channel count (`in_channels` in and out) and adds no padding itself."""
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        """Pad one zero column on the right and one zero row at the bottom, then convolve."""
        # no asymmetric padding in torch conv, must do it ourselves
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """
    Residual block of the VQ-VAE encoder: two (GroupNorm -> x * sigmoid(x) -> 3x3 conv) stages with dropout
    before the second convolution, plus a skip connection.

    Args:
        config: Configuration object; only `config.dropout` is read.
        in_channels (`int`): Number of input channels (must be divisible by 32 for the GroupNorm).
        out_channels (`int`, *optional*): Number of output channels. Defaults to `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`):
            When the channel count changes, project the skip path with a 3x3 convolution instead of a 1x1 one.
    """

    def __init__(
        self,
        config,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        # Every layer below uses the resolved `self.out_channels`, so the `out_channels=None` default works
        # instead of handing `None` to the layer constructors.
        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=self.out_channels, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        # A projection on the skip path is only created when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, self.out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, self.out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, hidden_states):
        """Return `shortcut(hidden_states) + block(hidden_states)`; the input tensor is not modified."""
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = hidden_states * torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = hidden_states * torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map, with a residual connection."""

    def __init__(self, in_channels):
        """All projections are 1x1 convolutions that keep the channel count; the norm uses 32 groups."""
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states):
        """Return `hidden_states + proj_out(attention(norm(hidden_states)))`."""
        normed = self.norm(hidden_states)
        query_states = self.q(normed)
        key_states = self.k(normed)
        value_states = self.v(normed)

        # compute attention: flatten the spatial grid so every position attends to every other one
        batch_size, channels, height, width = query_states.shape
        num_positions = height * width
        query_states = query_states.reshape(batch_size, channels, num_positions).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, num_positions)
        scores = torch.bmm(query_states, key_states)
        scores = scores * (int(channels) ** (-0.5))
        probs = F.softmax(scores, dim=2)

        # attend to values: transpose the weights so the product stays channels-first
        value_states = value_states.reshape(batch_size, channels, num_positions)
        context = torch.bmm(value_states, probs.permute(0, 2, 1))
        context = context.reshape(batch_size, channels, height, width)

        return hidden_states + self.proj_out(context)
class ChameleonVQVAEEncoder(nn.Module):
    """
    Convolutional encoder of the VQ-VAE: an input convolution, `len(config.channel_multiplier)` resolution
    levels of residual blocks (optionally followed by attention) with a downsample between levels, a middle
    block, and an output convolution producing the latent channels.
    """

    def __init__(self, config):
        """Build the encoder from `config`; the order of construction determines the sub-module names."""
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        # Prepending 1 makes entry `i` the multiplier of the channels *entering* level `i`.
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                # Only the first block of a level changes the channel count.
                block_in = block_out
                # One attention block per residual block, so `attn[i_block]` lines up with `block[i_block]`.
                if (
                    config.attn_resolutions is not None
                    and curr_res in config.attn_resolutions
                    and config.attn_type == "vanilla"
                ):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            # Bare container module: groups this level's blocks under `down.<i_level>.*`.
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last halves the spatial resolution.
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        # Attention in the middle block is replaced by a no-op unless `attn_type` is "vanilla".
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.FloatTensor):
        """Encode `pixel_values` (fed directly to `conv_in`) and return the output of `conv_out`."""
        # downsampling
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](
                    hidden_states[-1],
                )
                # `attn` is either empty or holds one entry per residual block of this level.
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))

        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid.block_1(last_hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        # In-place x * sigmoid(x) on the freshly created norm output.
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    Image entries of the vocabulary are the names starting with "IMGIMG"; the characters between that prefix
    and the final character spell the image-token index with the letters A-J standing for the digits 0-9.
    """

    def __init__(self, vocab_map):
        """Store the name -> id vocabulary and look up the id of `"<image>"` (`None` when absent)."""
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        """Inverse vocabulary: token id -> token name."""
        inverse = {}
        for name, token_id in self.vocab_map.items():
            inverse[token_id] = name
        return inverse

    @cached_property
    def image_tokens(self):
        """Sorted ids of all vocabulary entries whose name starts with "IMGIMG"."""
        ids = [token_id for name, token_id in self.vocab_map.items() if name.startswith("IMGIMG")]
        ids.sort()
        return ids

    @cached_property
    def bpe2img(self):
        """Map each image BPE id to the integer index spelled in its token name."""
        letters_to_digits = str.maketrans("ABCDEFGHIJ", "0123456789")
        prefix_length = len("IMGIMG")

        mapping = {}
        for token_id in self.image_tokens:
            # Drop the "IMGIMG" prefix and the trailing character, then decode A-J as 0-9.
            encoded_index = self.val2name[token_id][prefix_length:-1]
            mapping[token_id] = int(encoded_index.translate(letters_to_digits))
        return mapping

    @cached_property
    def img2bpe(self):
        """Inverse of `bpe2img`: image-token index -> BPE id."""
        forward = self.bpe2img
        return dict(zip(forward.values(), forward.keys()))

    @cached_property
    def bpe2img_search_tensors(self):
        """Pair of tensors holding the sorted BPE ids and, separately sorted, the image-token indices."""
        sorted_bpe_ids = sorted(self.bpe2img.keys())
        sorted_img_ids = sorted(self.bpe2img.values())
        return torch.tensor(sorted_bpe_ids), torch.tensor(sorted_img_ids)

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Dense lookup table: position = image-token index, value = BPE id (0 where no entry exists)."""
        table_size = max(self.img2bpe.keys()) + 1
        lookup = torch.zeros(table_size, dtype=torch.int)
        for img_index, bpe_id in self.img2bpe.items():
            lookup[img_index] = bpe_id
        return lookup

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Translate image-token indices to BPE ids; the lookup runs on CPU and the result returns to the input device."""
        source_device = img_batch.device
        lookup = self.img2bpe_mapping_tensor
        return lookup[img_batch.to("cpu")].to(source_device)
@auto_docstring
class ChameleonPreTrainedModel(PreTrainedModel):
    # Common base class of the Chameleon models defined in this file. It declares class-level settings only
    # and adds no methods of its own.
    # NOTE(review): these attributes are presumably consumed by `PreTrainedModel` and the loading/dispatch
    # utilities — confirm their exact meaning against the base class.
    config: ChameleonConfig
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    # Both decoder-layer variants are listed because the model is built from one or the other.
    _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]

    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # Module classes associated with each recordable output name.
    _can_record_outputs = {
        "hidden_states": [ChameleonDecoderLayer, ChameleonSwinDecoderLayer],
        "attentions": ChameleonAttention,
    }
# The `custom_intro` literal is kept flush-left on purpose: its exact text is a runtime value.
@auto_docstring(
    custom_intro="""
The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
[ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
Taigman](https://huggingface.co/papers/2203.13131).
"""
)
class ChameleonVQVAE(ChameleonPreTrainedModel):
    config: ChameleonVQVAEConfig
    _no_split_modules = [
        "ChameleonVQVAEVectorQuantizer",
        "ChameleonVQVAEEncoderAttnBlock",
        "ChameleonVQVAEEncoderResnetBlock",
    ]
    _can_record_outputs = {
        "hidden_states": ChameleonVQVAEEncoderResnetBlock,
        "attentions": ChameleonVQVAEEncoderAttnBlock,
    }

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__(config)

        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        # 1x1 convolutions mapping between the encoder's latent channels and the codebook dimension.
        latent_channels, codebook_dim = config.latent_channels, config.embed_dim
        self.quant_conv = torch.nn.Conv2d(latent_channels, codebook_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(codebook_dim, latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen
        self.post_init()

    @check_model_inputs
    def encode(
        self, pixel_values: torch.LongTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> ChameleonVQVAEModelOutput:
        # Encoder -> 1x1 conv into codebook space -> nearest-codebook quantization.
        encoder_output = self.encoder(pixel_values)
        quantizer_input = self.quant_conv(encoder_output)
        quantized, embedding_loss, token_indices = self.quantize(quantizer_input)
        # `last_hidden_state` carries the raw encoder output, i.e. before `quant_conv`.
        return ChameleonVQVAEModelOutput(
            last_hidden_state=encoder_output,
            quantized_last_hidden_state=quantized,
            image_tokens=token_indices,
            embedding_loss=embedding_loss,
        )
@auto_docstring
class ChameleonModel(ChameleonPreTrainedModel):
    def __init__(self, config: ChameleonConfig):
        """Build the token embedding, decoder stack, final norm, VQ-VAE image tokenizer and rotary embedding."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
        # `swin_norm` selects the alternative decoder layer implementation.
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer
        self.layers = nn.ModuleList(
            [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
        self.rotary_emb = ChameleonRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def _convert_to_bpe_tokens(self, image_tokens: torch.LongTensor, batch_size: int) -> torch.LongTensor:
        """Map VQ-VAE image tokens to BPE vocabulary ids and reshape them to `(batch_size, -1)`."""
        bpe_tokens = self.vocabulary_mapping.convert_img2bpe(image_tokens)
        return bpe_tokens.view(batch_size, -1)

    def get_image_tokens(self, pixel_values: torch.FloatTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module and converts the
        obtained image tokens into BPE tokens. No special begin-image / end-image
        tokens are added by this method.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.

        Returns:
            The BPE token ids of the images, viewed as `(batch_size, -1)`.
        """
        vqmodel_outputs: ChameleonVQVAEModelOutput = self.vqmodel.encode(pixel_values, return_dict=True)
        return self._convert_to_bpe_tokens(vqmodel_outputs.image_tokens, pixel_values.shape[0])

    @can_return_tuple
    @auto_docstring(
        custom_intro="Tokenizes images into discrete tokens with VQGAN module and embeds them with text embeddings layer."
    )
    def get_image_features(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        """
        vqmodel_outputs: ChameleonVQVAEModelOutput = self.vqmodel.encode(pixel_values, return_dict=True, **kwargs)
        # The text embedding layer must be indexed with BPE vocabulary ids, so the VQ-VAE image tokens
        # are converted first, exactly as `get_image_tokens` does. Embedding the raw VQ-VAE tokens
        # would look up unrelated rows. `image_tokens` on the output is left untouched.
        bpe_tokens = self._convert_to_bpe_tokens(vqmodel_outputs.image_tokens, pixel_values.shape[0])
        vqmodel_outputs.pooler_output = self.get_input_embeddings()(bpe_tokens)
        return vqmodel_outputs

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.

        The returned boolean mask is expanded to the shape of `inputs_embeds`.
        """
        if input_ids is None:
            # Without ids, find placeholders by comparing every embedding to the image placeholder embedding.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

        # Both counts are only used in the error message below.
        n_image_tokens = special_image_mask.sum()
        n_image_features = image_features.shape[0] * image_features.shape[1]
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        torch_compilable_check(
            inputs_embeds[special_image_mask].numel() == image_features.numel(),
            f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
        )
        return special_image_mask

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        # Unset flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # Raises when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if pixel_values is not None:
            # Overwrite the embeddings at the image placeholder positions with the image features.
            image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # Positions continue after whatever is already stored in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # embed positions
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            # Record the input of each layer; the final normed output is appended after the loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            # Tuple output drops every entry that is None.
            return tuple(
                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
@auto_docstring(
    custom_intro="""
    Chameleon Model with a head on top used for outputting logits for next token prediction.
    """
)
class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
    # The LM head shares its weight with the input embedding table of the wrapped model.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        """Wrap a `ChameleonModel` and add a bias-free linear head projecting hidden states to the vocabulary."""
        super().__init__(config)
        self.model = ChameleonModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_image_tokens(self, pixel_values):
        """Delegate to the wrapped model's `get_image_tokens`."""
        return self.model.get_image_tokens(pixel_values)

    @auto_docstring
    def get_image_features(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        return self.model.get_image_features(pixel_values, **kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
        >>> import torch
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", dtype=torch.bfloat16)
        >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

        >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
        >>> url = "https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image1 = Image.open(BytesIO(response.read()))

        >>> url = "https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image2 = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=[image1, image2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        # Flags left unset by the caller default to the config values.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Always request a structured output from the wrapped model so fields can be read by name.
        decoder_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        sequence_output = decoder_outputs.last_hidden_state

        # Project only the requested positions: an int keeps the trailing `logits_to_keep` positions
        # (0 keeps all of them), anything else is used directly as the sequence index.
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(sequence_output[:, kept_positions, :])

        # Image token ids are pushed to the lowest representable score; this list does not
        # include the special begin-image and end-image tokens.
        blocked_token_ids = self.model.vocabulary_mapping.image_tokens
        logits[:, :, blocked_token_ids] = torch.finfo(logits.dtype).min

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        pixel_values=None,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        prepared = super().prepare_inputs_for_generation(
            input_ids,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # Pixel values are used only in the first iteration if available.
        # In subsequent iterations, they are already merged with text and cached.
        # NOTE: first iteration doesn't have to be prefill, it can be the first
        # iteration with a question and cached system prompt (continue generate from cache)
        images_already_cached = use_cache and not is_first_iteration
        if images_already_cached:
            prepared["pixel_values"] = None

        return prepared
# Names exported by `from <module> import *`.
__all__ = [
    "ChameleonForConditionalGeneration",
    "ChameleonModel",
    "ChameleonPreTrainedModel",
    "ChameleonVQVAE",
]