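# NOTE: the following lines are residue from the web page this file was copied from (a topic-selection
# validation message and file statistics); they are not part of the module.
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 1604 lines
# 70 KiB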

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/sam2/modular_sam2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_sam2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers.utils.generic import OutputRecorder
from ... import initialization as init
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import (
ModelOutput,
auto_docstring,
can_return_tuple,
logging,
)
from ...utils.generic import TransformersKwargs, check_model_inputs, is_flash_attention_requested
from ..auto import AutoModel
from .configuration_sam2 import (
Sam2Config,
Sam2HieraDetConfig,
Sam2MaskDecoderConfig,
Sam2PromptEncoderConfig,
Sam2VisionConfig,
)
# Module-level logger; used below for one-time warnings (e.g. the SDPA fallback in `Sam2Attention.forward`).
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
class Sam2VisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
        model at the output of each stage.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    fpn_hidden_states (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
    fpn_position_encoding (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
    """

    # Both FPN fields hold one tensor per feature level (a tuple), not a single tensor.
    fpn_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    fpn_position_encoding: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(custom_intro="Base class for the Sam2 model's output.")
class Sam2ImageSegmentationOutput(ModelOutput):
    r"""
    iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
        The Intersection over Union (IoU) scores of the predicted masks.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
        The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
        by the processor to be brought to the original image size.
    object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
        Logits for the object score, indicating if an object is present.
    image_embeddings (`tuple(torch.FloatTensor)`):
        The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
        tensor has shape `(batch_size, channels, height, width)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
        Hidden-states of the vision model at the output of each stage.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attention weights of the vision model.
    mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attention weights of the mask decoder.
    """

    # Field order defines the positional order of the `ModelOutput` tuple form; do not reorder.
    iou_scores: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    object_score_logits: torch.FloatTensor | None = None
    # Defaults to None like every other field, so the annotation is optional too.
    image_embeddings: tuple[torch.FloatTensor, ...] | None = None
    vision_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    vision_attentions: tuple[torch.FloatTensor, ...] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
class Sam2PatchEmbeddings(nn.Module):
    r"""
    Converts an image batch into a channels-last grid of patch embeddings using a single strided convolution.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details.

    Returns:
        embeddings (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`):
            Patch embeddings; the grid size depends on image_size, patch_kernel_size, patch_stride and
            patch_padding.
    """

    def __init__(self, config: Sam2HieraDetConfig):
        super().__init__()
        self.projection = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=config.patch_kernel_size,
            stride=config.patch_stride,
            padding=config.patch_padding,
        )

    def forward(self, pixel_values):
        # Unpacking the shape rejects any input that is not 4-D.
        _batch_size, _num_channels, _height, _width = pixel_values.shape
        # Match the convolution's parameter dtype (e.g. fp16 / bf16 weights) before projecting.
        patches = self.projection(pixel_values.to(self.projection.weight.dtype))
        # (batch, hidden, h, w) -> (batch, h, w, hidden)
        return patches.permute(0, 2, 3, 1)
# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
class Sam2SinePositionEmbedding(nn.Module):
    """
    Fixed (non-learned) 2D sine/cosine position embedding, in the spirit of "Attention is all you need" but applied
    independently along the height and width axes of a feature map.

    Args:
        num_pos_feats (`int`, *optional*, defaults to 64):
            Number of channels produced per spatial axis; the output has `2 * num_pos_feats` channels (for an even
            `num_pos_feats`).
        temperature (`int`, *optional*, defaults to 10000):
            Base of the geometric progression of frequencies.
        normalize (`bool`, *optional*, defaults to `False`):
            Whether to rescale the running coordinates to roughly `(0, scale]` before encoding.
        scale (`float`, *optional*):
            Range used when `normalize=True`; defaults to `2 * pi`. Passing it with `normalize=False` raises.
    """

    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is None:
            self.scale = 2 * math.pi
        else:
            self.scale = scale

    @compile_compatible_method_lru_cache(maxsize=1)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: Tensor | None = None,
    ) -> Tensor:
        """
        Build the embedding for a feature map of `shape` `(batch, channels, height, width)`. Only the batch and
        spatial sizes are read. Returns a `(batch, 2 * num_pos_feats, height, width)` tensor. `mask` (True = padded
        position) defaults to all-valid.
        """
        if mask is None:
            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        valid = (~mask).to(dtype)
        # Running 1-based position along height (dim 1) and width (dim 2), counting valid pixels only.
        y_embed = valid.cumsum(1)
        x_embed = valid.cumsum(2)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        channel_idx = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
        # Channels 2k and 2k + 1 share the frequency temperature ** (2k / num_pos_feats).
        dim_t = self.temperature ** (2 * torch.div(channel_idx, 2, rounding_mode="floor") / self.num_pos_feats)

        pos_y = self._encode_axis(y_embed, dim_t)
        pos_x = self._encode_axis(x_embed, dim_t)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

    @staticmethod
    def _encode_axis(embed: Tensor, dim_t: Tensor) -> Tensor:
        """Interleave sin (even channels) and cos (odd channels) of `embed / dim_t` along a new last axis."""
        phase = embed[:, :, :, None] / dim_t
        return torch.stack((phase[:, :, :, 0::2].sin(), phase[:, :, :, 1::2].cos()), dim=4).flatten(3)
class Sam2VisionNeck(nn.Module):
    """
    Feature Pyramid Network neck on top of the hierarchical backbone.

    Each backbone stage output is projected to `config.fpn_hidden_size` channels by its own convolution. For levels
    listed in `config.fpn_top_down_levels`, a 2x nearest-upsampled copy of the previously processed (coarser) level
    is added. Every produced level also gets a sine position encoding.
    """

    def __init__(self, config: Sam2VisionConfig):
        super().__init__()
        self.config = config

        self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
        self.convs = nn.ModuleList()
        for in_channels in config.backbone_channel_list:
            self.convs.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=config.fpn_hidden_size,
                    kernel_size=config.fpn_kernel_size,
                    stride=config.fpn_stride,
                    padding=config.fpn_padding,
                ),
            )
        self.fpn_top_down_levels = config.fpn_top_down_levels

    def forward(
        self, hidden_states: tuple[torch.Tensor, ...]
    ) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
        """
        Args:
            hidden_states: Sequence of channels-last `(batch, height, width, channels)` stage outputs, indexed so
                that `hidden_states[len(self.convs) - 1]` is processed first with `self.convs[0]`.

        Returns:
            `(fpn_hidden_states, fpn_position_encoding)`: two tuples with one channels-first tensor per level,
            in processing order (from the last stage output back to the first).
        """
        fpn_hidden_states = ()
        fpn_position_encoding = ()

        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            conv = self.convs[n - i]
            # Cast to the dtype of the conv that is actually applied to this level.
            lateral_features = conv(hidden_states[i].permute(0, 3, 1, 2).to(conv.weight.dtype))
            if i not in self.fpn_top_down_levels or i == n:
                prev_features = lateral_features
            else:
                # Assumes consecutive levels differ by exactly 2x in resolution, otherwise the sum below fails.
                # Interpolate in float32, then cast back to the lateral dtype.
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode="nearest",
                    align_corners=None,
                    antialias=False,
                ).to(lateral_features.dtype)
                prev_features = lateral_features + top_down_features

            prev_position_encoding = self.position_encoding(
                prev_features.shape, prev_features.device, prev_features.dtype
            ).to(prev_features.dtype)

            fpn_hidden_states += (prev_features,)
            fpn_position_encoding += (prev_position_encoding,)

        return fpn_hidden_states, fpn_position_encoding
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def do_pool(x: torch.Tensor, query_stride: int | None = None) -> torch.Tensor:
if query_stride is None:
return x
# (B, H, W, C) -> (B, C, H, W)
x = x.permute(0, 3, 1, 2)
x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False)
# (B, C, H', W') -> (B, H', W', C)
x = x.permute(0, 2, 3, 1)
return x
class Sam2MultiScaleAttention(nn.Module):
    """
    Multi-head self-attention over a channels-last `(batch_size, height, width, dim)` grid (a whole feature map or
    a batch of windows).

    When `query_stride` is set, only the queries are max-pooled spatially, so the output grid is downsampled while
    keys and values keep the full resolution. The output has `dim_out` channels.
    """

    def __init__(
        self,
        config: Sam2HieraDetConfig,
        dim: int,
        dim_out: int,
        num_attention_heads: int,
        query_stride: tuple[int, int] | None = None,
    ):
        super().__init__()

        self.config = config

        self.dim = dim
        self.dim_out = dim_out
        self.query_stride = query_stride

        self.num_attention_heads = num_attention_heads
        head_dim = dim_out // num_attention_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

        self.is_causal = False

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        query, key, value = torch.unbind(qkv, 2)

        # Attention scores are computed only by the attention interface below. No unused score matrix is
        # materialised here.

        # Q pooling (for downsample at stage changes): heads are folded into channels, pooled spatially, unfolded.
        if self.query_stride:
            query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
            height, width = query.shape[1:3]  # downsampled shape
            query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)

        # transpose query, key, value to (B, nHead, H * W, C)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, _ = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            is_causal=self.is_causal,
            scaling=self.scale,
            **kwargs,
        )
        # Back to a channels-last grid at the (possibly pooled) query resolution.
        attn_output = attn_output.reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)

        return attn_output
class Sam2FeedForward(nn.Module):
    """
    Multi-layer perceptron: `proj_in` plus activation, then `num_layers - 2` hidden `Linear` plus activation layers,
    then `proj_out`, optionally followed by a sigmoid.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: str = "relu",
        sigmoid_output: bool = False,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.activation = ACT2FN[activation]
        # Creation order (proj_in, proj_out, hidden layers) fixes parameter registration and RNG consumption order.
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        hidden_layers = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]
        self.layers = nn.ModuleList(hidden_layers)
        self.sigmoid_output = sigmoid_output

    def forward(self, hidden_states):
        hidden_states = self.activation(self.proj_in(hidden_states))
        for hidden_layer in self.layers:
            hidden_states = self.activation(hidden_layer(hidden_states))
        output = self.proj_out(hidden_states)
        return F.sigmoid(output) if self.sigmoid_output else output
def window_partition(hidden_state, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows. The bottom and right edges are
    zero-padded to a multiple of `window_size` first.

    Args:
        hidden_state (`torch.Tensor`):
            Input tokens with [batch_size, height, width, num_channels].
        window_size (`int`):
            Window size.

    Returns:
        `tuple(torch.FloatTensor)` comprising various elements:
        - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels],
          ordered batch-major, then window row, then window column.
        - (padded_height, padded_width): padded height and width before partition
    """
    batch_size, height, width, num_channels = hidden_state.shape

    pad_height, pad_width = ((window_size - size % window_size) % window_size for size in (height, width))

    # F.pad pads from the last dim backwards: (C_left, C_right, W_left, W_right, H_top, H_bottom).
    padded = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
    padded_height = height + pad_height
    padded_width = width + pad_width

    window_rows = padded_height // window_size
    window_cols = padded_width // window_size
    tiles = padded.view(batch_size, window_rows, window_size, window_cols, window_size, num_channels)
    # (B, rows, ws, cols, ws, C) -> (B, rows, cols, ws, ws, C) -> (B * rows * cols, ws, ws, C)
    windows = tiles.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows, (padded_height, padded_width)
def window_unpartition(windows, window_size, pad_height_width, height_width):
    """
    Inverse of `window_partition`: stitch windows back into a feature map and crop off the padding.

    Args:
        windows (`torch.Tensor`):
            Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
        window_size (`int`):
            Window size.
        pad_height_width (`tuple[int]`):
            Padded height and width (padded_height, padded_width).
        height_width (`tuple[int]`):
            Original height and width before padding.

    Returns:
        hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels] (contiguous).
    """
    padded_height, padded_width = pad_height_width
    height, width = height_width

    windows_per_image = padded_height * padded_width // window_size // window_size
    batch_size = windows.shape[0] // windows_per_image

    grid = windows.view(
        batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
    )
    # (B, rows, cols, ws, ws, C) -> (B, rows, ws, cols, ws, C) -> (B, padded_H, padded_W, C)
    full = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1)

    # height <= padded_height and width <= padded_width always hold, so this only drops padding.
    return full[:, :height, :width, :].contiguous()
class Sam2MultiScaleBlock(GradientCheckpointingLayer):
    """
    One Hiera block: (windowed or global) multi-scale self-attention followed by an MLP, each with a residual
    connection.

    The first block of a query-pooling stage downsamples the grid by `config.query_stride` and may change the
    channel width from the previous stage's embed dim to the current one.
    """

    def __init__(
        self,
        config: Sam2HieraDetConfig,
        stage_idx: int,
        block_idx: int,
        total_block_idx: int,
    ):
        super().__init__()

        # take embed dim from previous stage if first block of stage
        self.dim = (
            config.embed_dim_per_stage[stage_idx - 1]
            if stage_idx > 0 and block_idx == 0
            else config.embed_dim_per_stage[stage_idx]
        )
        self.dim_out = config.embed_dim_per_stage[stage_idx]
        self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
        # take window size from previous stage if first block of stage
        self.window_size = (
            config.window_size_per_stage[stage_idx - 1]
            if stage_idx > 0 and block_idx == 0
            else config.window_size_per_stage[stage_idx]
        )
        # window_size == 0 means global (non-windowed) attention for this block.
        self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
        # use query stride for first block of stage if stage is a query pool stage
        self.query_stride = (
            config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
        )

        self.attn = Sam2MultiScaleAttention(
            config,
            self.dim,
            self.dim_out,
            num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
            query_stride=self.query_stride,
        )
        self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
        self.mlp = Sam2FeedForward(
            self.dim_out,
            int(self.dim_out * config.mlp_ratio),
            self.dim_out,
            num_layers=2,
            activation=config.hidden_act,
        )
        if self.dim != self.dim_out:
            self.proj = nn.Linear(self.dim, self.dim_out)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        residual = hidden_states  # batch_size, height, width, channel

        hidden_states = self.layer_norm1(hidden_states)

        # Skip connection: project (and pool, at stage changes) so the residual matches the attention output.
        if self.dim != self.dim_out:
            residual = do_pool(self.proj(hidden_states), self.query_stride)

        # Window partition
        window_size = self.window_size
        if self.window_size > 0:
            H, W = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, pad_hw = window_partition(hidden_states, window_size)

        # Window Attention + Q Pooling (if stage change)
        hidden_states = self.attn(
            hidden_states=hidden_states,
            **kwargs,
        )

        # Reverse window partition. Global-attention blocks (window_size == 0) already return a full grid, and
        # must skip the padding arithmetic below, which would divide by zero.
        if self.window_size > 0:
            if self.query_stride:
                # Shapes have changed due to Q pooling: windows shrank by the stride and the target grid is the
                # (pooled) residual grid.
                window_size = self.window_size // self.query_stride[0]
                H, W = residual.shape[1:3]

                pad_h = (window_size - H % window_size) % window_size
                pad_w = (window_size - W % window_size) % window_size
                pad_hw = (H + pad_h, W + pad_w)

            hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W))

        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)

        return hidden_states
@dataclass
@auto_docstring(
    custom_intro="""
    Hiera model's outputs that also contains a pooling of the last hidden states.
    """
)
# NOTE(review): the custom_intro above mentions pooling, but this output carries no pooled field.
class Sam2HieraDetModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        Hidden-states at the output of the last layer of the model.
    intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
        Hidden-states at the output of the last block of each stage, in stage order.
    """

    # Field order defines the positional order of the `ModelOutput` tuple form; do not reorder.
    last_hidden_state: torch.FloatTensor | None = None
    intermediate_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@auto_docstring
class Sam2PreTrainedModel(PreTrainedModel):
    config_class = Sam2Config
    base_model_prefix = "sam2"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Apply the base `PreTrainedModel` initialisation, then the SAM2-specific overrides: zeroed Hiera positional
        embeddings, a re-drawn Gaussian prompt positional matrix, and a zeroed no-memory embedding.
        """
        super()._init_weights(module)

        if isinstance(module, Sam2HieraDetModel):
            for embedding in (module.pos_embed, module.pos_embed_window):
                if embedding is not None:
                    init.zeros_(embedding)
            return

        if isinstance(module, Sam2PositionalEmbedding):
            # Same std the buffer was created with (`scale * randn`).
            init.normal_(module.positional_embedding, std=module.scale)
            return

        if isinstance(module, Sam2Model) and module.no_memory_embedding is not None:
            init.zeros_(module.no_memory_embedding)
class Sam2HieraDetModel(Sam2PreTrainedModel):
    config_class = Sam2HieraDetConfig
    main_input_name = "pixel_values"
    _can_record_outputs = {
        "hidden_states": Sam2MultiScaleBlock,
        "attentions": Sam2MultiScaleAttention,
    }

    def __init__(self, config: Sam2HieraDetConfig):
        super().__init__(config)

        self.patch_embed = Sam2PatchEmbeddings(config)
        # Windowed positional embedding (https://huggingface.co/papers/2311.05613): a background embedding resized
        # to the token grid, plus a window-sized embedding tiled over it.
        background_size = config.window_positional_embedding_background_size
        first_window = config.window_size_per_stage[0]
        self.pos_embed = nn.Parameter(torch.zeros(1, config.hidden_size, *background_size))
        self.pos_embed_window = nn.Parameter(torch.zeros(1, config.hidden_size, first_window, first_window))

        # Index of the last block of every stage (e.g. blocks_per_stage [1, 2] -> [0, 2]).
        self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()

        self.blocks = nn.ModuleList()
        for stage_idx, num_stage_blocks in enumerate(config.blocks_per_stage):
            for block_idx in range(num_stage_blocks):
                # `len(self.blocks)` is the running, model-wide block index.
                self.blocks.append(
                    Sam2MultiScaleBlock(
                        config=config,
                        stage_idx=stage_idx,
                        block_idx=block_idx,
                        total_block_idx=len(self.blocks),
                    )
                )

        self.post_init()

    def get_input_embeddings(self):
        return self.patch_embed

    def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
        """Channels-last `(1, h, w, hidden_size)` positional embedding for an `h x w` token grid."""
        height, width = hw
        background = F.interpolate(self.pos_embed, size=(height, width), mode="bicubic")
        window = self.pos_embed_window
        repeats = [full // tile for full, tile in zip(background.shape, window.shape)]
        return (background + window.tile(repeats)).permute(0, 2, 3, 1)

    @check_model_inputs
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam2HieraDetModelOutput:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embeddings = self.patch_embed(pixel_values)
        hidden_states = embeddings + self._get_pos_embed(embeddings.shape[1:3])

        # Collect the output of the last block of each stage.
        stage_outputs = []
        for block_index, block_module in enumerate(self.blocks):
            hidden_states = block_module(hidden_states, **kwargs)
            if block_index in self.stage_ends:
                stage_outputs.append(hidden_states)

        return Sam2HieraDetModelOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_states=tuple(stage_outputs),
        )
@auto_docstring(
    custom_intro="""
    The vision model from Sam without any head or projection on top.
    """
)
class Sam2VisionModel(Sam2PreTrainedModel):
    config_class = Sam2VisionConfig
    main_input_name = "pixel_values"
    _can_record_outputs = {
        "hidden_states": Sam2MultiScaleBlock,
        "attentions": Sam2MultiScaleAttention,
    }

    def __init__(self, config: Sam2VisionConfig):
        super().__init__(config)
        self.config = config

        # Image backbone (built from its own sub-config) followed by the FPN neck.
        self.backbone = AutoModel.from_config(config.backbone_config)
        self.neck = Sam2VisionNeck(config)
        self.num_feature_levels = config.num_feature_levels

        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    @check_model_inputs
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam2VisionEncoderOutput:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        backbone_output = self.backbone(pixel_values, **kwargs)
        fpn_hidden_states, fpn_position_encoding = self.neck(backbone_output.intermediate_hidden_states)

        # Keep the last `num_feature_levels` levels emitted by the neck and reverse them, so the output runs from
        # high to low resolution.
        kept_levels = slice(-self.num_feature_levels, None)
        return Sam2VisionEncoderOutput(
            last_hidden_state=backbone_output.last_hidden_state,
            fpn_hidden_states=fpn_hidden_states[kept_levels][::-1],
            fpn_position_encoding=fpn_position_encoding[kept_levels][::-1],
        )
class Sam2PositionalEmbedding(nn.Module):
    """
    Random Fourier-feature positional encoding for 2D point coordinates.

    A fixed `(2, hidden_size // 2)` Gaussian matrix (scaled by `config.scale`) projects coordinates. The output
    concatenates the sine and cosine of the projection, giving `2 * (hidden_size // 2)` channels.
    """

    def __init__(self, config: Sam2PromptEncoderConfig):
        super().__init__()
        self.scale = config.scale
        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
        self.register_buffer("positional_embedding", positional_embedding)

    def forward(self, input_coords, input_shape=None):
        """
        Positionally encode points.

        Args:
            input_coords: Tensor whose last dimension holds `(x, y)`. The input is not modified.
            input_shape: Optional `(height, width)`. When given, `x` is divided by `width` and `y` by `height`
                (this path indexes the input as 4-D). Coordinates are otherwise assumed to be already in `[0, 1]`.

        Returns:
            Tensor of shape `(*input_coords.shape[:-1], 2 * (hidden_size // 2))`.
        """
        coordinates = input_coords.clone()
        # Promote integer coordinates before the in-place normalisation below. Otherwise the divided values would
        # be truncated back to integers when written into the integer tensor.
        if not coordinates.is_floating_point():
            coordinates = coordinates.to(torch.float32)

        if input_shape is not None:
            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]

        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coordinates = 2 * coordinates - 1
        coordinates = coordinates.to(self.positional_embedding.dtype)
        coordinates = coordinates @ self.positional_embedding
        coordinates = 2 * np.pi * coordinates
        # outputs d_1 x ... x d_n x channel shape
        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
class Sam2MaskEmbedding(nn.Module):
    """
    Encodes a single-channel mask prompt into a dense embedding.

    Two kernel-2 / stride-2 convolutions, each followed by a channels-first layer norm and the activation,
    downsample the mask by 4x. A 1x1 convolution then projects it to `hidden_size` channels.
    """

    def __init__(self, config: Sam2PromptEncoderConfig):
        super().__init__()
        self.mask_input_channels = config.mask_input_channels // 4
        self.activation = ACT2FN[config.hidden_act]
        self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
        self.layer_norm1 = Sam2LayerNorm(
            self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
        )
        # Normalise exactly the channels conv2 produces. `(mask_input_channels // 4) * 4` only equals this when
        # `config.mask_input_channels` is a multiple of 4.
        self.layer_norm2 = Sam2LayerNorm(
            config.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
        )

    def forward(self, masks):
        hidden_states = self.conv1(masks)
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.activation(hidden_states)

        hidden_states = self.conv2(hidden_states)
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.activation(hidden_states)
        dense_embeddings = self.conv3(hidden_states)
        return dense_embeddings
class Sam2PromptEncoder(nn.Module):
    """
    Encodes SAM2 prompts. Points and boxes become sparse per-token embeddings; a mask (or the learned
    "no mask" vector) becomes a dense embedding.
    """

    def __init__(self, config: Sam2PromptEncoderConfig):
        super().__init__()
        self.shared_embedding = Sam2PositionalEmbedding(config)
        self.mask_embed = Sam2MaskEmbedding(config)
        self.no_mask_embed = nn.Embedding(1, config.hidden_size)

        self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
        self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
        self.input_image_size = config.image_size

        self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
        self.hidden_size = config.hidden_size
        self.not_a_point_embed = nn.Embedding(1, config.hidden_size)

    def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            # Append one padding point (label -1) per prompt.
            points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
            labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
        input_shape = (self.input_image_size, self.input_image_size)
        point_embedding = self.shared_embedding(points, input_shape)

        # torch.where and expanding the labels tensor is required by the ONNX export
        point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)

        # Points labelled -10 get an all-zero embedding. torch.where (rather than masked assignment) keeps this
        # ONNX-exportable.
        point_embedding = torch.where(
            labels[..., None] != -10,
            point_embedding,
            torch.zeros_like(point_embedding),
        )

        # Add point embeddings for labels >= 0
        point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)

        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.view(*boxes.shape[:2], 2, 2)
        # add padding point for consistency with the original implementation
        coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
        corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
        corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
        corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
        corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
        return corner_embedding

    def forward(
        self,
        input_points: torch.Tensor | None,
        input_labels: torch.Tensor | None,
        input_boxes: torch.Tensor | None,
        input_masks: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            input_points (`torch.Tensor`, *optional*):
                4-D tensor of point coordinates in pixels of the `image_size` x `image_size` input frame, with the
                last dimension holding `(x, y)`.
            input_labels (`torch.Tensor`, *optional*):
                One label per point; required whenever `input_points` is given. `-1` selects the "not a point"
                embedding, `-10` gives an all-zero embedding, and values `>= 0` add the matching `point_embed` row.
            input_boxes (`torch.Tensor`, *optional*):
                Boxes whose last dimension holds 4 values (two corners); viewed as `(*boxes.shape[:2], 2, 2)`.
            input_masks (`torch.Tensor`, *optional*):
                Single-channel mask prompts, passed to the mask embedding.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: `(sparse_embeddings, dense_embeddings)`. `sparse_embeddings` is
            `None` when neither points nor boxes are given; otherwise point and box tokens are concatenated along
            dim 2. Without masks, `dense_embeddings` is `no_mask_embed` broadcast to
            `(batch_size, hidden_size, *image_embedding_size)`.

        Raises:
            ValueError: If `input_points` is given without `input_labels`.
        """
        sparse_embeddings = None
        batch_size = 1
        if input_points is not None:
            batch_size = input_points.shape[0]
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            # A padding point is only appended when no box follows the points.
            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
            sparse_embeddings = point_embeddings
        if input_boxes is not None:
            batch_size = input_boxes.shape[0]
            box_embeddings = self._embed_boxes(input_boxes)
            if sparse_embeddings is None:
                sparse_embeddings = box_embeddings
            else:
                sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
        if input_masks is not None:
            dense_embeddings = self.mask_embed(input_masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings
class Sam2Attention(nn.Module):
    """
    SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    """

    def __init__(self, config, downsample_rate=None):
        super().__init__()
        if downsample_rate is None:
            downsample_rate = config.attention_downsample_rate
        self.config = config
        self.hidden_size = config.hidden_size
        # Queries/keys/values live in a reduced `internal_dim` space; o_proj maps back to hidden_size.
        self.internal_dim = self.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.internal_dim // self.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_similarity: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Attend `query` over `key`/`value`. The first two input dims (presumably batch and point batch) are merged
        for attention and restored afterwards. `attention_similarity`, if given, is passed as an additive attention
        mask.
        """
        batch_size, point_batch_size = query.shape[:2]
        heads_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (B, P, tokens, internal_dim) -> (B * P, heads, tokens, head_dim)
            return projected.view(*heads_shape).transpose(1, 2)

        query_states = split_heads(self.q_proj(query))
        key_states = split_heads(self.k_proj(key))
        value_states = split_heads(self.v_proj(value))

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        needs_additive_mask = attention_similarity is not None
        if needs_additive_mask and is_flash_attention_requested(self.config):
            # Flash Attention cannot take the float target-guided mask; use SDPA for this call only.
            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
            logger.warning_once(
                "Falling back to SDPA for target-guided attention because "
                "Flash Attention does not support additive bias masks."
            )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_similarity,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        merged = attn_output.reshape(batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim)
        return self.o_proj(merged.contiguous()), attn_weights
class Sam2TwoWayAttentionBlock(GradientCheckpointingLayer):
    def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
        """
        A transformer block with four layers:
        (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
        sparse inputs (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`Sam2MaskDecoderConfig`):
                The configuration used to instantiate the block. The self-attention layer is built with
                `downsample_rate=1`; both cross-attention layers use the default downsample rate of `Sam2Attention`.
            skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
                Whether to skip adding `query_point_embedding` in the self-attention layer. When set, the
                self-attention output also replaces the queries instead of being added to them as a residual.
        """
        super().__init__()
        self.self_attn = Sam2Attention(config, downsample_rate=1)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)

        self.cross_attn_token_to_image = Sam2Attention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

        # NOTE(review): the MLP depth comes from `config.num_hidden_layers`, the same setting that controls how many
        # two-way blocks the transformer stacks. Changing it presumably changes the MLP layer count and therefore
        # the checkpoint layout — confirm before decoupling the two.
        self.mlp = Sam2FeedForward(
            config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
        )
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)

        self.layer_norm4 = nn.LayerNorm(config.hidden_size)
        self.cross_attn_image_to_token = Sam2Attention(config)

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Tensor | None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Run the four sub-layers, each followed by its residual add (except the skipped first self-attention
        residual) and layer norm.

        Args:
            queries (`Tensor`): sparse (token) stream.
            keys (`Tensor`): dense (image) stream.
            query_point_embedding (`Tensor`): positional embedding added to `queries` before attention.
            key_point_embedding (`Tensor`): positional embedding added to `keys` before attention.
            attention_similarity (`Tensor`, *optional*): additive mask forwarded only to the token-to-image
                cross attention.

        Returns:
            `(queries, keys, attn_out)` where `attn_out` is the raw output of the image-to-token cross attention
            (this third element is what `Sam2Model` records as `mask_decoder_attentions`).
        """
        # Self attention block
        if self.skip_first_layer_pe:
            queries, _ = self.self_attn(query=queries, key=queries, value=queries)
        else:
            query = queries + query_point_embedding
            attn_out, _ = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)

        # Cross attention block, tokens attending to image embedding
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        attn_out, _ = self.cross_attn_token_to_image(
            query=query, key=key, value=keys, attention_similarity=attention_similarity
        )
        queries = queries + attn_out
        queries = self.layer_norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)

        # Cross attention block, image embedding attending to tokens
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out

        keys = self.layer_norm4(keys)
        return queries, keys, attn_out
class Sam2TwoWayTransformer(nn.Module):
    """
    A stack of `Sam2TwoWayAttentionBlock`s followed by one last token-to-image attention and a layer norm.
    """

    def __init__(self, config: Sam2MaskDecoderConfig):
        super().__init__()
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers

        # Only the very first block skips the positional embedding in its self-attention.
        self.layers = nn.ModuleList(
            [
                Sam2TwoWayAttentionBlock(config, skip_first_layer_pe=(layer_idx == 0))
                for layer_idx in range(self.num_hidden_layers)
            ]
        )

        self.final_attn_token_to_image = Sam2Attention(config)
        self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        point_embeddings: Tensor,
        image_embeddings: Tensor,
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor,
        target_embedding=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        """
        Alternate attention between the point tokens and the flattened image embedding.

        Args:
            point_embeddings (`Tensor`): token stream; also used as the query positional embedding.
            image_embeddings (`Tensor`): image feature map, flattened here into a token sequence.
            image_positional_embeddings (`Tensor`): positional map with the same layout as `image_embeddings`.
            attention_similarity (`Tensor`, *optional*): additive mask for token-to-image attention in every block.
            target_embedding (`Tensor`, *optional*): added to the queries before every block.

        Returns:
            `(queries, keys)` after the final attention and layer norm.

        Raises:
            ValueError: if `image_embeddings` is `None`.
        """
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        def to_token_layout(feature_map: Tensor) -> Tensor:
            # Flatten spatial dims and move channels last, adding a singleton dim at position 1.
            # Presumably (N, C, H, W) -> (N, 1, H*W, C); TODO confirm input rank with callers.
            return feature_map.flatten(2).permute(0, 2, 1).unsqueeze(1)

        keys = to_token_layout(image_embeddings)
        key_pe = to_token_layout(image_positional_embeddings)
        queries = point_embeddings

        for block in self.layers:
            if target_embedding is not None:
                # NOTE: in-place add. On the first iteration `queries` is the same tensor object as
                # `point_embeddings`, so the addition is also visible through `point_embeddings`, which is later
                # used as the query positional embedding and in the final attention below.
                queries += target_embedding
            queries, keys, _ = block(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=key_pe,
                attention_similarity=attention_similarity,
                **kwargs,
            )

        # Final attention from the point tokens to the image, then residual + layer norm.
        attn_out, _ = self.final_attn_token_to_image(
            query=queries + point_embeddings,
            key=keys + key_pe,
            value=keys,
        )
        queries = self.layer_norm_final_attn(queries + attn_out)
        return queries, keys
class Sam2LayerNorm(nn.LayerNorm):
    r"""Layer normalization over the channel dimension for either channels-last or channels-first inputs.

    `data_format="channels_last"` (default) expects `(batch_size, height, width, channels)` and normalizes the last
    dimension directly. `data_format="channels_first"` expects `(batch_size, channels, height, width)`; the input is
    permuted to channels-last, normalized, and permuted back.
    """

    def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(normalized_shape, eps=eps, **kwargs)
        supported_formats = ("channels_last", "channels_first")
        if data_format not in supported_formats:
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        self.data_format = data_format

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
        """
        if self.data_format != "channels_first":
            return super().forward(features)
        # Move channels to the last axis, normalize, then restore the channels-first layout.
        normalized = super().forward(features.permute(0, 2, 3, 1))
        return normalized.permute(0, 3, 1, 2)
class Sam2MaskDecoder(nn.Module):
    """
    Predicts masks, IoU scores and object-score logits from image embeddings and prompt embeddings.

    Output-token layout (concatenated in `forward`): index 0 is the object-score token, index 1 the IoU token, and
    indices `2 .. 2 + num_mask_tokens - 1` the mask tokens. Mask token 0 produces the single-mask output and the
    remaining `num_multimask_outputs` tokens produce the multimask outputs.
    """

    def __init__(self, config: Sam2MaskDecoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        self.num_multimask_outputs = config.num_multimask_outputs
        # One extra mask token for the single-mask output.
        self.num_mask_tokens = config.num_multimask_outputs + 1

        self.iou_token = nn.Embedding(1, self.hidden_size)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)

        self.transformer = Sam2TwoWayTransformer(config)

        # Two stride-2 transposed convolutions that upsample the image embedding (each doubles H and W).
        self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
        self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
        self.upscale_layer_norm = Sam2LayerNorm(self.hidden_size // 4, data_format="channels_first")
        self.activation = nn.GELU()

        # One hypernetwork MLP per mask token, mapping the token to the channel count of the upscaled embedding.
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                Sam2FeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)
                for _ in range(self.num_mask_tokens)
            ]
        )
        self.iou_prediction_head = Sam2FeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            sigmoid_output=True,
        )

        # 1x1 projections of the high-resolution backbone features to the channel counts expected by the two
        # upscaling stages (hidden_size // 8 and hidden_size // 4).
        self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
        self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)

        self.obj_score_token = nn.Embedding(1, self.hidden_size)
        self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)

        self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        high_resolution_features: list[torch.Tensor],
        attention_similarity: torch.Tensor | None = None,
        target_embedding: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                The embeddings from the image encoder.
            image_positional_embeddings (`torch.Tensor`):
                Positional encoding with the shape of image_embeddings.
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes. Dim 1 is read as the point batch size and the tensor is
                concatenated with the output tokens along dim 2 (presumably
                `(batch_size, point_batch_size, num_prompt_tokens, hidden_size)`).
            dense_prompt_embeddings (`torch.Tensor`):
                The embeddings of the mask inputs, added to `image_embeddings`.
            multimask_output (`bool`):
                Whether to return multiple masks or a single mask.
            high_resolution_features (`list[torch.Tensor]`):
                Required. Exactly two high-resolution feature maps `[feat_s0, feat_s1]` from the vision encoder,
                already projected by `conv_s0` / `conv_s1`; they are added after the second and first upscaling
                stage respectively.
            attention_similarity (`torch.Tensor`, *optional*):
                The attention similarity tensor.
            target_embedding (`torch.Tensor`, *optional*):
                The target embedding.

        Returns:
            `(masks, iou_pred, sam_tokens_out, object_score_logits)`:
            - `masks`: `(batch_size, point_batch_size, num_selected_masks, mask_height, mask_width)` mask logits.
            - `iou_pred`: predicted IoU for the selected masks.
            - `sam_tokens_out`: `(batch_size, point_batch_size, num_selected_masks, hidden_size)` mask tokens.
            - `object_score_logits`: output of the object-score head applied to token 0.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]

        # Output tokens, in order: object-score token (0), IoU token (1), mask tokens (2 onwards).
        output_tokens = torch.cat(
            [
                self.obj_score_token.weight,
                self.iou_token.weight,
                self.mask_tokens.weight,
            ],
            dim=0,
        )
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        if sparse_prompt_embeddings.shape[0] != 0:
            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-mask
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)

        # Run the transformer
        point_embeddings, image_embeddings = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            **kwargs,
        )
        iou_token_out = point_embeddings[:, :, 1, :]
        mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]

        # Restore the image tokens to a (batch * point_batch, channels, height, width) feature map.
        image_embeddings = image_embeddings.transpose(2, 3).view(
            batch_size * point_batch_size, num_channels, height, width
        )

        # Upscale the image embedding, fusing in the high-resolution features at each stage.
        feat_s0, feat_s1 = high_resolution_features
        feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
        feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
        upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)

        # Each mask token goes through its own hypernetwork MLP; the results weight the upscaled channels.
        hyper_in = torch.stack(
            [
                mlp(mask_tokens_out[:, :, token_idx, :])
                for token_idx, mlp in enumerate(self.output_hypernetworks_mlps)
            ],
            dim=2,
        )

        _, upscaled_channels, upscaled_height, upscaled_width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.view(
            batch_size, point_batch_size, upscaled_channels, upscaled_height * upscaled_width
        )
        masks = (hyper_in @ upscaled_embedding).view(
            batch_size, point_batch_size, -1, upscaled_height, upscaled_width
        )

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]
        elif self.dynamic_multimask_via_stability and not self.training:
            # NOTE: the returned mask may come from a multimask token, but `sam_tokens_out` below still holds
            # mask token 0; the token is not re-selected to match.
            mask_slice = slice(0, 1)
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            mask_slice = slice(0, 1)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]

        sam_tokens_out = mask_tokens_out[:, :, mask_slice]  # (batch, point_batch, num_selected_masks, hidden)

        return masks, iou_pred, sam_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds.

        The score is `|logits > delta| / |logits > -delta|` over the last two (spatial) dims, with
        `delta = dynamic_multimask_stability_delta`; it is 1.0 when the denominator area is empty.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.

        Args:
            all_mask_logits: `(batch, point_batch, num_mask_tokens, height, width)` logits for every mask token.
            all_iou_scores: `(batch, point_batch, num_mask_tokens)` predicted IoU for every mask token.

        Returns:
            `(mask_logits, iou_scores)` with a singleton mask dim: `(B, P, 1, H, W)` and `(B, P, 1)`.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, :, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, :, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)  # [B, P]
        best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        best_scores_inds_expanded = best_scores_inds_expanded.expand(
            -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
        )
        best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded)  # [B, P, 1, H, W]
        best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1))  # [B, P, 1]

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, :, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
@auto_docstring(
    custom_intro="""
    Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
    input points and labels, boxes, or masks.
    """
)
class Sam2Model(Sam2PreTrainedModel):
    input_modalities = ("image", "text")
    # Records element 2 of each two-way attention block's output (its image-to-token attention output).
    _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2TwoWayAttentionBlock, index=2)}
    _tied_weights_keys = {}
    # Checkpoint keys that are ignored when loading this image-only model (judging by their names, presumably the
    # memory / video-tracking components).
    _keys_to_ignore_on_load_unexpected = [
        r"^memory_.*",
        r"^mask_downsample.*",
        r"^object_pointer_proj.*",
        r"^temporal_positional_encoding_projection_layer.*",
        "no_memory_positional_encoding",
        "no_object_pointer",
        "occlusion_spatial_embedding_parameter",
    ]

    def __init__(self, config: Sam2Config):
        super().__init__(config)
        self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config)
        self.vision_encoder = AutoModel.from_config(config.vision_config)
        self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config)
        # The module using it is not a PreTrainedModel subclass so we need this
        config.mask_decoder_config._attn_implementation = config._attn_implementation
        self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)

        self.num_feature_levels = config.vision_config.num_feature_levels
        self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
        # a single token to indicate no memory embedding from previous frames
        self.hidden_dim = config.vision_config.fpn_hidden_size
        self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))

        self.post_init()

    def get_input_embeddings(self):
        return self.vision_encoder.get_input_embeddings()

    def get_image_wide_positional_embeddings(self) -> torch.Tensor:
        """
        Positional embedding of the whole image-embedding grid.

        Builds normalized cell-center coordinates (x along dim 1 / width, y along dim 0 / height) over the prompt
        encoder's `image_embedding_size` and embeds them with the shared positional embedding.
        """
        size = self.prompt_encoder.image_embedding_size
        target_device = self.shared_image_embedding.positional_embedding.device
        target_dtype = self.shared_image_embedding.positional_embedding.dtype
        grid = torch.ones(size, device=target_device, dtype=target_dtype)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / size[0]
        x_embed = x_embed / size[1]

        positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
        return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # 1 x channel x height x width

    @torch.no_grad()
    def get_image_embeddings(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> list[torch.Tensor]:
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

        The result is the list of multi-level feature maps, each reshaped to `(batch_size, channels, *feat_size)`
        using `backbone_feature_sizes`, with the no-memory embedding added to the last level. It can be passed
        directly as `image_embeddings` to `forward`.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
        """
        batch_size = pixel_values.shape[0]
        image_outputs = self.get_image_features(pixel_values, return_dict=True, **kwargs)
        feature_maps = image_outputs.fpn_hidden_states

        # add no memory embedding to the last feature map
        feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding

        # reshape feature maps to the same shape as the backbone feature sizes
        image_embeddings = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
        ]

        return image_embeddings

    @torch.no_grad()
    def get_prompt_embeddings(
        self,
        input_points: torch.FloatTensor | None = None,
        input_labels: torch.LongTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_masks: torch.LongTensor | None = None,
    ):
        r"""
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

        The return value is the prompt encoder output as-is (`forward` unpacks it into sparse and dense embeddings).

        Args:
            input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the point is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. users can also pass manually the input boxes.
            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.
        """
        prompt_output = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        return prompt_output

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        input_points: torch.FloatTensor | None = None,
        input_labels: torch.LongTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_masks: torch.LongTensor | None = None,
        image_embeddings: torch.FloatTensor | None = None,
        multimask_output: bool = True,
        attention_similarity: torch.FloatTensor | None = None,
        target_embedding: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam2ImageSegmentationOutput:
        r"""
        input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_mask, 2)`):
            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
            better results. The points can be obtained by passing a list of list of list to the processor that will
            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
            multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
            coordinates of the point. If a different number of points is passed either for each image, or for each
            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
            computation of the embedding will be skipped for these points using the labels.
        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_mask)`):
            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
            official implementation, there are 3 types of labels

            - `1`: the point is a point that contains the object of interest
            - `0`: the point is a point that does not contain the object of interest
            - `-1`: the point corresponds to the background

            We added the label:

            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

            The padding labels should be automatically done by the processor. If `input_points` is given without
            `input_labels`, every point is labeled `1`.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
            that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
            In the order (`x1`, `y1`, `x2`, `y2`):

            - `x1`: the x coordinate of the top left point of the input box
            - `y1`: the y coordinate of the top left point of the input box
            - `x2`: the x coordinate of the bottom right point of the input box
            - `y2`: the y coordinate of the bottom right point of the input box

            When both `input_points` and `input_boxes` are provided, `num_boxes` must equal `point_batch_size`.
        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
            manually fed by the user. Masks whose last two dimensions differ from the prompt encoder's
            `mask_input_size` are first resized with antialiased bilinear interpolation, which requires a 4D input
            with a channel dimension, and then cast back to their original dtype.
        image_embeddings (`list[torch.FloatTensor]`):
            Multi-level image embeddings as returned by `get_image_embeddings`. The last entry is used as the main
            image embedding of the mask decoder and the preceding entries as its high-resolution features. For more
            memory efficient computation, users can first retrieve the image embeddings using the
            `get_image_embeddings` method, and then feed them to the `forward` method instead of feeding the
            `pixel_values`. Exactly one of `pixel_values` and `image_embeddings` must be provided.
        multimask_output (`bool`, *optional*):
            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
            "best" mask, by specifying `multimask_output=False`.
        attention_similarity (`torch.FloatTensor`, *optional*):
            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
            model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
        target_embedding (`torch.FloatTensor`, *optional*):
            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
            the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
        >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")

        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> with httpx.stream("GET", url) as response:
        ...     raw_image = Image.open(BytesIO(response.read())).convert("RGB")
        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

        >>> # Get segmentation mask
        >>> outputs = model(**inputs)

        >>> # Postprocess masks
        >>> masks = processor.post_process_masks(
        ...     outputs.pred_masks, inputs["original_sizes"]
        ... )
        ```
        """
        if not ((pixel_values is None) ^ (image_embeddings is None)):
            raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
        if input_points is not None and input_boxes is not None:
            if input_points.shape[1] != input_boxes.shape[1]:
                raise ValueError(
                    f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
                )

        image_positional_embeddings = self.get_image_wide_positional_embeddings()
        # repeat with batch size
        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

        vision_attentions = None
        vision_hidden_states = None

        if pixel_values is not None:
            image_outputs: Sam2VisionEncoderOutput = self.get_image_features(pixel_values, return_dict=True, **kwargs)
            feature_maps = image_outputs.fpn_hidden_states
            vision_hidden_states = image_outputs.hidden_states
            vision_attentions = image_outputs.attentions

            # add no memory embedding to the last feature map
            feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding

            # reshape feature maps to the same shape as the backbone feature sizes
            image_embeddings = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
            ]

        # Points without labels default to positive clicks.
        if input_points is not None and input_labels is None:
            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

        if input_points is None and input_boxes is None:
            # If no points are provide, pad with an empty point (with label -1)
            input_points = torch.zeros(
                batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
            )
            input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)

        if input_masks is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
                input_masks = F.interpolate(
                    input_masks.float(),
                    size=self.prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                ).to(input_masks.dtype)

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
            image_embeddings=image_embeddings[-1],
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            high_resolution_features=image_embeddings[:-1],
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            **kwargs,
        )

        return Sam2ImageSegmentationOutput(
            iou_scores=iou_scores,
            pred_masks=low_res_multimasks,
            object_score_logits=object_score_logits,
            image_embeddings=image_embeddings,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
        )

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam2VisionEncoderOutput:
        r"""
        pixel_values (`torch.FloatTensor`):
            Input pixel values of shape `(batch_size, num_channels, height, width)`.
        """
        vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)

        feature_maps = vision_outputs.fpn_hidden_states
        feature_maps_position_embeddings = vision_outputs.fpn_position_encoding

        # precompute projected level 0 and level 1 features in SAM decoder
        # to avoid running it again on every SAM click
        feature_maps = list(feature_maps)
        feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
        feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])

        # flatten NxCxHxW to HWxNxC
        feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
        feature_maps_position_embeddings = [
            feature_map_position_embedding.flatten(2).permute(2, 0, 1)
            for feature_map_position_embedding in feature_maps_position_embeddings
        ]
        vision_outputs.fpn_hidden_states = feature_maps
        vision_outputs.fpn_position_encoding = feature_maps_position_embeddings

        return vision_outputs
# Public API exported by `from ... import *`.
__all__ = [
    "Sam2Model",
    "Sam2VisionModel",
    "Sam2PreTrainedModel",
    "Sam2HieraDetModel",
]