You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
2413 lines
99 KiB
2413 lines
99 KiB
# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import math
|
|
from collections.abc import Callable, Iterable
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
from torch import Tensor
|
|
|
|
from transformers import CLIPTextModelWithProjection
|
|
|
|
from ... import initialization as init
|
|
from ...activations import ACT2FN
|
|
from ...masking_utils import create_bidirectional_mask
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithPooling,
|
|
ModelOutput,
|
|
)
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...pytorch_utils import compile_compatible_method_lru_cache
|
|
from ...utils import auto_docstring, can_return_tuple, logging
|
|
from ...utils.generic import TransformersKwargs, check_model_inputs, is_flash_attention_requested
|
|
from ..auto import AutoModel
|
|
from .configuration_sam3 import (
|
|
Sam3Config,
|
|
Sam3DETRDecoderConfig,
|
|
Sam3DETREncoderConfig,
|
|
Sam3GeometryEncoderConfig,
|
|
Sam3MaskDecoderConfig,
|
|
Sam3VisionConfig,
|
|
Sam3ViTConfig,
|
|
)
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@dataclass
@auto_docstring
class Sam3VisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    fpn_hidden_states (`tuple[torch.FloatTensor]`):
        Tuple of multi-level FPN feature maps.
    fpn_position_encoding (`tuple[torch.FloatTensor]`):
        Tuple of position encodings for each FPN level.
    """

    # Both fields default to None when the FPN outputs are not produced.
    fpn_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    fpn_position_encoding: tuple[torch.FloatTensor, ...] | None = None
|
|
|
|
|
|
@dataclass
@auto_docstring
class Sam3GeometryEncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_prompts, hidden_size)`):
        Encoded geometry prompt features (boxes).
    attention_mask (`torch.BoolTensor` of shape `(batch_size, num_prompts)`, *optional*):
        Attention mask for geometry prompts where True indicates valid positions and False indicates padding.
    """

    # None until set by the geometry encoder.
    last_hidden_state: torch.FloatTensor | None = None
    attention_mask: torch.BoolTensor | None = None
|
|
|
|
|
|
@dataclass
@auto_docstring
class Sam3DETREncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Encoded vision features (flattened from multi-level features).
    pos_embeds_flattened (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Flattened position embeddings for the vision features.
    text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`, *optional*):
        Text features (may be pooled after encoder processing).
    spatial_shapes (`torch.LongTensor` of shape `(num_levels, 2)`, *optional*):
        Spatial shapes (height, width) for each feature pyramid level.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all encoder layers.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from all encoder layers.
    """

    # None until set by the DETR encoder.
    last_hidden_state: torch.FloatTensor | None = None
    pos_embeds_flattened: torch.FloatTensor | None = None
    text_features: torch.FloatTensor | None = None
    spatial_shapes: torch.LongTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
|
|
|
|
|
|
@dataclass
@auto_docstring
class Sam3DETRDecoderOutput(ModelOutput):
    r"""
    intermediate_hidden_states (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, hidden_size)`):
        Decoder hidden states from all layers.
    reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
        Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
    presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
        Presence logits from all decoder layers indicating object presence confidence.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all decoder layers.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from all decoder layers (self-attention and cross-attention).
    """

    # None until populated by the DETR decoder.
    intermediate_hidden_states: torch.FloatTensor | None = None
    reference_boxes: torch.FloatTensor | None = None
    presence_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
|
|
|
|
|
|
@dataclass
@auto_docstring
class Sam3MaskDecoderOutput(ModelOutput):
    r"""
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
        Predicted segmentation masks for each query.
    semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
        Semantic segmentation output.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from mask decoder cross-attention layers.
    """

    # None until populated by the mask decoder.
    pred_masks: torch.FloatTensor | None = None
    semantic_seg: torch.FloatTensor | None = None
    attentions: tuple[torch.FloatTensor] | None = None
|
|
|
|
|
|
@dataclass
@auto_docstring
class Sam3ImageSegmentationOutput(ModelOutput):
    r"""
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
        Predicted segmentation masks for each query.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Predicted bounding boxes in (x1, y1, x2, y2) format.
    pred_logits (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
        Classification confidence scores for each query, computed via dot product between
        decoder query features and text features.
    presence_logits (`torch.FloatTensor` of shape `(batch_size, 1)`, *optional*):
        Presence logits from the DETR decoder presence token (last layer only). These indicate whether objects
        are present in the scene. Can be used to compute final scores by multiplying with pred_logits:
        `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()`.
    semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
        Semantic segmentation output.
    decoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all DETR decoder layers. Each tensor has shape `(batch_size, num_queries, hidden_size)`.
    decoder_reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`, *optional*):
        Reference boxes from all DETR decoder layers.
    encoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all DETR encoder layers.
    vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all vision encoder (ViT) layers.
    vision_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from vision encoder (ViT) layers.
    detr_encoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from DETR encoder layers.
    detr_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from DETR decoder layers (self-attention and cross-attention).
    mask_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from mask decoder layers.
    """

    # Primary predictions; None until populated by the model forward.
    pred_masks: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    pred_logits: torch.FloatTensor | None = None
    presence_logits: torch.FloatTensor | None = None
    semantic_seg: torch.FloatTensor | None = None
    # Optional intermediate states/attentions, only filled when requested.
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_reference_boxes: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    vision_hidden_states: tuple[torch.FloatTensor] | None = None
    vision_attentions: tuple[torch.FloatTensor] | None = None
    detr_encoder_attentions: tuple[torch.FloatTensor] | None = None
    detr_decoder_attentions: tuple[torch.FloatTensor] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor] | None = None
|
|
|
|
|
|
def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """Map probabilities in [0, 1] back to logits, clamping by `eps` to avoid infinities."""
    probs = x.clamp(min=0, max=1)
    numerator = probs.clamp(min=eps)
    denominator = (1 - probs).clamp(min=eps)
    return torch.log(numerator / denominator)
|
|
|
|
|
|
def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """
    Concatenate two right-padded sequences so the result is contiguous and right-padded.

    Tensors are batch-first; masks are batch-first with True=valid, False=padding.

    Args:
        seq1: A tensor of shape (batch_size, seq1_length, hidden_size).
        mask1: A tensor of shape (batch_size, seq1_length) with True=valid, False=padding.
        seq2: A tensor of shape (batch_size, seq2_length, hidden_size).
        mask2: A tensor of shape (batch_size, seq2_length) with True=valid, False=padding.
        return_index: If True, also returns the index of the ids of the element of seq2
            in the concatenated sequence. This can be used to retrieve the elements of seq2.

    Returns:
        A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
        otherwise (concatenated_sequence, concatenated_mask, index).
        The concatenated_mask uses True=valid, False=padding convention.

    Raises:
        ValueError: If batch sizes, hidden sizes, or mask lengths are inconsistent.
    """
    batch_size, seq1_length, hidden_size = seq1.shape
    batch_size2, seq2_length, hidden_size2 = seq2.shape

    # Explicit validation instead of `assert`, which is silently stripped under `python -O`.
    if not (batch_size == batch_size2 == mask1.size(0) == mask2.size(0)):
        raise ValueError(
            f"Batch sizes must match: got seq1={batch_size}, seq2={batch_size2}, "
            f"mask1={mask1.size(0)}, mask2={mask2.size(0)}."
        )
    if hidden_size != hidden_size2:
        raise ValueError(f"Hidden sizes must match: got {hidden_size} and {hidden_size2}.")
    if seq1_length != mask1.size(1):
        raise ValueError(f"mask1 length {mask1.size(1)} does not match seq1 length {seq1_length}.")
    if seq2_length != mask2.size(1):
        raise ValueError(f"mask2 length {mask2.size(1)} does not match seq2 length {seq2_length}.")

    actual_seq1_lengths = mask1.sum(dim=-1)
    actual_seq2_lengths = mask2.sum(dim=-1)

    final_lengths = actual_seq1_lengths + actual_seq2_lengths
    max_length = seq1_length + seq2_length

    # True exactly for the first `final_lengths[b]` positions of each row.
    concatenated_mask = (
        torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) < final_lengths[:, None]
    )

    concatenated_sequence = torch.zeros((batch_size, max_length, hidden_size), device=seq2.device, dtype=seq2.dtype)
    concatenated_sequence[:, :seq1_length, :] = seq1

    # Shift seq2 elements to start at the end of valid seq1.
    index = torch.arange(seq2_length, device=seq2.device)[None].repeat(batch_size, 1)
    index = index + actual_seq1_lengths[:, None]

    # Scatter seq2 into the right positions (overwrites seq1's padding region).
    concatenated_sequence = concatenated_sequence.scatter(1, index[:, :, None].expand(-1, -1, hidden_size), seq2)

    if return_index:
        return concatenated_sequence, concatenated_mask, index

    return concatenated_sequence, concatenated_mask
|
|
|
|
|
|
def box_cxcywh_to_xyxy(x):
    """Convert boxes from center format (cx, cy, w, h) to corner format (x1, y1, x2, y2)."""
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h)
    return torch.stack(corners, dim=-1)
|
|
|
|
|
|
class Sam3MLP(nn.Module):
    """Feed-forward block: fc1 -> dropout -> activation -> fc2."""

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Note the order: dropout is applied to the fc1 output *before* the nonlinearity.
        projected = self.fc1(hidden_states)
        activated = self.activation_fn(self.dropout(projected))
        return self.fc2(activated)
|
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (eager) scaled dot-product attention returning both the context and the weights."""
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores: scaled dot product between queries and keys.
    scores = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        # Crop the additive bias to the actual key length before applying it.
        scores = scores + attention_mask[:, :, :, : key.shape[-2]]

    attn_weights = nn.functional.softmax(scores, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Weighted sum over values, then move heads next to head_dim: [batch, seq, heads, head_dim].
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()

    return attn_output, attn_weights
|
|
|
|
|
|
class Sam3Attention(nn.Module):
    """
    Multi-head attention.
    Handles standard [batch_size, seq_len, hidden_size] tensors.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        # Bidirectional attention: no causal masking anywhere in this module.
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: [batch_size, query_len, hidden_size]
            key: [batch_size, key_len, hidden_size]
            value: [batch_size, value_len, hidden_size]
            attention_mask: [batch_size, num_heads, query_len, key_len] or broadcastable

        Returns:
            Tuple of (output, attention_weights)
                output: [batch_size, query_len, hidden_size]
                attention_weights: [batch_size, num_heads, query_len, key_len]
        """
        batch_size = query.shape[0]
        query_len = query.shape[1]
        key_len = key.shape[1]

        # Project and split into heads: [batch_size, num_heads, seq_len, head_dim].
        query = self.q_proj(query).view(batch_size, query_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(key).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(value).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        # Resolve the configured attention backend, defaulting to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        if (
            is_flash_attention_requested(self.config)
            and attention_mask is not None
            and attention_mask.dtype != torch.bool
        ):
            # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention
            # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
            logger.warning_once(
                "Sam3Attention: falling back to SDPA for relative-position cross-attention because "
                "Flash Attention does not support additive bias masks."
            )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Merge heads back into a single hidden dimension before the output projection.
        attn_output = attn_output.reshape(batch_size, query_len, self.num_attention_heads * self.head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
|
|
|
|
|
|
class Sam3ViTRotaryEmbedding(nn.Module):
    """
    Vision Rotary Position Embedding for SAM3, following transformers library standards.
    Supports 2D (axial) rotary embeddings for spatial dimensions.
    """

    def __init__(self, config: Sam3ViTConfig, end_x: int, end_y: int, scale: float = 1.0):
        super().__init__()
        head_dim = config.hidden_size // config.num_attention_heads
        # Axial RoPE splits the head dim between the x and y axes, with interleaved pairs.
        if head_dim % 4 != 0:
            raise ValueError("Dimension must be divisible by 4 for axial RoPE")
        self.end_x, self.end_y = end_x, end_y
        self.dim = head_dim
        self.rope_theta = config.rope_theta
        self.scale = scale

        # Base frequencies shared by both axes (a quarter of the head dim each).
        freqs = 1.0 / (config.rope_theta ** (torch.arange(0, head_dim, 4)[: (head_dim // 4)].float() / head_dim))

        # Recover (x, y) grid coordinates from the flattened token index (row-major over x).
        token_ids = torch.arange(end_x * end_y, dtype=torch.long)
        x_coords = (token_ids % end_x) * scale
        y_coords = torch.div(token_ids, end_x, rounding_mode="floor") * scale
        x_angles = torch.outer(x_coords, freqs).float()
        y_angles = torch.outer(y_coords, freqs).float()
        angles = torch.cat([x_angles, y_angles], dim=-1)
        # Duplicate each angle for the (even, odd) element of its rotated pair.
        angles = angles.repeat_interleave(2, dim=-1)
        # The feature map shape is fixed per stage, so cache cos/sin tables directly.
        self.register_buffer("rope_embeddings_cos", angles.cos(), persistent=False)
        self.register_buffer("rope_embeddings_sin", angles.sin(), persistent=False)

    @torch.no_grad()
    def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
        # Fixed feature-map size per stage: simply hand back the precomputed tables.
        return self.rope_embeddings_cos, self.rope_embeddings_sin
|
|
|
|
|
|
def rotate_pairwise(x):
    """
    Rotate adjacent feature pairs (x0, x1) -> (-x1, x0) along the last dimension.

    Different from Llama's half-tensor rotation: pairs are interleaved, so even
    indices receive the negated odd features and odd indices the even features.
    Equivalent to the explicit form:
    ```python
    x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device)
    x_rotated[..., ::2] = -x[..., 1::2]
    x_rotated[..., 1::2] = x[..., ::2]
    return x_rotated
    ```
    """
    paired = x.view(*x.shape[:-1], -1, 2)
    even = paired[..., 0]
    odd = paired[..., 1]
    rotated = torch.stack((-odd, even), dim=-1)
    return rotated.flatten(start_dim=-2)
|
|
|
|
|
|
def apply_rotary_pos_emb_2d(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to query and key tensors for self-attention.

    The rotation is computed in float32 and cast back to each input's dtype.

    Args:
        q: Query tensor of shape (batch_size, num_windows, seq_len, num_heads, head_dim)
        k: Key tensor of shape (batch_size, num_windows, seq_len, num_heads, head_dim)
        cos: Cosine position embedding of shape (seq_len, head_dim)
        sin: Sine position embedding of shape (seq_len, head_dim)

    Returns:
        Rotated (q, k) tensors
    """

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        # Upcast for numerical stability, then apply the standard RoPE identity.
        upcast = tensor.float()
        return (upcast * cos) + (rotate_pairwise(upcast) * sin)

    return _rotate(q).type_as(q), _rotate(k).type_as(k)
|
|
|
|
|
|
class Sam3ViTRoPEAttention(nn.Module):
    """Self-attention with rotary position encoding."""

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        # Vision self-attention is bidirectional.
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: Spatial tokens of shape [batch_size, height, width, hidden_size].
            position_embeddings: Precomputed (cos, sin) RoPE tables covering height * width positions.

        Returns:
            Tuple of (attn_output, attn_weights); attn_output has shape
            [batch_size, height, width, hidden_size].
        """
        batch_size, height, width, _ = hidden_states.shape
        seq_len = height * width
        # Flatten the spatial grid into a sequence and split into heads.
        new_shape = (batch_size, seq_len, self.num_attention_heads, self.head_dim)
        query = self.q_proj(hidden_states).view(*new_shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(*new_shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(*new_shape).transpose(1, 2)
        # Rotate queries/keys by their 2D positions; values are left untouched.
        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb_2d(query, key, cos=cos, sin=sin)

        # Resolve the configured attention backend, defaulting to eager.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            # Attention dropout is only active in training mode.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )
        # Restore the spatial [batch, height, width, hidden] layout before the output projection.
        attn_output = attn_output.reshape(batch_size, height, width, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
|
class Sam3ViTPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        image_size = config.pretrain_image_size
        patch_size = config.patch_size

        # Normalize scalar sizes to (height, width) pairs.
        if not isinstance(image_size, Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, Iterable):
            patch_size = (patch_size, patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        # Non-overlapping strided convolution: one output position per patch.
        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size, bias=False
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Cast to the projection weight dtype, then flatten the spatial grid into a sequence.
        patches = self.projection(pixel_values.to(self.projection.weight.dtype))
        return patches.flatten(2).transpose(1, 2)
|
|
|
|
|
|
class Sam3ViTEmbeddings(nn.Module):
    """
    Construct the patch embeddings and position embeddings for SAM3 ViT.

    Position embeddings are tiled (not interpolated) when resizing to match different input sizes.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()

        self.patch_embeddings = Sam3ViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches, config.hidden_size)
        )  # !Remove cls token in convert weights!

        self.dropout = nn.Dropout(config.hidden_dropout)
        self.patch_size = config.patch_size

    def _tile_position_embeddings(
        self,
        position_embeddings: torch.Tensor,
        height: int,
        width: int,
    ) -> torch.Tensor:
        """
        Tile position embeddings to match target spatial dimensions.

        Args:
            position_embeddings: Shape [1, num_pretrain_patches, hidden_size]
            height: Target height in patches
            width: Target width in patches

        Returns:
            Shape [1, height * width, hidden_size]
        """
        # Assumes the pretraining patch grid is square (num_pretrain_patches = side ** 2).
        pretrain_size = int(position_embeddings.shape[1] ** 0.5)

        # Skip tiling if sizes match (but always tile during tracing for consistent graph)
        if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width:
            return position_embeddings.reshape(1, height * width, -1)

        # Tile position embeddings to match target spatial dimensions
        hidden_size = position_embeddings.shape[-1]
        pos_embed = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2)
        # Repeat enough whole copies to cover the target grid, then crop to (height, width).
        repeat_h = height // pretrain_size + 1
        repeat_w = width // pretrain_size + 1
        pos_embed = pos_embed.tile([1, 1, repeat_h, repeat_w])[:, :, :height, :width]
        return pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size)

    def forward(
        self,
        pixel_values: torch.Tensor,
        # NOTE(review): accepted for API parity with other ViT embeddings but currently unused here.
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        height, width = pixel_values.shape[-2:]
        embeddings = self.patch_embeddings(pixel_values)

        # Calculate spatial dimensions in patches
        height_patches = height // self.patch_size
        width_patches = width // self.patch_size

        position_embeddings = self._tile_position_embeddings(
            self.position_embeddings,
            height_patches,
            width_patches,
        )
        embeddings = embeddings + position_embeddings
        embeddings = self.dropout(embeddings)

        return embeddings
|
|
|
|
|
|
def window_partition(hidden_state, window_size):
    """
    Split a feature map into non-overlapping square windows, zero-padding the
    bottom/right edges when height or width is not a multiple of `window_size`.

    Args:
        hidden_state (`torch.Tensor`):
            Input tokens with [batch_size, height, width, num_channels].
        window_size (`int`):
            Window size.

    Returns:
        `tuple(torch.FloatTensor)` comprising various elements:
        - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
        - (padded_height, padded_width): padded height and width before partition
    """
    batch_size, height, width, num_channels = hidden_state.shape

    # Amount of bottom/right padding needed to reach a multiple of window_size.
    pad_bottom = (window_size - height % window_size) % window_size
    pad_right = (window_size - width % window_size) % window_size

    # No-op when both pads are zero.
    hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_right, 0, pad_bottom))

    padded_height = height + pad_bottom
    padded_width = width + pad_right

    # Carve the padded map into a (rows of windows) x (cols of windows) grid.
    grid = hidden_state.view(
        batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
    )
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows, (padded_height, padded_width)
|
|
|
|
|
|
def window_unpartition(windows, window_size, pad_height_width, height_width):
    """
    Reassemble windows produced by `window_partition` into the full feature map
    and strip any bottom/right padding.

    Args:
        windows (`torch.Tensor`):
            Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
        window_size (`int`):
            Window size.
        pad_height_width (`tuple[int]`):
            Padded height and width (padded_height, padded_width).
        height_width (`tuple[int]`):
            Original height and width before padding.

    Returns:
        hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
    """
    padded_height, padded_width = pad_height_width
    height, width = height_width

    # Each image contributes (rows * cols) windows; infer the batch size from that.
    windows_per_row = padded_height // window_size
    windows_per_col = padded_width // window_size
    batch_size = windows.shape[0] // (windows_per_row * windows_per_col)

    grid = windows.view(batch_size, windows_per_row, windows_per_col, window_size, window_size, -1)
    full = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    full = full.view(batch_size, padded_height, padded_width, -1)

    # Crop away padding; height <= padded_height and width <= padded_width always hold.
    return full[:, :height, :width, :].contiguous()
|
|
|
|
|
|
class Sam3ViTLayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch (LayerScale)."""

    def __init__(self, config) -> None:
        super().__init__()
        # One learnable scale per hidden channel, initialized to a constant.
        self.lambda1 = nn.Parameter(config.layer_scale_init_value * torch.ones(config.hidden_size))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        scaled = hidden_state * self.lambda1
        return scaled
|
|
|
|
|
|
class Sam3ViTLayer(GradientCheckpointingLayer):
    """Vision Transformer layer with rotary position embeddings and optional windowed attention."""

    def __init__(self, config: Sam3ViTConfig, window_size: int = 0) -> None:
        """
        Args:
            config: Vision transformer configuration.
            window_size: Side length of local attention windows; 0 means global attention.
        """
        super().__init__()

        hidden_size = config.hidden_size
        image_size = config.image_size
        image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)

        patch_size = config.patch_size
        patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)

        # Patch-grid resolution of the full feature map.
        input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        # RoPE tables cover the attended extent: the full grid for global layers,
        # a single window for windowed layers.
        rotary_input_size = input_size if window_size == 0 else (window_size, window_size)
        # NOTE(review): the scale normalizes RoPE coordinates relative to the configured
        # window size (global layers get fractional positions) — confirm against reference impl.
        rotary_scale = config.window_size / rotary_input_size[0]
        self.rotary_emb = Sam3ViTRotaryEmbedding(
            config, end_x=rotary_input_size[0], end_y=rotary_input_size[1], scale=rotary_scale
        )
        self.attention = Sam3ViTRoPEAttention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = Sam3MLP(config)
        self.dropout = nn.Dropout(config.hidden_dropout)

        self.window_size = window_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Pre-norm attention block with a residual connection.
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)

        if self.window_size > 0:
            # height/width and pad_height_width are only bound on the windowed path;
            # the matching unpartition below is guarded by the same condition.
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            # Partition into non-overlapping windows for efficient attention
            hidden_states, pad_height_width = window_partition(hidden_states, self.window_size)

        position_embeddings = self.rotary_emb()
        hidden_states, _ = self.attention(hidden_states, position_embeddings, **kwargs)

        if self.window_size > 0:
            # Reverse window partition to restore original spatial layout
            hidden_states = window_unpartition(hidden_states, self.window_size, pad_height_width, (height, width))

        hidden_states = residual + hidden_states

        # Pre-norm MLP block; dropout is applied to the MLP branch before the residual add.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        return hidden_states
|
|
|
|
|
|
@auto_docstring
class Sam3PreTrainedModel(PreTrainedModel):
    # from_pretrained plumbing: config class, weight prefix, and canonical input name.
    config_class = Sam3Config
    base_model_prefix = "sam3"
    main_input_name = "pixel_values"
    input_modalities = ["image", "text"]
    # Supported attention backends.
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize SAM3-specific modules on top of the generic PreTrainedModel defaults."""
        super()._init_weights(module)
        if isinstance(module, Sam3ViTEmbeddings):
            init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, Sam3ViTRotaryEmbedding):
            # Recompute the cos/sin RoPE tables exactly as in Sam3ViTRotaryEmbedding.__init__;
            # they are registered as non-persistent buffers, so they are not loaded from checkpoints.
            end_x, end_y = module.end_x, module.end_y
            dim = module.dim
            freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
            flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
            x_positions = (flattened_indices % end_x) * module.scale
            y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
            freqs_x = torch.outer(x_positions, freqs).float()
            freqs_y = torch.outer(y_positions, freqs).float()
            inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
            inv_freq = inv_freq.repeat_interleave(2, dim=-1)
            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
            init.copy_(module.rope_embeddings_sin, inv_freq.sin())
|
|
|
|
|
|
@auto_docstring
class Sam3ViTModel(Sam3PreTrainedModel):
    """Vision Transformer trunk: patch embedding followed by a stack of (optionally windowed) attention layers."""

    def __init__(self, config: Sam3ViTConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = Sam3ViTEmbeddings(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Layers listed in `global_attn_indexes` attend globally (window_size=0);
        # every other layer uses windowed attention of the configured size.
        self.layers = nn.ModuleList(
            [
                Sam3ViTLayer(config, window_size=0 if i in config.global_attn_indexes else config.window_size)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.post_init()

    def get_input_embeddings(self) -> Sam3ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Patchify: [batch_size, seq_len, hidden_size]
        tokens = self.embeddings(pixel_values)

        patch_size = self.config.patch_size
        grid_height = pixel_values.shape[-2] // patch_size
        grid_width = pixel_values.shape[-1] // patch_size
        batch_size, embed_dim = tokens.shape[0], tokens.shape[-1]

        # Windowed attention needs a spatial layout: [batch_size, height, width, hidden_size].
        spatial_tokens = tokens.view(batch_size, grid_height, grid_width, embed_dim)
        spatial_tokens = self.layer_norm(spatial_tokens)

        for block in self.layers:
            spatial_tokens = block(spatial_tokens, **kwargs)

        # Flatten back to a token sequence: [batch_size, height*width, hidden_size].
        flat_tokens = spatial_tokens.view(batch_size, grid_height * grid_width, embed_dim)

        return BaseModelOutput(last_hidden_state=flat_tokens)
|
|
|
|
|
|
class Sam3SinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
    ):
        """
        Args:
            num_pos_feats: Number of embedding features per spatial axis (sines and cosines interleaved).
            temperature: Base of the geometric frequency progression.
            normalize: If True, `forward` normalizes cumulative positions into [0, scale].
            scale: Normalization scale; defaults to 2*pi. Only valid together with `normalize=True`.

        Raises:
            ValueError: If `scale` is given while `normalize` is False.
        """
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def encode_1d_positions(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode 1D coordinate pairs using sine/cosine positional embeddings.

        Args:
            x: 1D tensor of x coordinates (flattened)
            y: 1D tensor of y coordinates (flattened)

        Returns:
            Tuple of (pos_x, pos_y) positional embeddings
        """
        x_embed = x * self.scale
        y_embed = y * self.scale

        # Geometric frequency ladder; paired indices (2k, 2k+1) share a frequency.
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).to(x.dtype)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        # Interleave sin (even channels) and cos (odd channels).
        pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
        pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
        return pos_x, pos_y

    def encode_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Encode 4D box coordinates (x, y, w, h) for decoder conditioning using sine/cosine embeddings.

        Args:
            boxes: Box coordinates [batch_size, num_queries, 4] in (x, y, w, h) format

        Returns:
            Position embeddings [batch_size, num_queries, num_pos_feats*4]

        Raises:
            ValueError: If the last dimension of `boxes` is not 4.
        """
        # Explicit check rather than `assert`: asserts are stripped under `python -O`.
        if boxes.size(-1) != 4:
            raise ValueError(f"Expected 4D box coordinates (x, y, w, h), got shape {boxes.shape}")
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=boxes.device).to(boxes.dtype)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

        x_embed = boxes[:, :, 0] * self.scale
        y_embed = boxes[:, :, 1] * self.scale
        w_embed = boxes[:, :, 2] * self.scale
        h_embed = boxes[:, :, 3] * self.scale

        pos_x = x_embed[:, :, None] / dim_t
        pos_y = y_embed[:, :, None] / dim_t
        pos_w = w_embed[:, :, None] / dim_t
        pos_h = h_embed[:, :, None] / dim_t

        # Interleave sin/cos per coordinate, then concatenate in (y, x, w, h) order.
        pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
        pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)

        return pos

    @compile_compatible_method_lru_cache(maxsize=4)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: Tensor | None = None,
    ) -> Tensor:
        """
        Build a 2D sine/cosine position encoding for a feature map.

        Args:
            shape: Feature map shape; `shape[0]` is the batch size, `shape[2]`/`shape[3]` are H/W.
            device: Device on which to build the encoding.
            dtype: Output dtype.
            mask: Optional padding mask [batch, H, W] (True = padded). Defaults to no padding.

        Returns:
            Position encoding [batch, 2*num_pos_feats, H, W].
        """
        if mask is None:
            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        not_mask = (~mask).to(dtype)
        # Cumulative sums give each valid cell its 1-based row/column index.
        y_embed = not_mask.cumsum(1)
        x_embed = not_mask.cumsum(2)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # Channel-first output: [batch, 2*num_pos_feats, H, W].
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
|
|
|
|
|
|
class Sam3FPNLayer(nn.Module):
    """One FPN level: rescales a backbone feature map by `scale_factor`, then projects it to `fpn_dim` channels."""

    def __init__(self, in_channels: int, fpn_dim: int, scale_factor: float):
        super().__init__()
        self.scale_factor = scale_factor
        self.scale_layers = nn.ModuleList()

        # Pick the up/down-sampling stack and the channel count it produces.
        if scale_factor == 4.0:
            self.scale_layers.extend(
                [
                    nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2),
                    nn.GELU(),
                    nn.ConvTranspose2d(in_channels // 2, in_channels // 4, kernel_size=2, stride=2),
                ]
            )
            intermediate_channels = in_channels // 4
        elif scale_factor == 2.0:
            self.scale_layers.append(nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2))
            intermediate_channels = in_channels // 2
        elif scale_factor == 1.0:
            intermediate_channels = in_channels
        elif scale_factor == 0.5:
            self.scale_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            intermediate_channels = in_channels
        else:
            raise NotImplementedError(f"scale_factor={scale_factor} is not supported yet.")

        # 1x1 projection to the FPN width followed by a 3x3 refinement conv.
        self.proj1 = nn.Conv2d(in_channels=intermediate_channels, out_channels=fpn_dim, kernel_size=1)
        self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Cast the input to the projection weights' dtype before any conv runs.
        features = hidden_states.to(self.proj1.weight.dtype)
        for scale_layer in self.scale_layers:
            features = scale_layer(features)
        return self.proj2(self.proj1(features))
|
|
|
|
|
|
class Sam3VisionNeck(nn.Module):
    """FPN-style neck: turns a single backbone feature map into a multi-scale pyramid plus position encodings."""

    def __init__(self, config: Sam3VisionConfig):
        super().__init__()
        self.config = config

        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)

        # One FPN layer per configured scale factor.
        self.fpn_layers = nn.ModuleList(
            [
                Sam3FPNLayer(
                    in_channels=config.backbone_config.hidden_size, fpn_dim=config.fpn_hidden_size, scale_factor=scale
                )
                for scale in config.scale_factors
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
        """Apply each FPN level to `hidden_states`; return per-level features and matching sine position encodings."""
        level_features = []
        level_pos_encodings = []
        for fpn_layer in self.fpn_layers:
            feats = fpn_layer(hidden_states)
            level_features.append(feats)
            level_pos_encodings.append(self.position_encoding(feats.shape, feats.device, feats.dtype))
        return tuple(level_features), tuple(level_pos_encodings)
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The vision model from Sam without any head or projection on top.
    """
)
class Sam3VisionModel(Sam3PreTrainedModel):
    config_class = Sam3VisionConfig
    main_input_name = "pixel_values"
    _can_record_outputs = {
        "hidden_states": Sam3ViTLayer,
        "attentions": Sam3ViTRoPEAttention,
    }

    def __init__(self, config: Sam3VisionConfig):
        super().__init__(config)
        self.config = config
        self.backbone = AutoModel.from_config(config.backbone_config)
        self.neck = Sam3VisionNeck(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    @check_model_inputs
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3VisionEncoderOutput:
        """Run the ViT backbone and feed its tokens through the FPN neck."""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Backbone output is a flat token sequence: [batch_size, seq_len, hidden_size].
        last_hidden_state = self.backbone(pixel_values, **kwargs).last_hidden_state

        # The FPN neck wants a spatial NCHW map, so fold the tokens back onto the patch grid.
        patch_size = self.config.backbone_config.patch_size
        grid_height = pixel_values.shape[-2] // patch_size
        grid_width = pixel_values.shape[-1] // patch_size
        batch_size = last_hidden_state.shape[0]
        spatial_features = last_hidden_state.view(batch_size, grid_height, grid_width, -1).permute(0, 3, 1, 2)

        fpn_features, fpn_pos_encodings = self.neck(spatial_features)

        return Sam3VisionEncoderOutput(
            last_hidden_state=last_hidden_state,
            fpn_hidden_states=fpn_features,
            fpn_position_encoding=fpn_pos_encodings,
        )
|
|
|
|
|
|
class Sam3GeometryEncoderLayer(nn.Module):
    """Pre-norm transformer layer for geometry prompts: self-attention, cross-attention into vision, then MLP."""

    def __init__(self, config: Sam3GeometryEncoderConfig):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)
        self.self_attn = Sam3Attention(config)
        self.dropout = nn.Dropout(config.dropout)

        self.cross_attn = Sam3Attention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

        self.mlp = Sam3MLP(config)
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        prompt_feats: Tensor,
        vision_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_mask: Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        # 1) Self-attention over the prompt tokens (pre-norm, residual around the block).
        normed = self.layer_norm1(prompt_feats)
        attn_out, _ = self.self_attn(
            query=normed, key=normed, value=normed, attention_mask=prompt_mask, **kwargs
        )
        prompt_state = self.dropout(attn_out) + prompt_feats

        # 2) Cross-attention: prompts query vision features; keys carry the position encoding.
        normed = self.layer_norm2(prompt_state)
        vision_keys = vision_feats + vision_pos_encoding
        attn_out, _ = self.cross_attn(query=normed, key=vision_keys, value=vision_feats, **kwargs)
        prompt_state = self.dropout(attn_out) + prompt_state

        # 3) Feed-forward block.
        mlp_out = self.mlp(self.layer_norm3(prompt_state))
        prompt_state = self.dropout(mlp_out) + prompt_state

        return prompt_state
|
|
|
|
|
|
class Sam3GeometryEncoder(nn.Module):
    """
    Encoder for geometric prompts (boxes).

    Boxes are encoded using three approaches:
    - Direct projection: linear projection from coordinate space to hidden_size
    - Pooling: pool features from the backbone at the specified location (ROI align for boxes)
    - Position encoding: use position encoding of the box center

    These encodings are combined additively and further processed with transformer layers.
    """

    def __init__(self, config: Sam3GeometryEncoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.roi_size = config.roi_size

        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=True)
        # Two entries indexed by the box label (presumably 0=negative, 1=positive — confirm with caller).
        self.label_embed = nn.Embedding(2, self.hidden_size)
        # Single learned CLS token appended to the prompt sequence in `forward`.
        self.cls_embed = nn.Embedding(1, self.hidden_size)

        # Box encoding layers
        self.boxes_direct_project = nn.Linear(4, self.hidden_size)
        # kernel_size == roi_size collapses each pooled ROI patch into one vector per box.
        self.boxes_pool_project = nn.Conv2d(self.hidden_size, self.hidden_size, self.roi_size)
        # +2 input features: raw box height and width appended to the sine-encoded center.
        self.boxes_pos_enc_project = nn.Linear(self.hidden_size + 2, self.hidden_size)

        # Image feature normalization
        self.vision_layer_norm = nn.LayerNorm(self.hidden_size)

        # Prompt projection and normalization
        self.final_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.prompt_layer_norm = nn.LayerNorm(self.hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList([Sam3GeometryEncoderLayer(config) for _ in range(config.num_layers)])
        self.output_layer_norm = nn.LayerNorm(self.hidden_size)

    def _encode_box_coordinates(
        self, center_x: torch.Tensor, center_y: torch.Tensor, width: torch.Tensor, height: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode box coordinates by combining position-encoded centers with raw width/height.

        Args:
            center_x: 1D tensor of box center x coordinates
            center_y: 1D tensor of box center y coordinates
            width: 1D tensor of box widths
            height: 1D tensor of box heights

        Returns:
            Encoded box coordinates [N, hidden_size + 2]: sine-encoded (y, x) center followed by raw h and w.
        """
        pos_x, pos_y = self.position_encoding.encode_1d_positions(center_x, center_y)
        # Note the ordering: (pos_y, pos_x, height, width) — must match `boxes_pos_enc_project`'s input layout.
        pos = torch.cat((pos_y, pos_x, height[:, None], width[:, None]), dim=1)
        return pos

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features):
        """Encode box prompts. Mask convention: True=valid, False=padding."""
        batch_size, num_boxes = boxes.shape[:2]
        height, width = vision_features.shape[-2:]
        # 1) Direct linear projection of the normalized (cx, cy, w, h) coordinates.
        boxes_embed = self.boxes_direct_project(boxes)

        # Pool features using ROI align
        # Convert boxes from CxCyWH to xyxy format and denormalize
        boxes_xyxy = box_cxcywh_to_xyxy(boxes)
        # Boxes are assumed normalized to [0, 1]; scale to feature-map pixel coordinates — TODO confirm.
        scale = torch.tensor([width, height, width, height], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
        scale = scale.view(1, 1, 4)
        boxes_xyxy = boxes_xyxy * scale
        # ROI align expects list of boxes per batch element,
        # convert from bfloat16 to float16 as roi_align only supports float16 and float32
        dtype = torch.float16 if vision_features.dtype == torch.bfloat16 else vision_features.dtype
        sampled_features = torchvision.ops.roi_align(
            vision_features.to(dtype), boxes_xyxy.to(dtype).unbind(0), self.roi_size
        ).to(vision_features.dtype)

        # 2) Collapse each pooled (roi_size x roi_size) patch to a single vector and add it.
        pooled_projection = self.boxes_pool_project(sampled_features)
        pooled_projection = pooled_projection.view(batch_size, num_boxes, self.hidden_size)
        boxes_embed = boxes_embed + pooled_projection

        # Add position encoding
        # 3) Sine-encode the box center and append raw w/h, then project and add.
        center_x, center_y, box_width, box_height = boxes.unbind(-1)
        pos_enc = self._encode_box_coordinates(
            center_x.flatten(), center_y.flatten(), box_width.flatten(), box_height.flatten()
        )
        pos_enc = pos_enc.view(batch_size, num_boxes, pos_enc.shape[-1])
        pos_projection = self.boxes_pos_enc_project(pos_enc)
        boxes_embed = boxes_embed + pos_projection

        # Add label embeddings (positive/negative)
        label_embed = self.label_embed(boxes_labels.long())
        return label_embed + boxes_embed, boxes_mask

    def forward(
        self,
        box_embeddings: torch.Tensor,
        box_mask: torch.Tensor,
        box_labels: torch.Tensor,
        img_feats: tuple[torch.Tensor, ...],
        img_pos_embeds: tuple[torch.Tensor, ...] | None = None,
    ):
        """
        Forward pass for encoding geometric prompts.

        Args:
            box_embeddings: Box coordinates in CxCyWH format [batch_size, num_boxes, 4]
            box_mask: Attention mask for boxes [batch_size, num_boxes] (True=valid, False=padding)
            box_labels: Labels for boxes (positive/negative) [batch_size, num_boxes]
            img_feats: Image features from vision encoder; only the last (lowest-resolution) level is used
            img_pos_embeds: Optional position embeddings for image features; zeros are used if None

        Returns:
            Sam3GeometryEncoderOutput containing encoded geometry features and attention mask.
        """
        batch_size = box_embeddings.shape[0]

        # Prepare vision features for cross-attention: flatten spatial dimensions
        vision_feats = img_feats[-1]  # [B, C, H, W]
        vision_pos_embeds = img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(vision_feats)
        vision_feats_flat = vision_feats.flatten(2).transpose(1, 2)  # [B, H*W, C]
        vision_pos_embeds_flat = vision_pos_embeds.flatten(2).transpose(1, 2)  # [B, H*W, C]

        # Normalize image features for pooling operations
        # (LayerNorm operates on the channel dim, hence the permute round-trip.)
        img_feats_last = img_feats[-1]  # [B, C, H, W]
        img_feats_last = img_feats_last.permute(0, 2, 3, 1)  # [B, H, W, C]
        normalized_img_feats = self.vision_layer_norm(img_feats_last)
        normalized_img_feats = normalized_img_feats.permute(0, 3, 1, 2)  # [B, C, H, W]

        prompt_embeds, prompt_mask = self._encode_boxes(box_embeddings, box_mask, box_labels, normalized_img_feats)

        # Add CLS token (always valid)
        cls_embed = self.cls_embed.weight.view(1, self.hidden_size).unsqueeze(0).expand(batch_size, -1, -1)
        cls_mask = torch.ones(batch_size, 1, dtype=prompt_mask.dtype, device=prompt_mask.device)
        # `concat_padded_sequences` is a module-level helper; appends the CLS token after the box tokens.
        prompt_embeds, prompt_mask = concat_padded_sequences(prompt_embeds, prompt_mask, cls_embed, cls_mask)

        prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds))

        # Create bidirectional attention mask for transformer layers
        prompt_attention_mask = None
        if prompt_mask is not None:
            prompt_attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=prompt_embeds,
                attention_mask=prompt_mask,
            )

        # Apply transformer layers with cross-attention to vision features
        for layer in self.layers:
            prompt_embeds = layer(
                prompt_feats=prompt_embeds,
                vision_feats=vision_feats_flat,
                vision_pos_encoding=vision_pos_embeds_flat,
                prompt_mask=prompt_attention_mask,
            )

        # Final output normalization
        prompt_embeds = self.output_layer_norm(prompt_embeds)

        return Sam3GeometryEncoderOutput(
            last_hidden_state=prompt_embeds,
            attention_mask=prompt_mask,
        )
|
|
|
|
|
|
class Sam3DetrEncoderLayer(nn.Module):
    """DETR encoder layer with self-attention and cross-attention."""

    def __init__(self, config: Sam3DETREncoderConfig):
        super().__init__()
        self.config = config
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)
        self.self_attn = Sam3Attention(config)
        self.dropout = nn.Dropout(config.dropout)

        self.cross_attn = Sam3Attention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

        self.mlp = Sam3MLP(config)
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        vision_feats: Tensor,
        prompt_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_cross_attn_mask: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Fuse vision tokens with prompt features.

        Args:
            vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
            prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
            vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
            prompt_cross_attn_mask: Cross-attention mask for prompt features

        Returns:
            Updated vision features [batch_size, vision_len, hidden_size]
        """
        # 1) Pre-norm self-attention over vision tokens; queries/keys carry positions, values do not.
        normed = self.layer_norm1(vision_feats)
        with_pos = normed + vision_pos_encoding
        attn_out, _ = self.self_attn(
            query=with_pos,
            key=with_pos,
            value=normed,
            **kwargs,
        )
        vision_state = self.dropout(attn_out) + vision_feats

        # 2) Cross-attention: vision tokens query the text/prompt features.
        normed = self.layer_norm2(vision_state)
        attn_out, _ = self.cross_attn(
            query=normed,
            key=prompt_feats,
            value=prompt_feats,
            attention_mask=prompt_cross_attn_mask,
            **kwargs,
        )
        vision_state = self.dropout(attn_out) + vision_state

        # 3) Feed-forward block.
        mlp_out = self.mlp(self.layer_norm3(vision_state))
        vision_state = self.dropout(mlp_out) + vision_state

        return vision_state
|
|
|
|
|
|
class Sam3DetrEncoder(Sam3PreTrainedModel):
    """
    DETR-style encoder that processes multi-level vision features with text fusion.

    This encoder processes vision features from multiple levels (e.g., FPN features at different
    resolutions) and fuses them with text prompts through a stack of transformer encoder layers.
    """

    _can_record_outputs = {
        "hidden_states": Sam3DetrEncoderLayer,
        "attentions": Sam3Attention,
    }

    def __init__(self, config: Sam3DETREncoderConfig):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size

        self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])

        self.post_init()

    def _prepare_multilevel_features(
        self,
        vision_features: list[torch.Tensor],
        vision_pos_embeds: list[torch.Tensor],
    ):
        """
        Prepare multi-level vision features by flattening spatial dimensions and adding level embeddings.

        Args:
            vision_features: List of vision features at different levels [batch_size, channels, height, width]
            vision_pos_embeds: List of position embeddings for each level [batch_size, channels, height, width]

        Returns:
            Tuple of (features_flattened, pos_embeds_flattened, spatial_shapes) where the first two are the
            per-level tensors flattened to [batch_size, sum(H*W), channels] and concatenated, and
            spatial_shapes is a [num_levels, 2] long tensor of (height, width) pairs.
        """
        features_flattened = []
        pos_embeds_flattened = []
        spatial_shapes = []

        for features, pos_embed in zip(vision_features, vision_pos_embeds):
            height, width = features.shape[-2:]
            spatial_shapes.append((height, width))

            # Flatten spatial dimensions: [batch_size, channels, height, width] -> [batch_size, height*width, channels]
            features = features.flatten(2).transpose(1, 2)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)

            features_flattened.append(features)
            pos_embeds_flattened.append(pos_embed)

        # Concatenate all levels into single sequence
        features_flattened = torch.cat(features_flattened, dim=1)
        pos_embeds_flattened = torch.cat(pos_embeds_flattened, dim=1)

        spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=features_flattened.device)

        return (
            features_flattened,
            pos_embeds_flattened,
            spatial_shapes,
        )

    @check_model_inputs
    def forward(
        self,
        vision_features: list[torch.Tensor],
        text_features: torch.Tensor,
        vision_pos_embeds: list[torch.Tensor] | None = None,
        text_mask: torch.Tensor | None = None,
        spatial_sizes: list[tuple[int, int]] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3DETREncoderOutput:
        """
        Forward pass for the DETR encoder.

        Args:
            vision_features: List of vision features at different levels
            text_features: Text prompt features [batch_size, seq_len, hidden_size]
            vision_pos_embeds: Position embeddings for each level. NOTE(review): annotated as optional,
                but `_prepare_multilevel_features` zips over it unconditionally, so `None` would fail —
                callers appear to always pass it; confirm before relying on the default.
            text_mask: Optional text padding mask [batch_size, seq_len]
            spatial_sizes: Optional list of (height, width) tuples for reshaping

        Returns:
            Sam3DETREncoderOutput containing encoded features and metadata.
        """
        batch_size = vision_features[0].shape[0] if vision_features[0].dim() == 4 else vision_features[0].shape[1]

        # TODO: See if we can remove that reshaping and just use the features as is.
        if spatial_sizes is not None:
            # Work on shallow copies so we never mutate the caller's lists in place
            # (the original in-place item assignment leaked the reshaped tensors back to the caller).
            vision_features = list(vision_features)
            vision_pos_embeds = list(vision_pos_embeds)
            for i, (height, width) in enumerate(spatial_sizes):
                # Reshape from [height*width, batch_size, channels] to [batch_size, channels, height, width]
                vision_features[i] = vision_features[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
                vision_pos_embeds[i] = vision_pos_embeds[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)

        # Flatten multi-level features for encoder processing
        (
            features_flattened,
            pos_embeds_flattened,
            spatial_shapes,
        ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)

        # Expand the text padding mask into a bidirectional cross-attention mask, if provided.
        prompt_cross_attn_mask = None
        if text_mask is not None:
            prompt_cross_attn_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=features_flattened,
                attention_mask=text_mask,
                encoder_hidden_states=text_features,
            )

        hidden_states = features_flattened
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                prompt_feats=text_features,
                vision_pos_encoding=pos_embeds_flattened,
                prompt_cross_attn_mask=prompt_cross_attn_mask,
                **kwargs,
            )
        return Sam3DETREncoderOutput(
            last_hidden_state=hidden_states,
            pos_embeds_flattened=pos_embeds_flattened,
            text_features=text_features,
            spatial_shapes=spatial_shapes,
        )
|
|
|
|
|
|
class Sam3DecoderMLP(nn.Module):
    """Simple 2 or 3-layer MLP (ReLU between layers) used by decoder heads."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 2):
        super().__init__()
        # Guard clause first; only 2- and 3-layer variants exist.
        if num_layers not in (2, 3):
            raise ValueError(f"Only 2 or 3 layers supported, got {num_layers}")
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        if num_layers == 2:
            self.layer2 = nn.Linear(hidden_dim, output_dim)
            self.layer3 = None
        else:
            self.layer2 = nn.Linear(hidden_dim, hidden_dim)
            self.layer3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.relu(self.layer1(x))
        if self.layer3 is None:
            return self.layer2(hidden)
        return self.layer3(F.relu(self.layer2(hidden)))
|
|
|
|
|
|
class Sam3DetrDecoderLayer(nn.Module):
    """DETR decoder layer with self-attention, text cross-attention, and vision cross-attention."""

    def __init__(self, config: Sam3DETRDecoderConfig):
        super().__init__()
        self.config = config
        # Self-attention over the query set (post-norm style).
        self.self_attn = Sam3Attention(config)
        self.self_attn_dropout = nn.Dropout(config.dropout)
        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)

        # Cross-attention from queries into text features.
        self.text_cross_attn = Sam3Attention(config)
        self.text_cross_attn_dropout = nn.Dropout(config.dropout)
        self.text_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size)

        # Cross-attention from queries into vision features.
        self.vision_cross_attn = Sam3Attention(config)
        self.vision_cross_attn_dropout = nn.Dropout(config.dropout)
        self.vision_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size)

        self.mlp = Sam3MLP(config)
        self.mlp_layer_norm = nn.LayerNorm(config.hidden_size)
        self.mlp_dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        query_pos: torch.Tensor,
        text_features: torch.Tensor,
        vision_features: torch.Tensor,
        vision_pos_encoding: torch.Tensor,
        text_cross_attn_mask: torch.Tensor | None = None,
        vision_cross_attn_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Forward pass for decoder layer.

        Args:
            hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
            query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_features: Vision features [batch_size, height*width, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
            text_cross_attn_mask: Text cross-attention mask
            vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token

        Returns:
            Updated hidden states (including presence token at position 0)
        """
        # The presence token sits at position 0 and gets a zero position embedding.
        query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)

        # 1) Self-attention among the queries; positions go into q/k only, values stay raw.
        queries_with_pos = hidden_states + query_pos
        attn_out, _ = self.self_attn(
            query=queries_with_pos,
            key=queries_with_pos,
            value=hidden_states,
            attention_mask=None,
            **kwargs,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + self.self_attn_dropout(attn_out))

        # 2) Cross-attention into text features.
        attn_out, _ = self.text_cross_attn(
            query=hidden_states + query_pos,
            key=text_features,
            value=text_features,
            attention_mask=text_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.text_cross_attn_layer_norm(hidden_states + self.text_cross_attn_dropout(attn_out))

        # 3) Cross-attention into vision features; keys carry the vision position encoding.
        attn_out, _ = self.vision_cross_attn(
            query=hidden_states + query_pos,
            key=vision_features + vision_pos_encoding,
            value=vision_features,
            attention_mask=vision_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.vision_cross_attn_layer_norm(hidden_states + self.vision_cross_attn_dropout(attn_out))

        # 4) Feed-forward block.
        mlp_out = self.mlp(hidden_states)
        hidden_states = self.mlp_layer_norm(hidden_states + self.mlp_dropout(mlp_out))

        return hidden_states
|
|
|
|
|
|
class Sam3DetrDecoder(Sam3PreTrainedModel):
|
|
"""
|
|
DETR-style decoder with box refinement and presence token.
|
|
|
|
Simplified version that assumes:
|
|
- Box refinement is always enabled
|
|
- Intermediate outputs are always returned
|
|
- BoxRPB (relative position bias) with log-scale encoding
|
|
- Presence token is used
|
|
"""
|
|
|
|
_can_record_outputs = {
|
|
"hidden_states": Sam3DetrDecoderLayer,
|
|
"attentions": Sam3Attention,
|
|
}
|
|
|
|
    def __init__(
        self,
        config: Sam3DETRDecoderConfig,
    ):
        """Build the decoder stack, prediction heads, and learned query/presence embeddings."""
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size

        self.layers = nn.ModuleList([Sam3DetrDecoderLayer(config) for _ in range(config.num_layers)])

        self.output_layer_norm = nn.LayerNorm(config.hidden_size)

        # 3-layer MLP predicting a 4-dim box (delta) per query.
        self.box_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 4, 3)

        # Learned object queries and their initial 4-dim reference boxes.
        self.query_embed = nn.Embedding(config.num_queries, config.hidden_size)
        self.reference_points = nn.Embedding(config.num_queries, 4)

        # Single learned presence token plus its scoring head; the logit is clamped
        # at `clamp_presence_logit_max_val` (presumably for numerical stability — used outside this view).
        self.presence_token = nn.Embedding(1, config.hidden_size)
        self.presence_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 1, 3)
        self.presence_layer_norm = nn.LayerNorm(config.hidden_size)
        self.clamp_presence_logit_max_val = 10.0

        # Maps a 2*hidden_size sine encoding of reference boxes to per-query position embeddings.
        self.ref_point_head = Sam3DecoderMLP(2 * config.hidden_size, config.hidden_size, config.hidden_size, 2)

        # Per-axis MLPs producing the box relative position bias (one output per attention head).
        self.box_rpb_embed_x = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
        self.box_rpb_embed_y = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)

        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)

        self.post_init()
|
|
|
|
@compile_compatible_method_lru_cache(maxsize=1)
|
|
def _get_coords(
|
|
self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""Generate normalized coordinate grids."""
|
|
coords_h = torch.arange(0, height, device=device, dtype=dtype) / height
|
|
coords_w = torch.arange(0, width, device=device, dtype=dtype) / width
|
|
return coords_h, coords_w
|
|
|
|
def _get_rpb_matrix(
|
|
self, reference_boxes: torch.Tensor, spatial_shape: tuple[torch.Tensor, torch.Tensor]
|
|
) -> torch.Tensor:
|
|
"""
|
|
Compute box relative position bias (RPB) matrix using log-scale encoding.
|
|
RPB helps the decoder attend to relevant spatial locations based on predicted box positions.
|
|
|
|
Args:
|
|
reference_boxes: Reference boxes [batch_size, num_queries, 4] in sigmoid space
|
|
spatial_shape: (height, width) of the vision features as tensors
|
|
|
|
Returns:
|
|
RPB matrix [batch_size, num_heads, num_queries, height*width]
|
|
"""
|
|
height, width = spatial_shape
|
|
boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
|
|
batch_size, num_queries, _ = boxes_xyxy.shape
|
|
|
|
# Generate coordinate grids
|
|
coords_h, coords_w = self._get_coords(
|
|
height, width, dtype=reference_boxes.dtype, device=reference_boxes.device
|
|
)
|
|
|
|
# Compute deltas between coordinates and box boundaries
|
|
deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
|
|
deltas_y = deltas_y.view(batch_size, num_queries, -1, 2)
|
|
deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
|
|
deltas_x = deltas_x.view(batch_size, num_queries, -1, 2)
|
|
|
|
# Apply log-scale encoding
|
|
deltas_x_log = deltas_x * 8
|
|
deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / math.log2(8)
|
|
deltas_y_log = deltas_y * 8
|
|
deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / math.log2(8)
|
|
|
|
# Embed deltas
|
|
deltas_x = self.box_rpb_embed_x(deltas_x_log) # [batch_size, num_queries, width, num_heads]
|
|
deltas_y = self.box_rpb_embed_y(deltas_y_log) # [batch_size, num_queries, height, num_heads]
|
|
|
|
# Combine into 2D bias matrix
|
|
rpb_matrix = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
|
|
2
|
|
) # [batch_size, num_queries, height, width, num_heads]
|
|
rpb_matrix = rpb_matrix.flatten(2, 3) # [batch_size, num_queries, height*width, num_heads]
|
|
rpb_matrix = rpb_matrix.permute(0, 3, 1, 2).contiguous() # [batch_size, num_heads, num_queries, height*width]
|
|
return rpb_matrix
|
|
|
|
@check_model_inputs
|
|
def forward(
|
|
self,
|
|
vision_features: torch.Tensor,
|
|
text_features: torch.Tensor,
|
|
vision_pos_encoding: torch.Tensor,
|
|
text_mask: torch.Tensor | None = None,
|
|
spatial_shapes: torch.Tensor | None = None,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> tuple | Sam3DETRDecoderOutput:
|
|
"""
|
|
Forward pass for the DETR decoder.
|
|
|
|
Args:
|
|
vision_features: Vision features [batch_size, height*width, hidden_size]
|
|
text_features: Text features [batch_size, seq_len, hidden_size]
|
|
vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
|
|
text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
|
|
spatial_shapes: Spatial shapes [num_levels, 2]
|
|
|
|
Returns:
|
|
Sam3DETRDecoderOutput containing decoder outputs from all layers.
|
|
"""
|
|
batch_size = vision_features.shape[0]
|
|
|
|
query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
|
|
reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
|
|
reference_boxes = reference_boxes.sigmoid()
|
|
presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)
|
|
|
|
# Concatenate presence token with query embeddings
|
|
hidden_states = torch.cat([presence_token, query_embeds], dim=1)
|
|
|
|
text_cross_attn_mask = None
|
|
if text_mask is not None:
|
|
text_cross_attn_mask = create_bidirectional_mask(
|
|
config=self.config,
|
|
input_embeds=hidden_states,
|
|
attention_mask=text_mask,
|
|
encoder_hidden_states=text_features,
|
|
)
|
|
|
|
intermediate_outputs = []
|
|
intermediate_boxes = [reference_boxes]
|
|
intermediate_presence_logits = []
|
|
|
|
for layer in self.layers:
|
|
# Generate sine embeddings for conditional queries
|
|
reference_points_input = reference_boxes.unsqueeze(2)
|
|
query_sine_embed = self.position_encoding.encode_boxes(reference_points_input[:, :, 0, :])
|
|
query_pos = self.ref_point_head(query_sine_embed)
|
|
|
|
# Compute box relative position bias (RPB) attention mask
|
|
vision_cross_attn_mask = None
|
|
if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
|
|
spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
|
|
rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
|
|
# Prepend zeros row for presence token (it attends to all vision tokens equally)
|
|
vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)
|
|
|
|
hidden_states = layer(
|
|
hidden_states,
|
|
query_pos=query_pos,
|
|
text_features=text_features,
|
|
vision_features=vision_features,
|
|
vision_pos_encoding=vision_pos_encoding,
|
|
text_cross_attn_mask=text_cross_attn_mask,
|
|
vision_cross_attn_mask=vision_cross_attn_mask,
|
|
**kwargs,
|
|
)
|
|
|
|
# Extract query hidden states (without presence token) for box refinement
|
|
query_hidden_states = hidden_states[:, 1:]
|
|
|
|
# Box refinement: predict delta and update reference boxes
|
|
reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
|
|
delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
|
|
new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
|
|
reference_boxes = new_reference_boxes.detach()
|
|
|
|
intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
|
|
intermediate_boxes.append(new_reference_boxes)
|
|
|
|
# Process presence token
|
|
presence_hidden = hidden_states[:, :1]
|
|
presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
|
|
presence_logits = presence_logits.clamp(
|
|
min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
|
|
)
|
|
intermediate_presence_logits.append(presence_logits)
|
|
|
|
# Stack outputs from all layers
|
|
intermediate_outputs = torch.stack(intermediate_outputs)
|
|
intermediate_boxes = torch.stack(intermediate_boxes[:-1])
|
|
intermediate_presence_logits = torch.stack(intermediate_presence_logits)
|
|
|
|
return Sam3DETRDecoderOutput(
|
|
intermediate_hidden_states=intermediate_outputs,
|
|
reference_boxes=intermediate_boxes,
|
|
presence_logits=intermediate_presence_logits,
|
|
)
|
|
|
|
|
|
class Sam3DotProductScoring(nn.Module):
    """
    Scores decoder queries against the text prompt with a scaled dot product.

    Text features are refined by a residual MLP, mean-pooled over valid tokens,
    and projected; decoder queries are projected into the same space. The scaled
    (and clamped) dot product between the two yields per-query classification logits.
    """

    def __init__(self, config: Sam3Config):
        super().__init__()
        self.config = config
        dim = config.detr_decoder_config.hidden_size
        proj_dim = config.detr_decoder_config.hidden_size

        # Residual MLP that refines the text features before pooling.
        self.text_mlp = Sam3DecoderMLP(
            input_dim=dim,
            hidden_dim=config.detr_decoder_config.intermediate_size,
            output_dim=dim,
            num_layers=2,
        )
        self.text_mlp_dropout = nn.Dropout(config.detr_decoder_config.dropout)
        self.text_mlp_out_norm = nn.LayerNorm(dim)

        # Linear maps into the shared scoring space.
        self.text_proj = nn.Linear(dim, proj_dim)
        self.query_proj = nn.Linear(dim, proj_dim)

        # 1/sqrt(d) scaling, as in standard attention scoring.
        self.scale = float(1.0 / np.sqrt(proj_dim))

        # Clamp the logits to keep downstream losses numerically stable.
        self.clamp_logits = True
        self.clamp_max_val = 12.0

    def _pool_text_features(self, text_features: torch.Tensor, text_mask: torch.Tensor | None) -> torch.Tensor:
        """
        Mean-pool text features over the sequence dimension, ignoring padding.

        Args:
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len]; True marks valid tokens. When None,
                every token is treated as valid.

        Returns:
            Pooled features of shape [batch_size, hidden_size].
        """
        if text_mask is None:
            # No padding information: plain mean over the sequence.
            return text_features.mean(dim=1)

        valid = text_mask.to(text_features.dtype).unsqueeze(-1)  # [batch_size, seq_len, 1]
        # Guard against fully-padded rows (division by zero).
        token_counts = valid.sum(dim=1).clamp(min=1.0)  # [batch_size, 1]
        return (text_features * valid).sum(dim=1) / token_counts

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Compute classification scores via dot product.

        Args:
            decoder_hidden_states: [num_layers, batch_size, num_queries, hidden_size]
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len] where True=valid, False=padding

        Returns:
            scores: [num_layers, batch_size, num_queries, 1]
        """
        # Residual MLP refinement followed by LayerNorm.
        refined_text = self.text_mlp_out_norm(
            text_features + self.text_mlp_dropout(self.text_mlp(text_features))
        )

        pooled_text = self._pool_text_features(refined_text, text_mask)

        proj_text = self.text_proj(pooled_text)  # [batch_size, proj_dim]
        proj_queries = self.query_proj(decoder_hidden_states)  # [num_layers, batch_size, num_queries, proj_dim]

        # Reshape to [1, batch_size, proj_dim, 1] so matmul contracts the feature dim
        # against the queries, broadcasting over the layer dim.
        proj_text = proj_text.unsqueeze(-1).unsqueeze(0)
        scores = self.scale * torch.matmul(proj_queries, proj_text)
        if self.clamp_logits:
            scores = scores.clamp(min=-self.clamp_max_val, max=self.clamp_max_val)

        return scores
|
|
|
|
|
|
class Sam3MaskEmbedder(nn.Module):
    """
    Three-layer MLP (ReLU between layers, no activation after the last) that maps
    object queries to mask embeddings, following MaskFormer's mask embedder.
    """

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__()
        self.config = config
        dim = config.hidden_size

        # Three same-width linear layers; registration order/names match a hand-written list.
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(3))
        self.activation = nn.ReLU()

    def forward(self, queries: torch.Tensor) -> torch.Tensor:
        """
        Args:
            queries: Query embeddings [batch_size, num_queries, hidden_size]

        Returns:
            Mask embeddings [batch_size, num_queries, hidden_size]
        """
        *inner_layers, last_layer = self.layers
        embeddings = queries
        for linear in inner_layers:
            embeddings = self.activation(linear(embeddings))
        # Final layer is applied without an activation.
        return last_layer(embeddings)
|
|
|
|
|
|
class Sam3PixelDecoder(nn.Module):
    """
    Feature Pyramid Network (FPN) decoder that generates pixel-level features.
    Inspired by MaskFormer's pixel decoder.
    """

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__()
        self.config = config
        channels = config.hidden_size
        num_stages = config.num_upsampling_stages

        # One 3x3 conv + GroupNorm pair per upsampling stage.
        self.conv_layers = nn.ModuleList(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) for _ in range(num_stages)
        )
        self.norms = nn.ModuleList(nn.GroupNorm(8, channels) for _ in range(num_stages))

        self.out_channels = channels

    def forward(self, backbone_features: list[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            backbone_features: List of backbone features [batch_size, hidden_size, H_i, W_i]
                from low to high resolution (assumes already projected to hidden_size)

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W] at the finest resolution
        """
        # Start from the coarsest level and walk towards the finest,
        # fusing each skip connection along the way.
        fused = backbone_features[-1]
        for stage, skip in enumerate(backbone_features[-2::-1]):
            # Nearest-neighbor upsample the running map to the skip's resolution.
            upsampled = F.interpolate(fused, size=skip.shape[-2:], mode="nearest")
            # Skip connection, then conv -> norm -> ReLU.
            fused = self.conv_layers[stage](upsampled + skip)
            fused = F.relu(self.norms[stage](fused))

        return fused
|
|
|
|
|
|
class Sam3MaskDecoder(Sam3PreTrainedModel):
    """
    Mask decoder that combines object queries with pixel-level features to predict instance masks.
    Also produces a semantic segmentation output and supports cross-attention to prompts.
    """

    # Layers/modules whose outputs are recorded by the `check_model_inputs` machinery.
    _can_record_outputs = {
        "attentions": Sam3Attention,
    }

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__(config)
        self.config = config
        hidden_size = config.hidden_size

        # Pixel decoder (FPN) that fuses multi-scale backbone features.
        self.pixel_decoder = Sam3PixelDecoder(config)

        # Mask embedder (MLP to transform queries into the mask-embedding space).
        self.mask_embedder = Sam3MaskEmbedder(config)

        # Projection from pixel decoder output to mask embedding space for the per-query dot product.
        self.instance_projection = nn.Conv2d(self.pixel_decoder.out_channels, hidden_size, kernel_size=1)

        # Semantic segmentation head (always present in UniversalSegmentationHead); single channel.
        self.semantic_projection = nn.Conv2d(self.pixel_decoder.out_channels, 1, kernel_size=1)

        # Pre-norm cross-attention letting encoder features attend to the prompt tokens.
        self.prompt_cross_attn = Sam3Attention(config)
        self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
        self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

        self.post_init()

    @check_model_inputs
    def forward(
        self,
        decoder_queries: torch.Tensor,
        backbone_features: list[torch.Tensor],
        encoder_hidden_states: torch.Tensor,
        prompt_features: torch.Tensor | None = None,
        prompt_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3MaskDecoderOutput:
        """
        Args:
            decoder_queries: Decoder output queries [batch_size, num_queries, hidden_size]
            backbone_features: List of backbone features to process through FPN
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
            prompt_features: Prompt features (text + geometry) for cross-attention [batch_size, prompt_len, hidden_size]
            prompt_mask: Padding mask [batch_size, prompt_len] where True=valid, False=padding

        Returns:
            Sam3MaskDecoderOutput containing predicted masks and semantic segmentation.
        """
        if prompt_features is not None:
            # Cross-attention: encoder features attend to prompt features (pre-norm, residual).
            residual = encoder_hidden_states
            normed_hidden_states = self.prompt_cross_attn_norm(encoder_hidden_states)

            cross_attn_mask = None
            if prompt_mask is not None:
                cross_attn_mask = create_bidirectional_mask(
                    config=self.config,
                    input_embeds=normed_hidden_states,
                    encoder_hidden_states=prompt_features,
                    attention_mask=prompt_mask,
                )

            attn_output, _ = self.prompt_cross_attn(
                query=normed_hidden_states,
                key=prompt_features,
                value=prompt_features,
                attention_mask=cross_attn_mask,
                **kwargs,
            )
            encoder_hidden_states = residual + self.prompt_cross_attn_dropout(attn_output)

        # Process backbone features through FPN to get pixel embeddings.
        pixel_embed = self._embed_pixels(
            backbone_features=backbone_features,
            encoder_hidden_states=encoder_hidden_states,
        )

        # Predict instance masks via dot product between query embeddings and pixel embeddings.
        instance_embeds = self.instance_projection(pixel_embed)
        mask_embeddings = self.mask_embedder(decoder_queries)
        pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, instance_embeds)

        # Generate class-agnostic semantic segmentation (single channel).
        semantic_seg = self.semantic_projection(pixel_embed)

        return Sam3MaskDecoderOutput(
            pred_masks=pred_masks,
            semantic_seg=semantic_seg,
        )

    def _embed_pixels(
        self,
        backbone_features: list[torch.Tensor],
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Embed pixels by combining backbone FPN features with encoder vision features.
        The encoder vision features replace the finest-resolution backbone feature.

        Args:
            backbone_features: List of backbone features [batch_size, C, H_i, W_i]
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W]
        """
        # Shallow-copy the list so we can swap out the finest level without mutating
        # the caller's list. Per-tensor clone() is unnecessary here: the tensors are
        # only read downstream (the FPN uses out-of-place ops) and the finest-level
        # entry is replaced outright, so cloning would just waste memory/compute.
        backbone_visual_feats = list(backbone_features)

        # Extract vision features from encoder output and reshape to spatial format.
        # The encoder sequence starts with the flattened finest-level vision tokens.
        spatial_dim = backbone_features[-1].shape[-2] * backbone_features[-1].shape[-1]
        encoder_visual_embed = encoder_hidden_states[:, :spatial_dim, :]
        batch_size, _, hidden_size = encoder_visual_embed.shape
        height, width = backbone_features[-1].shape[-2:]
        encoder_visual_embed = encoder_visual_embed.transpose(1, 2).reshape(batch_size, hidden_size, height, width)

        # Replace finest backbone feature with encoder vision features.
        backbone_visual_feats[-1] = encoder_visual_embed

        # Process through FPN decoder.
        pixel_embed = self.pixel_decoder(backbone_visual_feats)

        return pixel_embed
|
|
|
|
|
|
class Sam3Model(Sam3PreTrainedModel):
    """SAM3 detector: open-vocabulary detection/segmentation from text and optional box prompts."""

    input_modalities = ["image", "text"]
    _checkpoint_conversion_mapping = {
        r"detector_model.(.+)": r"\1"  # the regex allows to remove the prefix, and add it back in revert mode
    }
    # Video-tracker weights in combined checkpoints are not used by this image-only model.
    _keys_to_ignore_on_load_unexpected = [
        r"^tracker_model.",
        r"^tracker_neck.",
    ]

    def __init__(self, config: Sam3Config):
        # loading from a sam3_video config: unwrap the nested detector config
        if hasattr(config, "detector_config") and config.detector_config is not None:
            detector_config = config.detector_config
            if isinstance(detector_config, dict):
                detector_config = Sam3Config(**detector_config)
            config = detector_config
        super().__init__(config)
        self.vision_encoder = Sam3VisionModel(config.vision_config)
        self.text_encoder = CLIPTextModelWithProjection(config.text_config)
        self.vocab_size = config.text_config.vocab_size

        # Project text features from text encoder hidden size to model hidden size
        # CLIP text encoder outputs 1024-dim features, but we need 256-dim for DETR
        self.text_projection = nn.Linear(config.text_config.hidden_size, config.detr_encoder_config.hidden_size)

        # Pass _attn_implementation to subconfigs BEFORE creating modules
        config.geometry_encoder_config._attn_implementation = config._attn_implementation
        config.detr_encoder_config._attn_implementation = config._attn_implementation
        config.detr_decoder_config._attn_implementation = config._attn_implementation
        config.mask_decoder_config._attn_implementation = config._attn_implementation

        self.geometry_encoder = Sam3GeometryEncoder(config.geometry_encoder_config)
        self.detr_encoder = Sam3DetrEncoder(config.detr_encoder_config)
        self.detr_decoder = Sam3DetrDecoder(config.detr_decoder_config)
        self.mask_decoder = Sam3MaskDecoder(config.mask_decoder_config)

        # Dot product scoring to compute classification scores
        self.dot_product_scoring = Sam3DotProductScoring(config)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3Model, Sam3Processor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3Model.from_pretrained("facebook/sam3")
        >>> processor = Sam3Processor.from_pretrained("facebook/sam3")

        >>> # Pre-compute text embeddings
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> text_embeds = model.get_text_features(**text_inputs).pooler_output

        >>> # Reuse text embeddings for multiple images
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> outputs = model(pixel_values=img_inputs.pixel_values, text_embeds=text_embeds)
        ```
        """
        text_outputs = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs
        )
        last_hidden_state = text_outputs.last_hidden_state
        # NOTE: `pooler_output` is repurposed to hold the *per-token* projected features
        # [batch, seq_len, detr_hidden_size] (not a pooled vector) — `forward` consumes
        # it as a sequence of text features and concatenates geometry prompts onto it.
        text_outputs.pooler_output = self.text_projection(last_hidden_state)

        return text_outputs

    @auto_docstring
    def get_vision_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3VisionEncoderOutput:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3Model, Sam3Processor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3Model.from_pretrained("facebook/sam3")
        >>> processor = Sam3Processor.from_pretrained("facebook/sam3")

        >>> # Pre-compute vision embeddings
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> vision_embeds = model.get_vision_features(pixel_values=img_inputs.pixel_values)

        >>> # Reuse vision embeddings for multiple text prompts
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> outputs = model(vision_embeds=vision_embeds, input_ids=text_inputs.input_ids)
        ```
        """
        vision_outputs = self.vision_encoder(pixel_values, **kwargs)
        return vision_outputs

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        vision_embeds: Sam3VisionEncoderOutput | None = None,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        text_embeds: torch.FloatTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_boxes_labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3ImageSegmentationOutput:
        r"""
        vision_embeds (`Sam3VisionEncoderOutput`, *optional*):
            Pre-computed vision embeddings. Can be used to easily reuse vision embeddings. If provided, `pixel_values`
            should not be passed. Mutually exclusive with `pixel_values`.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Pre-computed text embeddings. Can be used to easily reuse text embeddings. If provided, `input_ids`
            should not be passed. Mutually exclusive with `input_ids`.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`, *optional*):
            Normalized box coordinates in [0, 1] range, in (cx, cy, w, h) format.
        input_boxes_labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`, *optional*):
            Labels for boxes: 1 (positive), 0 (negative).

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("facebook/sam3")
        >>> processor = AutoProcessor.from_pretrained("facebook/sam3")

        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read())).convert("RGB")
        >>> text = "car"
        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> # Get segmentation output
        >>> outputs = model(**inputs)
        >>> pred_masks = outputs.pred_masks
        >>> pred_boxes = outputs.pred_boxes
        ```
        """
        # Exactly one source for each modality: raw inputs or pre-computed embeddings.
        if (pixel_values is None) == (vision_embeds is None):
            raise ValueError("You must specify exactly one of pixel_values or vision_embeds")

        if (input_ids is None) == (text_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or text_embeds")

        if pixel_values is not None:
            batch_size = pixel_values.shape[0]
            device = pixel_values.device
        else:
            batch_size = vision_embeds.fpn_hidden_states[0].shape[0]
            device = vision_embeds.fpn_hidden_states[0].device

        if vision_embeds is None:
            vision_outputs = self.vision_encoder(pixel_values, **kwargs)
        else:
            vision_outputs = vision_embeds

        # The coarsest FPN level is dropped for the detection path.
        fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1]
        fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1]

        if text_embeds is None:
            # `pooler_output` here is the projected per-token sequence (see get_text_features).
            text_features = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).pooler_output
        else:
            text_features = text_embeds

        text_mask = attention_mask.bool() if attention_mask is not None else None
        has_geometry_prompts = input_boxes is not None and input_boxes.numel() > 0

        geometry_prompt_features = None
        geometry_prompt_mask = None

        if has_geometry_prompts:
            if input_boxes is not None and input_boxes.numel() > 0:
                box_embeddings = input_boxes  # [batch_size, num_boxes, 4]
                # Default to all-positive labels when none are provided.
                box_labels = (
                    input_boxes_labels
                    if input_boxes_labels is not None
                    else torch.ones_like(box_embeddings[..., 0], dtype=torch.long)
                )
                # -10 acts as the padding sentinel in the labels: padded boxes are
                # masked out and their label is remapped to 0 before encoding.
                box_mask = (
                    (input_boxes_labels != -10)
                    if input_boxes_labels is not None
                    else torch.ones(batch_size, input_boxes.shape[1], dtype=torch.bool, device=device)
                )
                box_labels = torch.where(box_labels == -10, 0, box_labels)
            else:
                # Empty placeholders keep downstream shapes consistent.
                box_embeddings = torch.zeros(batch_size, 0, 4, dtype=text_features.dtype, device=device)
                box_labels = torch.zeros(batch_size, 0, dtype=torch.long, device=device)
                box_mask = torch.zeros(batch_size, 0, dtype=torch.bool, device=device)

            geometry_outputs = self.geometry_encoder(
                box_embeddings=box_embeddings,
                box_mask=box_mask,
                box_labels=box_labels,
                img_feats=fpn_hidden_states,
                img_pos_embeds=fpn_position_encoding,
            )

            geometry_prompt_features = geometry_outputs.last_hidden_state
            geometry_prompt_mask = geometry_outputs.attention_mask

        if geometry_prompt_features is not None:
            # Repeat text_features for all geometry prompts (single text prompt broadcast
            # over a batch of geometry prompts).
            if text_features.shape[0] == 1 and geometry_prompt_features.shape[0] > 1:
                text_features = text_features.repeat(geometry_prompt_features.shape[0], 1, 1)
            combined_prompt_features = torch.cat([text_features, geometry_prompt_features], dim=1)
            if text_mask is not None and text_mask.shape[0] == 1 and geometry_prompt_mask.shape[0] > 1:
                text_mask = text_mask.repeat(geometry_prompt_mask.shape[0], 1)

            # Build the joint validity mask, padding with all-valid where one side has no mask.
            if text_mask is not None and geometry_prompt_mask is not None:
                combined_prompt_mask = torch.cat([text_mask, geometry_prompt_mask], dim=1)
            elif text_mask is not None:
                geo_valid_mask = torch.ones(
                    batch_size, geometry_prompt_features.shape[1], dtype=torch.bool, device=device
                )
                combined_prompt_mask = torch.cat([text_mask, geo_valid_mask], dim=1)
            elif geometry_prompt_mask is not None:
                text_valid_mask = torch.ones(batch_size, text_features.shape[1], dtype=torch.bool, device=device)
                combined_prompt_mask = torch.cat([text_valid_mask, geometry_prompt_mask], dim=1)
            else:
                combined_prompt_mask = None
        else:
            combined_prompt_features = text_features
            combined_prompt_mask = text_mask

        # Fuse the finest FPN level with the prompt features.
        encoder_outputs = self.detr_encoder(
            vision_features=[fpn_hidden_states[-1]],
            text_features=combined_prompt_features,
            vision_pos_embeds=[fpn_position_encoding[-1]],
            text_mask=combined_prompt_mask,
            **kwargs,
        )

        decoder_outputs = self.detr_decoder(
            vision_features=encoder_outputs.last_hidden_state,
            text_features=encoder_outputs.text_features,
            vision_pos_encoding=encoder_outputs.pos_embeds_flattened,
            text_mask=combined_prompt_mask,
            spatial_shapes=encoder_outputs.spatial_shapes,
            **kwargs,
        )

        # Refine boxes from decoder: decoder returns the reference boxes fed into each
        # layer, so re-applying box_head recovers each layer's refined prediction.
        all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states)
        reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes)
        all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid()
        all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh)

        all_pred_logits = self.dot_product_scoring(
            decoder_hidden_states=decoder_outputs.intermediate_hidden_states,
            text_features=encoder_outputs.text_features,
            text_mask=combined_prompt_mask,
        ).squeeze(-1)

        # Index [-1] keeps only the last decoder layer's predictions for the final output.
        pred_logits = all_pred_logits[-1]
        pred_boxes = all_pred_boxes[-1]
        decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1]
        presence_logits = decoder_outputs.presence_logits[-1]

        mask_outputs = self.mask_decoder(
            decoder_queries=decoder_hidden_states,
            backbone_features=list(fpn_hidden_states),
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            prompt_features=combined_prompt_features,
            prompt_mask=combined_prompt_mask,
            **kwargs,
        )

        return Sam3ImageSegmentationOutput(
            pred_masks=mask_outputs.pred_masks,
            pred_boxes=pred_boxes,
            pred_logits=pred_logits,
            presence_logits=presence_logits,
            semantic_seg=mask_outputs.semantic_seg,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_reference_boxes=decoder_outputs.reference_boxes,
            encoder_hidden_states=encoder_outputs.hidden_states,
            vision_hidden_states=vision_outputs.hidden_states,
            vision_attentions=vision_outputs.attentions,
            detr_encoder_attentions=encoder_outputs.attentions,
            detr_decoder_attentions=decoder_outputs.attentions,
            mask_decoder_attentions=mask_outputs.attentions,
        )
|
|
|
|
|
|
# Public API of this module (consumed by the library's lazy import/auto registry).
__all__ = ["Sam3Model", "Sam3VisionModel", "Sam3ViTModel", "Sam3PreTrainedModel"]
|