You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2413 lines
99 KiB

# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable, Iterable
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch import Tensor
from transformers import CLIPTextModelWithProjection
from ... import initialization as init
from ...activations import ACT2FN
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ModelOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.generic import TransformersKwargs, check_model_inputs, is_flash_attention_requested
from ..auto import AutoModel
from .configuration_sam3 import (
Sam3Config,
Sam3DETRDecoderConfig,
Sam3DETREncoderConfig,
Sam3GeometryEncoderConfig,
Sam3MaskDecoderConfig,
Sam3VisionConfig,
Sam3ViTConfig,
)
# Module-level logger, obtained from the library's `logging` helper and keyed by this module's name.
logger = logging.get_logger(__name__)
# Output container that adds the FPN fields declared below to those inherited from
# `BaseModelOutputWithPooling`.
@dataclass
@auto_docstring
class Sam3VisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    fpn_hidden_states (`tuple[torch.FloatTensor]`):
        Tuple of multi-level FPN feature maps.
    fpn_position_encoding (`tuple[torch.FloatTensor]`):
        Tuple of position encodings for each FPN level.
    """

    # Both fields default to `None` even though their annotations do not include `| None`.
    fpn_hidden_states: tuple[torch.FloatTensor, ...] = None
    fpn_position_encoding: tuple[torch.FloatTensor, ...] = None
# Output container for encoded geometry prompts and their validity mask.
@dataclass
@auto_docstring
class Sam3GeometryEncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_prompts, hidden_size)`):
        Encoded geometry prompt features (boxes).
    attention_mask (`torch.BoolTensor` of shape `(batch_size, num_prompts)`, *optional*):
        Attention mask for geometry prompts where True indicates valid positions and False indicates padding.
    """

    # `last_hidden_state` defaults to `None` although its annotation is not optional.
    last_hidden_state: torch.FloatTensor = None
    attention_mask: torch.BoolTensor | None = None
# Output container for the DETR encoder: flattened vision features plus the optional tensors listed below.
@dataclass
@auto_docstring
class Sam3DETREncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Encoded vision features (flattened from multi-level features).
    pos_embeds_flattened (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Flattened position embeddings for the vision features.
    text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`, *optional*):
        Text features (may be pooled after encoder processing).
    spatial_shapes (`torch.LongTensor` of shape `(num_levels, 2)`, *optional*):
        Spatial shapes (height, width) for each feature pyramid level.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all encoder layers.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from all encoder layers.
    """

    # Every field defaults to `None`; only `last_hidden_state` is annotated as non-optional.
    last_hidden_state: torch.FloatTensor = None
    pos_embeds_flattened: torch.FloatTensor | None = None
    text_features: torch.FloatTensor | None = None
    spatial_shapes: torch.LongTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
# Output container for the DETR decoder: per-layer hidden states, reference boxes and presence logits.
@dataclass
@auto_docstring
class Sam3DETRDecoderOutput(ModelOutput):
    r"""
    intermediate_hidden_states (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, hidden_size)`):
        Decoder hidden states from all layers.
    reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
        Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
    presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
        Presence logits from all decoder layers indicating object presence confidence.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all decoder layers.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from all decoder layers (self-attention and cross-attention).
    """

    # The first three fields are annotated as non-optional but, like the rest, default to `None`.
    intermediate_hidden_states: torch.FloatTensor = None
    reference_boxes: torch.FloatTensor = None
    presence_logits: torch.FloatTensor = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
# Output container for the mask decoder: per-query masks plus optional semantic map and attentions.
@dataclass
@auto_docstring
class Sam3MaskDecoderOutput(ModelOutput):
    r"""
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
        Predicted segmentation masks for each query.
    semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
        Semantic segmentation output.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from mask decoder cross-attention layers.
    """

    # `pred_masks` defaults to `None` although its annotation is not optional.
    pred_masks: torch.FloatTensor = None
    semantic_seg: torch.FloatTensor | None = None
    attentions: tuple[torch.FloatTensor] | None = None
# Top-level output container for image segmentation: predictions first, then optional intermediate
# states and attention weights from each sub-module.
@dataclass
@auto_docstring
class Sam3ImageSegmentationOutput(ModelOutput):
    r"""
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
        Predicted segmentation masks for each query.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Predicted bounding boxes in (x1, y1, x2, y2) format.
    pred_logits (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
        Classification confidence scores for each query, computed via dot product between
        decoder query features and text features.
    presence_logits (`torch.FloatTensor` of shape `(batch_size, 1)`, *optional*):
        Presence logits from the DETR decoder presence token (last layer only). These indicate whether objects
        are present in the scene. Can be used to compute final scores by multiplying with pred_logits:
        `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()`.
    semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
        Semantic segmentation output.
    decoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all DETR decoder layers. Each tensor has shape `(batch_size, num_queries, hidden_size)`.
    decoder_reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`, *optional*):
        Reference boxes from all DETR decoder layers.
    encoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all DETR encoder layers.
    vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all vision encoder (ViT) layers.
    vision_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from vision encoder (ViT) layers.
    detr_encoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from DETR encoder layers.
    detr_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from DETR decoder layers (self-attention and cross-attention).
    mask_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from mask decoder layers.
    """

    # Primary predictions; annotated as non-optional but defaulting to `None`.
    pred_masks: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    # Optional scores and auxiliary outputs.
    pred_logits: torch.FloatTensor | None = None
    presence_logits: torch.FloatTensor | None = None
    semantic_seg: torch.FloatTensor | None = None
    # Optional intermediate states from the decoder, DETR encoder and vision encoder.
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_reference_boxes: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    vision_hidden_states: tuple[torch.FloatTensor] | None = None
    # Optional attention weights, one field per sub-module.
    vision_attentions: tuple[torch.FloatTensor] | None = None
    detr_encoder_attentions: tuple[torch.FloatTensor] | None = None
    detr_decoder_attentions: tuple[torch.FloatTensor] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor] | None = None
def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """
    Compute the logit (inverse of the sigmoid) of `x` without producing infinities.

    Args:
        x: Tensor of values interpreted as probabilities. Values outside `[0, 1]` are clamped into that range.
        eps: Lower bound applied separately to `x` and to `1 - x` before taking the ratio.

    Returns:
        `log(max(x, eps) / max(1 - x, eps))`, computed element-wise.
    """
    probabilities = x.clamp(min=0, max=1)
    numerator = probabilities.clamp(min=eps)
    denominator = (1 - probabilities).clamp(min=eps)
    return torch.log(numerator / denominator)
def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """
    Join two right-padded, batch-first sequences into one contiguous right-padded sequence.

    For every batch element the valid entries of `seq2` are placed directly after the valid entries of `seq1`.
    Masks follow the True=valid, False=padding convention.

    Args:
        seq1: A tensor of shape (batch_size, seq1_length, hidden_size).
        mask1: A tensor of shape (batch_size, seq1_length) with True=valid, False=padding.
        seq2: A tensor of shape (batch_size, seq2_length, hidden_size).
        mask2: A tensor of shape (batch_size, seq2_length) with True=valid, False=padding.
        return_index: If True, also return, for each element of seq2, its position in the concatenated
            sequence. This can be used to gather the seq2 elements back out.

    Returns:
        `(concatenated_sequence, concatenated_mask)` if `return_index` is False, otherwise
        `(concatenated_sequence, concatenated_mask, index)`. The output length is always
        `seq1_length + seq2_length`; the mask marks the first `valid1 + valid2` positions as valid.
    """
    batch_size, seq1_length, hidden_size = seq1.shape
    batch_size2, seq2_length, hidden_size2 = seq2.shape
    assert batch_size == batch_size2 == mask1.size(0) == mask2.size(0)
    assert hidden_size == hidden_size2
    assert seq1_length == mask1.size(1)
    assert seq2_length == mask2.size(1)

    device = seq2.device
    valid_len1 = mask1.sum(dim=-1)
    valid_len2 = mask2.sum(dim=-1)
    max_length = seq1_length + seq2_length

    # A position is valid when it lies before the combined valid length of its batch element.
    positions = torch.arange(max_length, device=device)
    concatenated_mask = positions[None, :] < (valid_len1 + valid_len2)[:, None]

    # Start from seq1 (including its padding slots) followed by zeros ...
    concatenated_sequence = seq2.new_zeros((batch_size, max_length, hidden_size))
    concatenated_sequence[:, :seq1_length, :] = seq1

    # ... then write seq2 starting right after the last valid seq1 element of each batch element.
    index = torch.arange(seq2_length, device=device)[None, :] + valid_len1[:, None]
    scatter_index = index.unsqueeze(-1).expand(-1, -1, hidden_size)
    concatenated_sequence = concatenated_sequence.scatter(1, scatter_index, seq2)

    if not return_index:
        return concatenated_sequence, concatenated_mask
    return concatenated_sequence, concatenated_mask, index
def box_cxcywh_to_xyxy(x):
    """
    Convert boxes from center format (cx, cy, w, h) to corner format (x1, y1, x2, y2).

    The last dimension of `x` must hold the four box coordinates; all leading dimensions are preserved.
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_width = 0.5 * width
    half_height = 0.5 * height
    corners = (
        center_x - half_width,
        center_y - half_height,
        center_x + half_width,
        center_y + half_height,
    )
    return torch.stack(corners, dim=-1)
class Sam3MLP(nn.Module):
    """
    Two-layer feed-forward block: `hidden_size -> intermediate_size -> hidden_size`.

    The activation is looked up in `ACT2FN` by `config.hidden_act`, and dropout uses `config.hidden_dropout`.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1 -> dropout -> activation -> fc2` to `hidden_states`."""
        # NOTE(review): dropout runs *before* the activation here, which differs from the more common
        # activation-then-dropout ordering. Kept as is; confirm this is intended.
        expanded = self.dropout(self.fc1(hidden_states))
        return self.fc2(self.activation_fn(expanded))
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain (non-fused) scaled dot-product attention.

    Args:
        module: The calling module; only its `training` flag is read, to decide whether dropout is active.
        query: Tensor whose last two dims are `(query_len, head_dim)`; the transposes below assume 4 dims.
        key: Tensor whose last two dims are `(key_len, head_dim)`.
        value: Tensor whose last two dims are `(key_len, head_dim)`.
        attention_mask: Optional 4D mask, cropped to `key_len` along its last dim. A floating-point mask is
            added to the attention logits. A boolean mask is interpreted as True = attend, False = masked.
        scaling: Multiplier applied to the logits. Defaults to `head_dim ** -0.5`.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted for interface compatibility with other attention backends; unused here.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped relative to `query`
        and `attn_weights` are the post-softmax (and post-dropout) weights.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores: dot product between every query and every key.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        if attention_mask.dtype == torch.bool:
            # Adding a boolean mask would only shift the allowed logits by 1 instead of masking anything,
            # so disallowed positions are explicitly filled with the most negative representable value.
            attn_weights = attn_weights.masked_fill(~attention_mask, torch.finfo(attn_weights.dtype).min)
        else:
            attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
class Sam3Attention(nn.Module):
    """
    Multi-head attention over batch-first `[batch_size, seq_len, hidden_size]` tensors.

    Query, key and value each get their own linear projection, and the attention kernel is chosen from
    `ALL_ATTENTION_FUNCTIONS` according to `config._attn_implementation`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: [batch_size, query_len, hidden_size]
            key: [batch_size, key_len, hidden_size]
            value: [batch_size, value_len, hidden_size] (reshaped using `key_len`, so both lengths must match)
            attention_mask: [batch_size, num_heads, query_len, key_len] or broadcastable

        Returns:
            Tuple of (output, attention_weights)
            output: [batch_size, query_len, hidden_size]
            attention_weights: [batch_size, num_heads, query_len, key_len]
        """
        batch_size, query_len = query.shape[0], query.shape[1]
        key_len = key.shape[1]

        def split_heads(projected: torch.Tensor, length: int) -> torch.Tensor:
            # [batch, length, hidden] -> [batch, heads, length, head_dim]
            return projected.view(batch_size, length, self.num_attention_heads, self.head_dim).transpose(1, 2)

        query = split_heads(self.q_proj(query), query_len)
        key = split_heads(self.k_proj(key), key_len)
        value = split_heads(self.v_proj(value), key_len)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        uses_additive_mask = attention_mask is not None and attention_mask.dtype != torch.bool
        if is_flash_attention_requested(self.config) and uses_additive_mask:
            # Non-boolean masks (e.g. relative position bias) cannot be passed to Flash Attention, so this
            # single call is routed through SDPA while the rest of the model keeps using Flash Attention.
            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
            logger.warning_once(
                "Sam3Attention: falling back to SDPA for relative-position cross-attention because "
                "Flash Attention does not support additive bias masks."
            )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Merge the heads back into the hidden dimension before the output projection.
        merged = attn_output.reshape(batch_size, query_len, self.num_attention_heads * self.head_dim).contiguous()
        return self.o_proj(merged), attn_weights
class Sam3ViTRotaryEmbedding(nn.Module):
    """
    2D (axial) rotary position embedding for the SAM3 ViT.

    The cos/sin tables are built once in the constructor for a fixed `end_x * end_y` grid and stored as
    non-persistent buffers. Each table has shape `(end_x * end_y, head_dim)`: the first half of the channels
    encodes the x position and the second half the y position, with every frequency repeated twice.
    """

    def __init__(self, config: Sam3ViTConfig, end_x: int, end_y: int, scale: float = 1.0):
        super().__init__()
        dim = config.hidden_size // config.num_attention_heads
        # Half of the head channels go to each axis and channels are rotated in pairs, hence the factor 4.
        if dim % 4 != 0:
            raise ValueError("Dimension must be divisible by 4 for axial RoPE")
        self.end_x, self.end_y = end_x, end_y
        self.dim = dim
        self.rope_theta = config.rope_theta
        self.scale = scale

        exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
        base_freqs = 1.0 / (config.rope_theta**exponents)

        # Row-major grid: token index -> (column, row), each multiplied by `scale`.
        token_ids = torch.arange(end_x * end_y, dtype=torch.long)
        columns = (token_ids % end_x) * scale
        rows = torch.div(token_ids, end_x, rounding_mode="floor") * scale

        angles_x = torch.outer(columns, base_freqs).float()
        angles_y = torch.outer(rows, base_freqs).float()
        angles = torch.cat([angles_x, angles_y], dim=-1).repeat_interleave(2, dim=-1)

        # The feature-map size is fixed, so the cos/sin tables can be registered directly.
        self.register_buffer("rope_embeddings_cos", angles.cos(), persistent=False)
        self.register_buffer("rope_embeddings_sin", angles.sin(), persistent=False)

    @torch.no_grad()
    def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the pre-computed `(cos, sin)` tables."""
        return self.rope_embeddings_cos, self.rope_embeddings_sin
def rotate_pairwise(x):
    """
    Rotate adjacent channel pairs of the last dimension: `(a, b) -> (-b, a)`.

    This differs from the Llama-style rotation, which swaps the two halves of the tensor.
    It is an optimized version of the following more explicit implementation:

    ```python
    x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device)
    x_rotated[..., ::2] = -x[..., 1::2]
    x_rotated[..., 1::2] = x[..., ::2]
    return x_rotated
    ```
    """
    # Group the last dimension into (num_pairs, 2) so each pair can be swapped in one stack call.
    pairs = x.view(*x.shape[:-1], -1, 2)
    even, odd = pairs.unbind(dim=-1)
    rotated = torch.stack((-odd, even), dim=-1)
    return rotated.flatten(start_dim=-2)
def apply_rotary_pos_emb_2d(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to query and key tensors for self-attention.

    The rotation is computed in float32 and the results are cast back to the dtype of the respective input.

    Args:
        q: Query tensor whose trailing dims broadcast against `cos`/`sin`. The caller in this file passes
            `(batch_size, num_heads, seq_len, head_dim)`.
        k: Key tensor with the same layout as `q`.
        cos: Cosine position embedding of shape (seq_len, head_dim)
        sin: Sine position embedding of shape (seq_len, head_dim)

    Returns:
        Rotated (q, k) tensors
    """

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        states_fp32 = states.float()
        rotated = (states_fp32 * cos) + (rotate_pairwise(states_fp32) * sin)
        return rotated.type_as(states)

    return _rotate(q), _rotate(k)
class Sam3ViTRoPEAttention(nn.Module):
    """
    Multi-head self-attention with rotary position encoding over a spatial token grid.

    Input is a `(batch_size, height, width, hidden_size)` tensor; the spatial dims are flattened into one
    sequence for attention and restored afterwards.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: `(batch_size, height, width, hidden_size)` token grid.
            position_embeddings: `(cos, sin)` rotary tables, applied to queries and keys after the head split.

        Returns:
            Tuple `(attn_output, attn_weights)`. `attn_output` has shape
            `(batch_size, height, width, hidden_size)`; `attn_weights` is whatever the selected attention
            backend returns as its second value.
        """
        batch_size, height, width, _ = hidden_states.shape
        seq_len = height * width
        new_shape = (batch_size, seq_len, self.num_attention_heads, self.head_dim)

        # Project, flatten the grid into a sequence and move heads ahead of the sequence dim.
        query = self.q_proj(hidden_states).view(*new_shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(*new_shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(*new_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb_2d(query, key, cos=cos, sin=sin)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            # Attention dropout is only active in training mode.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Restore the spatial layout before the output projection.
        attn_output = attn_output.reshape(batch_size, height, width, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Sam3ViTPatchEmbeddings(nn.Module):
    """
    Convert `pixel_values` of shape `(batch_size, num_channels, height, width)` into patch embeddings of shape
    `(batch_size, seq_length, hidden_size)` using a single strided, bias-free convolution.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()

        def _as_pair(value):
            # Scalars are expanded to a (value, value) pair; iterables are used as given.
            return value if isinstance(value, Iterable) else (value, value)

        image_size = _as_pair(config.pretrain_image_size)
        patch_size = _as_pair(config.patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        # Number of patches at the pre-training resolution.
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size, bias=False
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Project the image to patches and flatten the patch grid into a sequence."""
        weight_dtype = self.projection.weight.dtype
        patch_grid = self.projection(pixel_values.to(weight_dtype))
        # (batch, hidden, grid_h, grid_w) -> (batch, grid_h * grid_w, hidden)
        return patch_grid.flatten(2).transpose(1, 2)
class Sam3ViTEmbeddings(nn.Module):
    """
    Patch embeddings plus learned position embeddings for the SAM3 ViT.

    When the input grid differs from the pre-training grid, position embeddings are tiled (repeated and
    cropped), not interpolated.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.patch_embeddings = Sam3ViTPatchEmbeddings(config)
        # One embedding per pre-training patch; there is no slot for a cls token
        # (original note: "!Remove cls token in convert weights!").
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.patch_size = config.patch_size

    def _tile_position_embeddings(
        self,
        position_embeddings: torch.Tensor,
        height: int,
        width: int,
    ) -> torch.Tensor:
        """
        Tile position embeddings to match target spatial dimensions.

        Args:
            position_embeddings: Shape [1, num_pretrain_patches, hidden_size]; the pre-training grid is
                treated as square (side = integer part of the square root of the patch count).
            height: Target height in patches
            width: Target width in patches

        Returns:
            Shape [1, height * width, hidden_size]
        """
        pretrain_size = int(position_embeddings.shape[1] ** 0.5)

        # Fast path when nothing needs tiling. It is skipped while tracing so the traced graph always
        # contains the tiling ops.
        if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width:
            return position_embeddings.reshape(1, height * width, -1)

        hidden_size = position_embeddings.shape[-1]
        # [1, side * side, hidden] -> [1, hidden, side, side]
        grid = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2)

        # Repeat often enough to cover the target extent, then crop to the exact size.
        repeats_h = height // pretrain_size + 1
        repeats_w = width // pretrain_size + 1
        tiled = grid.tile([1, 1, repeats_h, repeats_w])
        cropped = tiled[:, :, :height, :width]
        return cropped.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """
        Embed `pixel_values` into patch tokens and add tiled position embeddings followed by dropout.

        `interpolate_pos_encoding` is accepted but not read by this method.
        """
        height, width = pixel_values.shape[-2:]
        patch_tokens = self.patch_embeddings(pixel_values)

        # Spatial size of the patch grid for this input.
        grid_height = height // self.patch_size
        grid_width = width // self.patch_size
        position_embeddings = self._tile_position_embeddings(self.position_embeddings, grid_height, grid_width)

        return self.dropout(patch_tokens + position_embeddings)
def window_partition(hidden_state, window_size):
    """
    Split a token grid into non-overlapping square windows, zero-padding bottom/right when needed.

    Args:
        hidden_state (`torch.Tensor`):
            Input tokens with [batch_size, height, width, num_channels].
        window_size (`int`):
            Window size.

    Returns:
        `tuple(torch.FloatTensor)` comprising various elements:
        - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
        - (padded_height, padded_width): padded height and width before partition
    """
    batch_size, height, width, num_channels = hidden_state.shape

    # Amount needed to reach the next multiple of `window_size` (0 when already a multiple).
    pad_bottom = -height % window_size
    pad_right = -width % window_size
    padded_height = height + pad_bottom
    padded_width = width + pad_right

    # Pad spec runs from the last dim backwards: channels, width, height. A no-op when both pads are 0.
    padded = nn.functional.pad(hidden_state, (0, 0, 0, pad_right, 0, pad_bottom))

    rows = padded_height // window_size
    cols = padded_width // window_size
    blocks = padded.view(batch_size, rows, window_size, cols, window_size, num_channels)
    # Bring the two window-grid dims together so each window becomes one contiguous slab.
    windows = blocks.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows, (padded_height, padded_width)
def window_unpartition(windows, window_size, pad_height_width, height_width):
    """
    Reassemble windows into the full token grid and crop away the padding.

    Args:
        windows (`torch.Tensor`):
            Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
        window_size (`int`):
            Window size.
        pad_height_width (`tuple[int]`):
            Padded height and width (padded_height, padded_width).
        height_width (`tuple[int]`):
            Original height and width before padding.

    Returns:
        hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
    """
    padded_height, padded_width = pad_height_width
    height, width = height_width

    windows_per_image = padded_height * padded_width // window_size // window_size
    batch_size = windows.shape[0] // windows_per_image

    rows = padded_height // window_size
    cols = padded_width // window_size
    blocks = windows.view(batch_size, rows, cols, window_size, window_size, -1)
    # Interleave window rows with in-window rows (and likewise for columns) to rebuild the grid.
    grid = blocks.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1)

    # height <= padded_height and width <= padded_width always hold, so slicing only removes padding.
    return grid[:, :height, :width, :].contiguous()
class Sam3ViTLayerScale(nn.Module):
    """Learnable per-channel scaling, initialised to `config.layer_scale_init_value` for every channel."""

    def __init__(self, config) -> None:
        super().__init__()
        initial_scale = torch.ones(config.hidden_size) * config.layer_scale_init_value
        self.lambda1 = nn.Parameter(initial_scale)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Multiply `hidden_state` by the learned scale, broadcasting over the last dimension."""
        scaled = hidden_state * self.lambda1
        return scaled
class Sam3ViTLayer(GradientCheckpointingLayer):
    """
    Pre-norm ViT block: rotary self-attention followed by an MLP, each wrapped in a residual connection.

    With `window_size > 0` attention runs inside non-overlapping windows of that size; with `window_size == 0`
    it runs over the whole token grid.
    """

    def __init__(self, config: Sam3ViTConfig, window_size: int = 0) -> None:
        super().__init__()
        hidden_size = config.hidden_size

        def _as_pair(value):
            return value if isinstance(value, (list, tuple)) else (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)
        input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        # The rotary table covers either the full patch grid (global attention) or a single window.
        if window_size == 0:
            rotary_input_size = input_size
        else:
            rotary_input_size = (window_size, window_size)
        # Positions are multiplied by `config.window_size / grid_side` inside the rotary embedding.
        rotary_scale = config.window_size / rotary_input_size[0]
        self.rotary_emb = Sam3ViTRotaryEmbedding(
            config, end_x=rotary_input_size[0], end_y=rotary_input_size[1], scale=rotary_scale
        )

        self.attention = Sam3ViTRoPEAttention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = Sam3MLP(config)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.window_size = window_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run the block on a `(batch_size, height, width, hidden_size)` grid and return the same shape."""
        use_windows = self.window_size > 0

        # --- attention sub-block ---
        attn_input = self.layer_norm1(hidden_states)
        if use_windows:
            original_hw = (attn_input.shape[1], attn_input.shape[2])
            # Fold windows into the batch dimension so attention stays local to each window.
            attn_input, padded_hw = window_partition(attn_input, self.window_size)

        attn_output, _ = self.attention(attn_input, self.rotary_emb(), **kwargs)

        if use_windows:
            # Undo the partition and drop any padding that was added.
            attn_output = window_unpartition(attn_output, self.window_size, padded_hw, original_hw)
        hidden_states = hidden_states + attn_output

        # --- MLP sub-block ---
        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + self.dropout(mlp_output)
@auto_docstring
class Sam3PreTrainedModel(PreTrainedModel):
    config_class = Sam3Config
    base_model_prefix = "sam3"
    main_input_name = "pixel_values"
    input_modalities = ["image", "text"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Run the parent initialisation, then apply SAM3-specific handling:

        - `Sam3ViTEmbeddings`: position embeddings are drawn from N(0, `initializer_range`).
        - `Sam3ViTRotaryEmbedding`: the cos/sin buffers are recomputed from the module's stored
          hyper-parameters and copied in. They are registered as non-persistent in that module, so a
          checkpoint would not carry them.
        """
        super()._init_weights(module)

        if isinstance(module, Sam3ViTEmbeddings):
            init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
            return
        if not isinstance(module, Sam3ViTRotaryEmbedding):
            return

        # Same computation as in `Sam3ViTRotaryEmbedding.__init__`.
        dim = module.dim
        end_x, end_y = module.end_x, module.end_y
        exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
        base_freqs = 1.0 / (module.rope_theta**exponents)

        token_ids = torch.arange(end_x * end_y, dtype=torch.long)
        columns = (token_ids % end_x) * module.scale
        rows = torch.div(token_ids, end_x, rounding_mode="floor") * module.scale

        angles_x = torch.outer(columns, base_freqs).float()
        angles_y = torch.outer(rows, base_freqs).float()
        angles = torch.cat([angles_x, angles_y], dim=-1).repeat_interleave(2, dim=-1)

        init.copy_(module.rope_embeddings_cos, angles.cos())
        init.copy_(module.rope_embeddings_sin, angles.sin())
@auto_docstring
class Sam3ViTModel(Sam3PreTrainedModel):
    def __init__(self, config: Sam3ViTConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = Sam3ViTEmbeddings(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Layers listed in `global_attn_indexes` attend globally (window_size=0); all others are windowed.
        layers = []
        for layer_idx in range(config.num_hidden_layers):
            is_global = layer_idx in config.global_attn_indexes
            layers.append(Sam3ViTLayer(config, window_size=0 if is_global else config.window_size))
        self.layers = nn.ModuleList(layers)

        self.post_init()

    def get_input_embeddings(self) -> Sam3ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        tokens = self.embeddings(pixel_values)  # [batch_size, seq_len, hidden_size]
        batch_size, hidden_size = tokens.shape[0], tokens.shape[-1]

        patch_size = self.config.patch_size
        height = pixel_values.shape[-2] // patch_size
        width = pixel_values.shape[-1] // patch_size

        # Windowed attention needs the spatial layout: [batch_size, height, width, hidden_size]
        hidden_states = self.layer_norm(tokens.view(batch_size, height, width, hidden_size))

        for layer in self.layers:
            hidden_states = layer(hidden_states, **kwargs)

        # Flatten back to a sequence: [batch_size, height*width, hidden_size]
        sequence_output = hidden_states.view(batch_size, height * width, hidden_size)
        return BaseModelOutput(last_hidden_state=sequence_output)
class Sam3SinePositionEmbedding(nn.Module):
    """
    Sine/cosine position embedding, very similar to the one used by the "Attention Is All You Need" paper and
    generalized to work on images, on 1D coordinate pairs and on 4D boxes.

    Args:
        num_pos_feats: Number of embedding channels produced for each encoded coordinate.
        temperature: Base of the geometric progression used for the per-channel frequency divisors.
        normalize: Whether `forward` rescales the cumulative positions by their last value times `scale`.
        scale: Multiplier applied to coordinates. Defaults to `2 * pi` when not given; may only be passed
            together with `normalize=True`.

    Raises:
        ValueError: If `scale` is passed while `normalize` is `False`.
    """

    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def _frequency_divisors(self, device: torch.device | str, dtype: torch.dtype) -> torch.Tensor:
        """Return the `[num_pos_feats]` divisors `temperature ** (2 * (i // 2) / num_pos_feats)`."""
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
        return self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

    @staticmethod
    def _interleave_sin_cos(pos: torch.Tensor) -> torch.Tensor:
        """Apply sin to even channels and cos to odd channels of the last dim, interleaved back in order."""
        return torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1).flatten(-2)

    def encode_1d_positions(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode 1D coordinate pairs using sine/cosine positional embeddings.

        Args:
            x: 1D tensor of x coordinates (flattened)
            y: 1D tensor of y coordinates (flattened)

        Returns:
            Tuple of (pos_x, pos_y) positional embeddings, each of shape [len(x), num_pos_feats]
        """
        dim_t = self._frequency_divisors(x.device, x.dtype)
        # Coordinates are always multiplied by `scale` here, independently of `normalize`.
        pos_x = self._interleave_sin_cos((x * self.scale)[:, None] / dim_t)
        pos_y = self._interleave_sin_cos((y * self.scale)[:, None] / dim_t)
        return pos_x, pos_y

    def encode_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Encode 4D box coordinates (x, y, w, h) for decoder conditioning using sine/cosine embeddings.

        Args:
            boxes: Box coordinates [batch_size, num_queries, 4] in (x, y, w, h) format

        Returns:
            Position embeddings [batch_size, num_queries, num_pos_feats*4], concatenated in (y, x, w, h) order

        Raises:
            ValueError: If the last dimension of `boxes` is not 4.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if boxes.size(-1) != 4:
            raise ValueError(f"Expected 4D box coordinates (x, y, w, h), got shape {boxes.shape}")
        dim_t = self._frequency_divisors(boxes.device, boxes.dtype)
        x_embed, y_embed, w_embed, h_embed = (boxes * self.scale).unbind(-1)
        pos_x = self._interleave_sin_cos(x_embed[..., None] / dim_t)
        pos_y = self._interleave_sin_cos(y_embed[..., None] / dim_t)
        pos_w = self._interleave_sin_cos(w_embed[..., None] / dim_t)
        pos_h = self._interleave_sin_cos(h_embed[..., None] / dim_t)
        return torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1)

    @compile_compatible_method_lru_cache(maxsize=4)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: Tensor | None = None,
    ) -> Tensor:
        """
        Build the 2D position embedding for a feature map.

        Args:
            shape: Shape of the feature map; indices 0, 2 and 3 are read as (batch, height, width).
            device: Device on which the embedding is created.
            dtype: Floating dtype of the embedding.
            mask: Optional boolean mask [batch, height, width]; `True` entries are excluded from the cumulative
                position count. Defaults to an all-`False` mask.

        Returns:
            Position embedding of shape [batch, 2 * num_pos_feats, height, width] (y channels first, then x).
        """
        if mask is None:
            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        not_mask = (~mask).to(dtype)
        # Cumulative sums give 1-based row / column indices over the unmasked positions.
        y_embed = not_mask.cumsum(1)
        x_embed = not_mask.cumsum(2)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = self._frequency_divisors(device, dtype)
        pos_x = self._interleave_sin_cos(x_embed[:, :, :, None] / dim_t)
        pos_y = self._interleave_sin_cos(y_embed[:, :, :, None] / dim_t)
        # [batch, height, width, 2 * num_pos_feats] -> channels-first
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
class Sam3FPNLayer(nn.Module):
    """
    Single FPN level: optionally resample the input feature map by `scale_factor`, then project it to
    `fpn_dim` channels with a 1x1 convolution followed by a 3x3 convolution.

    Supported scale factors are 4.0 and 2.0 (transposed-convolution upsampling), 1.0 (no resampling)
    and 0.5 (max-pool downsampling); anything else raises `NotImplementedError`.
    """

    def __init__(self, in_channels: int, fpn_dim: int, scale_factor: float):
        super().__init__()
        self.scale_factor = scale_factor

        half_channels = in_channels // 2
        quarter_channels = in_channels // 4

        # Collect the resampling ops first; they are registered as `scale_layers` below.
        resampling_ops = []
        if scale_factor == 4.0:
            resampling_ops.append(nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2))
            resampling_ops.append(nn.GELU())
            resampling_ops.append(nn.ConvTranspose2d(half_channels, quarter_channels, kernel_size=2, stride=2))
            resampled_channels = quarter_channels
        elif scale_factor == 2.0:
            resampling_ops.append(nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2))
            resampled_channels = half_channels
        elif scale_factor == 1.0:
            resampled_channels = in_channels
        elif scale_factor == 0.5:
            resampling_ops.append(nn.MaxPool2d(kernel_size=2, stride=2))
            resampled_channels = in_channels
        else:
            raise NotImplementedError(f"scale_factor={scale_factor} is not supported yet.")

        self.scale_layers = nn.ModuleList(resampling_ops)
        self.proj1 = nn.Conv2d(in_channels=resampled_channels, out_channels=fpn_dim, kernel_size=1)
        self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Resample and project a channels-first feature map; the input is first cast to the projection dtype."""
        features = hidden_states.to(self.proj1.weight.dtype)
        for resample in self.scale_layers:
            features = resample(features)
        return self.proj2(self.proj1(features))
class Sam3VisionNeck(nn.Module):
    """
    FPN neck: runs the backbone feature map through one `Sam3FPNLayer` per configured scale factor and
    produces a matching sine position encoding for every resulting level.
    """

    def __init__(self, config: Sam3VisionConfig):
        super().__init__()
        self.config = config
        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)

        # One FPN layer per entry of `config.scale_factors`, in the configured order
        backbone_channels = config.backbone_config.hidden_size
        per_scale_layers = [
            Sam3FPNLayer(in_channels=backbone_channels, fpn_dim=config.fpn_hidden_size, scale_factor=factor)
            for factor in config.scale_factors
        ]
        self.fpn_layers = nn.ModuleList(per_scale_layers)

    def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
        """
        Args:
            hidden_states: Backbone feature map fed unchanged to every FPN layer.

        Returns:
            `(fpn_hidden_states, fpn_position_encoding)`: two tuples with one entry per FPN level.
        """
        level_features = []
        level_position_encodings = []
        for fpn_layer in self.fpn_layers:
            feature_map = fpn_layer(hidden_states)
            level_features.append(feature_map)
            # The position encoding is derived from the level's own shape / device / dtype
            level_position_encodings.append(
                self.position_encoding(feature_map.shape, feature_map.device, feature_map.dtype)
            )
        return tuple(level_features), tuple(level_position_encodings)
@auto_docstring(
    custom_intro="\nThe vision model from Sam without any head or projection on top.\n"
)
class Sam3VisionModel(Sam3PreTrainedModel):
    config_class = Sam3VisionConfig
    main_input_name = "pixel_values"
    _can_record_outputs = {
        "hidden_states": Sam3ViTLayer,
        "attentions": Sam3ViTRoPEAttention,
    }

    def __init__(self, config: Sam3VisionConfig):
        super().__init__(config)
        self.config = config
        # The backbone is instantiated from its own sub-config; the neck turns its output into FPN levels.
        self.backbone = AutoModel.from_config(config.backbone_config)
        self.neck = Sam3VisionNeck(config)
        self.post_init()

    def get_input_embeddings(self):
        """Delegate to the backbone's input embeddings."""
        return self.backbone.get_input_embeddings()

    @check_model_inputs
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3VisionEncoderOutput:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        sequence_output = self.backbone(pixel_values, **kwargs).last_hidden_state  # [batch_size, seq_len, hidden_size]

        # Patch-grid size is derived from the input resolution and the backbone patch size.
        patch_size = self.config.backbone_config.patch_size
        grid_height = pixel_values.shape[-2] // patch_size
        grid_width = pixel_values.shape[-1] // patch_size

        # [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size, height, width] for the FPN neck
        feature_map = sequence_output.view(sequence_output.shape[0], grid_height, grid_width, -1)
        feature_map = feature_map.permute(0, 3, 1, 2)

        fpn_hidden_states, fpn_position_encoding = self.neck(feature_map)
        return Sam3VisionEncoderOutput(
            last_hidden_state=sequence_output,
            fpn_hidden_states=fpn_hidden_states,
            fpn_position_encoding=fpn_position_encoding,
        )
class Sam3GeometryEncoderLayer(nn.Module):
    """
    Pre-norm transformer layer for geometric prompts: self-attention over the prompt tokens, cross-attention
    from the prompts to the vision features, then an MLP. Every sub-block output goes through the shared
    dropout and is added to its residual.
    """

    def __init__(self, config: Sam3GeometryEncoderConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = Sam3Attention(config)
        self.dropout = nn.Dropout(config.dropout)
        self.cross_attn = Sam3Attention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.mlp = Sam3MLP(config)
        self.layer_norm3 = nn.LayerNorm(hidden_size)

    def forward(
        self,
        prompt_feats: Tensor,
        vision_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_mask: Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Args:
            prompt_feats: Prompt token features (the stream that is updated and returned).
            vision_feats: Vision features used as cross-attention values.
            vision_pos_encoding: Position encoding added to `vision_feats` to form the cross-attention keys.
            prompt_mask: Attention mask forwarded to the prompt self-attention.

        Returns:
            Updated prompt features.
        """
        # 1) Self-attention among prompt tokens
        normed_prompts = self.layer_norm1(prompt_feats)
        self_attn_output, _ = self.self_attn(
            query=normed_prompts, key=normed_prompts, value=normed_prompts, attention_mask=prompt_mask, **kwargs
        )
        prompt_feats = self.dropout(self_attn_output) + prompt_feats

        # 2) Cross-attention: keys carry the position encoding, values do not
        normed_prompts = self.layer_norm2(prompt_feats)
        cross_attn_output, _ = self.cross_attn(
            query=normed_prompts, key=vision_feats + vision_pos_encoding, value=vision_feats, **kwargs
        )
        prompt_feats = self.dropout(cross_attn_output) + prompt_feats

        # 3) Feed-forward
        mlp_output = self.mlp(self.layer_norm3(prompt_feats))
        return self.dropout(mlp_output) + prompt_feats
class Sam3GeometryEncoder(nn.Module):
    """
    Encoder for geometric prompts (boxes).

    Each box is embedded in three ways that are summed together:
        - Direct projection: linear projection of the 4 box coordinates to hidden_size
        - Pooling: ROI-align over the (layer-normalized) image features, followed by a convolution
        - Position encoding: sine encoding of the box center concatenated with the raw height / width,
          followed by a linear projection
    A label embedding is added, a CLS token is concatenated, and the sequence is processed by transformer
    layers that cross-attend to the vision features.
    """

    def __init__(self, config: Sam3GeometryEncoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.roi_size = config.roi_size
        hidden_size = self.hidden_size

        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=True)
        self.label_embed = nn.Embedding(2, hidden_size)
        self.cls_embed = nn.Embedding(1, hidden_size)

        # Box encoders; the position-encoding branch takes 2 extra inputs (raw height and width)
        self.boxes_direct_project = nn.Linear(4, hidden_size)
        self.boxes_pool_project = nn.Conv2d(hidden_size, hidden_size, self.roi_size)
        self.boxes_pos_enc_project = nn.Linear(hidden_size + 2, hidden_size)

        # Normalization applied to image features before pooling
        self.vision_layer_norm = nn.LayerNorm(hidden_size)

        # Projection + normalization applied to the assembled prompt sequence
        self.final_proj = nn.Linear(hidden_size, hidden_size)
        self.prompt_layer_norm = nn.LayerNorm(hidden_size)

        # Transformer stack and output normalization
        self.layers = nn.ModuleList([Sam3GeometryEncoderLayer(config) for _ in range(config.num_layers)])
        self.output_layer_norm = nn.LayerNorm(hidden_size)

    def _encode_box_coordinates(
        self, center_x: torch.Tensor, center_y: torch.Tensor, width: torch.Tensor, height: torch.Tensor
    ) -> torch.Tensor:
        """
        Build the per-box positional feature: sine-encoded centers followed by raw height and width.

        Args:
            center_x: 1D tensor of box center x coordinates
            center_y: 1D tensor of box center y coordinates
            width: 1D tensor of box widths
            height: 1D tensor of box heights

        Returns:
            Tensor [N, embedding_dim] laid out as (pos_y, pos_x, height, width)
        """
        pos_x, pos_y = self.position_encoding.encode_1d_positions(center_x, center_y)
        return torch.cat((pos_y, pos_x, height.unsqueeze(1), width.unsqueeze(1)), dim=1)

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features):
        """
        Embed box prompts and return `(embeddings, boxes_mask)`; the mask is passed through untouched.
        Mask convention: True=valid, False=padding.
        """
        batch_size, num_boxes = boxes.shape[:2]
        feat_height, feat_width = vision_features.shape[-2:]

        # --- Branch 1: direct projection of the raw coordinates ---
        direct_embed = self.boxes_direct_project(boxes)

        # --- Branch 2: ROI-align pooling ---
        # Boxes go from CxCyWH to xyxy and are scaled to feature-map coordinates
        boxes_xyxy = box_cxcywh_to_xyxy(boxes)
        feature_scale = torch.tensor(
            [feat_width, feat_height, feat_width, feat_height], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device
        ).view(1, 1, 4)
        boxes_xyxy = boxes_xyxy * feature_scale

        # roi_align only supports float16 and float32, so bfloat16 inputs are routed through float16.
        # It takes one box tensor per batch element, hence the unbind.
        roi_dtype = vision_features.dtype
        if roi_dtype == torch.bfloat16:
            roi_dtype = torch.float16
        per_image_boxes = boxes_xyxy.to(roi_dtype).unbind(0)
        roi_features = torchvision.ops.roi_align(vision_features.to(roi_dtype), per_image_boxes, self.roi_size)
        roi_features = roi_features.to(vision_features.dtype)
        pooled_embed = self.boxes_pool_project(roi_features).view(batch_size, num_boxes, self.hidden_size)
        boxes_embed = direct_embed + pooled_embed

        # --- Branch 3: position encoding of the box geometry ---
        center_x, center_y, box_width, box_height = (coord.flatten() for coord in boxes.unbind(-1))
        pos_enc = self._encode_box_coordinates(center_x, center_y, box_width, box_height)
        pos_enc = pos_enc.view(batch_size, num_boxes, pos_enc.shape[-1])
        boxes_embed = boxes_embed + self.boxes_pos_enc_project(pos_enc)

        # Label embedding (positive/negative) is added last
        return self.label_embed(boxes_labels.long()) + boxes_embed, boxes_mask

    def forward(
        self,
        box_embeddings: torch.Tensor,
        box_mask: torch.Tensor,
        box_labels: torch.Tensor,
        img_feats: tuple[torch.Tensor, ...],
        img_pos_embeds: tuple[torch.Tensor, ...] | None = None,
    ):
        """
        Encode geometric prompts.

        Args:
            box_embeddings: Box coordinates in CxCyWH format [batch_size, num_boxes, 4]
            box_mask: Attention mask for boxes [batch_size, num_boxes]
            box_labels: Labels for boxes (positive/negative) [batch_size, num_boxes]
            img_feats: Image features from the vision encoder; only the last entry is used
            img_pos_embeds: Optional position embeddings for the image features; only the last entry is used,
                and zeros are substituted when omitted

        Returns:
            Sam3GeometryEncoderOutput containing encoded geometry features and attention mask.
        """
        batch_size = box_embeddings.shape[0]
        last_level_feats = img_feats[-1]  # [B, C, H, W]
        if img_pos_embeds is None:
            last_level_pos = torch.zeros_like(last_level_feats)
        else:
            last_level_pos = img_pos_embeds[-1]

        # Sequence layout for cross-attention: [B, C, H, W] -> [B, H*W, C]
        vision_feats_flat = last_level_feats.flatten(2).transpose(1, 2)
        vision_pos_embeds_flat = last_level_pos.flatten(2).transpose(1, 2)

        # LayerNorm acts on the channel dim, so go channels-last and back: used only for ROI pooling
        normalized_img_feats = self.vision_layer_norm(last_level_feats.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        prompt_embeds, prompt_mask = self._encode_boxes(box_embeddings, box_mask, box_labels, normalized_img_feats)

        # Concatenate the CLS token, which is always marked valid
        cls_embed = self.cls_embed.weight.view(1, 1, self.hidden_size).expand(batch_size, -1, -1)
        cls_mask = torch.ones(batch_size, 1, dtype=prompt_mask.dtype, device=prompt_mask.device)
        prompt_embeds, prompt_mask = concat_padded_sequences(prompt_embeds, prompt_mask, cls_embed, cls_mask)
        prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds))

        # Bidirectional attention mask for the transformer layers
        prompt_attention_mask = (
            create_bidirectional_mask(
                config=self.config,
                input_embeds=prompt_embeds,
                attention_mask=prompt_mask,
            )
            if prompt_mask is not None
            else None
        )

        for encoder_layer in self.layers:
            prompt_embeds = encoder_layer(
                prompt_feats=prompt_embeds,
                vision_feats=vision_feats_flat,
                vision_pos_encoding=vision_pos_embeds_flat,
                prompt_mask=prompt_attention_mask,
            )

        return Sam3GeometryEncoderOutput(
            last_hidden_state=self.output_layer_norm(prompt_embeds),
            attention_mask=prompt_mask,
        )
class Sam3DetrEncoderLayer(nn.Module):
    """Pre-norm DETR encoder layer: vision self-attention, cross-attention to the prompt, then an MLP."""

    def __init__(self, config: Sam3DETREncoderConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = Sam3Attention(config)
        self.dropout = nn.Dropout(config.dropout)
        self.cross_attn = Sam3Attention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.mlp = Sam3MLP(config)
        self.layer_norm3 = nn.LayerNorm(hidden_size)

    def forward(
        self,
        vision_feats: Tensor,
        prompt_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_cross_attn_mask: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Run one encoder layer over the vision stream.

        Args:
            vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
            prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
            vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
            prompt_cross_attn_mask: Cross-attention mask for prompt features

        Returns:
            Updated vision features [batch_size, vision_len, hidden_size]
        """
        # Self-attention: position encoding is added to queries and keys only, values stay position-free
        normed_vision = self.layer_norm1(vision_feats)
        vision_with_pos = normed_vision + vision_pos_encoding
        self_attn_output, _ = self.self_attn(
            query=vision_with_pos,
            key=vision_with_pos,
            value=normed_vision,
            **kwargs,
        )
        vision_feats = self.dropout(self_attn_output) + vision_feats

        # Cross-attention: vision tokens query the text/prompt features
        cross_attn_output, _ = self.cross_attn(
            query=self.layer_norm2(vision_feats),
            key=prompt_feats,
            value=prompt_feats,
            attention_mask=prompt_cross_attn_mask,
            **kwargs,
        )
        vision_feats = self.dropout(cross_attn_output) + vision_feats

        # Feed-forward
        mlp_output = self.mlp(self.layer_norm3(vision_feats))
        return self.dropout(mlp_output) + vision_feats
class Sam3DetrEncoder(Sam3PreTrainedModel):
    """
    DETR-style encoder that processes multi-level vision features with text fusion.

    This encoder processes vision features from multiple levels (e.g., FPN features at different
    resolutions) and fuses them with text prompts through a stack of transformer encoder layers.
    """

    _can_record_outputs = {
        "hidden_states": Sam3DetrEncoderLayer,
        "attentions": Sam3Attention,
    }

    def __init__(self, config: Sam3DETREncoderConfig):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size
        self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])
        self.post_init()

    def _prepare_multilevel_features(
        self,
        vision_features: list[torch.Tensor],
        vision_pos_embeds: list[torch.Tensor],
    ):
        """
        Flatten the spatial dimensions of every level and concatenate all levels into one sequence.

        Args:
            vision_features: List of vision features at different levels [batch_size, channels, height, width]
            vision_pos_embeds: List of position embeddings for each level [batch_size, channels, height, width]

        Returns:
            Tuple `(features_flattened, pos_embeds_flattened, spatial_shapes)` where the first two are
            [batch_size, sum(height*width), channels] and `spatial_shapes` is a long tensor [num_levels, 2]
            holding each level's (height, width).
        """
        features_flattened = []
        pos_embeds_flattened = []
        spatial_shapes = []
        for features, pos_embed in zip(vision_features, vision_pos_embeds):
            height, width = features.shape[-2:]
            spatial_shapes.append((height, width))
            # [batch_size, channels, height, width] -> [batch_size, height*width, channels]
            features_flattened.append(features.flatten(2).transpose(1, 2))
            pos_embeds_flattened.append(pos_embed.flatten(2).transpose(1, 2))

        # Concatenate all levels into a single sequence
        features_flattened = torch.cat(features_flattened, dim=1)
        pos_embeds_flattened = torch.cat(pos_embeds_flattened, dim=1)
        spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=features_flattened.device)
        return (
            features_flattened,
            pos_embeds_flattened,
            spatial_shapes,
        )

    @check_model_inputs
    def forward(
        self,
        vision_features: list[torch.Tensor],
        text_features: torch.Tensor,
        vision_pos_embeds: list[torch.Tensor] | None = None,
        text_mask: torch.Tensor | None = None,
        spatial_sizes: list[tuple[int, int]] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3DETREncoderOutput:
        """
        Forward pass for the DETR encoder.

        Args:
            vision_features: List of vision features at different levels
            text_features: Text prompt features [batch_size, seq_len, hidden_size]
            vision_pos_embeds: Optional list of position embeddings for each level; zeros are used when omitted
            text_mask: Optional text padding mask [batch_size, seq_len]
            spatial_sizes: Optional list of (height, width) tuples; when given, the corresponding levels are
                reshaped from [height*width, batch_size, channels] to [batch_size, channels, height, width]

        Returns:
            Sam3DETREncoderOutput containing encoded features and metadata.
        """
        # Batch dim is first for 4D (channels-first) inputs and second for sequence-first inputs
        first_level = vision_features[0]
        batch_size = first_level.shape[0] if first_level.dim() == 4 else first_level.shape[1]

        # Work on shallow copies: the reshaping below must not overwrite entries of the caller's
        # sequences (this also lets tuples be passed in).
        vision_features = list(vision_features)
        if vision_pos_embeds is not None:
            vision_pos_embeds = list(vision_pos_embeds)

        # TODO: See if we can remove that reshaping and just use the features as is.
        if spatial_sizes is not None:
            for i, (height, width) in enumerate(spatial_sizes):
                # Reshape from [height*width, batch_size, channels] to [batch_size, channels, height, width]
                vision_features[i] = vision_features[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
                if vision_pos_embeds is not None:
                    vision_pos_embeds[i] = (
                        vision_pos_embeds[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
                    )

        # The signature allows omitting position embeddings; fall back to zeros (same convention as
        # Sam3GeometryEncoder) instead of failing later when the levels are zipped together.
        if vision_pos_embeds is None:
            vision_pos_embeds = [torch.zeros_like(features) for features in vision_features]

        # Flatten multi-level features for encoder processing
        (
            features_flattened,
            pos_embeds_flattened,
            spatial_shapes,
        ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)

        prompt_cross_attn_mask = None
        if text_mask is not None:
            prompt_cross_attn_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=features_flattened,
                attention_mask=text_mask,
                encoder_hidden_states=text_features,
            )

        hidden_states = features_flattened
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                prompt_feats=text_features,
                vision_pos_encoding=pos_embeds_flattened,
                prompt_cross_attn_mask=prompt_cross_attn_mask,
                **kwargs,
            )

        return Sam3DETREncoderOutput(
            last_hidden_state=hidden_states,
            pos_embeds_flattened=pos_embeds_flattened,
            text_features=text_features,
            spatial_shapes=spatial_shapes,
        )
class Sam3DecoderMLP(nn.Module):
    """
    Small ReLU MLP with either 2 or 3 linear layers, used by decoder heads.

    With 2 layers `layer3` is `None` and `layer2` maps straight to `output_dim`; with 3 layers `layer2`
    stays at `hidden_dim` and `layer3` produces the output. No activation follows the last layer.

    Raises:
        ValueError: If `num_layers` is neither 2 nor 3.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 2):
        super().__init__()
        if num_layers not in (2, 3):
            raise ValueError(f"Only 2 or 3 layers supported, got {num_layers}")
        has_third_layer = num_layers == 3
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim if has_third_layer else output_dim)
        self.layer3 = nn.Linear(hidden_dim, output_dim) if has_third_layer else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.relu(self.layer1(x))
        if self.layer3 is None:
            return self.layer2(hidden)
        return self.layer3(F.relu(self.layer2(hidden)))
class Sam3DetrDecoderLayer(nn.Module):
    """
    Post-norm DETR decoder layer made of four sub-blocks: query self-attention, text cross-attention,
    vision cross-attention and an MLP. Each sub-block adds its dropped-out output to the residual and then
    applies its own LayerNorm.
    """

    def __init__(self, config: Sam3DETRDecoderConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        dropout = config.dropout

        self.self_attn = Sam3Attention(config)
        self.self_attn_dropout = nn.Dropout(dropout)
        self.self_attn_layer_norm = nn.LayerNorm(hidden_size)

        self.text_cross_attn = Sam3Attention(config)
        self.text_cross_attn_dropout = nn.Dropout(dropout)
        self.text_cross_attn_layer_norm = nn.LayerNorm(hidden_size)

        self.vision_cross_attn = Sam3Attention(config)
        self.vision_cross_attn_dropout = nn.Dropout(dropout)
        self.vision_cross_attn_layer_norm = nn.LayerNorm(hidden_size)

        self.mlp = Sam3MLP(config)
        self.mlp_layer_norm = nn.LayerNorm(hidden_size)
        self.mlp_dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        query_pos: torch.Tensor,
        text_features: torch.Tensor,
        vision_features: torch.Tensor,
        vision_pos_encoding: torch.Tensor,
        text_cross_attn_mask: torch.Tensor | None = None,
        vision_cross_attn_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Run one decoder layer.

        Args:
            hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
            query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_features: Vision features [batch_size, height*width, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
            text_cross_attn_mask: Text cross-attention mask
            vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token

        Returns:
            Updated hidden states (including presence token at position 0)
        """
        # The presence token gets an all-zero position embedding, inserted at the front of the query axis
        query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)

        # Self-attention: positions are added to queries and keys, values stay position-free
        positioned_queries = hidden_states + query_pos
        self_attn_output, _ = self.self_attn(
            query=positioned_queries,
            key=positioned_queries,
            value=hidden_states,
            attention_mask=None,
            **kwargs,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + self.self_attn_dropout(self_attn_output))

        # Text cross-attention: queries attend to text features
        text_attn_output, _ = self.text_cross_attn(
            query=hidden_states + query_pos,
            key=text_features,
            value=text_features,
            attention_mask=text_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.text_cross_attn_layer_norm(
            hidden_states + self.text_cross_attn_dropout(text_attn_output)
        )

        # Vision cross-attention: queries attend to vision features (with RPB);
        # keys carry the vision position encoding, values do not
        vision_attn_output, _ = self.vision_cross_attn(
            query=hidden_states + query_pos,
            key=vision_features + vision_pos_encoding,
            value=vision_features,
            attention_mask=vision_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.vision_cross_attn_layer_norm(
            hidden_states + self.vision_cross_attn_dropout(vision_attn_output)
        )

        # Feed-forward
        mlp_output = self.mlp(hidden_states)
        return self.mlp_layer_norm(hidden_states + self.mlp_dropout(mlp_output))
class Sam3DetrDecoder(Sam3PreTrainedModel):
    """
    DETR-style decoder with box refinement and presence token.

    Simplified version that assumes:
        - Box refinement is always enabled
        - Intermediate outputs are always returned
        - BoxRPB (relative position bias) with log-scale encoding
        - Presence token is used
    """

    _can_record_outputs = {
        "hidden_states": Sam3DetrDecoderLayer,
        "attentions": Sam3Attention,
    }

    def __init__(
        self,
        config: Sam3DETRDecoderConfig,
    ):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size
        self.layers = nn.ModuleList([Sam3DetrDecoderLayer(config) for _ in range(config.num_layers)])
        self.output_layer_norm = nn.LayerNorm(config.hidden_size)
        # Predicts 4 box deltas per query
        self.box_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 4, 3)
        # Learned queries and their initial (pre-sigmoid) reference boxes
        self.query_embed = nn.Embedding(config.num_queries, config.hidden_size)
        self.reference_points = nn.Embedding(config.num_queries, 4)
        # Presence token and its scoring head
        self.presence_token = nn.Embedding(1, config.hidden_size)
        self.presence_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 1, 3)
        self.presence_layer_norm = nn.LayerNorm(config.hidden_size)
        self.clamp_presence_logit_max_val = 10.0
        # Maps the 4 * (hidden_size // 2) box sine embedding to a query position embedding
        self.ref_point_head = Sam3DecoderMLP(2 * config.hidden_size, config.hidden_size, config.hidden_size, 2)
        # Per-axis MLPs producing one relative-position bias value per attention head
        self.box_rpb_embed_x = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
        self.box_rpb_embed_y = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)
        self.post_init()

    @compile_compatible_method_lru_cache(maxsize=1)
    def _get_coords(
        self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Generate normalized coordinate grids: `arange(size) / size` for each axis."""
        coords_h = torch.arange(0, height, device=device, dtype=dtype) / height
        coords_w = torch.arange(0, width, device=device, dtype=dtype) / width
        return coords_h, coords_w

    def _get_rpb_matrix(
        self, reference_boxes: torch.Tensor, spatial_shape: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute box relative position bias (RPB) matrix using log-scale encoding.

        RPB helps the decoder attend to relevant spatial locations based on predicted box positions.

        Args:
            reference_boxes: Reference boxes [batch_size, num_queries, 4] in sigmoid space
            spatial_shape: (height, width) of the vision features as tensors

        Returns:
            RPB matrix [batch_size, num_heads, num_queries, height*width]
        """
        height, width = spatial_shape
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
        batch_size, num_queries, _ = boxes_xyxy.shape

        # Generate coordinate grids
        coords_h, coords_w = self._get_coords(
            height, width, dtype=reference_boxes.dtype, device=reference_boxes.device
        )

        # Deltas between every grid coordinate and the two box edges on that axis:
        # columns 1 and 3 of the xyxy boxes for y, columns 0 and 2 for x
        flat_boxes = boxes_xyxy.reshape(-1, 1, 4)
        deltas_y = (coords_h.view(1, -1, 1) - flat_boxes[:, :, 1:4:2]).view(batch_size, num_queries, -1, 2)
        deltas_x = (coords_w.view(1, -1, 1) - flat_boxes[:, :, 0:3:2]).view(batch_size, num_queries, -1, 2)

        # Apply log-scale encoding: sign(8d) * log2(|8d| + 1) / log2(8)
        deltas_x_log = deltas_x * 8
        deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / math.log2(8)
        deltas_y_log = deltas_y * 8
        deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / math.log2(8)

        # Embed deltas
        deltas_x = self.box_rpb_embed_x(deltas_x_log)  # [batch_size, num_queries, width, num_heads]
        deltas_y = self.box_rpb_embed_y(deltas_y_log)  # [batch_size, num_queries, height, num_heads]

        # Combine into 2D bias matrix: [batch_size, num_queries, height, width, num_heads]
        rpb_matrix = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2)
        rpb_matrix = rpb_matrix.flatten(2, 3)  # [batch_size, num_queries, height*width, num_heads]
        rpb_matrix = rpb_matrix.permute(0, 3, 1, 2).contiguous()  # [batch_size, num_heads, num_queries, height*width]
        return rpb_matrix

    @check_model_inputs
    def forward(
        self,
        vision_features: torch.Tensor,
        text_features: torch.Tensor,
        vision_pos_encoding: torch.Tensor,
        text_mask: torch.Tensor | None = None,
        spatial_shapes: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3DETRDecoderOutput:
        """
        Forward pass for the DETR decoder.

        Args:
            vision_features: Vision features [batch_size, height*width, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
            text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
            spatial_shapes: Spatial shapes [num_levels, 2]; the RPB mask is only built when there is exactly
                one level

        Returns:
            Sam3DETRDecoderOutput containing decoder outputs from all layers.
        """
        batch_size = vision_features.shape[0]

        query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
        reference_boxes = reference_boxes.sigmoid()
        presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate presence token with query embeddings (presence token at position 0)
        hidden_states = torch.cat([presence_token, query_embeds], dim=1)

        text_cross_attn_mask = None
        if text_mask is not None:
            text_cross_attn_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=hidden_states,
                attention_mask=text_mask,
                encoder_hidden_states=text_features,
            )

        # Whether an RPB mask can be built does not change across layers, so resolve it once
        rpb_spatial_shape = None
        if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
            rpb_spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])

        intermediate_outputs = []
        intermediate_boxes = [reference_boxes]
        intermediate_presence_logits = []

        for layer in self.layers:
            # Generate sine embeddings for conditional queries from the current reference boxes
            query_sine_embed = self.position_encoding.encode_boxes(reference_boxes)
            query_pos = self.ref_point_head(query_sine_embed)

            # Compute box relative position bias (RPB) attention mask
            vision_cross_attn_mask = None
            if rpb_spatial_shape is not None:
                rpb_matrix = self._get_rpb_matrix(reference_boxes, rpb_spatial_shape)
                # Prepend zeros row for presence token (it attends to all vision tokens equally)
                vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)

            hidden_states = layer(
                hidden_states,
                query_pos=query_pos,
                text_features=text_features,
                vision_features=vision_features,
                vision_pos_encoding=vision_pos_encoding,
                text_cross_attn_mask=text_cross_attn_mask,
                vision_cross_attn_mask=vision_cross_attn_mask,
                **kwargs,
            )

            # Normalize the query hidden states (without presence token) once; the result feeds both the
            # box head and the per-layer outputs
            normed_query_states = self.output_layer_norm(hidden_states[:, 1:])

            # Box refinement: predict delta in logit space and update reference boxes
            reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
            delta_boxes = self.box_head(normed_query_states)
            new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
            # The next layer starts from a detached copy of the refined boxes
            reference_boxes = new_reference_boxes.detach()

            intermediate_outputs.append(normed_query_states)
            intermediate_boxes.append(new_reference_boxes)

            # Process presence token
            presence_hidden = hidden_states[:, :1]
            presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
            presence_logits = presence_logits.clamp(
                min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
            )
            intermediate_presence_logits.append(presence_logits)

        # Stack outputs from all layers; the last refined boxes are dropped so that entry i holds the
        # reference boxes layer i started from
        intermediate_outputs = torch.stack(intermediate_outputs)
        intermediate_boxes = torch.stack(intermediate_boxes[:-1])
        intermediate_presence_logits = torch.stack(intermediate_presence_logits)

        return Sam3DETRDecoderOutput(
            intermediate_hidden_states=intermediate_outputs,
            reference_boxes=intermediate_boxes,
            presence_logits=intermediate_presence_logits,
        )
class Sam3DotProductScoring(nn.Module):
    """
    Scores decoder queries against the prompt with a scaled dot product.

    The text features go through a residual MLP block and a layer norm, are mean-pooled over the sequence, and are
    then projected; the decoder queries are projected as well. The per-query score is the dot product between the
    two projections, scaled by `1 / sqrt(projection_dim)` and optionally clamped.
    """

    def __init__(self, config: Sam3Config):
        super().__init__()
        self.config = config
        decoder_config = config.detr_decoder_config
        hidden_size = decoder_config.hidden_size
        # The projection space has the same width as the decoder hidden states
        projection_dim = decoder_config.hidden_size

        # Residual MLP block applied to the text features before pooling
        self.text_mlp = Sam3DecoderMLP(
            input_dim=hidden_size,
            hidden_dim=decoder_config.intermediate_size,
            output_dim=hidden_size,
            num_layers=2,
        )
        self.text_mlp_dropout = nn.Dropout(decoder_config.dropout)
        self.text_mlp_out_norm = nn.LayerNorm(hidden_size)

        # Linear maps that bring pooled text and queries into the shared scoring space
        self.text_proj = nn.Linear(hidden_size, projection_dim)
        self.query_proj = nn.Linear(hidden_size, projection_dim)

        # Dot products are multiplied by 1 / sqrt(projection_dim)
        self.scale = 1.0 / math.sqrt(projection_dim)

        # Scores are clipped to [-clamp_max_val, clamp_max_val] to avoid numerical issues
        self.clamp_logits = True
        self.clamp_max_val = 12.0

    def _pool_text_features(self, text_features: torch.Tensor, text_mask: torch.Tensor | None) -> torch.Tensor:
        """
        Average the text features over the sequence dimension, skipping padded positions.

        Args:
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len] where True indicates valid tokens, False indicates padding

        Returns:
            pooled_text: [batch_size, hidden_size]
        """
        if text_mask is None:
            # Every position counts: plain mean over the sequence
            return text_features.mean(dim=1)

        # 1.0 for valid tokens, 0.0 for padding, broadcastable over the hidden dimension
        weights = text_mask.unsqueeze(-1).to(text_features.dtype)  # [batch_size, seq_len, 1]
        # Guard against division by zero for rows without any valid token
        token_counts = torch.clamp(weights.sum(dim=1), min=1.0)  # [batch_size, 1]
        masked_sum = (text_features * weights).sum(dim=1)  # [batch_size, hidden_size]
        return masked_sum / token_counts

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Compute one score per decoder query via dot product with the pooled text representation.

        Args:
            decoder_hidden_states: [num_layers, batch_size, num_queries, hidden_size]
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len] where True=valid, False=padding

        Returns:
            scores: [num_layers, batch_size, num_queries, 1]
        """
        # Residual MLP block followed by layer norm
        mlp_update = self.text_mlp_dropout(self.text_mlp(text_features))
        refined_text = self.text_mlp_out_norm(mlp_update + text_features)

        text_vector = self.text_proj(self._pool_text_features(refined_text, text_mask))
        query_vectors = self.query_proj(decoder_hidden_states)

        # [batch_size, dim] -> [1, batch_size, dim, 1] so that matmul broadcasts over the layer dimension
        text_column = text_vector.unsqueeze(-1).unsqueeze(0)
        scores = torch.matmul(query_vectors, text_column) * self.scale

        if not self.clamp_logits:
            return scores
        limit = self.clamp_max_val
        return scores.clamp(min=-limit, max=limit)
class Sam3MaskEmbedder(nn.Module):
    """
    Three-layer MLP that turns object queries into mask embeddings.
    Similar to MaskFormer's mask embedder.
    """

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size

        # Three square linear layers; the input and output widths are both `hidden_size`
        self.layers = nn.ModuleList([nn.Linear(width, width) for _ in range(3)])
        self.activation = nn.ReLU()

    def forward(self, queries: torch.Tensor) -> torch.Tensor:
        """
        Args:
            queries: Query embeddings [batch_size, num_queries, hidden_size]

        Returns:
            Mask embeddings [batch_size, num_queries, hidden_size]
        """
        num_layers = len(self.layers)
        embeddings = queries
        for position, linear in enumerate(self.layers, start=1):
            embeddings = linear(embeddings)
            # The non-linearity sits between layers only; the last layer's output is returned as-is
            if position != num_layers:
                embeddings = self.activation(embeddings)
        return embeddings
class Sam3PixelDecoder(nn.Module):
    """
    Feature Pyramid Network (FPN) decoder that generates pixel-level features.
    Inspired by MaskFormer's pixel decoder.

    Decoding is top-down: it starts from the last feature map of the input list and repeatedly resizes the running
    result to the previous level, adds that level as a skip connection, and applies a 3x3 conv, GroupNorm and ReLU.
    """

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        num_upsampling_stages = config.num_upsampling_stages

        # One (conv, norm) pair per merge step of the top-down pathway
        self.conv_layers = nn.ModuleList(
            [
                nn.Conv2d(hidden_size, hidden_size, kernel_size=3, stride=1, padding=1)
                for _ in range(num_upsampling_stages)
            ]
        )
        self.norms = nn.ModuleList([nn.GroupNorm(8, hidden_size) for _ in range(num_upsampling_stages)])
        self.out_channels = hidden_size

    def forward(self, backbone_features: list[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            backbone_features: List of backbone features [batch_size, hidden_size, H_i, W_i] (assumes already
                projected to hidden_size). The coarsest level must come last: decoding starts from the last entry
                and walks back toward the first one, resizing to each earlier level in turn. Each merge step
                consumes one conv/norm pair, so at most `num_upsampling_stages + 1` levels can be passed.

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W] at the spatial size of the first list entry. When the
            list holds a single feature map, it is returned unchanged.
        """
        # Start from the coarsest feature (last in list)
        prev_fpn = backbone_features[-1]

        # Walk the remaining levels from coarse to fine (the last one is already the starting point)
        for layer_idx, backbone_feat in enumerate(reversed(backbone_features[:-1])):
            # Resize the running output to the spatial size of the current level
            prev_fpn = F.interpolate(prev_fpn, size=backbone_feat.shape[-2:], mode="nearest")

            # Skip connection with the backbone feature of this level (out-of-place, inputs are not modified)
            prev_fpn = prev_fpn + backbone_feat

            # Conv -> GroupNorm -> ReLU
            prev_fpn = self.conv_layers[layer_idx](prev_fpn)
            prev_fpn = self.norms[layer_idx](prev_fpn)
            prev_fpn = F.relu(prev_fpn)

        return prev_fpn
class Sam3MaskDecoder(Sam3PreTrainedModel):
    """
    Mask decoder that combines object queries with pixel-level features to predict instance masks.
    Also produces a semantic segmentation output and supports cross-attention to prompts.
    """

    _can_record_outputs = {
        "attentions": Sam3Attention,
    }

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__(config)
        self.config = config
        hidden_size = config.hidden_size

        # Pixel decoder (FPN)
        self.pixel_decoder = Sam3PixelDecoder(config)

        # Mask embedder (MLP to transform queries)
        self.mask_embedder = Sam3MaskEmbedder(config)

        # Projection from pixel decoder output to mask embedding space
        self.instance_projection = nn.Conv2d(self.pixel_decoder.out_channels, hidden_size, kernel_size=1)

        # Semantic segmentation head (always present in UniversalSegmentationHead)
        self.semantic_projection = nn.Conv2d(self.pixel_decoder.out_channels, 1, kernel_size=1)

        # Pre-norm cross-attention block letting the encoder features attend to the prompt
        self.prompt_cross_attn = Sam3Attention(config)
        self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
        self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

        self.post_init()

    @check_model_inputs
    def forward(
        self,
        decoder_queries: torch.Tensor,
        backbone_features: list[torch.Tensor],
        encoder_hidden_states: torch.Tensor,
        prompt_features: torch.Tensor | None = None,
        prompt_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3MaskDecoderOutput:
        """
        Args:
            decoder_queries: Decoder output queries [batch_size, num_queries, hidden_size]
            backbone_features: List of backbone features to process through FPN
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
            prompt_features: Prompt features (text + geometry) for cross-attention [batch_size, prompt_len, hidden_size]
            prompt_mask: Padding mask [batch_size, prompt_len] where True=valid, False=padding

        Returns:
            Sam3MaskDecoderOutput containing predicted masks and semantic segmentation.
        """
        if prompt_features is not None:
            # Cross-attention: encoder features attend to prompt features (pre-norm, residual added afterwards)
            residual = encoder_hidden_states
            normed_hidden_states = self.prompt_cross_attn_norm(encoder_hidden_states)

            cross_attn_mask = None
            if prompt_mask is not None:
                cross_attn_mask = create_bidirectional_mask(
                    config=self.config,
                    input_embeds=normed_hidden_states,
                    encoder_hidden_states=prompt_features,
                    attention_mask=prompt_mask,
                )

            attn_output, _ = self.prompt_cross_attn(
                query=normed_hidden_states,
                key=prompt_features,
                value=prompt_features,
                attention_mask=cross_attn_mask,
                **kwargs,
            )
            encoder_hidden_states = residual + self.prompt_cross_attn_dropout(attn_output)

        # Process backbone features through FPN to get pixel embeddings
        pixel_embed = self._embed_pixels(
            backbone_features=backbone_features,
            encoder_hidden_states=encoder_hidden_states,
        )

        # Predict instance masks via dot product between query embeddings and pixel embeddings
        instance_embeds = self.instance_projection(pixel_embed)
        mask_embeddings = self.mask_embedder(decoder_queries)
        pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, instance_embeds)

        # Generate semantic segmentation
        semantic_seg = self.semantic_projection(pixel_embed)

        return Sam3MaskDecoderOutput(
            pred_masks=pred_masks,
            semantic_seg=semantic_seg,
        )

    def _embed_pixels(
        self,
        backbone_features: list[torch.Tensor],
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Embed pixels by combining backbone FPN features with encoder vision features.
        The encoder vision features replace the last backbone feature, i.e. the level the pixel decoder starts
        its top-down pass from (its coarsest input).

        Args:
            backbone_features: List of backbone features [batch_size, C, H_i, W_i]. Neither the list nor its
                tensors are modified.
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]. The first `H * W` positions
                (with `H, W` taken from the last backbone feature) are interpreted as the flattened vision tokens.

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W]
        """
        # A shallow copy is enough: only a list slot is rebound below and the pixel decoder uses out-of-place ops
        # exclusively, so the caller's tensors are never written to and need not be cloned.
        backbone_visual_feats = list(backbone_features)

        # Extract vision features from encoder output and reshape to spatial format
        height, width = backbone_features[-1].shape[-2:]
        encoder_visual_embed = encoder_hidden_states[:, : height * width, :]
        batch_size, _, hidden_size = encoder_visual_embed.shape
        encoder_visual_embed = encoder_visual_embed.transpose(1, 2).reshape(batch_size, hidden_size, height, width)

        # Swap the last backbone level for the prompt-conditioned encoder features
        backbone_visual_feats[-1] = encoder_visual_embed

        # Process through FPN decoder
        return self.pixel_decoder(backbone_visual_feats)
class Sam3Model(Sam3PreTrainedModel):
    """
    Image-level SAM3 model.

    Wires together a vision encoder, a CLIP text encoder (with a linear projection to the DETR hidden size), a
    geometry encoder for box prompts, a DETR encoder/decoder pair, a mask decoder and a dot-product scoring head.
    """

    input_modalities = ["image", "text"]
    _checkpoint_conversion_mapping = {
        r"detector_model.(.+)": r"\1"  # the regex allows to remove the prefix, and add it back in revert mode
    }
    _keys_to_ignore_on_load_unexpected = [
        r"^tracker_model.",
        r"^tracker_neck.",
    ]

    def __init__(self, config: Sam3Config):
        # loading from a sam3_video config
        if hasattr(config, "detector_config") and config.detector_config is not None:
            detector_config = config.detector_config
            if isinstance(detector_config, dict):
                detector_config = Sam3Config(**detector_config)
            config = detector_config
        super().__init__(config)

        self.vision_encoder = Sam3VisionModel(config.vision_config)
        self.text_encoder = CLIPTextModelWithProjection(config.text_config)
        self.vocab_size = config.text_config.vocab_size

        # Project text features from text encoder hidden size to model hidden size
        # CLIP text encoder outputs 1024-dim features, but we need 256-dim for DETR
        self.text_projection = nn.Linear(config.text_config.hidden_size, config.detr_encoder_config.hidden_size)

        # Pass _attn_implementation to subconfigs BEFORE creating modules
        config.geometry_encoder_config._attn_implementation = config._attn_implementation
        config.detr_encoder_config._attn_implementation = config._attn_implementation
        config.detr_decoder_config._attn_implementation = config._attn_implementation
        config.mask_decoder_config._attn_implementation = config._attn_implementation

        self.geometry_encoder = Sam3GeometryEncoder(config.geometry_encoder_config)
        self.detr_encoder = Sam3DetrEncoder(config.detr_encoder_config)
        self.detr_decoder = Sam3DetrDecoder(config.detr_decoder_config)
        self.mask_decoder = Sam3MaskDecoder(config.mask_decoder_config)

        # Dot product scoring to compute classification scores
        self.dot_product_scoring = Sam3DotProductScoring(config)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3Model, Sam3Processor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3Model.from_pretrained("facebook/sam3")
        >>> processor = Sam3Processor.from_pretrained("facebook/sam3")

        >>> # Pre-compute text embeddings
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> text_embeds = model.get_text_features(**text_inputs).pooler_output

        >>> # Reuse text embeddings for multiple images
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> outputs = model(pixel_values=img_inputs.pixel_values, text_embeds=text_embeds)
        ```
        """
        text_outputs = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs
        )
        # `pooler_output` is overwritten with the per-token hidden states projected to the DETR hidden size,
        # which is the tensor `forward` consumes as `text_embeds`.
        last_hidden_state = text_outputs.last_hidden_state
        text_outputs.pooler_output = self.text_projection(last_hidden_state)

        return text_outputs

    @auto_docstring
    def get_vision_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3VisionEncoderOutput:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3Model, Sam3Processor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3Model.from_pretrained("facebook/sam3")
        >>> processor = Sam3Processor.from_pretrained("facebook/sam3")

        >>> # Pre-compute vision embeddings
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> vision_embeds = model.get_vision_features(pixel_values=img_inputs.pixel_values)

        >>> # Reuse vision embeddings for multiple text prompts
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> outputs = model(vision_embeds=vision_embeds, input_ids=text_inputs.input_ids)
        ```
        """
        vision_outputs = self.vision_encoder(pixel_values, **kwargs)
        return vision_outputs

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        vision_embeds: Sam3VisionEncoderOutput | None = None,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        text_embeds: torch.FloatTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_boxes_labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3ImageSegmentationOutput:
        r"""
        vision_embeds (`Sam3VisionEncoderOutput`, *optional*):
            Pre-computed vision embeddings. Can be used to easily reuse vision embeddings. If provided, `pixel_values`
            should not be passed. Mutually exclusive with `pixel_values`.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Pre-computed text embeddings. Can be used to easily reuse text embeddings. If provided, `input_ids`
            should not be passed. Mutually exclusive with `input_ids`.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`, *optional*):
            Normalized box coordinates in [0, 1] range, in (cx, cy, w, h) format.
        input_boxes_labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`, *optional*):
            Labels for boxes: 1 (positive), 0 (negative).

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("facebook/sam3")
        >>> processor = AutoProcessor.from_pretrained("facebook/sam3")

        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read())).convert("RGB")
        >>> text = "car"
        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> # Get segmentation output
        >>> outputs = model(**inputs)
        >>> pred_masks = outputs.pred_masks
        >>> pred_boxes = outputs.pred_boxes
        ```
        """
        # Each modality must come from exactly one source: raw inputs or pre-computed embeddings
        if (pixel_values is None) == (vision_embeds is None):
            raise ValueError("You must specify exactly one of pixel_values or vision_embeds")

        if (input_ids is None) == (text_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or text_embeds")

        if pixel_values is not None:
            batch_size = pixel_values.shape[0]
            device = pixel_values.device
        else:
            batch_size = vision_embeds.fpn_hidden_states[0].shape[0]
            device = vision_embeds.fpn_hidden_states[0].device

        if vision_embeds is None:
            vision_outputs = self.vision_encoder(pixel_values, **kwargs)
        else:
            vision_outputs = vision_embeds

        # The last FPN level is dropped; the remaining ones feed the geometry encoder and the mask decoder,
        # and the last remaining level feeds the DETR encoder.
        fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1]
        fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1]

        if text_embeds is None:
            text_features = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).pooler_output
        else:
            text_features = text_embeds

        text_mask = attention_mask.bool() if attention_mask is not None else None

        # Geometry (box) prompts are encoded only when at least one box is supplied
        geometry_prompt_features = None
        geometry_prompt_mask = None
        if input_boxes is not None and input_boxes.numel() > 0:
            box_embeddings = input_boxes  # [batch_size, num_boxes, 4]
            if input_boxes_labels is not None:
                # Entries labelled -10 are excluded through the mask and their label is rewritten to 0
                # (presumably -10 is the processor's padding value; confirm against the processor).
                box_mask = input_boxes_labels != -10
                box_labels = torch.where(input_boxes_labels == -10, 0, input_boxes_labels)
            else:
                # Without labels, every box is treated as a valid positive box
                box_labels = torch.ones_like(box_embeddings[..., 0], dtype=torch.long)
                box_mask = torch.ones(batch_size, input_boxes.shape[1], dtype=torch.bool, device=device)

            geometry_outputs = self.geometry_encoder(
                box_embeddings=box_embeddings,
                box_mask=box_mask,
                box_labels=box_labels,
                img_feats=fpn_hidden_states,
                img_pos_embeds=fpn_position_encoding,
            )

            geometry_prompt_features = geometry_outputs.last_hidden_state
            geometry_prompt_mask = geometry_outputs.attention_mask

        if geometry_prompt_features is not None:
            # Repeat text_features for all geometry prompts
            if text_features.shape[0] == 1 and geometry_prompt_features.shape[0] > 1:
                text_features = text_features.repeat(geometry_prompt_features.shape[0], 1, 1)
            combined_prompt_features = torch.cat([text_features, geometry_prompt_features], dim=1)
            if text_mask is not None and text_mask.shape[0] == 1 and geometry_prompt_mask.shape[0] > 1:
                text_mask = text_mask.repeat(geometry_prompt_mask.shape[0], 1)

            # Concatenate the masks in the same order as the features; a missing mask means "all valid"
            if text_mask is not None and geometry_prompt_mask is not None:
                combined_prompt_mask = torch.cat([text_mask, geometry_prompt_mask], dim=1)
            elif text_mask is not None:
                geo_valid_mask = torch.ones(
                    batch_size, geometry_prompt_features.shape[1], dtype=torch.bool, device=device
                )
                combined_prompt_mask = torch.cat([text_mask, geo_valid_mask], dim=1)
            elif geometry_prompt_mask is not None:
                text_valid_mask = torch.ones(batch_size, text_features.shape[1], dtype=torch.bool, device=device)
                combined_prompt_mask = torch.cat([text_valid_mask, geometry_prompt_mask], dim=1)
            else:
                combined_prompt_mask = None
        else:
            combined_prompt_features = text_features
            combined_prompt_mask = text_mask

        encoder_outputs = self.detr_encoder(
            vision_features=[fpn_hidden_states[-1]],
            text_features=combined_prompt_features,
            vision_pos_embeds=[fpn_position_encoding[-1]],
            text_mask=combined_prompt_mask,
            **kwargs,
        )

        decoder_outputs = self.detr_decoder(
            vision_features=encoder_outputs.last_hidden_state,
            text_features=encoder_outputs.text_features,
            vision_pos_encoding=encoder_outputs.pos_embeds_flattened,
            text_mask=combined_prompt_mask,
            spatial_shapes=encoder_outputs.spatial_shapes,
            **kwargs,
        )

        # Refine boxes from decoder
        all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states)
        reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes)
        all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid()
        all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh)

        all_pred_logits = self.dot_product_scoring(
            decoder_hidden_states=decoder_outputs.intermediate_hidden_states,
            text_features=encoder_outputs.text_features,
            text_mask=combined_prompt_mask,
        ).squeeze(-1)

        # Only the predictions of the last decoder layer are returned
        pred_logits = all_pred_logits[-1]
        pred_boxes = all_pred_boxes[-1]
        decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1]
        presence_logits = decoder_outputs.presence_logits[-1]

        mask_outputs = self.mask_decoder(
            decoder_queries=decoder_hidden_states,
            backbone_features=list(fpn_hidden_states),
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            prompt_features=combined_prompt_features,
            prompt_mask=combined_prompt_mask,
            **kwargs,
        )

        return Sam3ImageSegmentationOutput(
            pred_masks=mask_outputs.pred_masks,
            pred_boxes=pred_boxes,
            pred_logits=pred_logits,
            presence_logits=presence_logits,
            semantic_seg=mask_outputs.semantic_seg,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_reference_boxes=decoder_outputs.reference_boxes,
            encoder_hidden_states=encoder_outputs.hidden_states,
            vision_hidden_states=vision_outputs.hidden_states,
            vision_attentions=vision_outputs.attentions,
            detr_encoder_attentions=encoder_outputs.attentions,
            detr_decoder_attentions=decoder_outputs.attentions,
            mask_decoder_attentions=mask_outputs.attentions,
        )
# Names exported by `from <this module> import *`.
__all__ = ["Sam3Model", "Sam3VisionModel", "Sam3ViTModel", "Sam3PreTrainedModel"]