# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch InstructBLIP model."""

import math
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithPast,
    Seq2SeqLMOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.generic import OutputRecorder, check_model_inputs
from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig


logger = logging.get_logger(__name__)
|
|
|
|
|
|
@dataclass
|
|
@auto_docstring
|
|
class BaseModelOutputWithVisionQformerOutputs(BaseModelOutputWithPooling):
|
|
r"""
|
|
vision_outputs (`BaseModelOutputWithPooling`):
|
|
Outputs of the vision encoder.
|
|
qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
|
|
Outputs of the Q-Former (Querying Transformer).
|
|
"""
|
|
|
|
vision_outputs: BaseModelOutputWithPooling | None = None
|
|
qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
|
|
|
|
|
|
@dataclass
|
|
@auto_docstring(
|
|
custom_intro="""
|
|
Class defining the outputs of [`InstructBlipForConditionalGeneration`].
|
|
"""
|
|
)
|
|
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlip
|
|
class InstructBlipForConditionalGenerationModelOutput(ModelOutput):
|
|
r"""
|
|
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
|
Language modeling loss from the language model.
|
|
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
|
Prediction scores of the language modeling head of the language model.
|
|
vision_outputs (`BaseModelOutputWithPooling`):
|
|
Outputs of the vision encoder.
|
|
qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
|
|
Outputs of the Q-Former (Querying Transformer).
|
|
language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
|
|
Outputs of the language model.
|
|
"""
|
|
|
|
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
|
|
vision_outputs: BaseModelOutputWithPooling | None = None
|
|
qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
|
|
language_model_outputs: CausalLMOutputWithPast | Seq2SeqLMOutput | None = None
|
|
|
|
def to_tuple(self) -> tuple[Any]:
|
|
return tuple(
|
|
self[k]
|
|
if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
|
|
else getattr(self, k).to_tuple()
|
|
for k in self.keys()
|
|
)
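    # Illustrative sketch, not part of the modeling code: `to_tuple` flattens the nested
    # sub-outputs so the result can be consumed positionally. Assuming `model` is an
    # `InstructBlipForConditionalGeneration` and `labels` were passed (so `loss` is set):
    #
    #   out = model(pixel_values, qformer_input_ids, input_ids=input_ids, labels=labels)
    #   loss, logits, vision_tuple, qformer_tuple, lm_tuple = out.to_tuple()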
|
|
|
|
|
|
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlip
|
|
class InstructBlipVisionEmbeddings(nn.Module):
|
|
def __init__(self, config: InstructBlipVisionConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.embed_dim = config.hidden_size
|
|
self.image_size = config.image_size
|
|
self.patch_size = config.patch_size
|
|
|
|
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
|
|
|
self.patch_embedding = nn.Conv2d(
|
|
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
|
|
)
|
|
|
|
self.num_patches = (self.image_size // self.patch_size) ** 2
|
|
self.num_positions = self.num_patches + 1
|
|
|
|
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
|
|
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
|
"""
|
|
        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images. It is also adapted to support torch.jit tracing.
|
|
|
|
Adapted from:
|
|
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
|
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
|
"""
|
|
|
|
num_patches = embeddings.shape[1] - 1
|
|
num_positions = self.position_embedding.shape[1] - 1
|
|
|
|
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
|
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
|
return self.position_embedding
|
|
|
|
class_pos_embed = self.position_embedding[:, :1]
|
|
patch_pos_embed = self.position_embedding[:, 1:]
|
|
|
|
dim = embeddings.shape[-1]
|
|
|
|
new_height = height // self.patch_size
|
|
new_width = width // self.patch_size
|
|
|
|
sqrt_num_positions = torch_int(num_positions**0.5)
|
|
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
|
|
|
patch_pos_embed = nn.functional.interpolate(
|
|
patch_pos_embed,
|
|
size=(new_height, new_width),
|
|
mode="bicubic",
|
|
align_corners=False,
|
|
)
|
|
|
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
|
|
|
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
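    # Illustrative sketch of the interpolation above (the sizes are assumptions based on the
    # default vision config: image_size=224, patch_size=14, hidden_size=1408):
    #
    #   pos = self.position_embedding                                  # (1, 257, 1408): [CLS] + 16*16 patches
    #   grid = pos[:, 1:].reshape(1, 16, 16, 1408).permute(0, 3, 1, 2)
    #   grid = nn.functional.interpolate(grid, size=(24, 24), mode="bicubic", align_corners=False)
    #   new_pos = torch.cat([pos[:, :1], grid.permute(0, 2, 3, 1).view(1, -1, 1408)], dim=1)
    #   # new_pos.shape -> torch.Size([1, 577, 1408]) for a 336x336 input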
|
|
|
|
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
|
batch_size, _, height, width = pixel_values.shape
|
|
target_dtype = self.patch_embedding.weight.dtype
|
|
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
|
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
|
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
|
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
|
if interpolate_pos_encoding:
|
|
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
|
|
else:
|
|
position_embedding = self.position_embedding
|
|
embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
|
|
return embeddings
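# Illustrative sketch, not part of the modeling code (shapes assume the default vision config
# with image_size=224, patch_size=14 and hidden_size=1408):
#
#   emb = InstructBlipVisionEmbeddings(InstructBlipVisionConfig())
#   out = emb(torch.randn(2, 3, 224, 224))
#   # out.shape -> torch.Size([2, 257, 1408]): one [CLS] embedding followed by 16*16 patch embeddings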
|
|
|
|
|
|
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32
|
|
def eager_attention_forward(
|
|
module: nn.Module,
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
attention_mask: torch.Tensor | None,
|
|
scaling: float,
|
|
dropout: float = 0.0,
|
|
**kwargs,
|
|
):
|
|
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
|
if attention_mask is not None:
|
|
attn_weights = attn_weights + attention_mask
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
|
|
|
attn_output = torch.matmul(attn_weights, value)
|
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
|
|
return attn_output, attn_weights
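# Illustrative sketch, not part of the modeling code (the shapes are assumptions for the example):
# `eager_attention_forward` expects `query`/`key`/`value` of shape (batch, num_heads, seq_len, head_dim)
# and returns the attention output transposed back to (batch, seq_len, num_heads, head_dim) together
# with the attention weights.
#
#   q = k = v = torch.randn(2, 16, 257, 88)
#   out, weights = eager_attention_forward(nn.Module(), q, k, v, attention_mask=None, scaling=88**-0.5)
#   # out.shape -> torch.Size([2, 257, 16, 88]); weights.shape -> torch.Size([2, 16, 257, 257])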
|
|
|
|
|
|
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip
|
|
class InstructBlipAttention(nn.Module):
|
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
|
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
self.config = config
|
|
self.embed_dim = config.hidden_size
|
|
self.num_heads = config.num_attention_heads
|
|
self.head_dim = self.embed_dim // self.num_heads
|
|
if self.head_dim * self.num_heads != self.embed_dim:
|
|
raise ValueError(
|
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
|
f" {self.num_heads})."
|
|
)
|
|
self.scale = self.head_dim**-0.5
|
|
self.is_causal = False
|
|
self.attention_dropout = config.attention_dropout
|
|
|
|
# small tweak here compared to CLIP, no bias here
|
|
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
|
|
|
|
if config.qkv_bias:
|
|
q_bias = nn.Parameter(torch.zeros(self.embed_dim))
|
|
v_bias = nn.Parameter(torch.zeros(self.embed_dim))
|
|
else:
|
|
q_bias = None
|
|
v_bias = None
|
|
|
|
if q_bias is not None:
|
|
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
|
|
self.qkv.bias = nn.Parameter(qkv_bias)
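        # Note: the concatenation above gives the packed QKV projection a bias on the query and
        # value slices only; the key slice is initialized to zero, matching the BLIP-2 ViT this
        # module is copied from.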
|
|
|
|
self.projection = nn.Linear(self.embed_dim, self.embed_dim)
|
|
|
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
**kwargs,
|
|
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
|
"""Input shape: Batch x Time x Channel"""
|
|
|
|
bsz, tgt_len, embed_dim = hidden_states.size()
|
|
|
|
mixed_qkv = self.qkv(hidden_states)
|
|
|
|
mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
|
|
2, 0, 3, 1, 4
|
|
)
|
|
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
|
|
|
|
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
|
self.config._attn_implementation, eager_attention_forward
|
|
)
|
|
|
|
attn_output, attn_weights = attention_interface(
|
|
self,
|
|
query_states,
|
|
key_states,
|
|
value_states,
|
|
attention_mask=None,
|
|
dropout=0.0 if not self.training else self.attention_dropout,
|
|
scaling=self.scale,
|
|
**kwargs,
|
|
)
|
|
|
|
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
|
attn_output = self.projection(attn_output)
|
|
|
|
return attn_output, attn_weights
|
|
|
|
|
|
# Copied from transformers.models.blip.modeling_blip.BlipMLP
|
|
class InstructBlipMLP(nn.Module):
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
self.config = config
|
|
self.activation_fn = ACT2FN[config.hidden_act]
|
|
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
|
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
hidden_states = self.fc1(hidden_states)
|
|
hidden_states = self.activation_fn(hidden_states)
|
|
hidden_states = self.fc2(hidden_states)
|
|
return hidden_states
|
|
|
|
|
|
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
|
|
class InstructBlipEncoderLayer(GradientCheckpointingLayer):
|
|
def __init__(self, config: InstructBlipConfig):
|
|
super().__init__()
|
|
self.embed_dim = config.hidden_size
|
|
self.self_attn = InstructBlipAttention(config)
|
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
|
self.mlp = InstructBlipMLP(config)
|
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
|
|
|
@auto_docstring
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> torch.FloatTensor:
|
|
residual = hidden_states
|
|
|
|
hidden_states = self.layer_norm1(hidden_states)
|
|
hidden_states, _ = self.self_attn(
|
|
hidden_states=hidden_states,
|
|
**kwargs,
|
|
)
|
|
hidden_states = hidden_states + residual
|
|
residual = hidden_states
|
|
hidden_states = self.layer_norm2(hidden_states)
|
|
hidden_states = self.mlp(hidden_states)
|
|
|
|
hidden_states = hidden_states + residual
|
|
|
|
return hidden_states
|
|
|
|
|
|
@auto_docstring
|
|
class InstructBlipPreTrainedModel(PreTrainedModel):
|
|
config: InstructBlipConfig
|
|
base_model_prefix = "blip"
|
|
input_modalities = ("image", "text")
|
|
supports_gradient_checkpointing = True
|
|
_supports_attention_backend = True
|
|
_supports_flash_attn = True
|
|
_supports_sdpa = True
|
|
_supports_flex_attn = True
|
|
|
|
_can_compile_fullgraph = True
|
|
|
|
_no_split_modules = [
|
|
"InstructBlipQFormerEmbeddings",
|
|
"InstructBlipAttention",
|
|
"InstructBlipQFormerMultiHeadAttention",
|
|
"InstructBlipQFormerSelfOutput",
|
|
]
|
|
|
|
@torch.no_grad()
|
|
def _init_weights(self, module):
|
|
"""Initialize the weights"""
|
|
super()._init_weights(module)
|
|
factor = self.config.initializer_range
|
|
if isinstance(module, InstructBlipVisionEmbeddings):
|
|
init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
|
init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
|
elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)):
|
|
init.zeros_(module.query_tokens)
|
|
elif isinstance(module, InstructBlipQFormerEmbeddings):
|
|
init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
|
|
|
|
|
|
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip
|
|
class InstructBlipEncoder(nn.Module):
|
|
"""
|
|
    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
|
|
[`InstructBlipEncoderLayer`].
|
|
|
|
Args:
|
|
config (`InstructBlipConfig`):
|
|
The corresponding vision configuration for the `InstructBlipEncoder`.
|
|
"""
|
|
|
|
def __init__(self, config: InstructBlipConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.layers = nn.ModuleList([InstructBlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
|
self.gradient_checkpointing = False
|
|
|
|
@auto_docstring
|
|
def forward(
|
|
self,
|
|
inputs_embeds,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> tuple | BaseModelOutput:
|
|
hidden_states = inputs_embeds
|
|
for encoder_layer in self.layers:
|
|
hidden_states = encoder_layer(
|
|
hidden_states,
|
|
**kwargs,
|
|
)
|
|
|
|
return BaseModelOutput(last_hidden_state=hidden_states)
|
|
|
|
|
|
class InstructBlipVisionModel(InstructBlipPreTrainedModel):
|
|
main_input_name = "pixel_values"
|
|
input_modalities = ("image",)
|
|
config: InstructBlipVisionConfig
|
|
_can_record_outputs = {
|
|
"hidden_states": InstructBlipEncoderLayer,
|
|
"attentions": InstructBlipAttention,
|
|
}
|
|
|
|
def __init__(self, config: InstructBlipVisionConfig):
|
|
super().__init__(config)
|
|
self.config = config
|
|
embed_dim = config.hidden_size
|
|
|
|
self.embeddings = InstructBlipVisionEmbeddings(config)
|
|
self.encoder = InstructBlipEncoder(config)
|
|
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
|
|
|
self.post_init()
|
|
|
|
@check_model_inputs(tie_last_hidden_states=False)
|
|
@auto_docstring
|
|
def forward(
|
|
self,
|
|
pixel_values: torch.FloatTensor | None = None,
|
|
interpolate_pos_encoding: bool = False,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> tuple | BaseModelOutputWithPooling:
|
|
if pixel_values is None:
|
|
raise ValueError("You have to specify pixel_values")
|
|
|
|
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
|
|
|
encoder_outputs: BaseModelOutput = self.encoder(
|
|
inputs_embeds=hidden_states,
|
|
**kwargs,
|
|
)
|
|
|
|
last_hidden_state = encoder_outputs.last_hidden_state
|
|
last_hidden_state = self.post_layernorm(last_hidden_state)
|
|
|
|
pooled_output = last_hidden_state[:, 0, :]
|
|
pooled_output = self.post_layernorm(pooled_output)
|
|
|
|
return BaseModelOutputWithPooling(
|
|
last_hidden_state=last_hidden_state,
|
|
pooler_output=pooled_output,
|
|
)
|
|
|
|
def get_input_embeddings(self):
|
|
return self.embeddings
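# Illustrative usage sketch, not part of the modeling code (randomly initialized weights and
# random pixel values, default `InstructBlipVisionConfig`):
#
#   vision = InstructBlipVisionModel(InstructBlipVisionConfig())
#   feats = vision(pixel_values=torch.randn(1, 3, 224, 224))
#   # feats.last_hidden_state.shape -> torch.Size([1, 257, 1408])
#   # feats.pooler_output.shape     -> torch.Size([1, 1408])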
|
|
|
|
|
|
class InstructBlipQFormerMultiHeadAttention(nn.Module):
|
|
def __init__(self, config, is_cross_attention=False):
|
|
super().__init__()
|
|
self.config = config
|
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
|
raise ValueError(
|
|
"The hidden size (%d) is not a multiple of the number of attention heads (%d)"
|
|
% (config.hidden_size, config.num_attention_heads)
|
|
)
|
|
|
|
self.num_attention_heads = config.num_attention_heads
|
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
|
|
|
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
|
if is_cross_attention:
|
|
self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
|
|
self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
|
|
else:
|
|
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
|
|
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
|
self.save_attention = False
|
|
|
|
def save_attn_gradients(self, attn_gradients):
|
|
self.attn_gradients = attn_gradients
|
|
|
|
def get_attn_gradients(self):
|
|
return self.attn_gradients
|
|
|
|
def save_attention_map(self, attention_map):
|
|
self.attention_map = attention_map
|
|
|
|
def get_attention_map(self):
|
|
return self.attention_map
|
|
|
|
def transpose_for_scores(self, x):
|
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
|
x = x.view(*new_x_shape)
|
|
return x.permute(0, 2, 1, 3)
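    # Illustrative sketch: `transpose_for_scores` splits the hidden dimension into heads. With the
    # (assumed) default Q-Former config of 12 heads of size 64, a (2, 32, 768) tensor becomes
    # (2, 12, 32, 64), i.e. (batch, num_attention_heads, seq_len, attention_head_size).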
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states,
|
|
attention_mask=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
):
|
|
# If this is instantiated as a cross-attention module, the keys
|
|
# and values come from an encoder; the attention mask needs to be
|
|
# such that the encoder's padding tokens are not attended to.
|
|
is_cross_attention = encoder_hidden_states is not None
|
|
|
|
if is_cross_attention:
|
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
|
attention_mask = encoder_attention_mask
|
|
else:
|
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
|
|
|
mixed_query_layer = self.query(hidden_states)
|
|
|
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
|
|
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
|
|
|
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
|
attention_scores_dtype = attention_scores.dtype
|
|
|
|
if attention_mask is not None:
|
|
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
|
|
attention_scores = attention_scores + attention_mask
|
|
|
|
# Normalize the attention scores to probabilities.
|
|
attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
|
|
|
|
if is_cross_attention and self.save_attention:
|
|
self.save_attention_map(attention_probs)
|
|
attention_probs.register_hook(self.save_attn_gradients)
|
|
|
|
# This is actually dropping out entire tokens to attend to, which might
|
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
|
attention_probs_dropped = self.dropout(attention_probs)
|
|
|
|
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
|
|
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
|
context_layer = context_layer.view(*new_context_layer_shape)
|
|
|
|
return context_layer, attention_probs
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipQFormer
|
|
class InstructBlipQFormerSelfOutput(nn.Module):
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
|
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
|
hidden_states = self.dense(hidden_states)
|
|
hidden_states = self.dropout(hidden_states)
|
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
|
return hidden_states
|
|
|
|
|
|
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlip
|
|
class InstructBlipQFormerAttention(nn.Module):
|
|
def __init__(self, config, is_cross_attention=False):
|
|
super().__init__()
|
|
self.attention = InstructBlipQFormerMultiHeadAttention(config, is_cross_attention)
|
|
self.output = InstructBlipQFormerSelfOutput(config)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
attention_mask: torch.FloatTensor | None = None,
|
|
encoder_hidden_states: torch.FloatTensor | None = None,
|
|
encoder_attention_mask: torch.FloatTensor | None = None,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> torch.Tensor:
|
|
attn_output, _ = self.attention(
|
|
hidden_states=hidden_states,
|
|
attention_mask=attention_mask,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
**kwargs,
|
|
)
|
|
attention_output = self.output(attn_output, hidden_states)
|
|
return attention_output
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipQFormer
|
|
class InstructBlipQFormerIntermediate(nn.Module):
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
|
if isinstance(config.hidden_act, str):
|
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
|
else:
|
|
self.intermediate_act_fn = config.hidden_act
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
hidden_states = self.dense(hidden_states)
|
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
|
return hidden_states
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipQFormer
|
|
class InstructBlipQFormerOutput(nn.Module):
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
|
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
|
hidden_states = self.dense(hidden_states)
|
|
hidden_states = self.dropout(hidden_states)
|
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
|
return hidden_states
|
|
|
|
|
|
class InstructBlipQFormerLayer(GradientCheckpointingLayer):
|
|
def __init__(self, config, layer_idx):
|
|
super().__init__()
|
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
|
self.seq_len_dim = 1
|
|
self.attention = InstructBlipQFormerAttention(config)
|
|
|
|
self.layer_idx = layer_idx
|
|
|
|
if layer_idx % config.cross_attention_frequency == 0:
|
|
self.crossattention = InstructBlipQFormerAttention(config, is_cross_attention=True)
|
|
self.has_cross_attention = True
|
|
else:
|
|
self.has_cross_attention = False
|
|
|
|
self.intermediate = InstructBlipQFormerIntermediate(config)
|
|
self.output = InstructBlipQFormerOutput(config)
|
|
|
|
self.intermediate_query = InstructBlipQFormerIntermediate(config)
|
|
self.output_query = InstructBlipQFormerOutput(config)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states,
|
|
attention_mask=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
query_length=0,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
):
|
|
attention_output = self.attention(
|
|
hidden_states,
|
|
attention_mask=attention_mask,
|
|
**kwargs,
|
|
)
|
|
|
|
if query_length > 0:
|
|
query_attention_output = attention_output[:, :query_length, :]
|
|
|
|
if self.has_cross_attention:
|
|
if encoder_hidden_states is None:
|
|
raise ValueError("encoder_hidden_states must be given for cross-attention layers")
|
|
query_attention_output = self.crossattention(
|
|
query_attention_output,
|
|
attention_mask=attention_mask,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
**kwargs,
|
|
)
|
|
|
|
layer_output = apply_chunking_to_forward(
|
|
self.feed_forward_chunk_query,
|
|
self.chunk_size_feed_forward,
|
|
self.seq_len_dim,
|
|
query_attention_output,
|
|
)
|
|
|
|
if attention_output.shape[1] > query_length:
|
|
layer_output_text = apply_chunking_to_forward(
|
|
self.feed_forward_chunk,
|
|
self.chunk_size_feed_forward,
|
|
self.seq_len_dim,
|
|
attention_output[:, query_length:, :],
|
|
).to(layer_output.device)
|
|
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
|
else:
|
|
layer_output = apply_chunking_to_forward(
|
|
self.feed_forward_chunk,
|
|
self.chunk_size_feed_forward,
|
|
self.seq_len_dim,
|
|
attention_output,
|
|
)
|
|
return layer_output
|
|
|
|
def feed_forward_chunk(self, attention_output):
|
|
intermediate_output = self.intermediate(attention_output)
|
|
layer_output = self.output(intermediate_output, attention_output)
|
|
return layer_output
|
|
|
|
def feed_forward_chunk_query(self, attention_output):
|
|
intermediate_output = self.intermediate_query(attention_output)
|
|
layer_output = self.output_query(intermediate_output, attention_output)
|
|
return layer_output
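# Illustrative sketch of how a Q-Former layer routes its tokens, not part of the modeling code
# (`query_length`, the shapes and the variable names are assumptions for the example):
#
#   hidden_states = torch.cat([query_tokens, text_embeddings], dim=1)   # (B, 32 + T, 768)
#   out = layer(hidden_states, query_length=32, encoder_hidden_states=image_embeds)
#
# All positions share the self-attention, but only the first 32 (the learned queries) additionally
# go through `crossattention` (on layers where `layer_idx % config.cross_attention_frequency == 0`)
# and the `*_query` feed-forward; the remaining text positions use the plain feed-forward, and the
# two halves are concatenated back together at the end of the layer.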
|
|
|
|
|
|
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlip
|
|
class InstructBlipQFormerEncoder(nn.Module):
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
self.config = config
|
|
self.layer = nn.ModuleList(
|
|
[InstructBlipQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
|
)
|
|
self.gradient_checkpointing = False
|
|
|
|
@can_return_tuple
|
|
def forward(
|
|
self,
|
|
hidden_states,
|
|
attention_mask=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
query_length=0,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
):
|
|
for i in range(self.config.num_hidden_layers):
|
|
layer_module = self.layer[i]
|
|
|
|
hidden_states = layer_module(
|
|
hidden_states,
|
|
attention_mask,
|
|
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
query_length=query_length,
|
|
**kwargs,
|
|
)
|
|
|
|
return BaseModelOutputWithPastAndCrossAttentions(
|
|
last_hidden_state=hidden_states,
|
|
)
|
|
|
|
|
|
class InstructBlipQFormerEmbeddings(nn.Module):
|
|
"""Construct the embeddings from word and position embeddings."""
|
|
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
|
|
|
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
|
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
|
self.register_buffer(
|
|
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
|
)
|
|
|
|
self.config = config
|
|
|
|
def forward(
|
|
self,
|
|
input_ids=None,
|
|
position_ids=None,
|
|
query_embeds=None,
|
|
past_key_values_length=0,
|
|
):
|
|
if input_ids is not None:
|
|
seq_length = input_ids.size()[1]
|
|
else:
|
|
seq_length = 0
|
|
|
|
if position_ids is None:
|
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
|
|
|
|
if input_ids is not None:
|
|
embeddings = self.word_embeddings(input_ids)
|
|
|
|
position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
|
|
embeddings = embeddings + position_embeddings
|
|
|
|
if query_embeds is not None:
|
|
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
|
else:
|
|
embeddings = query_embeds
|
|
|
|
embeddings = embeddings.to(self.layernorm.weight.dtype)
|
|
embeddings = self.layernorm(embeddings)
|
|
embeddings = self.dropout(embeddings)
|
|
return embeddings
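    # Illustrative sketch (shapes are assumptions): with 32 learned queries and a 12-token
    # instruction, `query_embeds` of shape (B, 32, 768) and `input_ids` of shape (B, 12) yield an
    # embedding tensor of shape (B, 44, 768); position embeddings are only added to the text part.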
|
|
|
|
|
|
class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
|
|
"""
|
|
Querying Transformer (Q-Former), used in InstructBLIP. Slightly modified from BLIP-2 as it also takes the
|
|
instruction as input.
|
|
"""
|
|
|
|
_supports_attention_backend = False # adds position on attn weights before last matmul
|
|
_supports_flash_attn = False
|
|
_supports_sdpa = False
|
|
_supports_flex_attn = False
|
|
|
|
_can_record_outputs = {
|
|
"hidden_states": InstructBlipQFormerLayer,
|
|
"attentions": [
|
|
OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".attention"),
|
|
],
|
|
"cross_attentions": [
|
|
OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
|
|
],
|
|
}
|
|
|
|
def __init__(self, config: InstructBlipQFormerConfig):
|
|
super().__init__(config)
|
|
self.config = config
|
|
|
|
self.embeddings = InstructBlipQFormerEmbeddings(config)
|
|
|
|
self.encoder = InstructBlipQFormerEncoder(config)
|
|
|
|
self.post_init()
|
|
|
|
def get_input_embeddings(self):
|
|
return self.embeddings.word_embeddings
|
|
|
|
def set_input_embeddings(self, value):
|
|
self.embeddings.word_embeddings = value
|
|
|
|
def get_extended_attention_mask(
|
|
self,
|
|
attention_mask: torch.Tensor,
|
|
input_shape: tuple[int],
|
|
device: torch.device,
|
|
has_query: bool = False,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
|
|
|
Arguments:
|
|
attention_mask (`torch.Tensor`):
|
|
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
|
input_shape (`tuple[int]`):
|
|
The shape of the input to the model.
|
|
device: (`torch.device`):
|
|
The device of the input to the model.
|
|
|
|
Returns:
|
|
            `torch.Tensor`: The extended attention mask, with the same dtype as `attention_mask.dtype`.
|
|
"""
|
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
|
if attention_mask.dim() == 3:
|
|
extended_attention_mask = attention_mask[:, None, :, :]
|
|
elif attention_mask.dim() == 2:
|
|
# Provided a padding mask of dimensions [batch_size, seq_length]
|
|
# - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
extended_attention_mask = attention_mask[:, None, None, :]
|
|
else:
|
|
raise ValueError(
|
|
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
|
|
)
|
|
|
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
# masked positions, this operation will create a tensor which is 0.0 for
|
|
# positions we want to attend and -10000.0 for masked positions.
|
|
# Since we are adding it to the raw scores before the softmax, this is
|
|
# effectively the same as removing these entirely.
|
|
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
|
return extended_attention_mask
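    # Illustrative sketch (the mask values are assumptions for the example): a padding mask
    # [[1, 1, 0]] of shape (1, 3) becomes an additive mask of shape (1, 1, 1, 3) with values
    # [[[[0.0, 0.0, -10000.0]]]], broadcast over heads and query positions when added to the
    # raw attention scores.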
|
|
|
|
@check_model_inputs
|
|
@auto_docstring
|
|
def forward(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
attention_mask: torch.FloatTensor | None = None,
|
|
position_ids: torch.LongTensor | None = None,
|
|
query_embeds: torch.Tensor | None = None,
|
|
encoder_hidden_states: torch.FloatTensor | None = None,
|
|
encoder_attention_mask: torch.FloatTensor | None = None,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> tuple[torch.FloatTensor] | BaseModelOutputWithPoolingAndCrossAttentions:
|
|
r"""
|
|
query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
            Hidden states to be used in the attention computation. When cross-attention is used, these serve as
            the queries (the keys and values come from `encoder_hidden_states`).
|
|
"""
|
|
if input_ids is None and query_embeds is None:
|
|
raise ValueError("You have to specify query_embeds when input_ids is None")
|
|
|
|
query_length = query_embeds.shape[1] if query_embeds is not None else 0
|
|
|
|
embedding_output = self.embeddings(
|
|
input_ids=input_ids,
|
|
position_ids=position_ids,
|
|
query_embeds=query_embeds,
|
|
)
|
|
|
|
input_shape = embedding_output.size()[:-1]
|
|
batch_size, seq_length = input_shape
|
|
device = embedding_output.device
|
|
|
|
if attention_mask is None:
|
|
            attention_mask = torch.ones((batch_size, seq_length), device=device)
|
|
|
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
|
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
|
|
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
if encoder_hidden_states is not None:
|
|
if isinstance(encoder_hidden_states, list):
|
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
|
else:
|
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
|
|
|
if isinstance(encoder_attention_mask, list):
|
|
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
|
elif encoder_attention_mask is None:
|
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
|
else:
|
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
|
else:
|
|
encoder_extended_attention_mask = None
|
|
|
|
encoder_outputs: BaseModelOutput = self.encoder(
|
|
embedding_output,
|
|
attention_mask=extended_attention_mask,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
encoder_attention_mask=encoder_extended_attention_mask,
|
|
query_length=query_length,
|
|
**kwargs,
|
|
)
|
|
sequence_output = encoder_outputs.last_hidden_state
|
|
pooled_output = sequence_output[:, 0, :]
|
|
|
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
|
last_hidden_state=sequence_output,
|
|
pooler_output=pooled_output,
|
|
)
|
|
|
|
|
|
@auto_docstring(
|
|
custom_intro="""
|
|
    InstructBLIP base model consisting of a vision encoder, a Q-Former and a language model.
|
|
"""
|
|
)
|
|
class InstructBlipModel(InstructBlipPreTrainedModel):
|
|
main_input_name = "pixel_values"
|
|
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
|
|
|
|
def __init__(self, config: InstructBlipConfig):
|
|
super().__init__(config)
|
|
|
|
self.vision_model = InstructBlipVisionModel(config.vision_config)
|
|
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
|
self.qformer = InstructBlipQFormerModel(config.qformer_config)
|
|
|
|
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
|
self.language_model = AutoModel.from_config(config.text_config)
|
|
|
|
# Initialize weights and apply final processing
|
|
self.post_init()
|
|
|
|
def get_input_embeddings(self):
|
|
return self.language_model.get_input_embeddings()
|
|
|
|
def set_input_embeddings(self, value):
|
|
self.language_model.set_input_embeddings(value)
|
|
|
|
def _preprocess_accelerate(self):
|
|
r"""
|
|
Some pre-processing hacks to make the model `accelerate` compatible. Check
|
|
https://github.com/huggingface/transformers/pull/21707 for more details.
|
|
"""
|
|
hf_device_map = self.hf_device_map
|
|
|
|
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
|
|
# warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
|
|
logger.warning(
|
|
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
|
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
|
" Please pass a `device_map` that contains `language_model` to remove this warning."
|
|
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
|
|
" more details on creating a `device_map` for large models.",
|
|
)
|
|
|
|
if hasattr(self.language_model, "_hf_hook"):
|
|
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
|
|
|
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
|
|
"""
|
|
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
|
|
"""
|
|
if input_ids is None:
|
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
|
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
)
|
|
special_image_mask = special_image_mask.all(-1)
|
|
else:
|
|
special_image_mask = input_ids == self.config.image_token_id
|
|
|
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
return special_image_mask
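    # Illustrative sketch of how the mask is used downstream (token values are assumptions): with
    # the default `num_query_tokens=32`, the prompt is expected to contain 32 image-placeholder
    # positions, e.g. `input_ids = [img, img, ..., img, "What", "is", ...]`. `get_placeholder_mask`
    # marks exactly those positions and `masked_scatter` then overwrites their embeddings with the
    # 32 projected Q-Former outputs, leaving the text embeddings untouched.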
|
|
|
|
@can_return_tuple
|
|
@auto_docstring
|
|
def forward(
|
|
self,
|
|
pixel_values: torch.FloatTensor,
|
|
qformer_input_ids: torch.FloatTensor,
|
|
qformer_attention_mask: torch.LongTensor | None = None,
|
|
input_ids: torch.FloatTensor | None = None,
|
|
attention_mask: torch.LongTensor | None = None,
|
|
decoder_input_ids: torch.LongTensor | None = None,
|
|
decoder_attention_mask: torch.LongTensor | None = None,
|
|
inputs_embeds: torch.Tensor | None = None,
|
|
interpolate_pos_encoding: bool = False,
|
|
**kwargs: Unpack[FlashAttentionKwargs],
|
|
) -> tuple | InstructBlipForConditionalGenerationModelOutput:
|
|
r"""
|
|
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
|
to serve as text prompt, which the Q-Former model will encode.
|
|
|
|
Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
|
|
details.
|
|
|
|
[What are input IDs?](../glossary#input-ids)
|
|
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
|
- 1 for tokens that are **not masked**,
|
|
- 0 for tokens that are **masked**.
|
|
|
|
[What are attention masks?](../glossary#attention-mask)
|
|
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
|
be used by default.
|
|
|
|
Only relevant in case an encoder-decoder language model (like T5) is used.
|
|
"""
|
|
|
|
# step 1: forward the images through the vision encoder,
|
|
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
|
vision_outputs = self.vision_model(
|
|
pixel_values=pixel_values,
|
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
|
**kwargs,
|
|
)
|
|
image_embeds = vision_outputs[0]
|
|
|
|
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
|
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
|
|
|
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
|
|
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
|
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
|
if qformer_attention_mask is None:
|
|
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
|
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
|
query_outputs = self.qformer(
|
|
input_ids=qformer_input_ids,
|
|
attention_mask=qformer_attention_mask,
|
|
query_embeds=query_tokens,
|
|
encoder_hidden_states=image_embeds,
|
|
encoder_attention_mask=image_attention_mask,
|
|
**kwargs,
|
|
)
|
|
query_output = query_outputs[0][:, : query_tokens.size(1), :]
|
|
|
|
if inputs_embeds is None:
|
|
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
|
if attention_mask is None:
|
|
attention_mask = torch.ones_like(input_ids)
|
|
|
|
# step 3: use the language model, conditioned on the query outputs and the prompt
|
|
language_model_inputs = self.language_projection(query_output)
|
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
|
|
if self.config.use_decoder_only_language_model:
|
|
outputs = self.language_model(
|
|
inputs_embeds=inputs_embeds,
|
|
attention_mask=attention_mask,
|
|
**kwargs,
|
|
)
|
|
else:
|
|
outputs = self.language_model(
|
|
inputs_embeds=inputs_embeds,
|
|
attention_mask=attention_mask,
|
|
decoder_input_ids=decoder_input_ids,
|
|
decoder_attention_mask=decoder_attention_mask,
|
|
**kwargs,
|
|
)
|
|
|
|
return InstructBlipForConditionalGenerationModelOutput(
|
|
vision_outputs=vision_outputs,
|
|
qformer_outputs=query_outputs,
|
|
language_model_outputs=outputs,
|
|
)
|
|
|
|
|
|
@auto_docstring(
|
|
custom_intro="""
|
|
InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
|
|
encoder, Querying Transformer (Q-Former) and a language model.
|
|
|
|
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
|
the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
|
|
"""
|
|
)
|
|
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
|
|
config: InstructBlipConfig
|
|
main_input_name = "pixel_values"
|
|
|
|
_can_compile_fullgraph = True
|
|
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
|
|
|
|
def __init__(self, config: InstructBlipConfig):
|
|
super().__init__(config)
|
|
|
|
self.vision_model = InstructBlipVisionModel._from_config(config.vision_config)
|
|
|
|
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
|
self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config)
|
|
|
|
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
|
|
|
if config.use_decoder_only_language_model:
|
|
language_model = AutoModelForCausalLM.from_config(config.text_config)
|
|
else:
|
|
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
|
|
|
|
self.language_model = language_model
|
|
|
|
# Initialize weights and apply final processing
|
|
self.post_init()
|
|
|
|
def get_input_embeddings(self):
|
|
return self.language_model.get_input_embeddings()
|
|
|
|
def set_input_embeddings(self, value):
|
|
self.language_model.set_input_embeddings(value)
|
|
|
|
def set_output_embeddings(self, new_embeddings):
|
|
self.language_model.set_output_embeddings(new_embeddings)
|
|
|
|
def get_output_embeddings(self) -> nn.Module:
|
|
return self.language_model.get_output_embeddings()
|
|
|
|
def get_encoder(self, modality=None):
|
|
if modality is None:
|
|
return self.language_model.get_encoder()
|
|
else:
|
|
return super().get_encoder(modality=modality)
|
|
|
|
def get_decoder(self):
|
|
return self.language_model.get_decoder()
|
|
|
|
# Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate
|
|
def _preprocess_accelerate(self):
|
|
r"""
|
|
Some pre-processing hacks to make the model `accelerate` compatible. Check
|
|
https://github.com/huggingface/transformers/pull/21707 for more details.
|
|
"""
|
|
hf_device_map = self.hf_device_map
|
|
|
|
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
|
|
# warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
|
|
logger.warning(
|
|
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
|
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
|
" Please pass a `device_map` that contains `language_model` to remove this warning."
|
|
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
|
|
" more details on creating a `device_map` for large models.",
|
|
)
|
|
|
|
if hasattr(self.language_model, "_hf_hook"):
|
|
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
|
|
|
@can_return_tuple
|
|
@auto_docstring
|
|
def get_image_features(
|
|
self,
|
|
pixel_values: torch.FloatTensor,
|
|
qformer_input_ids: torch.LongTensor,
|
|
qformer_attention_mask: torch.LongTensor | None = None,
|
|
interpolate_pos_encoding: bool | None = False,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> tuple | BaseModelOutputWithVisionQformerOutputs:
|
|
r"""
|
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
|
The tensors corresponding to the input images.
|
|
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
|
to serve as text prompt, which the Q-Former model will encode.
|
|
|
|
Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
|
|
details.
|
|
|
|
[What are input IDs?](../glossary#input-ids)
|
|
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
|
- 1 for tokens that are **not masked**,
|
|
- 0 for tokens that are **masked**.
|
|
|
|
[What are attention masks?](../glossary#attention-mask)
|
|
"""
|
|
# step 1: forward the images through the vision encoder,
|
|
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
|
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
|
|
pixel_values=pixel_values,
|
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
|
return_dict=True,
|
|
**kwargs,
|
|
)
|
|
vision_outputs = BaseModelOutputWithVisionQformerOutputs(**vision_outputs, vision_outputs=vision_outputs)
|
|
image_embeds = vision_outputs[0]
|
|
|
|
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
|
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
|
|
|
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
|
|
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
|
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
|
if qformer_attention_mask is None:
|
|
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
|
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
|
qformer_outputs = self.qformer(
|
|
input_ids=qformer_input_ids,
|
|
attention_mask=qformer_attention_mask,
|
|
query_embeds=query_tokens,
|
|
encoder_hidden_states=image_embeds,
|
|
encoder_attention_mask=image_attention_mask,
|
|
return_dict=True,
|
|
**kwargs,
|
|
)
|
|
vision_outputs.qformer_outputs = qformer_outputs
|
|
query_output = qformer_outputs[0][:, : query_tokens.size(1), :]
|
|
|
|
# step 3: use the language model, conditioned on the query outputs and the prompt
|
|
image_features = self.language_projection(query_output)
|
|
vision_outputs.pooler_output = image_features
|
|
|
|
return vision_outputs
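    # Illustrative usage sketch, not part of the modeling code (`model`, `pixel_values` and
    # `qformer_ids` are assumptions; shapes use the default 32 query tokens):
    #
    #   feats = model.get_image_features(pixel_values, qformer_input_ids=qformer_ids)
    #   feats.pooler_output        # (batch_size, 32, text_hidden_size) features to scatter into the LM
    #   feats.vision_outputs       # full vision encoder outputs
    #   feats.qformer_outputs      # full Q-Former outputs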
|
|
|
|
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
|
|
"""
|
|
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
|
|
"""
|
|
if input_ids is None:
|
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
|
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
|
)
|
|
special_image_mask = special_image_mask.all(-1)
|
|
else:
|
|
special_image_mask = input_ids == self.config.image_token_id
|
|
|
|
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
return special_image_mask
|
|
|
|
@can_return_tuple
|
|
@auto_docstring
|
|
def forward(
|
|
self,
|
|
pixel_values: torch.FloatTensor,
|
|
qformer_input_ids: torch.FloatTensor,
|
|
qformer_attention_mask: torch.LongTensor | None = None,
|
|
input_ids: torch.FloatTensor | None = None,
|
|
attention_mask: torch.LongTensor | None = None,
|
|
decoder_input_ids: torch.LongTensor | None = None,
|
|
decoder_attention_mask: torch.LongTensor | None = None,
|
|
inputs_embeds: torch.FloatTensor | None = None,
|
|
labels: torch.LongTensor | None = None,
|
|
interpolate_pos_encoding: bool = False,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> tuple | InstructBlipForConditionalGenerationModelOutput:
|
|
r"""
|
|
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
|
to serve as text prompt, which the Q-Former model will encode.
|
|
|
|
Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
|
|
details.
|
|
|
|
[What are input IDs?](../glossary#input-ids)
|
|
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
|
- 1 for tokens that are **not masked**,
|
|
- 0 for tokens that are **masked**.
|
|
|
|
[What are attention masks?](../glossary#attention-mask)
|
|
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
|
be used by default.
|
|
|
|
Only relevant in case an encoder-decoder language model (like T5) is used.
|
|
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
|
|
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
|
config.vocab_size]`
|
|
|
|
Examples:
|
|
|
|
```python
|
|
>>> from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
|
|
>>> import torch
|
|
>>> from PIL import Image
|
|
>>> import httpx
|
|
>>> from io import BytesIO
|
|
|
|
>>> model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
|
>>> processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
|
|
|
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
>>> model.to(device) # doctest: +IGNORE_RESULT
|
|
|
|
>>> url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
|
|
>>> with httpx.stream("GET", url) as response:
|
|
... image = Image.open(BytesIO(response.read())).convert("RGB")
|
|
>>> prompt = "What is unusual about this image?"
|
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
|
|
|
|
>>> outputs = model.generate(
|
|
... **inputs,
|
|
... do_sample=False,
|
|
... num_beams=5,
|
|
... max_length=256,
|
|
... min_length=1,
|
|
... top_p=0.9,
|
|
... repetition_penalty=1.5,
|
|
... length_penalty=1.0,
|
|
... temperature=1,
|
|
... )
|
|
>>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
|
|
>>> print(generated_text)
|
|
The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.
|
|
```"""
|
|
|
|
image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features(
|
|
pixel_values,
|
|
qformer_input_ids=qformer_input_ids,
|
|
qformer_attention_mask=qformer_attention_mask,
|
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
|
return_dict=True,
|
|
)
|
|
language_model_inputs = image_features.pooler_output
|
|
qformer_outputs = image_features.qformer_outputs
|
|
vision_outputs = image_features.vision_outputs
|
|
|
|
if inputs_embeds is None:
|
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
|
|
if attention_mask is None:
|
|
attention_mask = torch.ones_like(input_ids)
|
|
|
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
|
|
if self.config.use_decoder_only_language_model:
|
|
outputs = self.language_model(
|
|
inputs_embeds=inputs_embeds,
|
|
attention_mask=attention_mask,
|
|
**kwargs,
|
|
)
|
|
logits = outputs[0]
|
|
loss = None
|
|
if labels is not None:
|
|
loss = self.loss_function(
|
|
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
|
)
|
|
|
|
else:
|
|
kwargs["return_dict"] = True
|
|
outputs = self.language_model(
|
|
inputs_embeds=inputs_embeds,
|
|
attention_mask=attention_mask,
|
|
decoder_input_ids=decoder_input_ids,
|
|
decoder_attention_mask=decoder_attention_mask,
|
|
labels=labels,
|
|
**kwargs,
|
|
)
|
|
loss = outputs.loss
|
|
logits = outputs.logits
|
|
|
|
return InstructBlipForConditionalGenerationModelOutput(
|
|
loss=loss,
|
|
logits=logits,
|
|
vision_outputs=vision_outputs,
|
|
qformer_outputs=qformer_outputs,
|
|
language_model_outputs=outputs,
|
|
)
|
|
|
|
@torch.no_grad()
|
|
def generate(
|
|
self,
|
|
pixel_values: torch.FloatTensor,
|
|
qformer_input_ids: torch.LongTensor | None = None,
|
|
qformer_attention_mask: torch.LongTensor | None = None,
|
|
input_ids: torch.LongTensor | None = None,
|
|
attention_mask: torch.LongTensor | None = None,
|
|
inputs_embeds: torch.FloatTensor | None = None,
|
|
interpolate_pos_encoding: bool = False,
|
|
**generate_kwargs,
|
|
) -> torch.LongTensor:
|
|
"""
|
|
        Overrides the `generate` function so that the model can be used as a conditional generator.
|
|
|
|
Args:
|
|
pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
|
|
Input images to be processed.
|
|
qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
|
The sequence used as a prompt to be fed to the Q-Former module.
|
|
qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
|
Mask to avoid performing attention on padding token indices.
|
|
input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
|
The sequence used as a prompt for the generation.
|
|
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
|
Mask to avoid performing attention on padding token indices.
|
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
Embedded representation of the inputs. Should be float, not int tokens.
|
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
|
Whether to interpolate the positional encoding of the image embeddings.
|
|
|
|
Returns:
|
|
            captions (`torch.LongTensor`): Generated token ids, of shape `(batch_size * num_captions, sequence_length)`.
|
|
"""
|
|
if hasattr(self, "hf_device_map"):
|
|
# preprocess for `accelerate`
|
|
self._preprocess_accelerate()
|
|
|
|
batch_size = pixel_values.shape[0]
|
|
image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features(
|
|
pixel_values,
|
|
qformer_input_ids=qformer_input_ids,
|
|
qformer_attention_mask=qformer_attention_mask,
|
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
|
return_dict=True,
|
|
)
|
|
language_model_inputs = image_features.pooler_output
|
|
|
|
if inputs_embeds is None:
|
|
if input_ids is None:
|
|
                image_tokens = [self.config.image_token_id] * self.config.num_query_tokens
|
|
start_tokens = image_tokens + [self.config.text_config.bos_token_id]
|
|
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
|
input_ids = input_ids.repeat(batch_size, 1)
|
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
|
|
if attention_mask is None:
|
|
attention_mask = torch.ones_like(input_ids)
|
|
|
|
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
|
|
|
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
|
|
if not self.language_model.config.is_encoder_decoder:
|
|
inputs["input_ids"] = input_ids
|
|
|
|
outputs = self.language_model.generate(**inputs, **generate_kwargs)
|
|
|
|
return outputs
|
|
|
|
|
|
__all__ = [
    "InstructBlipQFormerModel",
    "InstructBlipPreTrainedModel",
    "InstructBlipModel",
    "InstructBlipForConditionalGeneration",
    "InstructBlipVisionModel",
]