# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from collections.abc import Callable, Iterable from dataclasses import dataclass import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchvision from torch import Tensor from transformers import CLIPTextModelWithProjection from ... import initialization as init from ...activations import ACT2FN from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, ModelOutput, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import auto_docstring, can_return_tuple, logging from ...utils.generic import TransformersKwargs, check_model_inputs, is_flash_attention_requested from ..auto import AutoModel from .configuration_sam3 import ( Sam3Config, Sam3DETRDecoderConfig, Sam3DETREncoderConfig, Sam3GeometryEncoderConfig, Sam3MaskDecoderConfig, Sam3VisionConfig, Sam3ViTConfig, ) logger = logging.get_logger(__name__) @dataclass @auto_docstring class Sam3VisionEncoderOutput(BaseModelOutputWithPooling): r""" fpn_hidden_states (`tuple[torch.FloatTensor]`): Tuple of multi-level FPN feature maps. 
fpn_position_encoding (`tuple[torch.FloatTensor]`): Tuple of position encodings for each FPN level. """ fpn_hidden_states: tuple[torch.FloatTensor, ...] = None fpn_position_encoding: tuple[torch.FloatTensor, ...] = None @dataclass @auto_docstring class Sam3GeometryEncoderOutput(ModelOutput): r""" last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_prompts, hidden_size)`): Encoded geometry prompt features (boxes). attention_mask (`torch.BoolTensor` of shape `(batch_size, num_prompts)`, *optional*): Attention mask for geometry prompts where True indicates valid positions and False indicates padding. """ last_hidden_state: torch.FloatTensor = None attention_mask: torch.BoolTensor | None = None @dataclass @auto_docstring class Sam3DETREncoderOutput(ModelOutput): r""" last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Encoded vision features (flattened from multi-level features). pos_embeds_flattened (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Flattened position embeddings for the vision features. text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`, *optional*): Text features (may be pooled after encoder processing). spatial_shapes (`torch.LongTensor` of shape `(num_levels, 2)`, *optional*): Spatial shapes (height, width) for each feature pyramid level. hidden_states (`tuple[torch.FloatTensor]`, *optional*): Tuple of hidden states from all encoder layers. attentions (`tuple[torch.FloatTensor]`, *optional*): Tuple of attention weights from all encoder layers. 
""" last_hidden_state: torch.FloatTensor = None pos_embeds_flattened: torch.FloatTensor | None = None text_features: torch.FloatTensor | None = None spatial_shapes: torch.LongTensor | None = None hidden_states: tuple[torch.FloatTensor] | None = None attentions: tuple[torch.FloatTensor] | None = None @dataclass @auto_docstring class Sam3DETRDecoderOutput(ModelOutput): r""" intermediate_hidden_states (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, hidden_size)`): Decoder hidden states from all layers. reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`): Predicted reference boxes from all decoder layers in (cx, cy, w, h) format. presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`): Presence logits from all decoder layers indicating object presence confidence. hidden_states (`tuple[torch.FloatTensor]`, *optional*): Tuple of hidden states from all decoder layers. attentions (`tuple[torch.FloatTensor]`, *optional*): Tuple of attention weights from all decoder layers (self-attention and cross-attention). """ intermediate_hidden_states: torch.FloatTensor = None reference_boxes: torch.FloatTensor = None presence_logits: torch.FloatTensor = None hidden_states: tuple[torch.FloatTensor] | None = None attentions: tuple[torch.FloatTensor] | None = None @dataclass @auto_docstring class Sam3MaskDecoderOutput(ModelOutput): r""" pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`): Predicted segmentation masks for each query. semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): Semantic segmentation output. attentions (`tuple[torch.FloatTensor]`, *optional*): Tuple of attention weights from mask decoder cross-attention layers. 
""" pred_masks: torch.FloatTensor = None semantic_seg: torch.FloatTensor | None = None attentions: tuple[torch.FloatTensor] | None = None @dataclass @auto_docstring class Sam3ImageSegmentationOutput(ModelOutput): r""" pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`): Predicted segmentation masks for each query. pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): Predicted bounding boxes in (x1, y1, x2, y2) format. pred_logits (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): Classification confidence scores for each query, computed via dot product between decoder query features and text features. presence_logits (`torch.FloatTensor` of shape `(batch_size, 1)`, *optional*): Presence logits from the DETR decoder presence token (last layer only). These indicate whether objects are present in the scene. Can be used to compute final scores by multiplying with pred_logits: `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()`. semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): Semantic segmentation output. decoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Tuple of hidden states from all DETR decoder layers. Each tensor has shape `(batch_size, num_queries, hidden_size)`. decoder_reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`, *optional*): Reference boxes from all DETR decoder layers. encoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Tuple of hidden states from all DETR encoder layers. vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Tuple of hidden states from all vision encoder (ViT) layers. vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from vision encoder (ViT) layers. detr_encoder_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from DETR encoder layers. 
def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """The inverse function (logit) of the sigmoid activation, clamped for numerical stability."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """
    Concatenates two right-padded sequences, such that the resulting sequence is contiguous and also right-padded.
    Tensors are batch-first, masks are batch-first with True=valid, False=padding.

    Args:
        seq1: A tensor of shape (batch_size, seq1_length, hidden_size).
        mask1: A tensor of shape (batch_size, seq1_length) with True=valid, False=padding.
        seq2: A tensor of shape (batch_size, seq2_length, hidden_size).
        mask2: A tensor of shape (batch_size, seq2_length) with True=valid, False=padding.
        return_index: If True, also returns the index of the ids of the element of seq2 in the concatenated
            sequence. This can be used to retrieve the elements of seq2.

    Returns:
        A tuple (concatenated_sequence, concatenated_mask) if return_index is False, otherwise
        (concatenated_sequence, concatenated_mask, index). The concatenated_mask uses True=valid, False=padding
        convention.

    Raises:
        ValueError: If batch sizes, hidden sizes, or sequence lengths of the inputs are inconsistent.
    """
    batch_size, seq1_length, hidden_size = seq1.shape
    batch_size2, seq2_length, hidden_size2 = seq2.shape
    # Validate with explicit exceptions: `assert` is silently stripped under `python -O`.
    if not (batch_size == batch_size2 == mask1.size(0) == mask2.size(0)):
        raise ValueError("seq1, mask1, seq2 and mask2 must share the same batch size")
    if hidden_size != hidden_size2:
        raise ValueError(f"seq1 and seq2 must share the same hidden size, got {hidden_size} and {hidden_size2}")
    if seq1_length != mask1.size(1):
        raise ValueError("mask1 must have the same sequence length as seq1")
    if seq2_length != mask2.size(1):
        raise ValueError("mask2 must have the same sequence length as seq2")

    actual_seq1_lengths = mask1.sum(dim=-1)
    actual_seq2_lengths = mask2.sum(dim=-1)
    final_lengths = actual_seq1_lengths + actual_seq2_lengths
    max_length = seq1_length + seq2_length
    concatenated_mask = (
        torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) < final_lengths[:, None]
    )
    concatenated_sequence = torch.zeros((batch_size, max_length, hidden_size), device=seq2.device, dtype=seq2.dtype)
    concatenated_sequence[:, :seq1_length, :] = seq1
    # Shift seq2 elements to start right after the last valid element of seq1
    index = torch.arange(seq2_length, device=seq2.device)[None].repeat(batch_size, 1)
    index = index + actual_seq1_lengths[:, None]
    # Scatter seq2 into the right positions
    concatenated_sequence = concatenated_sequence.scatter(1, index[:, :, None].expand(-1, -1, hidden_size), seq2)
    if return_index:
        return concatenated_sequence, concatenated_mask, index
    return concatenated_sequence, concatenated_mask


def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format."""
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)
class Sam3MLP(nn.Module):
    """Two-layer feed-forward block: linear -> dropout -> activation -> linear."""

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE(review): dropout is applied before the activation here; this mirrors the
        # original computation order and is kept as-is.
        projected = self.fc1(hidden_states)
        projected = self.dropout(projected)
        activated = self.activation_fn(projected)
        return self.fc2(activated)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference (eager) scaled dot-product attention.

    Inputs are per-head tensors of shape [batch, heads, seq, head_dim]; returns the attention
    output transposed to [batch, seq, heads, head_dim] together with the attention weights.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores between every query and key position.
    scores = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the additive mask to the actual key length before applying it.
        scores = scores + attention_mask[:, :, :, : key.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value)
    context = context.transpose(1, 2).contiguous()
    return context, probs
""" def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.head_dim = self.hidden_size // config.num_attention_heads self.scaling = self.head_dim**-0.5 self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) self.k_proj = nn.Linear(self.hidden_size, self.hidden_size) self.v_proj = nn.Linear(self.hidden_size, self.hidden_size) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size) def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: """ Args: query: [batch_size, query_len, hidden_size] key: [batch_size, key_len, hidden_size] value: [batch_size, value_len, hidden_size] attention_mask: [batch_size, num_heads, query_len, key_len] or broadcastable Returns: Tuple of (output, attention_weights) output: [batch_size, query_len, hidden_size] attention_weights: [batch_size, num_heads, query_len, key_len] """ batch_size = query.shape[0] query_len = query.shape[1] key_len = key.shape[1] query = self.q_proj(query).view(batch_size, query_len, self.num_attention_heads, self.head_dim).transpose(1, 2) key = self.k_proj(key).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2) value = self.v_proj(value).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) if ( is_flash_attention_requested(self.config) and attention_mask is not None and attention_mask.dtype != torch.bool ): # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention # Fallback to SDPA for this call only so the rest of the model can still benefit from FA attention_interface = 
ALL_ATTENTION_FUNCTIONS["sdpa"] logger.warning_once( "Sam3Attention: falling back to SDPA for relative-position cross-attention because " "Flash Attention does not support additive bias masks." ) attn_output, attn_weights = attention_interface( self, query, key, value, attention_mask=attention_mask, dropout=0.0, scaling=self.scaling, is_causal=self.is_causal, **kwargs, ) attn_output = attn_output.reshape(batch_size, query_len, self.num_attention_heads * self.head_dim).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Sam3ViTRotaryEmbedding(nn.Module): """ Vision Rotary Position Embedding for SAM3, following transformers library standards. Supports 2D (axial) rotary embeddings for spatial dimensions. """ def __init__(self, config: Sam3ViTConfig, end_x: int, end_y: int, scale: float = 1.0): super().__init__() dim = config.hidden_size // config.num_attention_heads # Ensure even dimension for proper axial splitting if dim % 4 != 0: raise ValueError("Dimension must be divisible by 4 for axial RoPE") self.end_x, self.end_y = end_x, end_y self.dim = dim self.rope_theta = config.rope_theta self.scale = scale freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) flattened_indices = torch.arange(end_x * end_y, dtype=torch.long) x_positions = (flattened_indices % end_x) * scale y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * scale freqs_x = torch.outer(x_positions, freqs).float() freqs_y = torch.outer(y_positions, freqs).float() inv_freq = torch.cat([freqs_x, freqs_y], dim=-1) inv_freq = inv_freq.repeat_interleave(2, dim=-1) # directly register the cos and sin embeddings as we have a fixed feature shape self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False) self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False) @torch.no_grad() def forward(self) -> tuple[torch.Tensor, torch.Tensor]: # As the feature map size is fixed for each 
def rotate_pairwise(x):
    """
    Rotate adjacent (even, odd) channel pairs of the last dimension: (x0, x1) -> (-x1, x0).

    Different from the Llama half-tensor rotation. This is an optimized version of the
    following more explicit implementation:

    ```python
    x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device)
    x_rotated[..., ::2] = -x[..., 1::2]
    x_rotated[..., 1::2] = x[..., ::2]
    return x_rotated
    ```
    """
    pairs = x.view(*x.shape[:-1], -1, 2)
    even, odd = pairs.unbind(dim=-1)
    rotated = torch.stack((-odd, even), dim=-1)
    return rotated.flatten(start_dim=-2)


def apply_rotary_pos_emb_2d(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply 2D (axial) rotary position embeddings to query and key tensors for self-attention.

    Args:
        q: Query tensor of shape (batch_size, num_windows, seq_len, num_heads, head_dim)
        k: Key tensor of shape (batch_size, num_windows, seq_len, num_heads, head_dim)
        cos: Cosine position embedding of shape (seq_len, head_dim)
        sin: Sine position embedding of shape (seq_len, head_dim)

    Returns:
        The rotated (q, k) tensors, cast back to their input dtypes.
    """
    # Compute in float32 for numerical stability, then cast back to the input dtypes.
    rotated_q = q.float()
    rotated_q = rotated_q * cos + rotate_pairwise(rotated_q) * sin
    rotated_k = k.float()
    rotated_k = rotated_k * cos + rotate_pairwise(rotated_k) * sin
    return rotated_q.type_as(q), rotated_k.type_as(k)
class Sam3ViTRoPEAttention(nn.Module):
    """Multi-head self-attention over 2D feature maps with rotary position encoding."""

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tensor:
        # Input is spatial: [batch, height, width, hidden]; flatten to a token sequence.
        batch_size, height, width, _ = hidden_states.shape
        tokens = height * width
        proj_shape = (batch_size, tokens, self.num_attention_heads, self.head_dim)

        # Project to per-head query/key/value: [batch, heads, tokens, head_dim].
        query_states = self.q_proj(hidden_states).view(*proj_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(*proj_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(*proj_shape).transpose(1, 2)

        # Rotate queries and keys with the precomputed axial RoPE tables.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_2d(query_states, key_states, cos=cos, sin=sin)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Fold heads back and restore the spatial layout before the output projection.
        attn_output = attn_output.reshape(batch_size, height, width, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
""" def __init__(self, config: Sam3ViTConfig): super().__init__() image_size, patch_size = config.pretrain_image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2) return embeddings class Sam3ViTEmbeddings(nn.Module): """ Construct the patch embeddings and position embeddings for SAM3 ViT. Position embeddings are tiled (not interpolated) when resizing to match different input sizes. """ def __init__(self, config: Sam3ViTConfig): super().__init__() self.patch_embeddings = Sam3ViTPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter( torch.randn(1, num_patches, config.hidden_size) ) # !Remove cls token in convert weights! self.dropout = nn.Dropout(config.hidden_dropout) self.patch_size = config.patch_size def _tile_position_embeddings( self, position_embeddings: torch.Tensor, height: int, width: int, ) -> torch.Tensor: """ Tile position embeddings to match target spatial dimensions. 
Args: position_embeddings: Shape [1, num_pretrain_patches, hidden_size] height: Target height in patches width: Target width in patches Returns: Shape [1, height * width, hidden_size] """ pretrain_size = int(position_embeddings.shape[1] ** 0.5) # Skip tiling if sizes match (but always tile during tracing for consistent graph) if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width: return position_embeddings.reshape(1, height * width, -1) # Tile position embeddings to match target spatial dimensions hidden_size = position_embeddings.shape[-1] pos_embed = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2) repeat_h = height // pretrain_size + 1 repeat_w = width // pretrain_size + 1 pos_embed = pos_embed.tile([1, 1, repeat_h, repeat_w])[:, :, :height, :width] return pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size) def forward( self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, ) -> torch.Tensor: height, width = pixel_values.shape[-2:] embeddings = self.patch_embeddings(pixel_values) # Calculate spatial dimensions in patches height_patches = height // self.patch_size width_patches = width // self.patch_size position_embeddings = self._tile_position_embeddings( self.position_embeddings, height_patches, width_patches, ) embeddings = embeddings + position_embeddings embeddings = self.dropout(embeddings) return embeddings def window_partition(hidden_state, window_size): """ Partition into non-overlapping windows with padding if needed. Args: hidden_state (`torch.Tensor`): Input tokens with [batch_size, height, width, num_channels]. window_size (`int`): Window size. Returns: `tuple(torch.FloatTensor)` comprising various elements: - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels]. 
def window_partition(hidden_state, window_size):
    """
    Partition a spatial feature map into non-overlapping square windows, zero-padding the
    bottom/right edges when the spatial dims are not multiples of `window_size`.

    Args:
        hidden_state (`torch.Tensor`):
            Input tokens with [batch_size, height, width, num_channels].
        window_size (`int`):
            Window size.

    Returns:
        `tuple(torch.FloatTensor)` comprising various elements:
        - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
        - (padded_height, padded_width): padded height and width before partition
    """
    batch_size, height, width, num_channels = hidden_state.shape

    pad_bottom = (window_size - height % window_size) % window_size
    pad_right = (window_size - width % window_size) % window_size
    # F.pad is a no-op when both pads are zero.
    hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_right, 0, pad_bottom))
    padded_height = height + pad_bottom
    padded_width = width + pad_right

    grid_height = padded_height // window_size
    grid_width = padded_width // window_size
    hidden_state = hidden_state.view(batch_size, grid_height, window_size, grid_width, window_size, num_channels)
    windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows, (padded_height, padded_width)
""" padded_height, padded_width = pad_height_width height, width = height_width batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size) hidden_state = windows.view( batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1 ) hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous() hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1) # We always have height <= padded_height and width <= padded_width hidden_state = hidden_state[:, :height, :width, :].contiguous() return hidden_state class Sam3ViTLayerScale(nn.Module): def __init__(self, config) -> None: super().__init__() self.lambda1 = nn.Parameter(config.layer_scale_init_value * torch.ones(config.hidden_size)) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: return hidden_state * self.lambda1 class Sam3ViTLayer(GradientCheckpointingLayer): """Vision Transformer layer with rotary position embeddings and optional windowed attention.""" def __init__(self, config: Sam3ViTConfig, window_size: int = 0) -> None: super().__init__() hidden_size = config.hidden_size image_size = config.image_size image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size) patch_size = config.patch_size patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size) input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.layer_norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) rotary_input_size = input_size if window_size == 0 else (window_size, window_size) rotary_scale = config.window_size / rotary_input_size[0] self.rotary_emb = Sam3ViTRotaryEmbedding( config, end_x=rotary_input_size[0], end_y=rotary_input_size[1], scale=rotary_scale ) self.attention = Sam3ViTRoPEAttention(config) self.layer_norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) self.mlp = Sam3MLP(config) self.dropout = 
class Sam3ViTLayer(GradientCheckpointingLayer):
    """Vision Transformer layer with rotary position embeddings and optional windowed attention."""

    def __init__(self, config: Sam3ViTConfig, window_size: int = 0) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        image_size = config.image_size
        image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
        patch_size = config.patch_size
        patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
        # Feature-map size in patches for global-attention layers.
        input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        # window_size == 0 means global attention over the full feature map.
        rotary_input_size = input_size if window_size == 0 else (window_size, window_size)
        # Equals 1.0 for windowed layers; rescales grid coordinates for global layers.
        rotary_scale = config.window_size / rotary_input_size[0]
        self.rotary_emb = Sam3ViTRotaryEmbedding(
            config, end_x=rotary_input_size[0], end_y=rotary_input_size[1], scale=rotary_scale
        )
        self.attention = Sam3ViTRoPEAttention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = Sam3MLP(config)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.window_size = window_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Pre-norm attention with residual; input is [batch, height, width, hidden].
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)

        if self.window_size > 0:
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            # Partition into non-overlapping windows for efficient attention
            hidden_states, pad_height_width = window_partition(hidden_states, self.window_size)

        position_embeddings = self.rotary_emb()
        hidden_states, _ = self.attention(hidden_states, position_embeddings, **kwargs)

        if self.window_size > 0:
            # Reverse window partition to restore original spatial layout
            hidden_states = window_unpartition(hidden_states, self.window_size, pad_height_width, (height, width))

        hidden_states = residual + hidden_states

        # Pre-norm MLP with residual; dropout only on the MLP branch.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)
        return hidden_states
@auto_docstring
class Sam3PreTrainedModel(PreTrainedModel):
    config_class = Sam3Config
    base_model_prefix = "sam3"
    main_input_name = "pixel_values"
    input_modalities = ["image", "text"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize SAM3-specific weights on top of the default `PreTrainedModel` scheme."""
        super()._init_weights(module)
        if isinstance(module, Sam3ViTEmbeddings):
            init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, Sam3ViTRotaryEmbedding):
            # Recompute the non-persistent cos/sin buffers (they are not stored in checkpoints);
            # mirrors the construction in Sam3ViTRotaryEmbedding.__init__.
            end_x, end_y = module.end_x, module.end_y
            dim = module.dim
            freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
            flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
            x_positions = (flattened_indices % end_x) * module.scale
            y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
            freqs_x = torch.outer(x_positions, freqs).float()
            freqs_y = torch.outer(y_positions, freqs).float()
            inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
            inv_freq = inv_freq.repeat_interleave(2, dim=-1)
            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
            init.copy_(module.rope_embeddings_sin, inv_freq.sin())


@auto_docstring
class Sam3ViTModel(Sam3PreTrainedModel):
    def __init__(self, config: Sam3ViTConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = Sam3ViTEmbeddings(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Windowed attention everywhere except the configured global-attention layers.
        self.layers = nn.ModuleList(
            [
                Sam3ViTLayer(config, window_size=config.window_size if i not in config.global_attn_indexes else 0)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.post_init()

    def get_input_embeddings(self) -> Sam3ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        hidden_states = self.embeddings(pixel_values)  # [batch_size, seq_len, hidden_size]

        batch_size = hidden_states.shape[0]
        height = pixel_values.shape[-2] // self.config.patch_size
        width = pixel_values.shape[-1] // self.config.patch_size
        hidden_size = hidden_states.shape[-1]

        # Reshape to spatial format for windowed attention: [batch_size, height, width, hidden_size]
        hidden_states = hidden_states.view(batch_size, height, width, hidden_size)
        hidden_states = self.layer_norm(hidden_states)

        for layer in self.layers:
            hidden_states = layer(hidden_states, **kwargs)

        # Reshape back to sequence format: [batch_size, height*width, hidden_size]
        hidden_states = hidden_states.view(batch_size, height * width, hidden_size)
        return BaseModelOutput(last_hidden_state=hidden_states)
""" def __init__( self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None ): super().__init__() if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize self.scale = 2 * math.pi if scale is None else scale def encode_1d_positions(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Encode 1D coordinate pairs using sine/cosine positional embeddings. Args: x: 1D tensor of x coordinates (flattened) y: 1D tensor of y coordinates (flattened) Returns: Tuple of (pos_x, pos_y) positional embeddings """ x_embed = x * self.scale y_embed = y * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).to(x.dtype) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, None] / dim_t pos_y = y_embed[:, None] / dim_t pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) return pos_x, pos_y def encode_boxes(self, boxes: torch.Tensor) -> torch.Tensor: """ Encode 4D box coordinates (x, y, w, h) for decoder conditioning using sine/cosine embeddings. 
Args: boxes: Box coordinates [batch_size, num_queries, 4] in (x, y, w, h) format Returns: Position embeddings [batch_size, num_queries, num_pos_feats*4] """ assert boxes.size(-1) == 4, f"Expected 4D box coordinates (x, y, w, h), got shape {boxes.shape}" dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=boxes.device).to(boxes.dtype) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) x_embed = boxes[:, :, 0] * self.scale y_embed = boxes[:, :, 1] * self.scale w_embed = boxes[:, :, 2] * self.scale h_embed = boxes[:, :, 3] * self.scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t pos_w = w_embed[:, :, None] / dim_t pos_h = h_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) return pos @compile_compatible_method_lru_cache(maxsize=4) def forward( self, shape: torch.Size, device: torch.device | str, dtype: torch.dtype, mask: Tensor | None = None, ) -> Tensor: if mask is None: mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) not_mask = (~mask).to(dtype) y_embed = not_mask.cumsum(1) x_embed = not_mask.cumsum(2) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), 
pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos class Sam3FPNLayer(nn.Module): def __init__(self, in_channels: int, fpn_dim: int, scale_factor: float): super().__init__() self.scale_factor = scale_factor # Build the upsampling/downsampling layers based on scale factor self.scale_layers = nn.ModuleList() if scale_factor == 4.0: self.scale_layers.append(nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)) self.scale_layers.append(nn.GELU()) self.scale_layers.append(nn.ConvTranspose2d(in_channels // 2, in_channels // 4, kernel_size=2, stride=2)) intermediate_channels = in_channels // 4 elif scale_factor == 2.0: self.scale_layers.append(nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)) intermediate_channels = in_channels // 2 elif scale_factor == 1.0: intermediate_channels = in_channels elif scale_factor == 0.5: self.scale_layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) intermediate_channels = in_channels else: raise NotImplementedError(f"scale_factor={scale_factor} is not supported yet.") self.proj1 = nn.Conv2d(in_channels=intermediate_channels, out_channels=fpn_dim, kernel_size=1) self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.to(self.proj1.weight.dtype) for layer in self.scale_layers: hidden_states = layer(hidden_states) hidden_states = self.proj1(hidden_states) hidden_states = self.proj2(hidden_states) return hidden_states class Sam3VisionNeck(nn.Module): def __init__(self, config: Sam3VisionConfig): super().__init__() self.config = config self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True) # Create one FPN layer per scale factor self.fpn_layers = 
nn.ModuleList( [ Sam3FPNLayer( in_channels=config.backbone_config.hidden_size, fpn_dim=config.fpn_hidden_size, scale_factor=scale ) for scale in config.scale_factors ] ) def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: fpn_hidden_states = () fpn_position_encoding = () for fpn_layer in self.fpn_layers: fpn_output = fpn_layer(hidden_states) fpn_hidden_states += (fpn_output,) # Generate position encoding for this FPN level pos_enc = self.position_encoding(fpn_output.shape, fpn_output.device, fpn_output.dtype) fpn_position_encoding += (pos_enc,) return fpn_hidden_states, fpn_position_encoding @auto_docstring( custom_intro=""" The vision model from Sam without any head or projection on top. """ ) class Sam3VisionModel(Sam3PreTrainedModel): config_class = Sam3VisionConfig main_input_name = "pixel_values" _can_record_outputs = { "hidden_states": Sam3ViTLayer, "attentions": Sam3ViTRoPEAttention, } def __init__(self, config: Sam3VisionConfig): super().__init__(config) self.config = config self.backbone = AutoModel.from_config(config.backbone_config) self.neck = Sam3VisionNeck(config) self.post_init() def get_input_embeddings(self): return self.backbone.get_input_embeddings() @check_model_inputs def forward( self, pixel_values: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Sam3VisionEncoderOutput: if pixel_values is None: raise ValueError("You have to specify pixel_values") backbone_output = self.backbone(pixel_values, **kwargs) hidden_states = backbone_output.last_hidden_state # [batch_size, seq_len, hidden_size] # Reshape for FPN neck: [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size, height, width] batch_size = hidden_states.shape[0] height = pixel_values.shape[-2] // self.config.backbone_config.patch_size width = pixel_values.shape[-1] // self.config.backbone_config.patch_size hidden_states_spatial = hidden_states.view(batch_size, height, width, -1).permute(0, 
3, 1, 2) fpn_hidden_states, fpn_position_encoding = self.neck(hidden_states_spatial) return Sam3VisionEncoderOutput( last_hidden_state=hidden_states, fpn_hidden_states=fpn_hidden_states, fpn_position_encoding=fpn_position_encoding, ) class Sam3GeometryEncoderLayer(nn.Module): def __init__(self, config: Sam3GeometryEncoderConfig): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size) self.self_attn = Sam3Attention(config) self.dropout = nn.Dropout(config.dropout) self.cross_attn = Sam3Attention(config) self.layer_norm2 = nn.LayerNorm(config.hidden_size) self.mlp = Sam3MLP(config) self.layer_norm3 = nn.LayerNorm(config.hidden_size) def forward( self, prompt_feats: Tensor, vision_feats: Tensor, vision_pos_encoding: Tensor, prompt_mask: Tensor, **kwargs: Unpack[TransformersKwargs], ): residual = prompt_feats hidden_states = self.layer_norm1(prompt_feats) hidden_states, _ = self.self_attn( query=hidden_states, key=hidden_states, value=hidden_states, attention_mask=prompt_mask, **kwargs ) hidden_states = self.dropout(hidden_states) + residual residual = hidden_states hidden_states = self.layer_norm2(hidden_states) key = vision_feats + vision_pos_encoding hidden_states, _ = self.cross_attn(query=hidden_states, key=key, value=vision_feats, **kwargs) hidden_states = self.dropout(hidden_states) + residual residual = hidden_states hidden_states = self.layer_norm3(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = self.dropout(hidden_states) + residual return hidden_states class Sam3GeometryEncoder(nn.Module): """ Encoder for geometric prompts (boxes). Boxes are encoded using three approaches: - Direct projection: linear projection from coordinate space to hidden_size - Pooling: pool features from the backbone at the specified location (ROI align for boxes) - Position encoding: use position encoding of the box center These encodings are combined additively and further processed with transformer layers. 
""" def __init__(self, config: Sam3GeometryEncoderConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.roi_size = config.roi_size self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=True) self.label_embed = nn.Embedding(2, self.hidden_size) self.cls_embed = nn.Embedding(1, self.hidden_size) # Box encoding layers self.boxes_direct_project = nn.Linear(4, self.hidden_size) self.boxes_pool_project = nn.Conv2d(self.hidden_size, self.hidden_size, self.roi_size) self.boxes_pos_enc_project = nn.Linear(self.hidden_size + 2, self.hidden_size) # Image feature normalization self.vision_layer_norm = nn.LayerNorm(self.hidden_size) # Prompt projection and normalization self.final_proj = nn.Linear(self.hidden_size, self.hidden_size) self.prompt_layer_norm = nn.LayerNorm(self.hidden_size) # Transformer layers self.layers = nn.ModuleList([Sam3GeometryEncoderLayer(config) for _ in range(config.num_layers)]) self.output_layer_norm = nn.LayerNorm(self.hidden_size) def _encode_box_coordinates( self, center_x: torch.Tensor, center_y: torch.Tensor, width: torch.Tensor, height: torch.Tensor ) -> torch.Tensor: """ Encode box coordinates by combining position-encoded centers with raw width/height. Args: center_x: 1D tensor of box center x coordinates center_y: 1D tensor of box center y coordinates width: 1D tensor of box widths height: 1D tensor of box heights Returns: Encoded box coordinates [N, embedding_dim] """ pos_x, pos_y = self.position_encoding.encode_1d_positions(center_x, center_y) pos = torch.cat((pos_y, pos_x, height[:, None], width[:, None]), dim=1) return pos def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features): """Encode box prompts. 
        Mask convention: True=valid, False=padding."""
        batch_size, num_boxes = boxes.shape[:2]
        height, width = vision_features.shape[-2:]
        # Encoding 1/3: direct linear projection of the raw (cx, cy, w, h) coordinates.
        boxes_embed = self.boxes_direct_project(boxes)

        # Encoding 2/3: pool features using ROI align.
        # Convert boxes from CxCyWH to xyxy format and denormalize
        boxes_xyxy = box_cxcywh_to_xyxy(boxes)
        scale = torch.tensor([width, height, width, height], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
        scale = scale.view(1, 1, 4)
        boxes_xyxy = boxes_xyxy * scale

        # ROI align expects list of boxes per batch element,
        # convert from bfloat16 to float16 as roi_align only supports float16 and float32
        dtype = torch.float16 if vision_features.dtype == torch.bfloat16 else vision_features.dtype
        sampled_features = torchvision.ops.roi_align(
            vision_features.to(dtype), boxes_xyxy.to(dtype).unbind(0), self.roi_size
        ).to(vision_features.dtype)
        pooled_projection = self.boxes_pool_project(sampled_features)
        pooled_projection = pooled_projection.view(batch_size, num_boxes, self.hidden_size)
        boxes_embed = boxes_embed + pooled_projection

        # Encoding 3/3: add position encoding of the box center plus raw width/height.
        center_x, center_y, box_width, box_height = boxes.unbind(-1)
        pos_enc = self._encode_box_coordinates(
            center_x.flatten(), center_y.flatten(), box_width.flatten(), box_height.flatten()
        )
        pos_enc = pos_enc.view(batch_size, num_boxes, pos_enc.shape[-1])
        pos_projection = self.boxes_pos_enc_project(pos_enc)
        boxes_embed = boxes_embed + pos_projection

        # Add label embeddings (positive/negative)
        label_embed = self.label_embed(boxes_labels.long())
        return label_embed + boxes_embed, boxes_mask

    def forward(
        self,
        box_embeddings: torch.Tensor,
        box_mask: torch.Tensor,
        box_labels: torch.Tensor,
        img_feats: tuple[torch.Tensor, ...],
        img_pos_embeds: tuple[torch.Tensor, ...] | None = None,
    ):
        """
        Forward pass for encoding geometric prompts.
        Args:
            box_embeddings: Box coordinates in CxCyWH format [batch_size, num_boxes, 4]
            box_mask: Attention mask for boxes [batch_size, num_boxes]
            box_labels: Labels for boxes (positive/negative) [batch_size, num_boxes]
            img_feats: Image features from vision encoder
            img_pos_embeds: Optional position embeddings for image features

        Returns:
            Sam3GeometryEncoderOutput containing encoded geometry features and attention mask.
        """
        batch_size = box_embeddings.shape[0]

        # Prepare vision features for cross-attention: flatten spatial dimensions
        vision_feats = img_feats[-1]  # [B, C, H, W]
        vision_pos_embeds = img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(vision_feats)
        vision_feats_flat = vision_feats.flatten(2).transpose(1, 2)  # [B, H*W, C]
        vision_pos_embeds_flat = vision_pos_embeds.flatten(2).transpose(1, 2)  # [B, H*W, C]

        # Normalize image features for pooling operations
        img_feats_last = img_feats[-1]  # [B, C, H, W]
        img_feats_last = img_feats_last.permute(0, 2, 3, 1)  # [B, H, W, C]
        normalized_img_feats = self.vision_layer_norm(img_feats_last)
        normalized_img_feats = normalized_img_feats.permute(0, 3, 1, 2)  # [B, C, H, W]

        prompt_embeds, prompt_mask = self._encode_boxes(box_embeddings, box_mask, box_labels, normalized_img_feats)

        # Add CLS token (always valid)
        cls_embed = self.cls_embed.weight.view(1, self.hidden_size).unsqueeze(0).expand(batch_size, -1, -1)
        cls_mask = torch.ones(batch_size, 1, dtype=prompt_mask.dtype, device=prompt_mask.device)
        prompt_embeds, prompt_mask = concat_padded_sequences(prompt_embeds, prompt_mask, cls_embed, cls_mask)

        prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds))

        # Create bidirectional attention mask for transformer layers
        prompt_attention_mask = None
        if prompt_mask is not None:
            prompt_attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=prompt_embeds,
                attention_mask=prompt_mask,
            )

        # Apply transformer layers with cross-attention to vision features
        for layer in self.layers:
            prompt_embeds = layer(
                prompt_feats=prompt_embeds,
                vision_feats=vision_feats_flat,
                vision_pos_encoding=vision_pos_embeds_flat,
                prompt_mask=prompt_attention_mask,
            )

        # Final output normalization
        prompt_embeds = self.output_layer_norm(prompt_embeds)

        return Sam3GeometryEncoderOutput(
            last_hidden_state=prompt_embeds,
            attention_mask=prompt_mask,
        )


class Sam3DetrEncoderLayer(nn.Module):
    """DETR encoder layer with self-attention and cross-attention."""

    def __init__(self, config: Sam3DETREncoderConfig):
        super().__init__()
        self.config = config
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)
        self.self_attn = Sam3Attention(config)
        self.dropout = nn.Dropout(config.dropout)
        self.cross_attn = Sam3Attention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)
        self.mlp = Sam3MLP(config)
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        vision_feats: Tensor,
        prompt_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_cross_attn_mask: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Forward pass for DETR encoder layer.
        Args:
            vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
            prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
            vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
            prompt_cross_attn_mask: Cross-attention mask for prompt features

        Returns:
            Updated vision features [batch_size, vision_len, hidden_size]
        """
        # Self-attention on vision features with position encoding
        residual = vision_feats
        hidden_states = self.layer_norm1(vision_feats)
        # Queries/keys carry the position encoding; values stay position-free.
        hidden_states_with_pos = hidden_states + vision_pos_encoding
        hidden_states, _ = self.self_attn(
            query=hidden_states_with_pos,
            key=hidden_states_with_pos,
            value=hidden_states,
            **kwargs,
        )
        hidden_states = self.dropout(hidden_states) + residual

        # Cross-attention: vision queries attend to text/prompt features
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states, _ = self.cross_attn(
            query=hidden_states,
            key=prompt_feats,
            value=prompt_feats,
            attention_mask=prompt_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.dropout(hidden_states) + residual

        # MLP
        residual = hidden_states
        hidden_states = self.layer_norm3(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.dropout(hidden_states) + residual
        return hidden_states


class Sam3DetrEncoder(Sam3PreTrainedModel):
    """
    DETR-style encoder that processes multi-level vision features with text fusion.

    This encoder processes vision features from multiple levels (e.g., FPN features at different resolutions) and
    fuses them with text prompts through a stack of transformer encoder layers.
""" _can_record_outputs = { "hidden_states": Sam3DetrEncoderLayer, "attentions": Sam3Attention, } def __init__(self, config: Sam3DETREncoderConfig): super().__init__(config) self.config = config self.hidden_size = config.hidden_size self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)]) self.post_init() def _prepare_multilevel_features( self, vision_features: list[torch.Tensor], vision_pos_embeds: list[torch.Tensor], ): """ Prepare multi-level vision features by flattening spatial dimensions and adding level embeddings. Args: vision_features: List of vision features at different levels [batch_size, channels, height, width] vision_pos_embeds: List of position embeddings for each level [batch_size, channels, height, width] Returns: Tuple containing flattened features, position embeddings, and spatial metadata """ features_flattened = [] pos_embeds_flattened = [] spatial_shapes = [] for features, pos_embed in zip(vision_features, vision_pos_embeds): height, width = features.shape[-2:] spatial_shapes.append((height, width)) # Flatten spatial dimensions: [batch_size, channels, height, width] -> [batch_size, height*width, channels] features = features.flatten(2).transpose(1, 2) pos_embed = pos_embed.flatten(2).transpose(1, 2) features_flattened.append(features) pos_embeds_flattened.append(pos_embed) # Concatenate all levels into single sequence features_flattened = torch.cat(features_flattened, dim=1) pos_embeds_flattened = torch.cat(pos_embeds_flattened, dim=1) spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=features_flattened.device) return ( features_flattened, pos_embeds_flattened, spatial_shapes, ) @check_model_inputs def forward( self, vision_features: list[torch.Tensor], text_features: torch.Tensor, vision_pos_embeds: list[torch.Tensor] | None = None, text_mask: torch.Tensor | None = None, spatial_sizes: list[tuple[int, int]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | 
Sam3DETREncoderOutput: """ Forward pass for the DETR encoder. Args: vision_features: List of vision features at different levels text_features: Text prompt features [batch_size, seq_len, hidden_size] vision_pos_embeds: Optional list of position embeddings for each level text_mask: Optional text padding mask [batch_size, seq_len] spatial_sizes: Optional list of (height, width) tuples for reshaping Returns: Sam3DETREncoderOutput containing encoded features and metadata. """ batch_size = vision_features[0].shape[0] if vision_features[0].dim() == 4 else vision_features[0].shape[1] # TODO: See if we can remove that reshaping and just use the features as is. if spatial_sizes is not None: for i, (height, width) in enumerate(spatial_sizes): # Reshape from [height*width, batch_size, channels] to [batch_size, channels, height, width] vision_features[i] = vision_features[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1) vision_pos_embeds[i] = vision_pos_embeds[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1) # Flatten multi-level features for encoder processing ( features_flattened, pos_embeds_flattened, spatial_shapes, ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds) prompt_cross_attn_mask = None if text_mask is not None: prompt_cross_attn_mask = create_bidirectional_mask( config=self.config, input_embeds=features_flattened, attention_mask=text_mask, encoder_hidden_states=text_features, ) hidden_states = features_flattened for layer in self.layers: hidden_states = layer( hidden_states, prompt_feats=text_features, vision_pos_encoding=pos_embeds_flattened, prompt_cross_attn_mask=prompt_cross_attn_mask, **kwargs, ) return Sam3DETREncoderOutput( last_hidden_state=hidden_states, pos_embeds_flattened=pos_embeds_flattened, text_features=text_features, spatial_shapes=spatial_shapes, ) class Sam3DecoderMLP(nn.Module): """Simple 2 or 3-layer MLP for decoder components.""" def __init__(self, input_dim: int, hidden_dim: int, output_dim: 
int, num_layers: int = 2): super().__init__() if num_layers == 2: self.layer1 = nn.Linear(input_dim, hidden_dim) self.layer2 = nn.Linear(hidden_dim, output_dim) self.layer3 = None elif num_layers == 3: self.layer1 = nn.Linear(input_dim, hidden_dim) self.layer2 = nn.Linear(hidden_dim, hidden_dim) self.layer3 = nn.Linear(hidden_dim, output_dim) else: raise ValueError(f"Only 2 or 3 layers supported, got {num_layers}") def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.relu(self.layer1(x)) if self.layer3 is not None: x = F.relu(self.layer2(x)) x = self.layer3(x) else: x = self.layer2(x) return x class Sam3DetrDecoderLayer(nn.Module): """DETR decoder layer with self-attention, text cross-attention, and vision cross-attention.""" def __init__(self, config: Sam3DETRDecoderConfig): super().__init__() self.config = config self.self_attn = Sam3Attention(config) self.self_attn_dropout = nn.Dropout(config.dropout) self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size) self.text_cross_attn = Sam3Attention(config) self.text_cross_attn_dropout = nn.Dropout(config.dropout) self.text_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size) self.vision_cross_attn = Sam3Attention(config) self.vision_cross_attn_dropout = nn.Dropout(config.dropout) self.vision_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size) self.mlp = Sam3MLP(config) self.mlp_layer_norm = nn.LayerNorm(config.hidden_size) self.mlp_dropout = nn.Dropout(config.dropout) def forward( self, hidden_states: torch.Tensor, query_pos: torch.Tensor, text_features: torch.Tensor, vision_features: torch.Tensor, vision_pos_encoding: torch.Tensor, text_cross_attn_mask: torch.Tensor | None = None, vision_cross_attn_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: """ Forward pass for decoder layer. 
        Args:
            hidden_states: Query features [batch_size, num_queries + 1, hidden_size]
                (includes presence token at position 0)
            query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_features: Vision features [batch_size, height*width, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
            text_cross_attn_mask: Text cross-attention mask
            vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token

        Returns:
            Updated hidden states (including presence token at position 0)
        """
        # Prepend zeros to query_pos for presence token
        query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)

        # Self-attention with query position encoding
        residual = hidden_states
        query_with_pos = hidden_states + query_pos
        attn_output, _ = self.self_attn(
            query=query_with_pos,
            key=query_with_pos,
            value=hidden_states,
            attention_mask=None,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_dropout(attn_output)
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Text cross-attention: queries attend to text features
        residual = hidden_states
        query_with_pos = hidden_states + query_pos
        attn_output, _ = self.text_cross_attn(
            query=query_with_pos,
            key=text_features,
            value=text_features,
            attention_mask=text_cross_attn_mask,
            **kwargs,
        )
        hidden_states = residual + self.text_cross_attn_dropout(attn_output)
        hidden_states = self.text_cross_attn_layer_norm(hidden_states)

        # Vision cross-attention: queries attend to vision features (with RPB)
        residual = hidden_states
        query_with_pos = hidden_states + query_pos
        key_with_pos = vision_features + vision_pos_encoding
        attn_output, _ = self.vision_cross_attn(
            query=query_with_pos,
            key=key_with_pos,
            value=vision_features,
            attention_mask=vision_cross_attn_mask,
            **kwargs,
        )
        hidden_states = residual + self.vision_cross_attn_dropout(attn_output)
        hidden_states = self.vision_cross_attn_layer_norm(hidden_states)

        # MLP
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_dropout(hidden_states)
        hidden_states = self.mlp_layer_norm(hidden_states)

        return hidden_states


class Sam3DetrDecoder(Sam3PreTrainedModel):
    """
    DETR-style decoder with box refinement and presence token.

    Simplified version that assumes:
    - Box refinement is always enabled
    - Intermediate outputs are always returned
    - BoxRPB (relative position bias) with log-scale encoding
    - Presence token is used
    """

    _can_record_outputs = {
        "hidden_states": Sam3DetrDecoderLayer,
        "attentions": Sam3Attention,
    }

    def __init__(
        self,
        config: Sam3DETRDecoderConfig,
    ):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size
        self.layers = nn.ModuleList([Sam3DetrDecoderLayer(config) for _ in range(config.num_layers)])
        self.output_layer_norm = nn.LayerNorm(config.hidden_size)
        # Per-layer box refinement head predicting (cx, cy, w, h) deltas.
        self.box_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 4, 3)
        self.query_embed = nn.Embedding(config.num_queries, config.hidden_size)
        self.reference_points = nn.Embedding(config.num_queries, 4)
        # Single learned token whose logit indicates whether the prompted concept is present.
        self.presence_token = nn.Embedding(1, config.hidden_size)
        self.presence_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 1, 3)
        self.presence_layer_norm = nn.LayerNorm(config.hidden_size)
        self.clamp_presence_logit_max_val = 10.0
        self.ref_point_head = Sam3DecoderMLP(2 * config.hidden_size, config.hidden_size, config.hidden_size, 2)
        self.box_rpb_embed_x = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
        self.box_rpb_embed_y = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)
        self.post_init()

    @compile_compatible_method_lru_cache(maxsize=1)
    def _get_coords(
        self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Generate normalized coordinate grids."""
        coords_h = torch.arange(0, height, device=device, dtype=dtype) / height
        coords_w = torch.arange(0, width, device=device, dtype=dtype) / width
        return coords_h, coords_w

    def _get_rpb_matrix(
        self, reference_boxes: torch.Tensor, spatial_shape: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute box relative position bias (RPB) matrix using log-scale encoding.

        RPB helps the decoder attend to relevant spatial locations based on predicted box positions.

        Args:
            reference_boxes: Reference boxes [batch_size, num_queries, 4] in sigmoid space
            spatial_shape: (height, width) of the vision features as tensors

        Returns:
            RPB matrix [batch_size, num_heads, num_queries, height*width]
        """
        height, width = spatial_shape
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
        batch_size, num_queries, _ = boxes_xyxy.shape

        # Generate coordinate grids
        coords_h, coords_w = self._get_coords(
            height, width, dtype=reference_boxes.dtype, device=reference_boxes.device
        )

        # Compute deltas between coordinates and box boundaries
        # (slices 1:4:2 -> (y1, y2); 0:3:2 -> (x1, x2))
        deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
        deltas_y = deltas_y.view(batch_size, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
        deltas_x = deltas_x.view(batch_size, num_queries, -1, 2)

        # Apply log-scale encoding
        deltas_x_log = deltas_x * 8
        deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / math.log2(8)
        deltas_y_log = deltas_y * 8
        deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / math.log2(8)

        # Embed deltas
        deltas_x = self.box_rpb_embed_x(deltas_x_log)  # [batch_size, num_queries, width, num_heads]
        deltas_y = self.box_rpb_embed_y(deltas_y_log)  # [batch_size, num_queries, height, num_heads]

        # Combine into 2D bias matrix
        rpb_matrix = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
            2
        )  # [batch_size, num_queries, height, width, num_heads]
        rpb_matrix = rpb_matrix.flatten(2, 3)  # [batch_size, num_queries, height*width, num_heads]
        rpb_matrix = rpb_matrix.permute(0, 3, 1, 2).contiguous()  # [batch_size, num_heads, num_queries, height*width]
        return rpb_matrix

    @check_model_inputs
    def forward(
        self,
        vision_features: torch.Tensor,
        text_features: torch.Tensor,
        vision_pos_encoding: torch.Tensor,
        text_mask: torch.Tensor | None = None,
        spatial_shapes: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3DETRDecoderOutput:
        """
        Forward pass for the DETR decoder.

        Args:
            vision_features: Vision features [batch_size, height*width, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
            text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
            spatial_shapes: Spatial shapes [num_levels, 2]

        Returns:
            Sam3DETRDecoderOutput containing decoder outputs from all layers.
        """
        batch_size = vision_features.shape[0]

        # Learned queries and initial reference boxes, broadcast across the batch.
        query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
        reference_boxes = reference_boxes.sigmoid()
        presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate presence token with query embeddings
        hidden_states = torch.cat([presence_token, query_embeds], dim=1)

        text_cross_attn_mask = None
        if text_mask is not None:
            text_cross_attn_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=hidden_states,
                attention_mask=text_mask,
                encoder_hidden_states=text_features,
            )

        intermediate_outputs = []
        intermediate_boxes = [reference_boxes]
        intermediate_presence_logits = []

        for layer in self.layers:
            # Generate sine embeddings for conditional queries
            reference_points_input = reference_boxes.unsqueeze(2)
            query_sine_embed = self.position_encoding.encode_boxes(reference_points_input[:, :, 0, :])
            query_pos = self.ref_point_head(query_sine_embed)

            # Compute box relative position bias (RPB) attention mask
            # (only supported for a single feature level).
            vision_cross_attn_mask = None
            if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
                spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
                rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
                # Prepend zeros row for presence token (it attends to all vision tokens equally)
                vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)

            hidden_states = layer(
                hidden_states,
                query_pos=query_pos,
                text_features=text_features,
                vision_features=vision_features,
                vision_pos_encoding=vision_pos_encoding,
                text_cross_attn_mask=text_cross_attn_mask,
                vision_cross_attn_mask=vision_cross_attn_mask,
                **kwargs,
            )

            # Extract query hidden states (without presence token) for box refinement
            query_hidden_states = hidden_states[:, 1:]

            # Box refinement: predict delta and update reference boxes
            reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
            delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
            new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
            # Detach so gradients do not flow through the iterative box updates across layers.
            reference_boxes = new_reference_boxes.detach()

            intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
            intermediate_boxes.append(new_reference_boxes)

            # Process presence token
            presence_hidden = hidden_states[:, :1]
            presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
            presence_logits = presence_logits.clamp(
                min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
            )
            intermediate_presence_logits.append(presence_logits)

        # Stack outputs from all layers
        intermediate_outputs = torch.stack(intermediate_outputs)
        # intermediate_boxes holds [initial, layer_0, ..., layer_N-1]; dropping the last entry
        # pairs each layer's output with the reference boxes that layer received as input.
        intermediate_boxes = torch.stack(intermediate_boxes[:-1])
        intermediate_presence_logits = torch.stack(intermediate_presence_logits)

        return Sam3DETRDecoderOutput(
            intermediate_hidden_states=intermediate_outputs,
            reference_boxes=intermediate_boxes,
            presence_logits=intermediate_presence_logits,
        )


class Sam3DotProductScoring(nn.Module):
    """
    Computes classification
scores by computing dot product between projected decoder queries and pooled text features. This is used to determine confidence/presence scores for each query. """ def __init__(self, config: Sam3Config): super().__init__() self.config = config hidden_size = config.detr_decoder_config.hidden_size projection_dim = config.detr_decoder_config.hidden_size self.text_mlp = Sam3DecoderMLP( input_dim=hidden_size, hidden_dim=config.detr_decoder_config.intermediate_size, output_dim=hidden_size, num_layers=2, ) self.text_mlp_dropout = nn.Dropout(config.detr_decoder_config.dropout) self.text_mlp_out_norm = nn.LayerNorm(hidden_size) # Projections for text and query features self.text_proj = nn.Linear(hidden_size, projection_dim) self.query_proj = nn.Linear(hidden_size, projection_dim) # Scale factor for dot product self.scale = float(1.0 / np.sqrt(projection_dim)) # Clamping to avoid numerical issues self.clamp_logits = True self.clamp_max_val = 12.0 def _pool_text_features(self, text_features: torch.Tensor, text_mask: torch.Tensor | None) -> torch.Tensor: """ Mean pool text features, accounting for padding. Args: text_features: [batch_size, seq_len, hidden_size] text_mask: [batch_size, seq_len] where True indicates valid tokens, False indicates padding Returns: pooled_text: [batch_size, hidden_size] """ if text_mask is None: # No padding, simple mean return text_features.mean(dim=1) is_valid = text_mask.to(text_features.dtype).unsqueeze(-1) # [batch_size, seq_len, 1] # Count valid tokens per batch num_valid = is_valid.sum(dim=1).clamp(min=1.0) # [batch_size, 1] # Mean pool only over valid tokens pooled_text = (text_features * is_valid).sum(dim=1) / num_valid # [batch_size, hidden_size] return pooled_text def forward( self, decoder_hidden_states: torch.Tensor, text_features: torch.Tensor, text_mask: torch.Tensor | None = None, ) -> torch.Tensor: """ Compute classification scores via dot product. 
        Args:
            decoder_hidden_states: [num_layers, batch_size, num_queries, hidden_size]
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len] where True=valid, False=padding

        Returns:
            scores: [num_layers, batch_size, num_queries, 1]
        """
        # Residual MLP refinement of the text features.
        orig_text_features = text_features
        text_features = self.text_mlp(text_features)
        text_features = self.text_mlp_dropout(text_features)
        text_features = text_features + orig_text_features
        text_features = self.text_mlp_out_norm(text_features)

        pooled_text = self._pool_text_features(text_features, text_mask)
        proj_text = self.text_proj(pooled_text)
        proj_queries = self.query_proj(decoder_hidden_states)

        # [batch, dim] -> [1, batch, dim, 1] so the matmul broadcasts over layers and queries.
        proj_text = proj_text.unsqueeze(-1)
        scores = torch.matmul(proj_queries, proj_text.unsqueeze(0))
        scores = scores * self.scale
        if self.clamp_logits:
            scores = scores.clamp(min=-self.clamp_max_val, max=self.clamp_max_val)
        return scores


class Sam3MaskEmbedder(nn.Module):
    """
    MLP that embeds object queries for mask prediction. Similar to MaskFormer's mask embedder.
    """

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        self.layers = nn.ModuleList(
            [
                nn.Linear(hidden_size, hidden_size),
                nn.Linear(hidden_size, hidden_size),
                nn.Linear(hidden_size, hidden_size),
            ]
        )
        self.activation = nn.ReLU()

    def forward(self, queries: torch.Tensor) -> torch.Tensor:
        """
        Args:
            queries: Query embeddings [batch_size, num_queries, hidden_size]

        Returns:
            Mask embeddings [batch_size, num_queries, hidden_size]
        """
        hidden_states = queries
        # ReLU between layers, but no activation after the final projection.
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states)
            if i < len(self.layers) - 1:
                hidden_states = self.activation(hidden_states)
        return hidden_states


class Sam3PixelDecoder(nn.Module):
    """
    Feature Pyramid Network (FPN) decoder that generates pixel-level features.

    Inspired by MaskFormer's pixel decoder.
""" def __init__(self, config: Sam3MaskDecoderConfig): super().__init__() self.config = config hidden_size = config.hidden_size num_upsampling_stages = config.num_upsampling_stages # Create conv layers and norms for FPN self.conv_layers = nn.ModuleList( [ nn.Conv2d(hidden_size, hidden_size, kernel_size=3, stride=1, padding=1) for _ in range(num_upsampling_stages) ] ) self.norms = nn.ModuleList([nn.GroupNorm(8, hidden_size) for _ in range(num_upsampling_stages)]) self.out_channels = hidden_size def forward(self, backbone_features: list[torch.Tensor]) -> torch.Tensor: """ Args: backbone_features: List of backbone features [batch_size, hidden_size, H_i, W_i] from low to high resolution (assumes already projected to hidden_size) Returns: Pixel embeddings [batch_size, hidden_size, H, W] at the finest resolution """ # Start from the coarsest feature (last in list) prev_fpn = backbone_features[-1] # Iterate through features from coarse to fine (excluding the last which we started with) for layer_idx, backbone_feat in enumerate(reversed(backbone_features[:-1])): # Upsample previous FPN output to match current backbone feature size prev_fpn = F.interpolate(prev_fpn, size=backbone_feat.shape[-2:], mode="nearest") # Add skip connection prev_fpn = prev_fpn + backbone_feat # Apply conv and norm prev_fpn = self.conv_layers[layer_idx](prev_fpn) prev_fpn = self.norms[layer_idx](prev_fpn) prev_fpn = F.relu(prev_fpn) return prev_fpn class Sam3MaskDecoder(Sam3PreTrainedModel): """ Mask decoder that combines object queries with pixel-level features to predict instance masks. Also produces a semantic segmentation output and supports cross-attention to prompts. 
""" _can_record_outputs = { "attentions": Sam3Attention, } def __init__(self, config: Sam3MaskDecoderConfig): super().__init__(config) self.config = config hidden_size = config.hidden_size # Pixel decoder (FPN) self.pixel_decoder = Sam3PixelDecoder(config) # Mask embedder (MLP to transform queries) self.mask_embedder = Sam3MaskEmbedder(config) # Projection from pixel decoder output to mask embedding space self.instance_projection = nn.Conv2d(self.pixel_decoder.out_channels, hidden_size, kernel_size=1) # Semantic segmentation head (always present in UniversalSegmentationHead) self.semantic_projection = nn.Conv2d(self.pixel_decoder.out_channels, 1, kernel_size=1) self.prompt_cross_attn = Sam3Attention(config) self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size) self.prompt_cross_attn_dropout = nn.Dropout(config.dropout) self.post_init() @check_model_inputs def forward( self, decoder_queries: torch.Tensor, backbone_features: list[torch.Tensor], encoder_hidden_states: torch.Tensor, prompt_features: torch.Tensor | None = None, prompt_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Sam3MaskDecoderOutput: """ Args: decoder_queries: Decoder output queries [batch_size, num_queries, hidden_size] backbone_features: List of backbone features to process through FPN encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size] prompt_features: Prompt features (text + geometry) for cross-attention [batch_size, prompt_len, hidden_size] prompt_mask: Padding mask [batch_size, prompt_len] where True=valid, False=padding Returns: Sam3MaskDecoderOutput containing predicted masks and semantic segmentation. 
""" if prompt_features is not None: # Cross-attention: encoder features attend to prompt features residual = encoder_hidden_states normed_hidden_states = self.prompt_cross_attn_norm(encoder_hidden_states) cross_attn_mask = None if prompt_mask is not None: cross_attn_mask = create_bidirectional_mask( config=self.config, input_embeds=normed_hidden_states, encoder_hidden_states=prompt_features, attention_mask=prompt_mask, ) attn_output, _ = self.prompt_cross_attn( query=normed_hidden_states, key=prompt_features, value=prompt_features, attention_mask=cross_attn_mask, **kwargs, ) encoder_hidden_states = residual + self.prompt_cross_attn_dropout(attn_output) # Process backbone features through FPN to get pixel embeddings pixel_embed = self._embed_pixels( backbone_features=backbone_features, encoder_hidden_states=encoder_hidden_states, ) # Predict instance masks via dot product between query embeddings and pixel embeddings instance_embeds = self.instance_projection(pixel_embed) mask_embeddings = self.mask_embedder(decoder_queries) pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, instance_embeds) # Generate semantic segmentation semantic_seg = self.semantic_projection(pixel_embed) return Sam3MaskDecoderOutput( pred_masks=pred_masks, semantic_seg=semantic_seg, ) def _embed_pixels( self, backbone_features: list[torch.Tensor], encoder_hidden_states: torch.Tensor, ) -> torch.Tensor: """ Embed pixels by combining backbone FPN features with encoder vision features. The encoder vision features replace the finest-resolution backbone feature. 
        Args:
            backbone_features: List of backbone features [batch_size, C, H_i, W_i]
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W]
        """
        # Clone so the caller's feature tensors are not mutated by the replacement below.
        backbone_visual_feats = [feat.clone() for feat in backbone_features]

        # Extract vision features from encoder output and reshape to spatial format
        # (the first H*W tokens of the encoder sequence map onto the last FPN level's grid).
        spatial_dim = backbone_features[-1].shape[-2] * backbone_features[-1].shape[-1]
        encoder_visual_embed = encoder_hidden_states[:, :spatial_dim, :]

        batch_size, _, hidden_size = encoder_visual_embed.shape
        height, width = backbone_features[-1].shape[-2:]
        encoder_visual_embed = encoder_visual_embed.transpose(1, 2).reshape(batch_size, hidden_size, height, width)

        # Replace finest backbone feature with encoder vision features
        backbone_visual_feats[-1] = encoder_visual_embed

        # Process through FPN decoder
        pixel_embed = self.pixel_decoder(backbone_visual_feats)

        return pixel_embed


class Sam3Model(Sam3PreTrainedModel):
    input_modalities = ["image", "text"]
    _checkpoint_conversion_mapping = {
        r"detector_model.(.+)": r"\1"  # the regex allows to remove the prefix, and add it back in revert mode
    }
    # Video-tracker weights belong to the sam3_video checkpoint, not this image model;
    # silently ignore them when loading.
    _keys_to_ignore_on_load_unexpected = [
        r"^tracker_model.",
        r"^tracker_neck.",
    ]

    def __init__(self, config: Sam3Config):
        # loading from a sam3_video config: unwrap the nested detector config.
        if hasattr(config, "detector_config") and config.detector_config is not None:
            detector_config = config.detector_config
            if isinstance(detector_config, dict):
                detector_config = Sam3Config(**detector_config)
            config = detector_config

        super().__init__(config)
        self.vision_encoder = Sam3VisionModel(config.vision_config)
        self.text_encoder = CLIPTextModelWithProjection(config.text_config)
        self.vocab_size = config.text_config.vocab_size

        # Project text features from text encoder hidden size to model hidden size
        # CLIP text encoder outputs 1024-dim features, but we need 256-dim for DETR
        self.text_projection = nn.Linear(config.text_config.hidden_size, config.detr_encoder_config.hidden_size)

        # Pass _attn_implementation to subconfigs BEFORE creating modules
        config.geometry_encoder_config._attn_implementation = config._attn_implementation
        config.detr_encoder_config._attn_implementation = config._attn_implementation
        config.detr_decoder_config._attn_implementation = config._attn_implementation
        config.mask_decoder_config._attn_implementation = config._attn_implementation

        self.geometry_encoder = Sam3GeometryEncoder(config.geometry_encoder_config)
        self.detr_encoder = Sam3DetrEncoder(config.detr_encoder_config)
        self.detr_decoder = Sam3DetrDecoder(config.detr_decoder_config)
        self.mask_decoder = Sam3MaskDecoder(config.mask_decoder_config)

        # Dot product scoring to compute classification scores
        self.dot_product_scoring = Sam3DotProductScoring(config)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3Model, Sam3Processor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3Model.from_pretrained("facebook/sam3")
        >>> processor = Sam3Processor.from_pretrained("facebook/sam3")

        >>> # Pre-compute text embeddings
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> text_embeds = model.get_text_features(**text_inputs).pooler_output

        >>> # Reuse text embeddings for multiple images
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> outputs = model(pixel_values=img_inputs.pixel_values, text_embeds=text_embeds)
        ```
        """
        text_outputs = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs
        )

        last_hidden_state = text_outputs.last_hidden_state
        # NOTE: pooler_output is overwritten with the *per-token* projected features
        # [batch_size, seq_len, hidden_size] — not a pooled vector. Sam3Model.forward
        # consumes it as sequence features together with the attention mask.
        text_outputs.pooler_output = self.text_projection(last_hidden_state)
        return text_outputs

    @auto_docstring
    def get_vision_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3VisionEncoderOutput:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3Model, Sam3Processor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3Model.from_pretrained("facebook/sam3")
        >>> processor = Sam3Processor.from_pretrained("facebook/sam3")

        >>> # Pre-compute vision embeddings
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...
image = Image.open(BytesIO(response.read())) >>> img_inputs = processor(images=image, return_tensors="pt") >>> vision_embeds = model.get_vision_features(pixel_values=img_inputs.pixel_values) >>> # Reuse vision embeddings for multiple text prompts >>> text_inputs = processor(text="cat", return_tensors="pt") >>> outputs = model(vision_embeds=vision_embeds, input_ids=text_inputs.input_ids) ``` """ vision_outputs = self.vision_encoder(pixel_values, **kwargs) return vision_outputs @check_model_inputs @auto_docstring def forward( self, pixel_values: torch.FloatTensor | None = None, vision_embeds: Sam3VisionEncoderOutput | None = None, input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, text_embeds: torch.FloatTensor | None = None, input_boxes: torch.FloatTensor | None = None, input_boxes_labels: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> Sam3ImageSegmentationOutput: r""" vision_embeds (`Sam3VisionEncoderOutput`, *optional*): Pre-computed vision embeddings. Can be used to easily reuse vision embeddings. If provided, `pixel_values` should not be passed. Mutually exclusive with `pixel_values`. text_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Pre-computed text embeddings. Can be used to easily reuse text embeddings. If provided, `input_ids` should not be passed. Mutually exclusive with `input_ids`. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`, *optional*): Normalized box coordinates in [0, 1] range, in (cx, cy, w, h) format. input_boxes_labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`, *optional*): Labels for boxes: 1 (positive), 0 (negative). 
Example: ```python >>> from PIL import Image >>> import httpx >>> from io import BytesIO >>> from transformers import AutoModel, AutoProcessor >>> model = AutoModel.from_pretrained("facebook/sam3") >>> processor = AutoProcessor.from_pretrained("facebook/sam3") >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" >>> with httpx.stream("GET", url) as response: ... image = Image.open(BytesIO(response.read())).convert("RGB") >>> text = "car" >>> inputs = processor(images=image, text=text, return_tensors="pt") >>> # Get segmentation output >>> outputs = model(**inputs) >>> pred_masks = outputs.pred_masks >>> pred_boxes = outputs.pred_boxes ``` """ if (pixel_values is None) == (vision_embeds is None): raise ValueError("You must specify exactly one of pixel_values or vision_embeds") if (input_ids is None) == (text_embeds is None): raise ValueError("You must specify exactly one of input_ids or text_embeds") if pixel_values is not None: batch_size = pixel_values.shape[0] device = pixel_values.device else: batch_size = vision_embeds.fpn_hidden_states[0].shape[0] device = vision_embeds.fpn_hidden_states[0].device if vision_embeds is None: vision_outputs = self.vision_encoder(pixel_values, **kwargs) else: vision_outputs = vision_embeds fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1] fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1] if text_embeds is None: text_features = self.get_text_features( input_ids=input_ids, attention_mask=attention_mask, return_dict=True ).pooler_output else: text_features = text_embeds text_mask = attention_mask.bool() if attention_mask is not None else None has_geometry_prompts = input_boxes is not None and input_boxes.numel() > 0 geometry_prompt_features = None geometry_prompt_mask = None if has_geometry_prompts: if input_boxes is not None and input_boxes.numel() > 0: box_embeddings = input_boxes # [batch_size, num_boxes, 4] box_labels = ( 
input_boxes_labels if input_boxes_labels is not None else torch.ones_like(box_embeddings[..., 0], dtype=torch.long) ) box_mask = ( (input_boxes_labels != -10) if input_boxes_labels is not None else torch.ones(batch_size, input_boxes.shape[1], dtype=torch.bool, device=device) ) box_labels = torch.where(box_labels == -10, 0, box_labels) else: box_embeddings = torch.zeros(batch_size, 0, 4, dtype=text_features.dtype, device=device) box_labels = torch.zeros(batch_size, 0, dtype=torch.long, device=device) box_mask = torch.zeros(batch_size, 0, dtype=torch.bool, device=device) geometry_outputs = self.geometry_encoder( box_embeddings=box_embeddings, box_mask=box_mask, box_labels=box_labels, img_feats=fpn_hidden_states, img_pos_embeds=fpn_position_encoding, ) geometry_prompt_features = geometry_outputs.last_hidden_state geometry_prompt_mask = geometry_outputs.attention_mask if geometry_prompt_features is not None: # Repeat text_features for all geometry prompts if text_features.shape[0] == 1 and geometry_prompt_features.shape[0] > 1: text_features = text_features.repeat(geometry_prompt_features.shape[0], 1, 1) combined_prompt_features = torch.cat([text_features, geometry_prompt_features], dim=1) if text_mask is not None and text_mask.shape[0] == 1 and geometry_prompt_mask.shape[0] > 1: text_mask = text_mask.repeat(geometry_prompt_mask.shape[0], 1) if text_mask is not None and geometry_prompt_mask is not None: combined_prompt_mask = torch.cat([text_mask, geometry_prompt_mask], dim=1) elif text_mask is not None: geo_valid_mask = torch.ones( batch_size, geometry_prompt_features.shape[1], dtype=torch.bool, device=device ) combined_prompt_mask = torch.cat([text_mask, geo_valid_mask], dim=1) elif geometry_prompt_mask is not None: text_valid_mask = torch.ones(batch_size, text_features.shape[1], dtype=torch.bool, device=device) combined_prompt_mask = torch.cat([text_valid_mask, geometry_prompt_mask], dim=1) else: combined_prompt_mask = None else: combined_prompt_features = 
text_features combined_prompt_mask = text_mask encoder_outputs = self.detr_encoder( vision_features=[fpn_hidden_states[-1]], text_features=combined_prompt_features, vision_pos_embeds=[fpn_position_encoding[-1]], text_mask=combined_prompt_mask, **kwargs, ) decoder_outputs = self.detr_decoder( vision_features=encoder_outputs.last_hidden_state, text_features=encoder_outputs.text_features, vision_pos_encoding=encoder_outputs.pos_embeds_flattened, text_mask=combined_prompt_mask, spatial_shapes=encoder_outputs.spatial_shapes, **kwargs, ) # Refine boxes from decoder all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states) reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes) all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid() all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh) all_pred_logits = self.dot_product_scoring( decoder_hidden_states=decoder_outputs.intermediate_hidden_states, text_features=encoder_outputs.text_features, text_mask=combined_prompt_mask, ).squeeze(-1) pred_logits = all_pred_logits[-1] pred_boxes = all_pred_boxes[-1] decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1] presence_logits = decoder_outputs.presence_logits[-1] mask_outputs = self.mask_decoder( decoder_queries=decoder_hidden_states, backbone_features=list(fpn_hidden_states), encoder_hidden_states=encoder_outputs.last_hidden_state, prompt_features=combined_prompt_features, prompt_mask=combined_prompt_mask, **kwargs, ) return Sam3ImageSegmentationOutput( pred_masks=mask_outputs.pred_masks, pred_boxes=pred_boxes, pred_logits=pred_logits, presence_logits=presence_logits, semantic_seg=mask_outputs.semantic_seg, decoder_hidden_states=decoder_outputs.hidden_states, decoder_reference_boxes=decoder_outputs.reference_boxes, encoder_hidden_states=encoder_outputs.hidden_states, vision_hidden_states=vision_outputs.hidden_states, vision_attentions=vision_outputs.attentions, 
detr_encoder_attentions=encoder_outputs.attentions, detr_decoder_attentions=decoder_outputs.attentions, mask_decoder_attentions=mask_outputs.attentions, ) __all__ = ["Sam3Model", "Sam3VisionModel", "Sam3ViTModel", "Sam3PreTrainedModel"]