# Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch KOSMOS-2 model.""" import math import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any import torch from torch import nn from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, CausalLMOutputWithCrossAttentions, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int from ...utils.generic import is_flash_attention_requested from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig logger = logging.get_logger(__name__) def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) def _make_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 ): """ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) @dataclass @auto_docstring class BaseModelOutputWithProjectionAttentions(BaseModelOutputWithPooling): r""" projection_attentions (`tuple(torch.FloatTensor)`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute the weighted average in the self-attention heads. """ projection_attentions: tuple[torch.FloatTensor] | None = None @dataclass @auto_docstring( custom_intro=""" Base class for text model's outputs that also contains a pooling of the last hidden states. """ ) class Kosmos2ModelOutput(ModelOutput): r""" past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. projection_attentions (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute the weighted average in the self-attention heads. vision_model_output (`BaseModelOutputWithPooling`, *optional*): The output of the [`Kosmos2VisionModel`]. """ last_hidden_state: torch.FloatTensor | None = None past_key_values: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None attentions: tuple[torch.FloatTensor] | None = None image_embeds: torch.FloatTensor | None = None projection_attentions: tuple[torch.FloatTensor] | None = None vision_model_output: BaseModelOutputWithPooling = None def to_tuple(self) -> tuple[Any]: return tuple( self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() for k in self.keys() ) @dataclass @auto_docstring( custom_intro=""" Model output class for `Kosmos2ForConditionalGeneration`. """ ) class Kosmos2ForConditionalGenerationModelOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. projection_attentions (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute the weighted average in the self-attention heads. vision_model_output (`BaseModelOutputWithPooling`, *optional*): The output of the [`Kosmos2VisionModel`]. """ loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None past_key_values: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None attentions: tuple[torch.FloatTensor] | None = None image_embeds: torch.FloatTensor | None = None projection_attentions: tuple[torch.FloatTensor] | None = None vision_model_output: BaseModelOutputWithPooling = None def to_tuple(self) -> tuple[Any]: return tuple( self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() for k in self.keys() ) # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Kosmos2 class Kosmos2VisionEmbeddings(nn.Module): def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False, ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 position_embedding = self.position_embedding.weight.unsqueeze(0) num_positions = position_embedding.shape[1] - 1 # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embedding(self.position_ids) class_pos_embed = position_embedding[:, :1] patch_pos_embed = position_embedding[:, 1:] dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})." ) target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) if interpolate_pos_encoding: embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) else: embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> Kosmos2 doesn't cast attn weights to fp32 def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor | None, scaling: float, dropout: float = 0.0, **kwargs, ): attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Kosmos2VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, causal_attention_mask: torch.Tensor | None = None, output_attentions: bool | None = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """Input shape: Batch x Time x Channel""" batch_size, seq_length, embed_dim = hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) values = self.v_proj(hidden_states) queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) # CLIP text model uses both `causal_attention_mask` and `attention_mask` # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` if not is_flash_attention_requested(self.config): if attention_mask is not None and causal_attention_mask is not None: attention_mask = attention_mask + causal_attention_mask elif causal_attention_mask is not None: attention_mask = causal_attention_mask else: self.is_causal = causal_attention_mask is not None attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) attn_output, attn_weights = attention_interface( self, queries, keys, values, attention_mask, is_causal=self.is_causal, scaling=self.scale, dropout=0.0 if not self.training else self.dropout, ) attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision class Kosmos2VisionMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = Kosmos2VisionAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Kosmos2VisionMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, causal_attention_mask: torch.Tensor, output_attentions: bool | None = False, ) -> tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(config.encoder_attention_heads,)`. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, causal_attention_mask=causal_attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class Kosmos2VisionEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Kosmos2VisionEncoderLayer`]. Args: config: Kosmos2VisionConfig """ def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False @can_return_tuple def forward( self, inputs_embeds, attention_mask: torch.Tensor | None = None, causal_attention_mask: torch.Tensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, ) -> tuple | BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Causal mask for the text model. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) layer_outputs = encoder_layer( hidden_states, attention_mask, causal_attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) return BaseModelOutputWithProjectionAttentions( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) # Similar to `transformers.models.clip.modeling_clip.CLIPVisionTransformer` but without docstring for `forward` class Kosmos2VisionTransformer(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPVision->Kosmos2Vision,ALTCLIP_VISION->KOSMOS2_VISION,AltCLIP->Kosmos2Vision def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = Kosmos2VisionEmbeddings(config) self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = Kosmos2VisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, pixel_values: torch.FloatTensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, interpolate_pos_encoding: bool = False, return_dict: bool | None = None, ) -> tuple | BaseModelOutputWithPooling: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) # Similar to `transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding` but allowing to pass `position_ids` class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None): super().__init__() self.offset = 2 self.num_positions = num_positions self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None): emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) if hasattr(self, "weights"): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) self.register_buffer("weights", emb_weights, persistent=False) @staticmethod # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None): """ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) if padding_idx is not None: emb[padding_idx, :] = 0 return emb.to(torch.get_default_dtype()) @torch.no_grad() def forward( self, input_ids: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, ): if input_ids is not None: bsz, seq_len = input_ids.size() if position_ids is None: # Create the position ids from the input token ids. Any padded tokens remain padded. position_ids = self.create_position_ids_from_input_ids( input_ids, self.padding_idx, past_key_values_length ).to(input_ids.device) else: bsz, seq_len = inputs_embeds.size()[:-1] if position_ids is None: position_ids = self.create_position_ids_from_inputs_embeds( inputs_embeds, past_key_values_length, self.padding_idx ) # expand embeddings if needed max_pos = self.padding_idx + 1 + seq_len + past_key_values_length if max_pos > self.weights.size(0): self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() @staticmethod # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds def create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length, padding_idx): """ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. Args: inputs_embeds: torch.Tensor Returns: torch.Tensor """ input_shape = inputs_embeds.size()[:-1] sequence_length = input_shape[1] position_ids = torch.arange( padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device ) return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length @staticmethod # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings.create_position_ids_from_input_ids def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`. Args: x: torch.Tensor x: Returns: torch.Tensor """ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. mask = input_ids.ne(padding_idx).int() incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask return incremental_indices.long() + padding_idx class KosmosTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" # Similar to transformers.models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`. def __init__( self, config, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool | None = False, add_inner_attn_layernorm: bool | None = False, bias: bool | None = True, layer_idx: bool | None = None, ): super().__init__() self.config = config self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads self.is_causal = True if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) # End opy self.inner_attn_ln = None if add_inner_attn_layernorm: self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor | None = None, past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, output_attentions: bool = False, cache_position: torch.Tensor | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] query_states = self.q_proj(hidden_states) query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) is_updated = False if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_values = past_key_values.cross_attention_cache else: curr_past_key_values = past_key_values.self_attention_cache else: curr_past_key_values = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_values.layers[self.layer_idx].keys value_states = curr_past_key_values.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() if self.inner_attn_ln is not None: attn_output = self.inner_attn_ln(attn_output) attn_output = self.out_proj(attn_output) return attn_output, attn_weights class Kosmos2TextFFN(nn.Module): def __init__(self, config: Kosmos2TextConfig): super().__init__() self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim) self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim) self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.ffn_layernorm(hidden_states) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states class Kosmos2TextBlock(GradientCheckpointingLayer): def __init__(self, config: Kosmos2TextConfig, layer_idx=None): super().__init__() self.embed_dim = config.embed_dim self.self_attn = KosmosTextAttention( config, embed_dim=self.embed_dim, num_heads=config.attention_heads, dropout=config.attention_dropout, is_decoder=True, add_inner_attn_layernorm=True, layer_idx=layer_idx, ) self.dropout = config.dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) if config.add_cross_attention: self.encoder_attn = KosmosTextAttention( config, embed_dim=self.embed_dim, num_heads=config.attention_heads, dropout=config.attention_dropout, is_decoder=True, add_inner_attn_layernorm=False, layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.ffn = Kosmos2TextFFN(config) self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, past_key_values: Cache | None = None, output_attentions: bool | None = False, use_cache: bool | None = True, cache_position: torch.Tensor | None = None, **kwargs, ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_values=past_key_values, attention_mask=attention_mask, output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block cross_attn_weights = None if encoder_hidden_states is not None: if not hasattr(self, "encoder_attn"): raise ValueError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" " by setting `config.add_cross_attention=True`" ) residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) # FFN hidden_states = self.ffn(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) return outputs class Kosmos2TextTransformer(nn.Module): """ Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`]. Args: config: Kosmos2TextConfig """ def __init__(self, config: Kosmos2TextConfig): super().__init__() self.config = config self.dropout = config.dropout self.layerdrop = config.layerdrop self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0 self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id) self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding( num_positions=config.max_position_embeddings, embedding_dim=config.embed_dim, padding_idx=config.pad_token_id, ) self.layers = nn.ModuleList([Kosmos2TextBlock(config, layer_idx=i) for i in range(config.layers)]) self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) self.gradient_checkpointing = False def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, device=inputs_embeds.device, past_key_values_length=past_key_values_length, ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( inputs_embeds.device ) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask def forward_embedding( self, input_ids, inputs_embeds: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, img_input_mask: torch.Tensor | None = None, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, ): # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`. if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if image_embeds is not None: inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view( -1, image_embeds.size(-1) ) inputs_embeds = inputs_embeds * self.embed_scale # embed positions positions = self.embed_positions( input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids, ) positions = positions.to(inputs_embeds.device) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states def forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, image_embeds_position_mask: torch.Tensor | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, cache_position: torch.Tensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple | BaseModelOutputWithPastAndCrossAttentions: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False if use_cache and past_key_values is None: past_key_values = ( EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) if encoder_hidden_states is not None or self.config.is_encoder_decoder else DynamicCache(config=self.config) ) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # We don't need img info. when `past_key_values_length` > 0 if past_key_values_length > 0: image_embeds = None image_embeds_position_mask = None hidden_states = self.forward_embedding( input_ids=input_ids, inputs_embeds=inputs_embeds, image_embeds=image_embeds, img_input_mask=image_embeds_position_mask, past_key_values_length=past_key_values_length, position_ids=position_ids, ) attention_mask = self._prepare_decoder_attention_mask( attention_mask, input_shape, hidden_states, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: continue layer_outputs = decoder_layer( hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) # add final layer norm hidden_states = self.layer_norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) @auto_docstring class Kosmos2PreTrainedModel(PreTrainedModel): config: Kosmos2Config input_modalities = ("image", "text") supports_gradient_checkpointing = True _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"] _supports_attention_backend = True _supports_flash_attn = True _supports_sdpa = True @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(self, Kosmos2VisionModel): factor = self.config.initializer_factor elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)): factor = self.config.vision_config.initializer_factor if isinstance(self, (Kosmos2TextModel, Kosmos2TextForCausalLM)): std = self.config.init_std elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)): std = self.config.text_config.init_std if isinstance(module, Kosmos2VisionEmbeddings): init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1))) elif isinstance(module, Kosmos2VisionAttention): in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor init.normal_(module.q_proj.weight, std=in_proj_std) init.normal_(module.k_proj.weight, std=in_proj_std) init.normal_(module.v_proj.weight, std=in_proj_std) init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, Kosmos2VisionMLP): in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor init.normal_(module.fc1.weight, std=fc_std) init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, KosmosTextAttention): init.normal_(module.q_proj.weight, std=std) init.normal_(module.k_proj.weight, std=std) init.normal_(module.v_proj.weight, std=std) init.normal_(module.out_proj.weight, std=std) elif isinstance(module, Kosmos2TextFFN): init.normal_(module.fc1.weight, std=std) init.normal_(module.fc2.weight, std=std) elif isinstance(module, Kosmos2TextForCausalLM): init.normal_(module.lm_head.weight, std=std) elif isinstance(module, Kosmos2ImageToTextProjection): init.normal_(module.dense.weight, std=std) init.normal_(module.latent_query) elif isinstance(module, Kosmos2TextTransformer): init.normal_(module.embed_tokens.weight, mean=0.0, std=std) if module.embed_tokens.padding_idx is not None: init.zeros_(module.embed_tokens.weight[module.embed_tokens.padding_idx]) elif isinstance(module, nn.LayerNorm): init.ones_(module.weight) init.zeros_(module.bias) elif isinstance(module, Kosmos2TextSinusoidalPositionalEmbedding): emb_weights = module.get_embedding( module.num_positions + module.offset, module.embedding_dim, module.padding_idx ) init.copy_(module.weights, emb_weights) if isinstance(module, nn.Linear) and module.bias is not None: init.zeros_(module.bias) class Kosmos2VisionModel(Kosmos2PreTrainedModel): config: Kosmos2VisionConfig main_input_name = "pixel_values" input_modalities = ("image",) # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model def __init__(self, config: Kosmos2VisionConfig): super().__init__(config) self.model = Kosmos2VisionTransformer(config) # Initialize weights and apply final processing self.post_init() # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.get_input_embeddings with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model def get_input_embeddings(self) -> nn.Module: return self.model.embeddings.patch_embedding @auto_docstring def forward( self, pixel_values: torch.FloatTensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, interpolate_pos_encoding: bool = False, return_dict: bool | None = None, **kwargs, ) -> tuple | BaseModelOutputWithProjectionAttentions: return self.model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) class Kosmos2TextModel(Kosmos2PreTrainedModel): config: Kosmos2TextConfig input_modalities = ("text",) def __init__(self, config: Kosmos2TextConfig): super().__init__(config) self.model = Kosmos2TextTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.model.embed_tokens @can_return_tuple @auto_docstring def forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, image_embeds_position_mask: torch.Tensor | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, cache_position: torch.Tensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple | BaseModelOutputWithPastAndCrossAttentions: r""" image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, 1]`: - 1 for places where to put the image features, - 0 for places that are not for image features (i.e. for text tokens). """ return self.model( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, position_ids=position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) @auto_docstring( custom_intro=""" The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input embeddings). """ ) class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): config: Kosmos2TextConfig _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: Kosmos2TextConfig): super().__init__(config) self.model = Kosmos2TextTransformer(config) self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.model.embed_tokens def get_output_embeddings(self) -> nn.Module: return self.lm_head @can_return_tuple @auto_docstring def forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, image_embeds_position_mask: torch.Tensor | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, cache_position: torch.Tensor | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithCrossAttentions: r""" image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, 1]`: - 1 for places where to put the image features, - 0 for places that are not for image features (i.e. for text tokens). labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ if labels is not None: if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False outputs: BaseModelOutputWithPastAndCrossAttentions = self.model( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, position_ids=position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation( self, input_ids, image_embeds=None, image_embeds_position_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, cache_position=None, is_first_iteration=False, **model_kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model # Pixel values are used only in the first iteration if available # In subsquent iterations, they are already merged with text and cached # NOTE: first iteration doesn't have to be prefill, it can be the first # iteration with a question and cached system prompt (continue generate from cache) if not is_first_iteration and use_cache: image_embeds = None image_embeds_position_mask = None # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) elif image_embeds_position_mask is not None: batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size() mask_len = image_embeds_position_mask.size()[-1] image_embeds_position_mask = torch.cat( ( image_embeds_position_mask, torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device), ), dim=1, ) model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, is_first_iteration=is_first_iteration, **model_kwargs, ) # Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer model_inputs.pop("position_ids", None) return model_inputs class Kosmos2ImageToTextProjection(nn.Module): """The layer that transforms the image model's output to part of the text model's input (namely, image features)""" def __init__(self, config: Kosmos2Config): super().__init__() self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim) self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim)) self.x_attn = KosmosTextAttention( config.text_config, config.text_config.embed_dim, config.text_config.attention_heads, dropout=config.text_config.attention_dropout, is_decoder=False, add_inner_attn_layernorm=False, ) def forward(self, features): hidden_states = self.dense(features) # shape = [batch, latent_query_num, h_dim] latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1) key_value_states = torch.cat([hidden_states, latent_query], dim=1) hidden_states, attn_weights = self.x_attn( hidden_states=latent_query, encoder_hidden_states=key_value_states, past_key_values=None, attention_mask=None, output_attentions=None, ) return hidden_states, attn_weights @auto_docstring( custom_intro=""" KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model. """ ) class Kosmos2Model(Kosmos2PreTrainedModel): config: Kosmos2Config main_input_name = "pixel_values" def __init__(self, config: Kosmos2Config): super().__init__(config) self.text_model = Kosmos2TextModel(config.text_config) self.vision_model = Kosmos2VisionModel(config.vision_config) self.image_to_text_projection = Kosmos2ImageToTextProjection(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.text_model.model.embed_tokens def set_input_embeddings(self, value): self.text_model.model.embed_tokens = value @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithProjectionAttentions: if "return_attentions" in kwargs: warnings.warn( "`return_attentions` is deprecated and will be removed in a future version. Please use `return_dict`" " and access `projection_attentions` from the returned `ModelOutput` instead.", FutureWarning, ) kwargs.pop("return_attentions", None) vision_output: BaseModelOutputWithProjectionAttentions = self.vision_model( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True, **kwargs, ) # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. image_embeds = self.vision_model.model.post_layernorm(vision_output[0]) # normalized features image_embeds = nn.functional.normalize(image_embeds, dim=-1) image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) vision_output.pooler_output = image_embeds vision_output.projection_attentions = projection_attentions return vision_output @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.Tensor | None = None, input_ids: torch.Tensor | None = None, image_embeds_position_mask: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, past_key_values: Cache | None = None, image_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, interpolate_pos_encoding: bool = False, return_dict: bool | None = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple | Kosmos2ModelOutput: r""" image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, 1]`: - 1 for places where to put the image features, - 0 for places that are not for image features (i.e. for text tokens). image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. Examples: ```python >>> from PIL import Image >>> import httpx >>> from io import BytesIO >>> from transformers import AutoProcessor, Kosmos2Model >>> model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224") >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224") >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" >>> with httpx.stream("GET", url) as response: ... image = Image.open(BytesIO(response.read())) >>> text = ( ... " An image of a snowman" ... " warming himself by a fire" ... "" ... ) >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True) >>> last_hidden_state = model( ... pixel_values=inputs["pixel_values"], ... input_ids=inputs["input_ids"], ... attention_mask=inputs["attention_mask"], ... image_embeds_position_mask=inputs["image_embeds_position_mask"], ... ).last_hidden_state >>> list(last_hidden_state.shape) [1, 91, 2048] ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_model_output = None projection_attentions = None if image_embeds is None: if pixel_values is None: raise ValueError("You have to specify either `pixel_values` or `image_embeds`.") image_features = self.get_image_features( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True ) image_embeds = image_features.pooler_output projection_attentions = image_features.projection_attentions outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, position_ids=position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, **kwargs, ) return Kosmos2ModelOutput( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_embeds=image_embeds, projection_attentions=projection_attentions, vision_model_output=vision_model_output, ) @auto_docstring( custom_intro=""" KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a language model. """ ) class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): config: Kosmos2Config main_input_name = "pixel_values" _tied_weights_keys = {"text_model.lm_head.weight": "text_model.model.embed_tokens.weight"} def __init__(self, config: Kosmos2Config): super().__init__(config) self.text_model = Kosmos2TextForCausalLM(config.text_config) self.vision_model = Kosmos2VisionModel(config.vision_config) self.image_to_text_projection = Kosmos2ImageToTextProjection(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.text_model.model.embed_tokens def set_input_embeddings(self, value): self.text_model.model.embed_tokens = value def get_output_embeddings(self) -> nn.Module: return self.text_model.get_output_embeddings() def set_output_embeddings(self, new_embeddings): self.text_model.set_output_embeddings(new_embeddings) @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.Tensor | None = None, input_ids: torch.Tensor | None = None, image_embeds_position_mask: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, past_key_values: Cache | None = None, image_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Kosmos2ForConditionalGenerationModelOutput: r""" image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, 1]`: - 1 for places where to put the image features, - 0 for places that are not for image features (i.e. for text tokens). image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` Examples: ```python >>> from PIL import Image >>> import httpx >>> from io import BytesIO >>> from transformers import AutoProcessor, Kosmos2ForConditionalGeneration >>> model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224") >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224") >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" >>> with httpx.stream("GET", url) as response: ... image = Image.open(BytesIO(response.read())) >>> prompt = " An image of" >>> inputs = processor(text=prompt, images=image, return_tensors="pt") >>> generated_ids = model.generate( ... pixel_values=inputs["pixel_values"], ... input_ids=inputs["input_ids"], ... attention_mask=inputs["attention_mask"], ... image_embeds=None, ... image_embeds_position_mask=inputs["image_embeds_position_mask"], ... use_cache=True, ... max_new_tokens=64, ... ) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False) >>> processed_text ' An image of a snowman warming himself by a fire.' >>> caption, entities = processor.post_process_generation(generated_text) >>> caption 'An image of a snowman warming himself by a fire.' >>> entities [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])] ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) vision_model_output = None projection_attentions = None if image_embeds is None: if pixel_values is None: raise ValueError("You have to specify either `pixel_values` or `image_embeds`.") vision_model_output = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0]) # normalized features image_embeds = nn.functional.normalize(image_embeds, dim=-1) image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) lm_outputs: CausalLMOutputWithCrossAttentions = self.text_model( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, position_ids=position_ids, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, logits_to_keep=logits_to_keep, **kwargs, ) return Kosmos2ForConditionalGenerationModelOutput( loss=lm_outputs.loss, logits=lm_outputs.logits, past_key_values=lm_outputs.past_key_values, hidden_states=lm_outputs.hidden_states, attentions=lm_outputs.attentions, image_embeds=image_embeds, projection_attentions=projection_attentions, vision_model_output=vision_model_output, ) @torch.no_grad() def generate( self, pixel_values: torch.Tensor | None = None, image_embeds_position_mask: torch.Tensor | None = None, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs, ): # in order to allow `inputs` argument (as in `GenerationMixin`) inputs = kwargs.pop("inputs", None) if pixel_values is not None and inputs is not None: raise ValueError( f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed." f"Make sure to either pass `inputs` or pixel_values=..." ) if pixel_values is None and inputs is not None: pixel_values = inputs if image_embeds is None: vision_model_output = self.vision_model(pixel_values) # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0]) # normalized features image_embeds = nn.functional.normalize(image_embeds, dim=-1) image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) output = self.text_model.generate( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, inputs_embeds=inputs_embeds, **kwargs, ) return output __all__ = ["Kosmos2ForConditionalGeneration", "Kosmos2Model", "Kosmos2PreTrainedModel"]