# Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch GIT model."""

import math
from collections.abc import Callable
from dataclasses import dataclass

import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...masking_utils import create_masks_for_generate
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPooling,
    CausalLMOutputWithPast,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    ModelOutput,
    auto_docstring,
    can_return_tuple,
    logging,
    torch_int,
)
from ...utils.generic import is_flash_attention_requested
from .configuration_git import GitConfig, GitVisionConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
class GitVisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None


# Copied from transformers.models.gemma3.modeling_gemma3.token_type_ids_mask_function
def token_type_ids_mask_function(
    token_type_ids: torch.Tensor | None,
    image_group_ids: torch.Tensor | None,
) -> Callable | None:
    """
    This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
    not start and end indices.
    """
    # Do not return an additional mask in this case
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # If it's 1 for both query and key/value, we are in an image block
        # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
        # Since vmap doesn't support `if statement` we workaround it with `torch.where`
        safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
        safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)

        token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
        token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)

        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
        token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)

        image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
        image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)

        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
        image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)

        is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
        same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx

        # This is bidirectional attention whenever we are dealing with image tokens
        return is_image_block & same_image_block

    return inner_mask


# Copied from transformers.models.gemma3.modeling_gemma3.create_causal_mask_mapping
def create_causal_mask_mapping(
    config: PreTrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    cache_position: torch.Tensor,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None,
    token_type_ids: torch.Tensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    is_training: bool = False,
    is_first_iteration: bool | None = None,
    **kwargs,
) -> dict:
    """
    Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
    for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.

    Uses `pixel_values` as an optional input to disambiguate edge cases.
    """
    if is_training and token_type_ids is None:
        raise ValueError("`token_type_ids` is required as a model input when training")

    mask_kwargs = {
        "config": config.get_text_config(),
        "input_embeds": input_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "position_ids": position_ids,
    }
    # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
    # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
    # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
    is_first_iteration = (
        is_first_iteration
        if is_first_iteration is not None
        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
    )
    if token_type_ids is not None and is_first_iteration:
        # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
        # undo the causal masking)

        # First find where a new image block starts: 1 if image and previous not image
        # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
        is_image = (token_type_ids == 1).to(cache_position.device)
        is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
        new_image_start = is_image & ~is_previous_image
        image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
        image_group_ids = torch.where(is_image, image_group_ids, -1)
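        # Worked example (illustrative only): for token_type_ids [1, 1, 0, 0, 1, 1] (two images separated by text),
        # new_image_start is [1, 0, 0, 0, 1, 0], so image_group_ids becomes [0, 0, -1, -1, 1, 1]: text keeps -1 and
        # each image block gets its own id, restricting the bidirectional mask to tokens of the same image.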
        mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
            token_type_ids.to(cache_position.device), image_group_ids
        )

    return create_masks_for_generate(**mask_kwargs)


class GitEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            embeddings = self.word_embeddings(input_ids)
        else:
            embeddings = inputs_embeds

        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class GitSelfAttention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
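        # Number of visual tokens prepended to the text sequence: one per image patch plus the CLS token,
        # e.g. a 224x224 image with 16x16 patches gives (224 / 16) ** 2 + 1 = 197 tokens.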
        self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
        if config.num_image_with_embedding is not None:
            self.image_patch_tokens *= config.num_image_with_embedding

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor]:
        batch_size = hidden_states.shape[0]
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        if past_key_values is not None:
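            # Re-use cached key/value states from earlier steps (the image prefix and previously generated text)
            # and append the projections of the current step.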
            key_layer, value_layer = past_key_values.update(
                key_layer, value_layer, self.layer_idx, cache_kwargs={"cache_position": cache_position}
            )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the GitModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, attention_probs


# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class GitSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


GIT_SELF_ATTENTION_CLASSES = {
    "eager": GitSelfAttention,
}


class GitAttention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
        self.output = GitSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.Tensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        attn_output, self_attn_weights = self.self(
            hidden_states,
            attention_mask,
            past_key_values,
            cache_position=cache_position,
        )
        attention_output = self.output(attn_output, hidden_states)
        return attention_output, self_attn_weights


# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class GitIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOutput
class GitOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class GitLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = GitAttention(config, layer_idx=layer_idx)
        self.intermediate = GitIntermediate(config)
        self.output = GitOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.Tensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        # uni-directional decoder self-attention, optionally re-using cached key/value states
        attention_output, self_attention_weights = self.attention(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return layer_output, self_attention_weights

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class GitEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = False,
        output_hidden_states: bool | None = False,
        return_dict: bool | None = True,
        cache_position: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPast:
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
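            # no cache object was passed in, so lazily create an empty DynamicCache that is filled during this pass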
            past_key_values = DynamicCache(config=self.config)

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@auto_docstring
class GitPreTrainedModel(PreTrainedModel):
    config: GitConfig
    base_model_prefix = "git"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, GitVisionEmbeddings):
            init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
            init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        elif isinstance(module, GitEmbeddings):
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
class GitVisionEmbeddings(nn.Module):
    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
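        # (1, num_patches, dim) -> (1, grid, grid, dim) -> (1, dim, grid, grid) so the position grid can be
        # resized with bicubic interpolation to the new patch grid size.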
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
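        # flatten the patch grid into a sequence: (batch_size, embed_dim, grid, grid) -> (batch_size, grid*grid, embed_dim)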
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings


class GitVisionMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
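    # query/key/value are laid out as (batch_size, num_heads, seq_len, head_dim); the additive attention mask,
    # if provided, is summed into the scores before the softmax.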
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class GitVisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        causal_attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
        # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
        if not is_flash_attention_requested(self.config):
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask
        else:
            self.is_causal = causal_attention_mask is not None

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
class GitVisionEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = GitVisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = GitVisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: bool | None = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->GitVision, CLIPConfig
class GitVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`GitVisionEncoderLayer`].

    Args:
        config: GitVisionConfig
    """

    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        causal_attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class GitVisionTransformer(nn.Module):
    # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPEncoder->GitVisionEncoder, AltCLIP->Git
    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = GitVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = GitVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool | None = False,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutput:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]

        last_hidden_state = self.post_layernorm(last_hidden_state)

        if not return_dict:
            return (last_hidden_state,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The vision model from CLIP, used in GIT, without any head or projection on top.
    """
)
class GitVisionModel(GitPreTrainedModel):
    config: GitVisionConfig
    main_input_name = "pixel_values"
    input_modalities = ("image",)

    # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
    def __init__(self, config: GitVisionConfig):
        super().__init__(config)
        self.vision_model = GitVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutput:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, GitVisionModel

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
        >>> model = GitVisionModel.from_pretrained("microsoft/git-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )


class GitProjection(nn.Module):
    def __init__(self, config: GitConfig):
        super().__init__()
        self.config = config
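        # Maps CLIP vision features (vision_config.hidden_size) into the text decoder's hidden_size, followed by LayerNorm.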
        self.visual_projection = nn.Sequential(
            nn.Linear(config.vision_config.hidden_size, config.hidden_size),
            nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
        )

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        return self.visual_projection(embeddings)


@auto_docstring(
    custom_intro="""
    The bare GIT Model transformer consisting of a CLIP image encoder and a text decoder, outputting raw hidden-states.
    """
)
class GitModel(GitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = GitEmbeddings(config)
        self.image_encoder = GitVisionModel(config.vision_config)
        self.encoder = GitEncoder(config)

        self.visual_projection = GitProjection(config)

        if config.num_image_with_embedding is not None:
            self.img_temporal_embedding = nn.ParameterList(
                nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
                for _ in range(config.num_image_with_embedding)
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModel
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
        >>> model = AutoModel.from_pretrained("microsoft/git-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> text = "this is an image of two cats"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

        # past_key_values_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values.get_seq_length()

        # Adjust position ids by adding image seq length
        if pixel_values is None and past_key_values is not None and input_ids.shape[1] == 1:
            position_ids = position_ids + past_key_values_length

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length,
                past_key_values_length + embedding_output.shape[1],
                device=embedding_output.device,
            )

        # Always create `token_type_ids` so we can re-use Gemma3 style mask preparation fn
        token_type_ids = torch.zeros_like(embedding_output, dtype=torch.int)[..., 0]

        if pixel_values is not None:
            if pixel_values.ndim == 4:
                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
                visual_features = self.image_encoder(
                    pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
                ).last_hidden_state

            elif pixel_values.ndim == 5:
                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
                visual_features = []
                for frame_idx in range(pixel_values.shape[1]):
                    visual_features_frame = self.image_encoder(
                        pixel_values[:, frame_idx, :, :], interpolate_pos_encoding=interpolate_pos_encoding
                    ).last_hidden_state
                    visual_features_frame += self.img_temporal_embedding[frame_idx]
                    visual_features.append(visual_features_frame)

                # finally, concatenate all features along sequence dimension
                visual_features = torch.cat(visual_features, dim=1)

            else:
                raise ValueError("pixel_values must be of rank 4 or 5")

            projected_visual_features = self.visual_projection(visual_features)

            # Repeat visual features to match embedding batch size.
            projected_visual_features = projected_visual_features.repeat(
                embedding_output.size(0) // projected_visual_features.size(0), 1, 1
            )

            # concatenate patch token and text token embeddings
            embedding_output = torch.cat((projected_visual_features, embedding_output), dim=1)
            image_token_type_ids = torch.ones_like(projected_visual_features, dtype=torch.int)[..., 0]
            token_type_ids = torch.cat([image_token_type_ids, token_type_ids], dim=-1)
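            # token_type_ids is now 1 at image patch positions and 0 at text positions; it drives the
            # bidirectional image mask built by create_causal_mask_mapping below.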
            cache_position = torch.arange(embedding_output.shape[1], device=embedding_output.device, dtype=torch.int)
            if attention_mask is not None:
                attention_mask = torch.cat([torch.ones_like(image_token_type_ids), attention_mask], dim=-1)
        elif past_key_values is not None and input_ids.shape[1] == 1:
            # Expand attention mask and cache position with image tokens because GIT doesn't add image
            # placeholder tokens when processing. Not worth the refactor, low usage!
            cache_position = torch.tensor(
                [past_key_values_length], dtype=cache_position.dtype, device=cache_position.device
            )
            extended_attention_mask = torch.ones(
                (attention_mask.shape[0], past_key_values_length - attention_mask.shape[1] + 1),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([extended_attention_mask, attention_mask], dim=-1)

        # Images attend to each other bidirectionally while text remains causal
        causal_mask = create_causal_mask_mapping(
            self.config,
            embedding_output,
            attention_mask,
            cache_position,
            past_key_values,
            None,
            token_type_ids,
            pixel_values,
        )

        hidden_states = embedding_output

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=causal_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithPast(
            last_hidden_state=sequence_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    GIT Model with a `language modeling` head on top for autoregressive language modeling.
    """
)
class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"output.weight": "git.embeddings.word_embeddings.weight"}

    def __init__(self, config):
        super().__init__(config)

        self.git = GitModel(config)
        self.output = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.output

    def set_output_embeddings(self, new_embeddings):
        self.output = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        cache_position: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Examples:

        Image captioning example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values

        >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
        >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_caption)
        two cats sleeping on a pink blanket next to remotes.
        ```

        Visual question answering (VQA) example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AutoModelForCausalLM
        >>> from huggingface_hub import hf_hub_download
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")

        >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
        >>> image = Image.open(file_path).convert("RGB")

        >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values

        >>> question = "what does the front of the bus say at the top?"

        >>> input_ids = processor(text=question, add_special_tokens=False).input_ids
        >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
        >>> input_ids = torch.tensor(input_ids).unsqueeze(0)

        >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
        >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
        ['what does the front of the bus say at the top? special']
        ```

        Video captioning example:

        ```python
        >>> import av
        >>> import numpy as np
        >>> from PIL import Image
        >>> from huggingface_hub import hf_hub_download
        >>> from transformers import AutoProcessor, AutoModelForCausalLM

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")

        >>> # set seed for reproducibility
        >>> np.random.seed(45)


        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`list[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`list[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices


        >>> # load video
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample frames
        >>> num_frames = model.config.num_image_with_embedding
        >>> indices = sample_frame_indices(
        ...     clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
        ... )
        >>> frames = read_video_pyav(container, indices)

        >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values

        >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)

        >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
        Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.git(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
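            # the first num_image_tokens positions correspond to image patches and carry no labels, so they are
            # dropped before shifting the logits for next-token prediction.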
            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss = self.loss_function(
                shifted_logits.view(-1, self.config.vocab_size),
                labels.view(-1),
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        pixel_values=None,
        attention_mask=None,
        use_cache=None,
        cache_position=None,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- `git` has special `pixel_values` handling

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=use_cache,
            cache_position=cache_position,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        if is_first_iteration or not use_cache:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs


__all__ = ["GitForCausalLM", "GitModel", "GitPreTrainedModel", "GitVisionModel"]