You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1882 lines
80 KiB

# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BridgeTower Model"""
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ... import initialization as init
from ...activations import ACT2FN, QuickGELUActivation
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
ModelOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
from ...utils.generic import can_return_tuple
from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
# Module-level logger, named after this module via the library's logging helper.
logger = logging.get_logger(__name__)
# Name of the tokenizer class associated with this model for documentation purposes.
# NOTE(review): not referenced anywhere in the visible part of this file — confirm it is still needed.
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`BridgeTowerModel`].
    """
)
class BridgeTowerModelOutput(ModelOutput):
    r"""
    text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
        Sequence of hidden-states at the text output of the last layer of the model.
    image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
        Sequence of hidden-states at the image output of the last layer of the model.
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
        Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
        token), respectively, after further processing through layers used for auxiliary pretraining tasks.
    """

    # Every field defaults to `None`, so an instance can be created with any subset of the outputs.
    text_features: torch.FloatTensor | None = None
    image_features: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    # `hidden_states` and `attentions` are not described in the docstring above.
    # NOTE(review): presumably their documentation is supplied by `auto_docstring` — confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of ['BridgeTowerForContrastiveLearning']
    """
)
class BridgeTowerContrastiveOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Image-text contrastive loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    text_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    image_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    cross_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
        The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    """

    # Every field defaults to `None`, so an instance can be created with any subset of the outputs.
    loss: torch.FloatTensor | None = None
    # NOTE(review): the docstring describes `logits` as language-modeling scores, which looks inherited from
    # another output class — verify against what `BridgeTowerForContrastiveLearning` actually stores here.
    logits: torch.FloatTensor | None = None
    # NOTE(review): the three `*_embeds` fields are annotated as tuples but documented as single tensors —
    # confirm which one matches the producer.
    text_embeds: tuple[torch.FloatTensor] | None = None
    image_embeds: tuple[torch.FloatTensor] | None = None
    cross_embeds: tuple[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
class BridgeTowerResidualAttention(nn.Module):
    """
    Pre-norm residual block: LayerNorm -> multi-head self-attention -> residual add, followed by
    LayerNorm -> (Linear, QuickGELU, Linear) -> residual add.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_eps

        # The number of heads is derived from the width: one head per 64 hidden channels.
        self.attn = nn.MultiheadAttention(width, width // 64)
        self.ln_1 = nn.LayerNorm(width, eps=norm_eps)
        self.mlp = nn.ModuleDict(
            OrderedDict(
                c_fc=nn.Linear(width, width * 4),
                gelu=QuickGELUActivation(),
                c_proj=nn.Linear(width * 4, width),
            )
        )
        self.ln_2 = nn.LayerNorm(width, eps=norm_eps)
        # Optional additive attention mask; nothing in this class ever sets it to a tensor.
        self.attn_mask = None

    def attention(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor):
        """
        Run self-attention over `hidden_state`.

        `attention_mask`, when given, is cast to bool and handed to `nn.MultiheadAttention` as the
        `key_padding_mask`. The stored `self.attn_mask` (if any) is moved to the dtype/device of the input
        and used as `attn_mask`. Only the attention output is returned (weights are not computed).
        """
        if attention_mask is not None:
            attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_state.device)
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=hidden_state.dtype, device=hidden_state.device)

        attn_output, _ = self.attn(
            hidden_state,
            hidden_state,
            hidden_state,
            need_weights=False,
            attn_mask=self.attn_mask,
            key_padding_mask=attention_mask,
        )
        return attn_output

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None):
        """Apply the attention sub-block and then the MLP sub-block, each with a residual connection."""
        attended = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask)

        mlp_state = self.ln_2(attended)
        # The ModuleDict preserves insertion order: c_fc -> gelu -> c_proj.
        for sublayer in self.mlp.values():
            mlp_state = sublayer(mlp_state)

        return attended + mlp_state
class BridgeTowerTransformer(nn.Module):
    """Stack of `BridgeTowerResidualAttention` blocks that returns the output of every block."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers

        # With `remove_last_layer` set, one block fewer than `num_hidden_layers` is built.
        block_count = self.num_hidden_layers - 1 if config.remove_last_layer else self.num_hidden_layers
        self.resblocks = nn.ModuleList([BridgeTowerResidualAttention(config) for _ in range(block_count)])
        self.stop_gradient = config.stop_gradient

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None):
        """
        Feed `hidden_state` through all blocks in order.

        Returns:
            A list with one tensor per block. When `stop_gradient` is set, the collected tensors are detached
            (the tensor passed on to the next block is not).
        """
        per_block_states = []
        for block in self.resblocks:
            hidden_state = block(hidden_state, attention_mask)
            per_block_states.append(hidden_state.detach() if self.stop_gradient else hidden_state)
        return per_block_states
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->BridgeTower
class BridgeTowerVisionEmbeddings(nn.Module):
    """Turns pixel values into a sequence of patch embeddings with a class token and position embeddings."""

    def __init__(self, config: BridgeTowerVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        # One position per patch plus one for the class token.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resize the learned position embeddings so they fit an input of `height` x `width` pixels, which allows
        the model to be used on resolutions other than the one it was configured for. Written so that it also
        works under torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        pos_table = self.position_embedding.weight.unsqueeze(0)
        patch_count = embeddings.shape[1] - 1
        stored_count = pos_table.shape[1] - 1

        # While tracing we always take the interpolation path so the exported graph handles dynamic shapes.
        if not torch.jit.is_tracing():
            if patch_count == stored_count and height == width:
                return self.position_embedding(self.position_ids)

        cls_pos, grid_pos = pos_table[:, :1], pos_table[:, 1:]
        channels = embeddings.shape[-1]
        target_grid = (height // self.patch_size, width // self.patch_size)

        # The stored patch positions form a square grid; lay them out as an image for interpolation.
        side = torch_int(stored_count**0.5)
        grid_pos = grid_pos.reshape(1, side, side, channels).permute(0, 3, 1, 2)
        grid_pos = nn.functional.interpolate(
            grid_pos,
            size=target_grid,
            mode="bicubic",
            align_corners=False,
        )
        grid_pos = grid_pos.permute(0, 2, 3, 1).view(1, -1, channels)

        return torch.cat((cls_pos, grid_pos), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """
        Embed `pixel_values` of shape `(batch, channels, height, width)`.

        Raises:
            ValueError: if the image size differs from the configured one and `interpolate_pos_encoding` is off.
        """
        batch_size, _, height, width = pixel_values.shape
        size_differs = height != self.image_size or width != self.image_size
        if size_differs and not interpolate_pos_encoding:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )

        # Run the convolution in the dtype of its weights; output is (batch, embed_dim, grid, grid).
        weight_dtype = self.patch_embedding.weight.dtype
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        # Flatten the grid and move channels last: (batch, grid * grid, embed_dim).
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(batch_size, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        if interpolate_pos_encoding:
            positions = self.interpolate_pos_encoding(tokens, height, width)
        else:
            positions = self.position_embedding(self.position_ids)
        return tokens + positions
class BridgeTowerVisionTransformer(nn.Module):
    """Vision tower: patch embeddings, a pre-LayerNorm, the residual-attention stack and post-LayerNorm(s)."""

    def __init__(self, config):
        super().__init__()
        norm_eps = config.layer_norm_eps

        self.embeddings = BridgeTowerVisionEmbeddings(config)
        self.ln_pre = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.transformer = BridgeTowerTransformer(config)
        self.ln_post = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.share_layernorm = config.share_layernorm
        if not config.share_layernorm:
            # One dedicated LayerNorm per layer output instead of the shared `ln_post`.
            self.ln_separate = nn.ModuleList(
                [nn.LayerNorm(config.hidden_size, eps=norm_eps) for _ in range(config.num_hidden_layers)]
            )

    def forward(
        self,
        pixel_values: torch.Tensor,
        attention_mask,
        interpolate_pos_encoding: bool = False,
    ):
        """
        Run the full vision tower and return the normalized outputs of all transformer blocks stacked along a
        new leading dimension.
        """
        embedded = self.ln_pre(self.embeddings(pixel_values, interpolate_pos_encoding))
        # NLD -> LND (sequence dimension first) before entering the transformer stack.
        per_layer = self.transformer(embedded.permute(1, 0, 2), attention_mask)

        # Stack the per-block outputs on a new leading axis, then swap axes 1 and 2 so that each layer's
        # slice is back in NLD order.
        stacked = torch.stack(per_layer, dim=0).permute(0, 2, 1, 3)

        if self.share_layernorm:
            return self.ln_post(stacked)

        # Pair each layer output with its own LayerNorm (zip stops at the shorter of the two).
        normalized = [norm(layer_state) for layer_state, norm in zip(stacked, self.ln_separate)]
        return torch.stack(normalized, dim=0)

    def forward_pre(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ):
        """Embed the pixels, apply `ln_pre` and return the result with the first two axes swapped (NLD -> LND)."""
        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        return self.ln_pre(embedded).permute(1, 0, 2)

    def forward_post(self, hidden_state: torch.Tensor):
        """Swap the first two axes of `hidden_state` back and apply `ln_post`."""
        return self.ln_post(hidden_state.permute(1, 0, 2))
class BridgeTowerLinkTower(nn.Module):
    """
    Merges uni-modal hidden states with cross-modal hidden states and layer-normalizes the result.

    Supported `config.link_tower_type` values:
    - `"add"`: plain sum.
    - `"scaled_add"`: the uni-modal states are multiplied by a learned scalar (initialized to 1.0) first.
    - `"interpolate"`: convex-style mix controlled by a learned scalar `beta` (initialized to 0.5).
    """

    def __init__(self, config):
        super().__init__()
        self.link_tower_type = config.link_tower_type
        self.hidden_size = config.hidden_size

        if config.link_tower_type not in ("add", "scaled_add", "interpolate"):
            raise NotImplementedError(f"link_tower_type {config.link_tower_type} is not implemented")

        # Only the variants that need a learned scalar register one.
        if config.link_tower_type == "scaled_add":
            self.scaled_factor = nn.Parameter(torch.tensor(1.0))
        elif config.link_tower_type == "interpolate":
            self.beta = nn.Parameter(torch.tensor(0.5))
        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, cross_modal_hidden_states, attention_mask):
        """
        Combine the two inputs according to the configured link type. `attention_mask` is accepted but unused.

        Raises:
            NotImplementedError: if `self.link_tower_type` is not one of the supported values.
        """
        link_type = self.link_tower_type
        if link_type == "add":
            merged = hidden_states + cross_modal_hidden_states
        elif link_type == "scaled_add":
            merged = hidden_states * self.scaled_factor + cross_modal_hidden_states
        elif link_type == "interpolate":
            merged = hidden_states * (1 - self.beta) + cross_modal_hidden_states * self.beta
        else:
            raise NotImplementedError(f"link_tower_type {self.link_tower_type} is not implemented")
        return self.LayerNorm(merged)
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BridgeTower
class BridgeTowerSelfOutput(nn.Module):
    """Output projection of an attention block: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, add the residual `input_tensor` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BridgeTower
class BridgeTowerIntermediate(nn.Module):
    """First half of the feed-forward block: expand to `intermediate_size` and apply the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # `hidden_act` is either the name of a registered activation or the callable itself.
        activation = config.hidden_act
        if isinstance(activation, str):
            activation = ACT2FN[activation]
        self.intermediate_act_fn = activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `activation(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BridgeTower
class BridgeTowerOutput(nn.Module):
    """Second half of the feed-forward block: project back to `hidden_size`, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` down, add the residual `input_tensor` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->BridgeTower
class BridgeTowerPooler(nn.Module):
    """Pools a sequence by passing the hidden state of its first token through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # "Pooling" here just means selecting the first token along dimension 1.
        leading_token = hidden_states[:, 0]
        return self.activation(self.dense(leading_token))
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain scaled dot-product attention.

    Args:
        module: Only its `training` flag is read, to decide whether dropout is active.
        query, key, value: 4-D tensors; `key` is transposed on its last two axes for the score product.
        attention_mask: Optional additive mask; its last axis is cropped to the key length before being added.
        scaling: Factor applied to the raw scores. Defaults to `query.size(-1) ** -0.5`.
        dropout: Dropout probability applied to the attention probabilities.

    Returns:
        `(attn_output, attn_weights)` where `attn_output` has axes 1 and 2 swapped (and is contiguous) and
        `attn_weights` are the post-dropout attention probabilities.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores: scaled dot product of queries with keys.
    scores = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        key_length = key.shape[-2]
        scores = scores + attention_mask[:, :, :, :key_length]

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower
class BridgeTowerSelfAttention(nn.Module):
    """Multi-head self-attention with optional key/value caching and a pluggable attention backend."""

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()
        heads_divide_width = config.hidden_size % config.num_attention_heads == 0
        if not heads_divide_width and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.is_decoder = config.is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Attend `hidden_states` to itself.

        When `past_key_values` is given, the freshly projected keys/values are pushed into the cache for
        `self.layer_idx` and the tensors returned by the cache are used for attention.

        Returns:
            `(attn_output, attn_weights)`; `attn_output` has the leading dimensions of `hidden_states` with the
            heads merged into the last dimension.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.attention_head_size)

        def project_to_heads(projection):
            # Linear projection, split the last dim into heads, and move the head axis in front of the sequence.
            return projection(hidden_states).view(*per_head_shape).transpose(1, 2)

        query_layer = project_to_heads(self.query)
        key_layer = project_to_heads(self.key)
        value_layer = project_to_heads(self.value)

        if past_key_values is not None:
            # An encoder-decoder cache keeps the self-attention entries in a dedicated sub-cache;
            # any other cache object is used directly.
            if isinstance(past_key_values, EncoderDecoderCache):
                layer_cache = past_key_values.self_attention_cache
            else:
                layer_cache = past_key_values

            # Store the new keys/values so they can be reused in later auto-regressive steps.
            key_layer, value_layer = layer_cache.update(
                key_layer,
                value_layer,
                self.layer_idx,
                {"cache_position": cache_position},
            )

        # Resolve the configured attention implementation, falling back to the eager one.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=self.dropout.p if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )
        merged_output = attn_output.reshape(*leading_shape, -1).contiguous()
        return merged_output, attn_weights
# Copied from transformers.models.roberta.modeling_roberta.RobertaCrossAttention with Roberta->BridgeTower
class BridgeTowerCrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from `hidden_states`, keys/values from `encoder_hidden_states`."""

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()
        heads_divide_width = config.hidden_size % config.num_attention_heads == 0
        if not heads_divide_width and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.is_causal = is_causal
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Attend from `hidden_states` to `encoder_hidden_states`.

        With a cache, keys/values for this layer are projected once, stored in the cross-attention cache and
        flagged in `past_key_values.is_updated`; later calls read them back instead of re-projecting.

        Returns:
            `(attn_output, attn_weights)` with `attn_output` of shape `(batch, target_len, -1)`.
        """
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = encoder_hidden_states.shape[1]
        head_dim = self.attention_head_size

        query_layer = self.query(hidden_states).view(bsz, tgt_len, -1, head_dim).transpose(1, 2)

        has_cache = past_key_values is not None
        if has_cache and past_key_values.is_updated.get(self.layer_idx):
            # Keys/values for this layer were already computed in an earlier call; reuse them.
            cached_layer = past_key_values.cross_attention_cache.layers[self.layer_idx]
            key_layer = cached_layer.keys
            value_layer = cached_layer.values
        else:
            key_layer = self.key(encoder_hidden_states).view(bsz, src_len, -1, head_dim).transpose(1, 2)
            value_layer = self.value(encoder_hidden_states).view(bsz, src_len, -1, head_dim).transpose(1, 2)

            if has_cache:
                # Store the projections and mark this layer so that subsequent calls take the reuse branch.
                key_layer, value_layer = past_key_values.cross_attention_cache.update(
                    key_layer, value_layer, self.layer_idx
                )
                past_key_values.is_updated[self.layer_idx] = True

        # Resolve the configured attention implementation, falling back to the eager one.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=self.dropout.p if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )
        merged_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        return merged_output, attn_weights
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER
class BridgeTowerAttention(nn.Module):
    """Wraps a self- or cross-attention module together with its output projection/residual block."""

    def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
        super().__init__()
        self.is_cross_attention = is_cross_attention
        if is_cross_attention:
            attention_class = BridgeTowerCrossAttention
        else:
            attention_class = BridgeTowerSelfAttention
        self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
        self.output = BridgeTowerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Run the wrapped attention and then the output block (with `hidden_states` as the residual).

        For cross-attention the mask handed to the inner module is `encoder_attention_mask`; otherwise it is
        `attention_mask`.

        Returns:
            `(attention_output, attn_weights)`.
        """
        effective_mask = encoder_attention_mask if self.is_cross_attention else attention_mask
        context, attn_weights = self.self(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=effective_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )
        return self.output(context, hidden_states), attn_weights
class BridgeTowerBertCrossLayer(nn.Module):
    """Cross-modal layer: self-attention, then cross-attention over the other modality, then a feed-forward block."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension along which the feed-forward block may be chunked.
        self.seq_len_dim = 1
        self.attention = BridgeTowerAttention(config, is_causal=True, layer_idx=layer_idx)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        self.crossattention = BridgeTowerAttention(
            config,
            is_causal=False,
            layer_idx=layer_idx,
            is_cross_attention=True,
        )
        self.intermediate = BridgeTowerIntermediate(config)
        self.output = BridgeTowerOutput(config)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        attention_mask=None,
        encoder_attention_mask=None,
        past_key_values=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Returns:
            `(layer_output, self_attn_weights, cross_attn_weights)`.
        """
        # The self-attention step never uses the cache; only cross-attention receives `past_key_values`.
        self_attended, self_attn_weights = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            past_key_values=None,
            **kwargs,
        )

        cross_attended, cross_attn_weights = self.crossattention(
            self_attended,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, cross_attended
        )
        return layer_output, self_attn_weights, cross_attn_weights

    def feed_forward_chunk(self, attention_output):
        """Feed-forward block applied to (a chunk of) the attention output, with a residual connection."""
        expanded = self.intermediate(attention_output)
        return self.output(expanded, attention_output)
class BridgeTowerTextLayer(GradientCheckpointingLayer):
    """
    Text transformer layer: self-attention, optional cross-attention (decoder configuration only) and a
    feed-forward block.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension along which the feed-forward block may be chunked.
        self.seq_len_dim = 1
        self.attention = BridgeTowerAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = BridgeTowerAttention(
                config,
                is_causal=False,
                layer_idx=layer_idx,
                is_cross_attention=True,
            )
        self.intermediate = BridgeTowerIntermediate(config)
        self.output = BridgeTowerOutput(config)

    # Adapted from transformers.models.bert.modeling_bert.BertLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Run the layer.

        Returns:
            `(layer_output, self_attn_weights)`, with `cross_attn_weights` appended as a third element when
            cross-attention ran (decoder configuration and `encoder_hidden_states` provided). The layer
            output is always at index 0, which is what `BridgeTowerTextEncoder` reads as the next hidden
            states.

        Raises:
            ValueError: if `encoder_hidden_states` is passed to a decoder layer that was built without
                cross-attention.
        """
        self_attention_output, self_attn_weights = self.attention(
            hidden_states,
            attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )
        attention_output = self_attention_output

        # Stays empty unless cross-attention actually runs.
        cross_attn_outputs = ()
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            cross_attention_output, cross_attn_weights = self.crossattention(
                self_attention_output,
                None,  # attention_mask
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_values=past_key_values,
                **kwargs,
            )
            attention_output = cross_attention_output
            cross_attn_outputs = (cross_attn_weights,)

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

        # Hidden states first, then self-attention weights, then (optionally) cross-attention weights.
        # Prepending the cross-attention weights would put them at index 0, where consumers expect the
        # hidden states.
        return (layer_output, self_attn_weights) + cross_attn_outputs

    def feed_forward_chunk(self, attention_output):
        """Feed-forward block applied to (a chunk of) the attention output, with a residual connection."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
# copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText
class BridgeTowerTextEncoder(nn.Module):
    """Sequential stack of `config.num_hidden_layers` `BridgeTowerTextLayer` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        layers = [BridgeTowerTextLayer(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = False,
        output_hidden_states: bool | None = False,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
        """
        Run every layer in order, optionally collecting per-layer hidden states and attention weights.

        Each layer's output tuple is read as `[0]` hidden states, `[1]` self-attention weights and, when
        `config.add_cross_attention` is set, `[2]` cross-attention weights. Hidden states are recorded after
        each layer, so the input embeddings themselves are not part of the collected tuple.
        """
        collect_cross = output_attentions and self.config.add_cross_attention
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if collect_cross else None

        for layer_module in self.layer:
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions += (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions += (layer_outputs[2],)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        # The cache is only handed back when the caller asked for it.
        returned_cache = past_key_values if use_cache else None
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=returned_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->BridgeTowerText
class BridgeTowerTextEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Helper buffers of shape (1, max_position_embeddings). They are registered with `persistent=False`,
        # so they follow the module across devices but are not part of the state dict.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """
        Sum word (or provided) embeddings, token-type embeddings and position embeddings, then apply LayerNorm
        and dropout.

        Missing `position_ids` are derived from `input_ids` (padding-aware) or, failing that, from the shape
        of `inputs_embeds`. Missing `token_type_ids` default to zeros.
        """
        ids_given = input_ids is not None

        if position_ids is None:
            if ids_given:
                # Positions are derived from the token ids; padded tokens keep the padding position.
                position_ids = self.create_position_ids_from_input_ids(
                    input_ids, self.padding_idx, past_key_values_length
                )
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)

        input_shape = input_ids.size() if ids_given else inputs_embeds.size()[:-1]
        batch_size, seq_length = input_shape

        # When token_type_ids are not supplied, fall back to the all-zeros buffer registered in the
        # constructor. Using the buffer helps users tracing the model without passing token_type_ids
        # (issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
                buffered = self.token_type_ids.expand(position_ids.shape[0], -1)
                buffered = torch.gather(buffered, dim=1, index=position_ids)
                token_type_ids = buffered.expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        summed = inputs_embeds + self.token_type_embeddings(token_type_ids)
        summed = summed + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(summed))

    @staticmethod
    def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor whose leading dimensions are `(batch_size, sequence_length)`.
            padding_idx: int, positions start at `padding_idx + 1`.

        Returns: torch.Tensor of shape `(batch_size, sequence_length)`.
        """
        input_shape = inputs_embeds.size()[:-1]
        first_position = padding_idx + 1
        sequential = torch.arange(
            first_position, input_shape[1] + first_position, dtype=torch.long, device=inputs_embeds.device
        )
        return sequential.unsqueeze(0).expand(input_shape)

    @staticmethod
    def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
        are ignored. This is modified from fairseq's `utils.make_positions`.

        Args:
            input_ids: torch.Tensor of token ids.
            padding_idx: int, id of the padding token.
            past_key_values_length: int, offset added to the position of every non-padding token.

        Returns: torch.Tensor
        """
        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
        non_pad = input_ids.ne(padding_idx).int()
        running_count = torch.cumsum(non_pad, dim=1).type_as(non_pad)
        # Multiplying by the mask zeroes padded slots, so they end up exactly at `padding_idx`.
        offsets = (running_count + past_key_values_length) * non_pad
        return offsets.long() + padding_idx
@auto_docstring
class BridgeTowerPreTrainedModel(PreTrainedModel):
    # Common base class of the BridgeTower models in this file. It holds the class-level settings
    # below and the per-module weight initialization routine.
    config: BridgeTowerConfig
    base_model_prefix = "bridgetower"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = False
    _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
    _skip_keys_device_placement = "past_key_values"

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """
        Initialize the parameters/buffers of a single `module` in place.

        The module type selects the scheme; every standard deviation is scaled by
        `config.initializer_factor`. Runs under `torch.no_grad()`.

        Args:
            module: The module to initialize.
        """
        # Global multiplier applied to every standard deviation used below.
        std = self.config.initializer_factor
        if isinstance(module, BridgeTowerVisionTransformer):
            # Standard deviations derived from the hidden size and layer count found on `self.config`.
            proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5)
            attn_std = self.config.hidden_size**-0.5
            fc_std = (2 * self.config.hidden_size) ** -0.5
            # Per residual block: attention input/output projections, then both MLP projections.
            for block in module.transformer.resblocks:
                init.normal_(block.attn.in_proj_weight, std=attn_std * std)
                init.zeros_(block.attn.in_proj_bias)
                init.normal_(block.attn.out_proj.weight, std=proj_std * std)
                init.normal_(block.mlp.c_fc.weight, std=fc_std * std)
                init.normal_(block.mlp.c_proj.weight, std=proj_std * std)

            init.normal_(module.embeddings.class_embedding, std=attn_std * std)
            init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std)
        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
            init.normal_(module.weight, mean=0.0, std=0.05 * std)
        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        elif isinstance(module, BridgeTowerForContrastiveLearning):
            init.constant_(module.logit_scale, self.config.logit_scale_init_value)
        elif isinstance(module, BridgeTowerVisionEmbeddings):
            # Refill the `position_ids` buffer with 0..num_positions-1, shaped (1, num_positions).
            init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
        elif isinstance(module, BridgeTowerTextEmbeddings):
            # Refill `position_ids` with 0..N-1 (N taken from the buffer itself) and zero `token_type_ids`.
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
            init.zeros_(module.token_type_ids)

        # Deliberately a separate `if`, not part of the chain above: an `nn.Linear` has already had its
        # weight drawn in the chain and still needs its bias zeroed here; the MLM head has its own `bias`.
        if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None:
            init.zeros_(module.bias)
class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
    # Thin wrapper that exposes a `BridgeTowerVisionTransformer` as `self.visual`.
    config: BridgeTowerVisionConfig
    input_modalities = ("image",)

    def __init__(self, config):
        super().__init__(config)
        self.visual = BridgeTowerVisionTransformer(config)
        self.post_init()

    @property
    def dtype(self):
        """dtype of the patch-embedding weight of the wrapped vision transformer."""
        patch_projection = self.visual.embeddings.patch_embedding
        return patch_projection.weight.dtype

    def forward(self, image, image_mask=None, interpolate_pos_encoding=False, **kwargs):
        """Cast `image` to `self.dtype` and run the wrapped vision transformer; extra `kwargs` are ignored."""
        image_in_model_dtype = image.type(self.dtype)
        return self.visual(image_in_model_dtype, image_mask, interpolate_pos_encoding)
@auto_docstring(
    custom_intro="""
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in *Attention is
    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin.
    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
    """
)
class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
    config: BridgeTowerTextConfig
    input_modalities = ("text",)

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        self.embeddings = BridgeTowerTextEmbeddings(config)
        self.encoder = BridgeTowerTextEncoder(config)
        # The pooler is optional; without it `forward` reports `pooler_output=None`.
        if add_pooling_layer:
            self.pooler = BridgeTowerPooler(config)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @can_return_tuple
    @auto_docstring
    # NOTE: bridgetower with its multimodality has a more complicated scheme making records harder
    # for now we skip the copies from bert but stay close to the original
    # copied from transformers.models.bert.modeling_bert.BertModel.forward
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        # Unset output flags fall back to the configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Caching is only honoured in decoder mode; otherwise it is forced off.
        if not self.config.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = self.config.use_cache

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(
                DynamicCache(config=self.config), DynamicCache(config=self.config)
            )

        # Exactly one of the two inputs must be given: both missing or both present is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            device, input_shape = input_ids.device, input_ids.shape
        else:
            device, input_shape = inputs_embeds.device, inputs_embeds.shape[:-1]
        seq_length = input_shape[1]

        past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
        if cache_position is None:
            # Positions of the new tokens, continuing after whatever is already cached.
            cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    # Copied from transformers.models.bert.modeling_bert.BertModel._create_attention_masks
    def _create_attention_masks(
        self,
        attention_mask,
        encoder_attention_mask,
        embedding_output,
        encoder_hidden_states,
        cache_position,
        past_key_values,
    ):
        if self.config.is_decoder:
            attention_mask = create_causal_mask(
                config=self.config,
                input_embeds=embedding_output,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
            )
        else:
            attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=embedding_output,
                attention_mask=attention_mask,
            )

        if encoder_attention_mask is not None:
            encoder_attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=embedding_output,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        return attention_mask, encoder_attention_mask
@auto_docstring(
    custom_intro="""
    The bare BridgeTower Model transformer outputting BridgeTowerModelOutput object without any specific head on
    """
)
class BridgeTowerModel(BridgeTowerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        vision_config = config.vision_config
        text_config = config.text_config

        # Linear maps from each tower's hidden size to the cross-modal hidden size: one shared map per
        # modality, or one per cross-modal layer.
        # NOTE(review): `forward` calls these attributes directly as modules, which only matches the
        # shared (single `nn.Linear`) variant - confirm the `nn.ModuleList` variant is handled elsewhere.
        if config.share_cross_modal_transformer_layers:
            self.cross_modal_text_transform = nn.Linear(text_config.hidden_size, config.hidden_size)
            self.cross_modal_image_transform = nn.Linear(vision_config.hidden_size, config.hidden_size)
        else:
            self.cross_modal_text_transform = nn.ModuleList(
                [nn.Linear(text_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
            )
            self.cross_modal_image_transform = nn.ModuleList(
                [nn.Linear(vision_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
            )

        # Two-entry embedding table; `forward` looks up index 0 for text and `image_token_type_idx` for images.
        self.token_type_embeddings = nn.Embedding(2, config.hidden_size)

        self.vision_model = BridgeTowerVisionModel(vision_config)

        self.text_model = BridgeTowerTextModel(text_config)

        # Optionally start every separate cross-modal layernorm of the vision tower from `ln_post`.
        if not vision_config.share_layernorm and config.init_layernorm_from_vision_encoder:
            for ln in self.vision_model.visual.cross_modal_ln_separate:
                ln.weight.data = self.vision_model.visual.ln_post.weight.data
                ln.bias.data = self.vision_model.visual.ln_post.bias.data

        # One stack of cross layers per modality, `config.num_hidden_layers` deep each.
        self.cross_modal_image_layers = nn.ModuleList(
            [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)]
        )
        self.cross_modal_text_layers = nn.ModuleList(
            [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)]
        )

        # Class token => Linear => Tanh
        self.cross_modal_image_pooler = BridgeTowerPooler(config)
        self.cross_modal_text_pooler = BridgeTowerPooler(config)

        # Initialize BridgeTower Components
        self.cross_modal_text_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_modal_image_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Link towers: a single shared module per modality, or `num_hidden_layers - 1` of them.
        # NOTE(review): `forward` indexes these attributes with `[link_layer_index]`, which only matches
        # the `nn.ModuleList` variant - confirm how the shared variant is meant to be used.
        if config.share_link_tower_layers:
            self.cross_modal_text_link_tower = BridgeTowerLinkTower(config)
            self.cross_modal_image_link_tower = BridgeTowerLinkTower(config)
        else:
            self.cross_modal_text_link_tower = nn.ModuleList(
                [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
            )
            self.cross_modal_image_link_tower = nn.ModuleList(
                [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
            )

        self.post_init()

    def get_input_embeddings(self):
        # Delegates to the text tower.
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        # Delegates to the text tower.
        self.text_model.set_input_embeddings(value)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        image_embeds: torch.FloatTensor | None = None,
        image_token_type_idx: int | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor] | BridgeTowerModelOutput:
        r"""
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        image_token_type_idx (`int`, *optional*):
            - The token type ids for images.
        output_hidden_states (`bool`, *optional*):
            If set to `True`, hidden states are returned as a list containing the hidden states of text, image, and
            cross-modal components respectively. i.e. `(hidden_states_text, hidden_states_image,
            hidden_states_cross_modal)` where each element is a list of the hidden states of the corresponding
            modality. `hidden_states_txt/img` are a list of tensors corresponding to unimodal hidden states and
            `hidden_states_cross_modal` is a list of tuples containing `cross_modal_text_hidden_states` and
            `cross_modal_image_hidden_states` of each bridge layer.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels are currently not supported.

        Examples:

        ```python
        >>> from transformers import BridgeTowerProcessor, BridgeTowerModel
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> # prepare image and text
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> text = "hello world"
        >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
        >>> model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base")

        >>> inputs = processor(image, text, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> outputs.keys()
        odict_keys(['text_features', 'image_features', 'pooler_output'])
        ```"""
        # Unset output flags fall back to the configuration.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Accumulators are tuples when the corresponding output is requested, `None` otherwise.
        all_hidden_states_text = () if output_hidden_states else None
        all_hidden_states_image = () if output_hidden_states else None
        all_hidden_states_cross = () if output_hidden_states else None
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # Only `input_ids` are supported; everything below reads `input_ids` directly.
        if inputs_embeds is not None and input_ids is None:
            raise NotImplementedError(
                "BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
            )

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Defaults to 1 when unset. NOTE(review): because of `or`, an explicit 0 is replaced by 1 as well.
        image_token_type_idx = image_token_type_idx or 1
        input_shape = input_ids.size()
        # NOTE(review): the `token_type_ids` argument is not forwarded to the text embeddings here.
        text_embeds = self.text_model.embeddings(input_ids=input_ids)

        if output_hidden_states:
            all_hidden_states_text += (text_embeds,)

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, dtype=torch.long, device=input_ids.device)
        extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, input_shape).to(
            input_ids.device
        )

        # The split_index determines how many layers of the uni-modal encoder are applied before the cross-modal encoder
        split_index = len(self.text_model.encoder.layer) - self.config.num_hidden_layers + 1

        # Run the first 'split_index' layers of the textual encoder
        for layer in self.text_model.encoder.layer[:split_index]:
            text_embeds = layer(text_embeds, extend_text_masks)[0]

            if output_hidden_states:
                all_hidden_states_text += (text_embeds,)

        if image_embeds is None:
            image_embeds = self.vision_model.visual.forward_pre(
                pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding
            )
        else:
            # Permute as BridgeTowerResidualAttention has batch_first=True
            image_embeds = image_embeds.permute(1, 0, 2)

        if output_hidden_states:
            all_hidden_states_image += (image_embeds,)

        # Run the first 'split_index' layers of the visual encoder
        for block in self.vision_model.visual.transformer.resblocks[:split_index]:
            image_embeds = block(image_embeds)
            if output_hidden_states:
                all_hidden_states_image += (image_embeds,)

        image_embeds_with_ln = self.vision_model.visual.forward_post(image_embeds.type(self.vision_model.dtype))

        # first layer is a special case because we don't have the output from the cross-encoder yet
        cross_modal_text = self.cross_modal_text_transform(text_embeds)

        # Text uses token-type index 0; the single embedding row is broadcast over the text features.
        text_token_type_embeddings = self.token_type_embeddings(
            torch.zeros(1, dtype=torch.long, device=input_ids.device)
        ).expand_as(cross_modal_text)

        cross_modal_text = self.cross_modal_text_layernorm(cross_modal_text + text_token_type_embeddings)

        image_embeds_with_ln = self.cross_modal_image_transform(image_embeds_with_ln)
        # Images use token-type index `image_token_type_idx`, broadcast the same way.
        image_token_type_embeddings = self.token_type_embeddings(
            torch.full((1,), image_token_type_idx, dtype=torch.long, device=input_ids.device)
        ).expand_as(image_embeds_with_ln)

        image_embeds_with_ln = image_embeds_with_ln + image_token_type_embeddings
        cross_modal_image = self.cross_modal_image_layernorm(image_embeds_with_ln)

        # NOTE(review): the `pixel_mask` argument is overwritten here with an all-ones mask over the
        # first two dimensions of the image features, so a caller-supplied mask is not used.
        pixel_mask = torch.ones(
            (cross_modal_image.size(0), cross_modal_image.size(1)),
            dtype=torch.long,
            device=input_ids.device,
        )
        extend_image_masks = self.text_model.get_extended_attention_mask(pixel_mask, pixel_mask.size()).to(
            input_ids.device
        )

        # First cross-modal layer: text attends to image ...
        layer_outputs_text = self.cross_modal_text_layers[0](
            cross_modal_text,
            cross_modal_image,
            attention_mask=extend_text_masks,
            encoder_attention_mask=extend_image_masks,
            output_attentions=output_attentions,
        )
        cross_text_features = layer_outputs_text[0]

        # ... and image attends to text, both starting from the same pre-fusion inputs.
        layer_outputs_image = self.cross_modal_image_layers[0](
            cross_modal_image,
            cross_modal_text,
            attention_mask=extend_image_masks,
            encoder_attention_mask=extend_text_masks,
            output_attentions=output_attentions,
        )
        cross_image_features = layer_outputs_image[0]

        if output_hidden_states:
            all_hidden_states_cross += ((cross_text_features, cross_image_features),)

        if output_attentions:
            all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),)

        link_layer_index = 0

        #  Each of the top 6 layers of the visual and textual encoders ([split_index:]) is connected to each layer of
        #  the cross-modal encoder via bridge layers, which brings bottom-up alignment and fusion to the cross-modal encoder.
        for i in range(split_index, len(self.text_model.encoder.layer)):
            # Advance both uni-modal towers by one layer.
            text_embeds = self.text_model.encoder.layer[i](text_embeds, extend_text_masks)[0]
            image_embeds = self.vision_model.visual.transformer.resblocks[i](image_embeds).type(
                self.vision_model.dtype
            )
            image_embeds_with_ln = (
                self.cross_modal_image_transform(self.vision_model.visual.forward_post(image_embeds))
                + image_token_type_embeddings
            )

            text_link_tower = self.cross_modal_text_link_tower[link_layer_index]
            image_link_tower = self.cross_modal_image_link_tower[link_layer_index]

            # Bridge layers for textual and visual encoders
            cross_text_features_ = text_link_tower(
                self.cross_modal_text_transform(text_embeds) + text_token_type_embeddings,
                cross_text_features,
                extend_text_masks,
            )
            cross_image_features_ = image_link_tower(image_embeds_with_ln, cross_image_features, extend_image_masks)

            # Cross-modal encoder via bridge layers of textual and visual encoders
            layer_outputs_text = self.cross_modal_text_layers[link_layer_index + 1](
                cross_text_features_,
                cross_image_features_,
                attention_mask=extend_text_masks,
                encoder_attention_mask=extend_image_masks,
                output_attentions=output_attentions,
            )
            cross_text_features = layer_outputs_text[0]

            layer_outputs_image = self.cross_modal_image_layers[link_layer_index + 1](
                cross_image_features_,
                cross_text_features_,
                attention_mask=extend_image_masks,
                encoder_attention_mask=extend_text_masks,
                output_attentions=output_attentions,
            )
            cross_image_features = layer_outputs_image[0]

            link_layer_index += 1

            if output_hidden_states:
                all_hidden_states_text += (text_embeds,)
                all_hidden_states_image += (image_embeds,)
                all_hidden_states_cross += ((cross_text_features, cross_image_features),)

            if output_attentions:
                all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),)

        #  Concatenate the cls token of the text and image features to get the final representation
        text_features, image_features = cross_text_features, cross_image_features
        cls_features = self.get_cls_features(text_features, image_features)

        if output_hidden_states:
            all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross)

        # Tuple output drops entries that are `None`, so the positions of hidden states and attentions
        # depend on which of them were requested.
        if not return_dict:
            return tuple(
                v
                for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return BridgeTowerModelOutput(
            text_features=text_features,
            image_features=image_features,
            pooler_output=cls_features,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def get_cls_features(self, text_features, image_features):
        # Pool each modality with its own pooler and concatenate the results along the last dimension.
        cls_features_text = self.cross_modal_text_pooler(text_features)
        cls_features_image = self.cross_modal_image_pooler(image_features)
        return torch.cat([cls_features_text, cls_features_image], dim=-1)
# Copied from transformers.models.vilt.modeling_vilt.ViltPredictionHeadTransform with Vilt->BridgeTower
class BridgeTowerPredictionHeadTransform(nn.Module):
    """Dense projection, activation and LayerNorm applied in sequence at `config.hidden_size` width."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        activation = config.hidden_act
        # A string selects an entry of ACT2FN; any other value is used as the activation as-is.
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(activated)
class BridgeTowerMLMHead(nn.Module):
    """Masked-language-modeling head: a prediction-head transform followed by a vocabulary projection plus bias."""

    def __init__(self, config, weight=None):
        super().__init__()
        self.config = config
        vocab_size = config.text_config.vocab_size
        self.transform = BridgeTowerPredictionHeadTransform(config)
        # The projection itself carries no bias; the bias is kept as a separate parameter on the head.
        self.decoder = nn.Linear(config.hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # An explicitly supplied weight replaces the freshly created decoder weight.
        if weight is not None:
            self.decoder.weight = weight

    def forward(self, x):
        transformed = self.transform(x)
        return self.decoder(transformed) + self.bias
class BridgeTowerITMHead(nn.Module):
    """Image-text matching head: one linear layer mapping `hidden_size` features to 2 scores."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, 2)

    def forward(self, x):
        return self.fc(x)
@auto_docstring(
    custom_intro="""
    BridgeTower Model with a language modeling head on top as done during pretraining.
    """
)
class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
    # The MLM decoder weight is tied to the word embeddings of the text tower.
    _tied_weights_keys = {"mlm_score.decoder.weight": "bridgetower.text_model.embeddings.word_embeddings.weight"}

    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = BridgeTowerModel(config)
        self.mlm_score = BridgeTowerMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.mlm_score.decoder

    def set_output_embeddings(self, new_embeddings):
        self.mlm_score.decoder = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        image_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> MaskedLMOutput | tuple[torch.FloatTensor]:
        r"""
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Examples:

        ```python
        >>> from transformers import BridgeTowerProcessor, BridgeTowerForMaskedLM
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000360943.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read())).convert("RGB")
        >>> text = "a <mask> looking out of the window"

        >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
        >>> model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")

        >>> # prepare inputs
        >>> encoding = processor(image, text, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**encoding)

        >>> results = processor.decode(outputs.logits.argmax(dim=-1).squeeze(0).tolist())

        >>> print(results)
        .a cat looking out of the window.
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bridgetower(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The text features are the first entry of the base model output in both output formats.
        mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token

            labels = labels.to(mlm_logits.device)
            masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))

        if not return_dict:
            # Wrap the logits in a 1-tuple: `tuple(mlm_logits)` would iterate the tensor along its first
            # dimension and return one slice per batch item. The base model's tuple holds text features,
            # image features and pooled features first, so `outputs[3:]` carries the optional hidden
            # states / attentions, mirroring the fields of `MaskedLMOutput`.
            output = (mlm_logits,) + tuple(outputs[3:])
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=mlm_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    BridgeTower Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the
    [CLS] token) for image-to-text matching.
    """
)
class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = BridgeTowerModel(config)

        # The pooled output concatenates the text and image poolers, hence twice the hidden size.
        self.itm_score = BridgeTowerITMHead(config.hidden_size * 2)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        image_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> SequenceClassifierOutput | tuple[torch.FloatTensor]:
        r"""
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
            The loss is a plain cross-entropy over the two classes, so both label values contribute to it.

        Examples:

        ```python
        >>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]

        >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
        >>> model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")

        >>> # forward pass
        >>> scores = dict()
        >>> for text in texts:
        ...     # prepare inputs
        ...     encoding = processor(image, text, return_tensors="pt")
        ...     outputs = model(**encoding)
        ...     scores[text] = outputs.logits[0, 1].item()
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bridgetower(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The pooled features are the third entry of the base model's tuple output.
        pooler_output = outputs.pooler_output if return_dict else outputs[2]

        logits = self.itm_score(pooler_output)

        itm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(logits.device)
            itm_loss = loss_fct(logits, labels)

        if not return_dict:
            # Wrap the logits in a 1-tuple: `tuple(logits)` would iterate the tensor along its first
            # dimension and return one row per batch item. `outputs[3:]` carries the optional hidden
            # states / attentions of the base model, mirroring the fields of `SequenceClassifierOutput`.
            output = (logits,) + tuple(outputs[3:])
            return ((itm_loss,) + output) if itm_loss is not None else output

        return SequenceClassifierOutput(
            loss=itm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class BridgeTowerContrastiveHead(nn.Module):
    """Contrastive projection head: one linear layer from `hidden_size` to `embed_size`."""

    def __init__(self, hidden_size, embed_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        return self.fc(x)
@auto_docstring(
    custom_intro="""
    BridgeTower Model with a image-text contrastive head on top computing image-text contrastive loss.
    """
)
class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = BridgeTowerModel(config)

        # One projection head per stream. The cross-modal head takes an input of width
        # `2 * hidden_size` because it is applied to the backbone's pooler output.
        self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)

        # Stored in log space: `forward` applies `.exp()` before scaling the similarity logits.
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        image_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = True,
        return_dict: bool | None = None,
        return_loss: bool | None = None,
        **kwargs,
    ) -> BridgeTowerContrastiveOutput | tuple[torch.FloatTensor]:
        r"""
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image
        >>> import torch

        >>> image_urls = [
        ...     "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
        ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
        ... ]
        >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]

        >>> with httpx.stream("GET", image_urls[0]) as response:
        ...     image1 = Image.open(BytesIO(response.read()))

        >>> with httpx.stream("GET", image_urls[1]) as response:
        ...     image2 = Image.open(BytesIO(response.read()))

        >>> images = [image1, image2]

        >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
        >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

        >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
        >>> loss = model(**inputs, return_loss=True).loss

        >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
        >>> loss_swapped = model(**inputs, return_loss=True).loss

        >>> print("Loss", round(loss.item(), 4))
        Loss 0.0019

        >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
        Loss with swapped images 2.126
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # NOTE: the caller's `output_hidden_states` is intentionally not forwarded. The
        # contrastive heads below read the per-stream hidden states, so they are always
        # requested from the backbone.
        outputs = self.bridgetower(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # In tuple mode the pooled output sits at index 2 and the hidden states at index 3.
        pooler_output = outputs.pooler_output if return_dict else outputs[2]

        # Hidden states arrive as a (text, image, cross-modal) triple; the cross-modal
        # entry is not needed here but the 3-way unpack is kept.
        hidden_states_txt, hidden_states_img, _ = outputs.hidden_states if return_dict else outputs[3]

        # Last entry of each per-stream hidden-state sequence.
        text_embeds = hidden_states_txt[-1]
        image_embeds = hidden_states_img[-1]

        # Run the vision tower's post-processing step on the image states
        # (presumably the final layer norm, going by the variable name - confirm in `forward_post`).
        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)

        # Look up token-type embedding id 1 and broadcast it over the image states.
        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
            torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
        ).expand_as(image_embeds_with_ln)

        image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings

        # normalized features
        # Each head is applied to position 0 along dim 1 (text / image) or to the pooled
        # output (cross-modal), then L2-normalized along the last dim. Image and cross
        # features are moved to the text features' device so they can be combined.
        text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
            device=text_embeds.device
        )
        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
            device=text_embeds.device
        )

        # Returned `logits` are the three normalized embeddings stacked on a new
        # second-to-last axis, not the pairwise similarity matrices computed below.
        logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)

        logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
        # Pairwise similarity matrices, scaled by the learned temperature.
        logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
        logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale

        itc_loss = None

        if return_loss:
            # Target for row i is column i: matching pairs lie on the diagonal.
            labels = torch.arange(len(logits), device=logits.device)
            text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
            text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
            image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
            # Unweighted mean of the three pairwise losses.
            itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0

        if not return_dict:
            # Tuple layout: [loss,] logits, text_embeds, image_embeds, cross_embeds, then the
            # backbone's entries from index 3 onwards.
            output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
            return ((itc_loss,) + output) if itc_loss is not None else output

        return BridgeTowerContrastiveOutput(
            loss=itc_loss,
            logits=logits,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            cross_embeds=cross_embeds,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this module: the names exported by a star-import.
__all__ = [
"BridgeTowerForContrastiveLearning",
"BridgeTowerForImageAndTextRetrieval",
"BridgeTowerForMaskedLM",
"BridgeTowerModel",
"BridgeTowerPreTrainedModel",
]