You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1127 lines
44 KiB

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_siglip2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ... import initialization as init
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
from ...utils.generic import check_model_inputs, is_flash_attention_requested
from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
class Siglip2VisionOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field defaults to None, so an instance can be built with only the outputs that were actually produced.
    # Field order is significant: it fixes the positional order used when the output is read as a tuple.
    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    # Tuples with one tensor per recorded layer, when requested.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
class Siglip2TextOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field defaults to None, so an instance can be built with only the outputs that were actually produced.
    # Field order is significant: it fixes the positional order used when the output is read as a tuple.
    text_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    # Tuples with one tensor per recorded layer, when requested.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring
class Siglip2Output(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Siglip2TextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Siglip2VisionModel`].
    """

    loss: torch.FloatTensor | None = None
    logits_per_image: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    text_embeds: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    # Nested sub-model outputs; annotated as optional because they default to None.
    text_model_output: BaseModelOutputWithPooling | None = None
    vision_model_output: BaseModelOutputWithPooling | None = None

    def to_tuple(self) -> tuple[Any]:
        """
        Convert this output to a plain tuple.

        Iterates over the populated keys in field order. The two nested sub-model outputs are themselves converted
        with their own `to_tuple()`, so the result contains no `ModelOutput` instances.
        """
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
class Siglip2VisionEmbeddings(nn.Module):
    """
    Embed already-patchified pixel values and add learned positional embeddings.

    The learned position table has `config.num_patches` entries, viewed as a square grid of side
    `int(num_patches ** 0.5)`. For every image, that grid is bilinearly resized to the image's own `(height, width)`
    patch grid and padded to the sequence length of `pixel_values`.
    """

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        # Inputs are already flattened per patch, so a Linear projection takes the place of a patchifying Conv2d.
        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )

        self.num_patches = config.num_patches
        # Side of the square grid the position table is reshaped into (assumes num_patches is a perfect square).
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    @staticmethod
    def resize_positional_embeddings(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        max_length: int,
    ) -> torch.Tensor:
        """
        Resize positional embeddings to image-specific size and pad to a fixed size.

        Args:
            positional_embeddings (`torch.Tensor`):
                Position embeddings of shape (height, width, embed_dim)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
            max_length (`int`):
                Maximum length of the positional embeddings to pad resized positional embeddings to

        Returns:
            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim), in the dtype of
            `positional_embeddings`. For image `i`, the first `height * width` rows hold the resized grid in
            row-major order; the remaining rows repeat the first resized embedding.
        """
        batch_size = spatial_shapes.shape[0]
        embed_dim = positional_embeddings.shape[-1]
        source_dtype = positional_embeddings.dtype

        resulted_positional_embeddings = torch.empty(
            (batch_size, max_length, embed_dim),
            device=positional_embeddings.device,
            dtype=source_dtype,
        )

        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
        positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)

        # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
        if positional_embeddings.device.type == "cpu":
            positional_embeddings = positional_embeddings.to(torch.float32)

        for i in range(batch_size):
            # (1, dim, height, width) -> (1, dim, target_height, target_width)
            height, width = spatial_shapes[i].tolist()  # will be itemized in F.interpolate either way
            torch_compilable_check((width > 0), "Width of resized positional embeddings must be positive.")
            torch_compilable_check((height > 0), "Height of resized positional embeddings must be positive.")
            torch_compilable_check((height * width) <= max_length, "Resized positional embeddings exceed max_length.")
            resized_embeddings = F.interpolate(
                positional_embeddings,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
            resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)

            # Cast to original dtype
            resized_embeddings = resized_embeddings.to(source_dtype)

            resulted_positional_embeddings[i, : height * width] = resized_embeddings
            # Fill the padding rows with the first resized embedding so no row of the `torch.empty` buffer is
            # left uninitialized.
            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]

        return resulted_positional_embeddings

    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
            spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
                Per-image `(height, width)` patch-grid sizes to resize the positional embeddings to

        Returns:
            `torch.Tensor` of shape (batch_size, max_num_patches, embed_dim): patch embeddings plus resized,
            padded positional embeddings.
        """
        # Apply patch embeddings to already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))

        # Get positional resized and padded positional embeddings
        positional_embeddings = self.position_embedding.weight.reshape(
            self.position_embedding_size, self.position_embedding_size, -1
        )
        resized_positional_embeddings = self.resize_positional_embeddings(
            positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
        )

        # Add positional embeddings to patch embeddings
        embeddings = patch_embeds + resized_positional_embeddings
        return embeddings
class Siglip2TextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for the Siglip2 text tower."""

    def __init__(self, config: Siglip2TextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids has shape (1, max_position_embeddings). It is non-persistent: it is rebuilt here at
        # construction time instead of being stored in the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
    ) -> torch.Tensor:
        """
        Sum token (or provided) embeddings with position embeddings.

        Args:
            input_ids: token ids; the last dimension is the sequence length. Ignored for the lookup when
                `inputs_embeds` is given, but still used to derive the sequence length if present.
            position_ids: optional explicit positions; defaults to `0 .. seq_length - 1`.
            inputs_embeds: optional precomputed token embeddings; the second-to-last dimension is the sequence length.

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is given, or if the sequence is longer than
                `max_position_embeddings`.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        # A sequence exactly as long as the position table is allowed; only longer ones are rejected.
        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than or equal to max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding})"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Siglip2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Input shape: Batch x Time x Channel"""
batch_size, seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class Siglip2MLP(nn.Module):
    """Position-wise feed-forward network: `fc2(act(fc1(x)))`, expanding to `intermediate_size` and back."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation is looked up by name from config.hidden_act.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
class Siglip2EncoderLayer(GradientCheckpointingLayer):
    """
    One pre-norm transformer block.

    Applies `x + attn(LN1(x))` and then `x + mlp(LN2(x))`.
    """

    def __init__(self, config: Siglip2VisionConfig | Siglip2TextConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Self-attention sub-block; the attention weights are discarded here.
        attn_delta, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attn_delta

        # Feed-forward sub-block.
        mlp_delta = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_delta
@auto_docstring
class Siglip2PreTrainedModel(PreTrainedModel):
    # Common base for the Siglip2 model classes in this file: shared config type, capability flags,
    # output-recording map and weight initialization.
    config: Siglip2Config
    base_model_prefix = "siglip2"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True

    # NOTE(review): presumably module classes that PreTrainedModel must keep on a single device when
    # sharding; the consuming logic is not visible in this file.
    _no_split_modules = [
        "Siglip2TextEmbeddings",
        "Siglip2VisionEmbeddings",
        "Siglip2EncoderLayer",
        "Siglip2MultiheadAttentionPoolingHead",
    ]

    # Attention backends this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # Output name -> module class whose outputs are collected for it.
    # NOTE(review): presumably consumed by the `check_model_inputs` decorator; confirm in the framework.
    _can_record_outputs = {
        "hidden_states": Siglip2EncoderLayer,
        "attentions": Siglip2Attention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights of `module` according to its type.

        The checks form a single if/elif chain, so only the first matching branch runs for a given module.
        NOTE(review): the Siglip2Attention / Siglip2MLP branches re-initialize child `nn.Linear` layers that also
        match the generic Linear branch; the final values depend on the order in which the framework visits
        parents vs. children — confirm that traversal order.
        """
        if isinstance(module, Siglip2VisionEmbeddings):
            # Works both for a composite Siglip2Config and for a standalone vision config.
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, Siglip2Config)
                else self.config.hidden_size
            )
            init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
            # Siglip2VisionEmbeddings as defined in this file registers no `position_ids` buffer, so this
            # guard is expected to be False for it.
            if hasattr(module, "position_ids"):
                init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        elif isinstance(module, nn.Embedding):
            init.default_flax_embed_init_(module.weight)
        elif isinstance(module, Siglip2Attention):
            init.xavier_uniform_(module.q_proj.weight)
            init.xavier_uniform_(module.k_proj.weight)
            init.xavier_uniform_(module.v_proj.weight)
            init.xavier_uniform_(module.out_proj.weight)
            init.zeros_(module.q_proj.bias)
            init.zeros_(module.k_proj.bias)
            init.zeros_(module.v_proj.bias)
            init.zeros_(module.out_proj.bias)
        elif isinstance(module, Siglip2MLP):
            init.xavier_uniform_(module.fc1.weight)
            init.xavier_uniform_(module.fc2.weight)
            # Near-zero (not exactly zero) bias initialization.
            init.normal_(module.fc1.bias, std=1e-6)
            init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
            init.xavier_uniform_(module.probe)
            # nn.MultiheadAttention keeps the fused q/k/v projection in in_proj_weight / in_proj_bias.
            init.xavier_uniform_(module.attention.in_proj_weight)
            init.zeros_(module.attention.in_proj_bias)
        elif isinstance(module, Siglip2Model):
            # logit_scale is used as exp(logit_scale), so zero corresponds to a unit scale; bias starts at zero.
            init.zeros_(module.logit_scale)
            init.zeros_(module.logit_bias)
        elif isinstance(module, Siglip2ForImageClassification):
            # Assumes a composite config with a `vision_config` (as used when building the classifier).
            init.normal_(
                module.classifier.weight,
                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            init.lecun_normal_(module.weight)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        elif isinstance(module, Siglip2TextEmbeddings):
            # Rebuild the non-persistent position_ids buffer as 0..N-1.
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
class Siglip2Encoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`Siglip2EncoderLayer`] blocks applied one after another.

    Args:
        config: Siglip2Config
    """

    def __init__(self, config: Siglip2Config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    # Ignore copy
    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Each layer receives the same mask and extra keyword arguments.
        state = inputs_embeds
        for layer in self.layers:
            state = layer(state, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=state)
class Siglip2VisionTransformer(Siglip2PreTrainedModel):
    # Vision tower: patch/position embeddings -> encoder -> post layer norm -> optional attention-pooling head.
    _input_embed_layer = "patch_embedding"

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # Configs lacking a `vision_use_head` attribute default to using the pooling head.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = Siglip2MultiheadAttentionPoolingHead(config)

        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        r"""
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        hidden_states = self.embeddings(pixel_values, spatial_shapes)

        # Flash attention is given the 2D mask unchanged; other backends get the expanded
        # [batch_size, 1, tgt_seq_len, src_seq_len] additive mask.
        expand_mask = attention_mask is not None and not is_flash_attention_requested(self.config)
        encoder_attention_mask = (
            _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) if expand_mask else attention_mask
        )

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = self.post_layernorm(encoder_outputs.last_hidden_state)
        # The pooling head receives the original (un-expanded) mask.
        pooled = self.head(sequence_output, attention_mask) if self.use_head else None

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class Siglip2TextTransformer(Siglip2PreTrainedModel):
    # Text tower: token/position embeddings -> bidirectional encoder -> final layer norm -> linear head on the
    # last position.
    _input_embed_layer = "token_embedding"

    def __init__(self, config: Siglip2TextConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = Siglip2TextEmbeddings(config)
        self.encoder = Siglip2Encoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.head = nn.Linear(embed_dim, config.projection_size)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Fold any leading dimensions into a single batch dimension.
        input_ids = input_ids.view(-1, input_ids.size(-1))
        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # note: Siglip2's text model does not use a causal mask, unlike the original CLIP model.
        # With flash attention the padding mask is dropped entirely; otherwise the 2D mask is expanded to
        # [batch_size, 1, tgt_seq_len, src_seq_len].
        if is_flash_attention_requested(self.config):
            attention_mask = None
        elif attention_mask is not None:
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )

        sequence_output = self.final_layer_norm(encoder_outputs.last_hidden_state)

        # Pool with the final position's hidden state, which may be a padding token.
        pooled = self.head(sequence_output[:, -1, :])

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled,
        )
@auto_docstring(
    custom_intro="""
    The text model from Siglip2 without any head or projection on top.
    """
)
class Siglip2TextModel(Siglip2PreTrainedModel):
    config: Siglip2TextConfig
    input_modalities = ("text",)

    def __init__(self, config: Siglip2TextConfig):
        super().__init__(config)
        # All computation (embeddings, encoder, final norm, head) lives in the wrapped transformer.
        self.text_model = Siglip2TextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        embeddings = self.text_model.embeddings
        return embeddings.token_embedding

    def set_input_embeddings(self, value):
        embeddings = self.text_model.embeddings
        embeddings.token_embedding = value

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, Siglip2TextModel

        >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (last-position) states
        ```"""
        forwarded = dict(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
        return self.text_model(**forwarded, **kwargs)
class Siglip2MultiheadAttentionPoolingHead(nn.Module):
    """
    Multihead Attention Pooling.

    A learned probe token attends over the encoder hidden states with `nn.MultiheadAttention`. The result goes
    through a residual `LayerNorm -> MLP` branch, and the probe position is returned as the pooled vector.
    """

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)
        self.num_heads = config.num_attention_heads

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            hidden_state: `(batch_size, seq_len, hidden_size)` encoder output (batch-first).
            attention_mask: optional `(batch_size, seq_len)` padding mask over `hidden_state`.

        Returns:
            `(batch_size, hidden_size)` pooled representation taken from the first probe position.
        """
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        if attention_mask is not None:
            target_len, source_len = probe.shape[1], hidden_state.shape[1]
            # [batch_size, src_len] -> [batch_size, 1, target_len, src_len] additive mask
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
            # nn.MultiheadAttention wants a 3D mask of shape (batch_size * num_heads, target_len, source_len).
            # The 4D mask already spans target_len, so only the head axis is replicated; repeating the target
            # axis too would only be correct for a single-token probe.
            attention_mask = attention_mask.repeat(1, self.num_heads, 1, 1)
            attention_mask = attention_mask.reshape(-1, target_len, source_len)

        hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]
@auto_docstring(
    custom_intro="""
    The vision model from Siglip2 without any head or projection on top.
    """
)
class Siglip2VisionModel(Siglip2PreTrainedModel):
    config: Siglip2VisionConfig
    main_input_name = "pixel_values"
    input_modalities = ("image",)

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__(config)

        self.vision_model = Siglip2VisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def set_input_embeddings(self, value: nn.Module):
        # Mirrors get_input_embeddings (and the text/classification models) so the patch projection can be swapped.
        self.vision_model.embeddings.patch_embedding = value

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask to avoid performing attention on padding patch positions.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, Siglip2VisionModel

        >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        return self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            **kwargs,
        )
@auto_docstring
class Siglip2Model(Siglip2PreTrainedModel):
    config: Siglip2Config

    def __init__(self, config: Siglip2Config):
        super().__init__(config)

        if not isinstance(config.text_config, Siglip2TextConfig):
            raise TypeError(
                "config.text_config is expected to be of type Siglip2TextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, Siglip2VisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type Siglip2VisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        # First, initialize the text and vision models with proper attention implementation
        text_model = Siglip2TextModel._from_config(text_config)
        vision_model = Siglip2VisionModel._from_config(vision_config)

        # Second, get the text and vision submodules (for backward compatibility)
        self.text_model = text_model.text_model
        self.vision_model = vision_model.vision_model

        # Learnable logit temperature (applied as exp(logit_scale)) and bias for the sigmoid loss.
        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value: nn.Module):
        self.text_model.embeddings.token_embedding = value

    @can_return_tuple
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.Tensor | None = None,
        spatial_shapes: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask to avoid performing attention on padding patch positions.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AutoModel
        >>> from transformers.image_utils import load_image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```
        """
        return self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            **kwargs,
        )

    # NOTE: Siglip2Model uses Pretrained backbones, so we don't need to add `check_model_inputs` here
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_attention_mask: torch.Tensor | None = None,
        spatial_shapes: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        return_loss: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> Siglip2Output:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask to avoid performing attention on padding patch positions.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```
        """
        # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        image_embeds = vision_outputs.pooler_output
        text_embeds = text_outputs.pooler_output

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))

        logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias

        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
            # Sign matrix: +1 on the diagonal (matching text/image pairs), -1 everywhere else.
            # Assumes text and image batches are the same size and aligned pairwise.
            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
            nll = -torch.sum(loglik, dim=-1)
            loss = nll.mean()

        return Siglip2Output(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@auto_docstring(
    custom_intro="""
    Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    """
)
class Siglip2ForImageClassification(Siglip2PreTrainedModel):
    # Vision-only model: only the vision tower described by `config.vision_config` is instantiated.
    main_input_name = "pixel_values"
    input_modalities = ("image",)

    def __init__(self, config: Siglip2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels

        # Create the vision model with proper attention
        # and take only vision_model submodule (for backward compatibility)
        vision_model = Siglip2VisionModel._from_config(config.vision_config)
        self.vision_model = vision_model.vision_model

        # Classifier head. When `num_labels <= 0` an identity is used, so `logits` are the pooled features.
        self.classifier = (
            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the vision embeddings."""
        return self.vision_model.embeddings.patch_embedding

    def set_input_embeddings(self, value: nn.Module):
        """Replace the patch-embedding module of the vision embeddings with `value`."""
        self.vision_model.embeddings.patch_embedding = value

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        pixel_attention_mask: torch.Tensor | None = None,
        spatial_shapes: torch.LongTensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> ImageClassifierOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask over the patch tokens (1 for real patches, 0 for padding). It is passed to the vision encoder as its
            attention mask and is also used to exclude padding tokens from the average pooling that feeds the
            classifier.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> # note: we are loading a `Siglip2Model` checkpoint from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
        >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs.last_hidden_state

        # Average-pool the patch tokens; when a mask is given, padded positions are excluded from the mean.
        if pixel_attention_mask is not None:
            pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
            # Clamp the valid-token count so a row with no valid patches pools to zeros instead of NaN (0 / 0),
            # which would otherwise poison the logits and the loss.
            num_valid_tokens = torch.sum(pool_mask, dim=1).clamp(min=1)
            sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / num_valid_tokens
        else:
            sequence_output = torch.mean(sequence_output, dim=1)

        # apply classifier
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this module: the names exported by `from ... import *`, in declaration order.
__all__ = [
    "Siglip2Model",
    "Siglip2PreTrainedModel",
    "Siglip2TextModel",
    "Siglip2VisionModel",
    "Siglip2ForImageClassification",
]