# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch SAM 2 model."""

from collections.abc import Callable
from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ... import initialization as init
from ...activations import ACT2FN
from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
    pil_torch_interpolation_mapping,
)
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import ImagesKwargs, Unpack
from ...utils import ModelOutput, TensorType, auto_docstring, can_return_tuple, logging
from ...utils.generic import TransformersKwargs, check_model_inputs, is_flash_attention_requested
from ..auto import AutoModel
from ..maskformer.modeling_maskformer import MaskFormerSinePositionEmbedding
from ..sam.image_processing_sam_fast import SamImageProcessorFast
from ..sam.modeling_sam import (
    SamLayerNorm,
    SamMaskDecoder,
    SamMaskEmbedding,
    SamModel,
    SamPromptEncoder,
    SamTwoWayAttentionBlock,
    SamTwoWayTransformer,
    eager_attention_forward,
)
from ..vitdet.modeling_vitdet import window_partition, window_unpartition
from .configuration_sam2 import (
    Sam2Config,
    Sam2HieraDetConfig,
    Sam2MaskDecoderConfig,
    Sam2PromptEncoderConfig,
    Sam2VisionConfig,
)


logger = logging.get_logger(__name__)


class Sam2FastImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    mask_size (`dict[str, int]`, *optional*):
        The size `{"height": int, "width": int}` to resize the segmentation maps to.
    """

    mask_size: dict[str, int]


@auto_docstring
class Sam2ImageProcessorFast(SamImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    size = {"height": 1024, "width": 1024}
    mask_size = {"height": 256, "width": 256}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True

    valid_kwargs = Sam2FastImageProcessorKwargs

    # modular artefacts
    do_pad = None
    pad_size = None
    mask_pad_size = None

    def __init__(self, **kwargs: Unpack[Sam2FastImageProcessorKwargs]):
        BaseImageProcessorFast.__init__(self, **kwargs)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        return_tensors: str | TensorType | None,
        **kwargs,
    ) -> "torch.Tensor":
        return BaseImageProcessorFast._preprocess(self, images, return_tensors=return_tensors, **kwargs).pixel_values

    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: ImageInput | None = None,
        **kwargs: Unpack[Sam2FastImageProcessorKwargs],
    ) -> BatchFeature:
        r"""
        segmentation_maps (`ImageInput`, *optional*):
            The segmentation maps to preprocess.
        """
        return super().preprocess(images, segmentation_maps, **kwargs)

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        segmentation_maps: ImageInput | None,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Union[str, "torch.device"] | None = None,
        **kwargs: Unpack[Sam2FastImageProcessorKwargs],
    ) -> BatchFeature:
        """
        Preprocess image-like inputs.
        """
        images = self._prepare_image_like_inputs(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        original_sizes = [image.shape[-2:] for image in images]
        images_kwargs = kwargs.copy()
        pixel_values = self._preprocess(images, **images_kwargs)
        data = {
            "pixel_values": pixel_values,
            "original_sizes": original_sizes,
        }

        if segmentation_maps is not None:
            processed_segmentation_maps = self._prepare_image_like_inputs(
                images=segmentation_maps,
                expected_ndims=2,
                do_convert_rgb=False,
                input_data_format=ChannelDimension.FIRST,
            )

            segmentation_maps_kwargs = kwargs.copy()
            segmentation_maps_kwargs.update(
                {
                    "do_normalize": False,
                    "do_rescale": False,
                    "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST],
                    "size": segmentation_maps_kwargs.pop("mask_size"),
                }
            )
            processed_segmentation_maps = self._preprocess(
                images=processed_segmentation_maps, **segmentation_maps_kwargs
            )
            data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)

        return BatchFeature(data=data, tensor_type=kwargs["return_tensors"])

    def _further_process_kwargs(
        self,
        size: SizeDict | None = None,
        mask_size: SizeDict | None = None,
        default_to_square: bool | None = None,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        data_format: ChannelDimension | None = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated.
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        if kwargs is None:
            kwargs = {}
        if size is not None:
            size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
        if mask_size is not None:
            mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
        if isinstance(image_mean, list):
            image_mean = tuple(image_mean)
        if isinstance(image_std, list):
            image_std = tuple(image_std)
        if data_format is None:
            data_format = ChannelDimension.FIRST

        kwargs["size"] = size
        kwargs["mask_size"] = mask_size
        kwargs["image_mean"] = image_mean
        kwargs["image_std"] = image_std
        kwargs["data_format"] = data_format

        # torch resize uses interpolation instead of resample
        # Check if resample is an int before checking if it's an instance of PILImageResampling
        # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
        # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
        resample = kwargs.pop("resample")
        kwargs["interpolation"] = (
            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
        )

        return kwargs

    def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor:
        """
        Apply non-overlapping constraints to the object scores in pred_masks. Here we
        keep only the highest scoring object at each spatial location in pred_masks.
        """
        batch_size = pred_masks.size(0)
        if batch_size == 1:
            return pred_masks

        device = pred_masks.device
        # "max_obj_inds": object index of the object with the highest score at each location
        max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
        # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
        batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
        keep = max_obj_inds == batch_obj_inds
        # suppress overlapping regions' scores below -10.0 so that the foreground regions
        # don't overlap (here sigmoid(-10.0)=4.5398e-05)
        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks

    def post_process_masks(
        self,
        masks,
        original_sizes,
        mask_threshold=0.0,
        binarize=True,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        apply_non_overlapping_constraints=False,
        **kwargs,
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in (height,
                width) format.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold for binarization and post-processing operations.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            max_hole_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a hole to fill.
            max_sprinkle_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a sprinkle to fill.
            apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
                Whether to apply non-overlapping constraints to the masks.

        Returns:
            (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
            is given by original_size.
        """
        if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
            original_sizes = original_sizes.tolist()
        # TODO: add connected components kernel for postprocessing
        output_masks = []
        for i, original_size in enumerate(original_sizes):
            if isinstance(masks[i], np.ndarray):
                masks[i] = torch.from_numpy(masks[i])
            elif not isinstance(masks[i], torch.Tensor):
                raise TypeError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
            interpolated_mask = F.interpolate(masks[i], original_size, mode="bilinear", align_corners=False)
            if apply_non_overlapping_constraints:
                interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask)
            if binarize:
                interpolated_mask = interpolated_mask > mask_threshold
            output_masks.append(interpolated_mask)

        return output_masks

    def _get_preprocess_shape(self):
        raise NotImplementedError("No _get_preprocess_shape for SAM 2.")

    def resize(self):
        raise NotImplementedError("No need to override resize for SAM 2.")

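
# A minimal usage sketch (comment only, not executed). It assumes a `Sam2Model` has already been
# loaded as `model`; the checkpoint name is the one used in the `Sam2Model.forward` docstring below.
#
#     processor = Sam2ImageProcessorFast.from_pretrained("danelcsb/sam2.1_hiera_tiny")
#     inputs = processor(images=image, return_tensors="pt")
#     outputs = model(**inputs)
#     masks = processor.post_process_masks(outputs.pred_masks, inputs["original_sizes"])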


@dataclass
@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
class Sam2VisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
        model at the output of each stage.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    fpn_hidden_states (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
    fpn_position_encoding (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
    """

    fpn_hidden_states: torch.FloatTensor | None = None
    fpn_position_encoding: torch.FloatTensor | None = None


@dataclass
@auto_docstring(custom_intro="Base class for the Sam2 model's output.")
class Sam2ImageSegmentationOutput(ModelOutput):
    r"""
    iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
        The Intersection over Union (IoU) scores of the predicted masks.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
        The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
        by the processor to be brought to the original image size.
    object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
        Logits for the object score, indicating if an object is present.
    image_embeddings (`tuple(torch.FloatTensor)`):
        The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
        tensor has shape `(batch_size, channels, height, width)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
        Hidden-states of the vision model at the output of each stage.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attentions weights of the vision model.
    mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attentions weights of the mask decoder.
    """

    iou_scores: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    object_score_logits: torch.FloatTensor | None = None
    image_embeddings: tuple[torch.FloatTensor, ...] = None
    vision_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    vision_attentions: tuple[torch.FloatTensor, ...] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None


class Sam2PatchEmbeddings(nn.Module):
    r"""
    Turns pixel values into patch embeddings for transformer consumption.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details.

    Returns:
        embeddings (`torch.FloatTensor`):
            Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding
    """

    def __init__(self, config: Sam2HieraDetConfig):
        super().__init__()
        num_channels = config.num_channels
        hidden_size = config.hidden_size

        self.projection = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=config.patch_kernel_size,
            stride=config.patch_stride,
            padding=config.patch_padding,
        )

    def forward(self, pixel_values):
        _, num_channels, height, width = pixel_values.shape
        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).permute(0, 2, 3, 1)
        return embeddings


class Sam2SinePositionEmbedding(MaskFormerSinePositionEmbedding):
    pass

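
# FPN-style neck on top of the Hiera backbone: each stage output is projected to `fpn_hidden_size`
# with a conv layer and, for the levels listed in `fpn_top_down_levels`, fused top-down by
# nearest-neighbor upsampling. A sine positional encoding is returned for every feature level.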
class Sam2VisionNeck(nn.Module):
    def __init__(self, config: Sam2VisionConfig):
        super().__init__()
        self.config = config

        self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
        self.convs = nn.ModuleList()
        for in_channels in config.backbone_channel_list:
            self.convs.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=config.fpn_hidden_size,
                    kernel_size=config.fpn_kernel_size,
                    stride=config.fpn_stride,
                    padding=config.fpn_padding,
                ),
            )
        self.fpn_top_down_levels = config.fpn_top_down_levels

    def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
        fpn_hidden_states = ()
        fpn_position_encoding = ()

        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            lateral_features = hidden_states[i].permute(0, 3, 1, 2)
            lateral_features = self.convs[n - i](lateral_features.to(self.convs[i].weight.dtype))
            if i not in self.fpn_top_down_levels or i == n:
                prev_features = lateral_features
            else:
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode="nearest",
                    align_corners=None,
                    antialias=False,
                ).to(lateral_features.dtype)
                prev_features = lateral_features + top_down_features

            prev_position_encoding = self.position_encoding(
                prev_features.shape, prev_features.device, prev_features.dtype
            ).to(prev_features.dtype)

            fpn_hidden_states += (prev_features,)
            fpn_position_encoding += (prev_position_encoding,)

        return fpn_hidden_states, fpn_position_encoding

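
# Downsample query tokens at stage transitions: the (B, H, W, C) map is max-pooled spatially with
# kernel and stride equal to `query_stride`, then returned in the same channels-last layout.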
def do_pool(x: torch.Tensor, query_stride: int | None = None) -> torch.Tensor:
    if query_stride is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    return x

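
# Self-attention used inside the Hiera blocks (window partitioning is handled by the caller). When
# `query_stride` is set, the queries are spatially downsampled with `do_pool` so the attention
# output already has the reduced resolution of the new stage.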
class Sam2MultiScaleAttention(nn.Module):
    def __init__(
        self,
        config: Sam2HieraDetConfig,
        dim: int,
        dim_out: int,
        num_attention_heads: int,
        query_stride: tuple[int, int] | None = None,
    ):
        super().__init__()

        self.config = config

        self.dim = dim
        self.dim_out = dim_out
        self.query_stride = query_stride

        self.num_attention_heads = num_attention_heads
        head_dim = dim_out // num_attention_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

        self.is_causal = False

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        query, key, value = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.query_stride:
            query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
            height, width = query.shape[1:3]  # downsampled shape
            query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)

        # transpose query, key, value to (B, nHead, H * W, C)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, _ = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            is_causal=self.is_causal,
            scaling=self.scale,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)

        return attn_output

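
# Small MLP used throughout SAM 2: the Hiera block MLPs, the two-way attention block MLPs, the IoU
# prediction head and the object score head all instantiate this module with different depths.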
class Sam2FeedForward(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: str = "relu",
        sigmoid_output: bool = False,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.activation = ACT2FN[activation]
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
        self.sigmoid_output = sigmoid_output

    def forward(self, hidden_states):
        hidden_states = self.proj_in(hidden_states)
        hidden_states = self.activation(hidden_states)
        for layer in self.layers:
            hidden_states = self.activation(layer(hidden_states))

        hidden_states = self.proj_out(hidden_states)
        if self.sigmoid_output:
            hidden_states = F.sigmoid(hidden_states)
        return hidden_states

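
# One Hiera block: layer norm, (optionally windowed) multi-scale attention with a pooled/projected
# residual path when the channel dimension changes between stages, followed by a second layer norm
# and an MLP residual.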
class Sam2MultiScaleBlock(GradientCheckpointingLayer):
    def __init__(
        self,
        config: Sam2HieraDetConfig,
        stage_idx: int,
        block_idx: int,
        total_block_idx: int,
    ):
        super().__init__()

        # take embed dim from previous stage if first block of stage
        self.dim = (
            config.embed_dim_per_stage[stage_idx - 1]
            if stage_idx > 0 and block_idx == 0
            else config.embed_dim_per_stage[stage_idx]
        )
        self.dim_out = config.embed_dim_per_stage[stage_idx]
        self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
        # take window size from previous stage if first block of stage
        self.window_size = (
            config.window_size_per_stage[stage_idx - 1]
            if stage_idx > 0 and block_idx == 0
            else config.window_size_per_stage[stage_idx]
        )
        self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
        # use query stride for first block of stage if stage is a query pool stage
        self.query_stride = (
            config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
        )

        self.attn = Sam2MultiScaleAttention(
            config,
            self.dim,
            self.dim_out,
            num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
            query_stride=self.query_stride,
        )
        self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
        self.mlp = Sam2FeedForward(
            self.dim_out,
            int(self.dim_out * config.mlp_ratio),
            self.dim_out,
            num_layers=2,
            activation=config.hidden_act,
        )
        if self.dim != self.dim_out:
            self.proj = nn.Linear(self.dim, self.dim_out)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        residual = hidden_states  # batch_size, height, width, channel

        hidden_states = self.layer_norm1(hidden_states)

        # Skip connection
        if self.dim != self.dim_out:
            residual = do_pool(self.proj(hidden_states), self.query_stride)

        # Window partition
        window_size = self.window_size
        if self.window_size > 0:
            H, W = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, pad_hw = window_partition(hidden_states, window_size)

        # Window Attention + Q Pooling (if stage change)
        attn_output = self.attn(
            hidden_states=hidden_states,
            **kwargs,
        )
        hidden_states = attn_output
        if self.query_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.query_stride[0]
            H, W = residual.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W))

        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)

        return hidden_states


@dataclass
@auto_docstring(
    custom_intro="""
    Hiera model's outputs that also contains a pooling of the last hidden states.
    """
)
class Sam2HieraDetModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        Hidden-states at the output of the last layer of the model.
    intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
        Sequence of hidden-states at the output of the intermediate layers of the model.
    """

    last_hidden_state: torch.FloatTensor | None = None
    intermediate_hidden_states: tuple[torch.FloatTensor, ...] | None = None


@auto_docstring
class Sam2PreTrainedModel(PreTrainedModel):
    config_class = Sam2Config
    base_model_prefix = "sam2"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, Sam2HieraDetModel):
            if module.pos_embed is not None:
                init.zeros_(module.pos_embed)
            if module.pos_embed_window is not None:
                init.zeros_(module.pos_embed_window)
        elif isinstance(module, Sam2PositionalEmbedding):
            init.normal_(module.positional_embedding, std=module.scale)
        elif isinstance(module, Sam2Model):
            if module.no_memory_embedding is not None:
                init.zeros_(module.no_memory_embedding)

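
# Hiera backbone: patch embedding plus a windowed positional embedding, followed by a stack of
# multi-scale blocks. The hidden states at the end of every stage are collected so the FPN neck
# can build its feature pyramid from them.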
class Sam2HieraDetModel(Sam2PreTrainedModel):
    config_class = Sam2HieraDetConfig
    main_input_name = "pixel_values"
    _can_record_outputs = {
        "hidden_states": Sam2MultiScaleBlock,
        "attentions": Sam2MultiScaleAttention,
    }

    def __init__(self, config: Sam2HieraDetConfig):
        super().__init__(config)

        self.patch_embed = Sam2PatchEmbeddings(config)
        # Windowed positional embedding (https://huggingface.co/papers/2311.05613)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0])
        )
        self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()
        self.blocks = nn.ModuleList()
        total_block_idx = 0
        for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage):
            for block_idx in range(blocks_per_stage):
                block = Sam2MultiScaleBlock(
                    config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx
                )
                self.blocks.append(block)
                total_block_idx += 1

        self.post_init()

    def get_input_embeddings(self):
        return self.patch_embed

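    # Interpolate the background positional embedding to the current feature resolution and add the
    # window positional embedding tiled across it (windowed absolute position embeddings, see the
    # paper linked in `__init__`), returning a channels-last map.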
    def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    @check_model_inputs
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam2HieraDetModelOutput:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.patch_embed(pixel_values)
        hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3])

        intermediate_hidden_states = ()
        for i, block_module in enumerate(self.blocks):
            hidden_states = block_module(hidden_states, **kwargs)

            if i in self.stage_ends:
                intermediate_hidden_states = intermediate_hidden_states + (hidden_states,)

        return Sam2HieraDetModelOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_states=intermediate_hidden_states,
        )

@auto_docstring(
    custom_intro="""
    The vision model from Sam without any head or projection on top.
    """
)
class Sam2VisionModel(Sam2PreTrainedModel):
    config_class = Sam2VisionConfig
    main_input_name = "pixel_values"
    _can_record_outputs = {
        "hidden_states": Sam2MultiScaleBlock,
        "attentions": Sam2MultiScaleAttention,
    }

    def __init__(self, config: Sam2VisionConfig):
        super().__init__(config)
        self.config = config

        self.backbone = AutoModel.from_config(config.backbone_config)

        self.neck = Sam2VisionNeck(config)
        self.num_feature_levels = config.num_feature_levels

        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    @check_model_inputs
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam2VisionEncoderOutput:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Forward through backbone
        backbone_output = self.backbone(pixel_values, **kwargs)
        hidden_states = backbone_output.last_hidden_state
        intermediate_hidden_states = backbone_output.intermediate_hidden_states

        fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
        # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
        fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
        fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]

        return Sam2VisionEncoderOutput(
            last_hidden_state=hidden_states,
            fpn_hidden_states=fpn_hidden_states,
            fpn_position_encoding=fpn_position_encoding,
        )

class Sam2PositionalEmbedding(nn.Module):
    def __init__(self, config: Sam2PromptEncoderConfig):
        super().__init__()
        self.scale = config.scale
        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
        self.register_buffer("positional_embedding", positional_embedding)

    def forward(self, input_coords, input_shape=None):
        """Positionally encode points that are normalized to [0,1]."""
        coordinates = input_coords.clone()

        if input_shape is not None:
            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
            coordinates.to(torch.float32)

        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coordinates = 2 * coordinates - 1
        coordinates = coordinates.to(self.positional_embedding.dtype)
        coordinates = coordinates @ self.positional_embedding
        coordinates = 2 * np.pi * coordinates
        # outputs d_1 x ... x d_n x channel shape
        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)


class Sam2MaskEmbedding(SamMaskEmbedding):
    pass


class Sam2PromptEncoder(SamPromptEncoder):
    def __init__(self, config: Sam2PromptEncoderConfig):
        nn.Module.__init__(self)
        self.shared_embedding = Sam2PositionalEmbedding(config)
        self.mask_embed = Sam2MaskEmbedding(config)
        self.no_mask_embed = nn.Embedding(1, config.hidden_size)

        self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
        self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
        self.input_image_size = config.image_size

        self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
        self.hidden_size = config.hidden_size
        self.not_a_point_embed = nn.Embedding(1, config.hidden_size)

    def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
            labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
        input_shape = (self.input_image_size, self.input_image_size)
        point_embedding = self.shared_embedding(points, input_shape)

        # torch.where and expanding the labels tensor is required by the ONNX export
        point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)

        # This is required for the ONNX export. The dtype, device need to be explicitly
        # specified as otherwise torch.onnx.export interprets as double
        point_embedding = torch.where(
            labels[..., None] != -10,
            point_embedding,
            torch.zeros_like(point_embedding),
        )

        # Add point embeddings for labels >= 0
        point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)

        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.view(*boxes.shape[:2], 2, 2)
        # add padding point for consistency with the original implementation
        coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
        corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
        corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
        corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
        corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
        return corner_embedding

class Sam2Attention(nn.Module):
    """
    SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    """

    def __init__(self, config, downsample_rate=None):
        super().__init__()
        downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
        self.config = config
        self.hidden_size = config.hidden_size
        self.internal_dim = config.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.internal_dim // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_similarity: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Input projections
        batch_size, point_batch_size = query.shape[:2]
        new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)

        query = self.q_proj(query).view(*new_shape).transpose(1, 2)
        key = self.k_proj(key).view(*new_shape).transpose(1, 2)
        value = self.v_proj(value).view(*new_shape).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        if is_flash_attention_requested(self.config) and attention_similarity is not None:
            # Target guided masks are represented as float masks and are incompatible with Flash Attention
            # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
            logger.warning_once(
                "Falling back to SDPA for target-guided attention because "
                "Flash Attention does not support additive bias masks."
            )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_similarity,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        attn_output = attn_output.reshape(
            batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights


class Sam2TwoWayAttentionBlock(SamTwoWayAttentionBlock, GradientCheckpointingLayer):
    def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
        nn.Module.__init__(self)
        self.self_attn = Sam2Attention(config, downsample_rate=1)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)

        self.cross_attn_token_to_image = Sam2Attention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

        self.mlp = Sam2FeedForward(
            config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
        )
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)

        self.layer_norm4 = nn.LayerNorm(config.hidden_size)
        self.cross_attn_image_to_token = Sam2Attention(config)

        self.skip_first_layer_pe = skip_first_layer_pe


class Sam2TwoWayTransformer(SamTwoWayTransformer):
    pass


class Sam2LayerNorm(SamLayerNorm):
    pass

class Sam2MaskDecoder(SamMaskDecoder):
    def __init__(self, config: Sam2MaskDecoderConfig):
        super().__init__(config)
        del self.iou_prediction_head
        self.iou_prediction_head = Sam2FeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            sigmoid_output=True,
        )

        self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
        self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)

        self.obj_score_token = nn.Embedding(1, self.hidden_size)
        self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)

        self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, :, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, :, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)  # [B, P]
        best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        best_scores_inds_expanded = best_scores_inds_expanded.expand(
            -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
        )
        best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded)  # [B, P, 1, H, W]
        best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1))  # [B, P, 1]

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, :, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        high_resolution_features: list[torch.Tensor],
        attention_similarity: torch.Tensor | None = None,
        target_embedding: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                The embeddings from the image encoder.
            image_positional_embeddings (`torch.Tensor`):
                Positional encoding with the shape of image_embeddings.
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes.
            dense_prompt_embeddings (`torch.Tensor`):
                The embeddings of the mask inputs.
            multimask_output (`bool`):
                Whether to return multiple masks or a single mask.
            high_resolution_features (`list[torch.Tensor]`, *optional*):
                The high-resolution features from the vision encoder.
            attention_similarity (`torch.Tensor`, *optional*):
                The attention similarity tensor.
            target_embedding (`torch.Tensor`, *optional*):
                The target embedding.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # Concatenate output tokens
        output_tokens = torch.cat(
            [
                self.obj_score_token.weight,
                self.iou_token.weight,
                self.mask_tokens.weight,
            ],
            dim=0,
        )
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        if sparse_prompt_embeddings.shape[0] != 0:
            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-mask
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
        # Run the transformer
        point_embeddings, image_embeddings = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            **kwargs,
        )
        iou_token_out = point_embeddings[:, :, 1, :]
        mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        # (use reshape rather than view: the transpose makes the tensor non-contiguous)
        image_embeddings = image_embeddings.transpose(2, 3).reshape(
            batch_size * point_batch_size, num_channels, height, width
        )

        feat_s0, feat_s1 = high_resolution_features
        feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
        feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
        upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)

        hyper_in_list: list[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = torch.stack(hyper_in_list, dim=2)

        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
        masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]
        elif self.dynamic_multimask_via_stability and not self.training:
            mask_slice = slice(0, 1)
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            mask_slice = slice(0, 1)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]

        sam_tokens_out = mask_tokens_out[:, :, mask_slice]  # [b, 3, c] shape

        return masks, iou_pred, sam_tokens_out, object_score_logits

@auto_docstring(
    custom_intro="""
    Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
    input points and labels, boxes, or masks.
    """
)
class Sam2Model(SamModel):
    _keys_to_ignore_on_load_unexpected = [
        r"^memory_.*",
        r"^mask_downsample.*",
        r"^object_pointer_proj.*",
        r"^temporal_positional_encoding_projection_layer.*",
        "no_memory_positional_encoding",
        "no_object_pointer",
        "occlusion_spatial_embedding_parameter",
    ]
    _tied_weights_keys = {}

    def __init__(self, config: Sam2Config):
        PreTrainedModel.__init__(self, config)
        self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config)
        self.vision_encoder = AutoModel.from_config(config.vision_config)
        self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config)
        # The module using it is not a PreTrainedModel subclass so we need this
        config.mask_decoder_config._attn_implementation = config._attn_implementation
        self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)

        self.num_feature_levels = config.vision_config.num_feature_levels
        self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
        # a single token to indicate no memory embedding from previous frames
        self.hidden_dim = config.vision_config.fpn_hidden_size
        self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))

        self.post_init()

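    # Builds a dense positional encoding over the prompt encoder's `image_embedding_size` grid using
    # the shared random-Fourier point embedding; the result is fed to the mask decoder as
    # `image_positional_embeddings`.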
    def get_image_wide_positional_embeddings(self) -> torch.Tensor:
        size = self.prompt_encoder.image_embedding_size
        target_device = self.shared_image_embedding.positional_embedding.device
        target_dtype = self.shared_image_embedding.positional_embedding.dtype
        grid = torch.ones(size, device=target_device, dtype=target_dtype)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / size[0]
        x_embed = x_embed / size[1]

        positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
        return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # channel x height x width

    @torch.no_grad()
    def get_image_embeddings(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> list[torch.Tensor]:
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
        """
        batch_size = pixel_values.shape[0]
        image_outputs = self.get_image_features(pixel_values, return_dict=True, **kwargs)
        feature_maps = image_outputs.fpn_hidden_states

        # add no memory embedding to the last feature map
        feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding

        # reshape feature maps to the same shape as the backbone feature sizes
        image_embeddings = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
        ]

        return image_embeddings

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam2VisionEncoderOutput:
        r"""
        pixel_values (`torch.FloatTensor`):
            Input pixel values of shape `(batch_size, num_channels, height, width)`.
        """
        vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)

        feature_maps = vision_outputs.fpn_hidden_states
        feature_maps_position_embeddings = vision_outputs.fpn_position_encoding

        # precompute projected level 0 and level 1 features in SAM decoder
        # to avoid running it again on every SAM click
        feature_maps = list(feature_maps)
        feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
        feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])

        # flatten NxCxHxW to HWxNxC
        feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
        feature_maps_position_embeddings = [
            feature_map_position_embedding.flatten(2).permute(2, 0, 1)
            for feature_map_position_embedding in feature_maps_position_embeddings
        ]
        vision_outputs.fpn_hidden_states = feature_maps
        vision_outputs.fpn_position_encoding = feature_maps_position_embeddings

        return vision_outputs

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        input_points: torch.FloatTensor | None = None,
        input_labels: torch.LongTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_masks: torch.LongTensor | None = None,
        image_embeddings: torch.FloatTensor | None = None,
        multimask_output: bool = True,
        attention_similarity: torch.FloatTensor | None = None,
        target_embedding: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam2ImageSegmentationOutput:
        r"""
        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
            better results. The points can be obtained by passing a list of list of list to the processor that will
            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
            multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
            coordinates of the point. If a different number of points is passed either for each image, or for each
            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
            computation of the embedding will be skipped for these points using the labels.
        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
            official implementation, there are 3 types of labels

            - `1`: the point is a point that contains the object of interest
            - `0`: the point is a point that does not contain the object of interest
            - `-1`: the point corresponds to the background

            We added the label:

            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

            The padding labels should be automatically done by the processor.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
            that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
            In the order (`x1`, `y1`, `x2`, `y2`):

            - `x1`: the x coordinate of the top left point of the input box
            - `y1`: the y coordinate of the top left point of the input box
            - `x2`: the x coordinate of the bottom right point of the input box
            - `y2`: the y coordinate of the bottom right point of the input box
        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
            manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
            Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
            method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
        multimask_output (`bool`, *optional*):
            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
            "best" mask, by specifying `multimask_output=False`.
        attention_similarity (`torch.FloatTensor`, *optional*):
            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
            model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
        target_embedding (`torch.FloatTensor`, *optional*):
            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
            the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
        >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")

        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> with httpx.stream("GET", url) as response:
        ...     raw_image = Image.open(BytesIO(response.read())).convert("RGB")
        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

        >>> # Get segmentation mask
        >>> outputs = model(**inputs)

        >>> # Postprocess masks
        >>> masks = processor.post_process_masks(
        ...     outputs.pred_masks, inputs["original_sizes"]
        ... )
        ```
        """
        if not ((pixel_values is None) ^ (image_embeddings is None)):
            raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
        if input_points is not None and input_boxes is not None:
            if input_points.shape[1] != input_boxes.shape[1]:
                raise ValueError(
                    f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
                )

        image_positional_embeddings = self.get_image_wide_positional_embeddings()
        # repeat with batch size
        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

        vision_attentions = None
        vision_hidden_states = None

        if pixel_values is not None:
            image_outputs: Sam2VisionEncoderOutput = self.get_image_features(pixel_values, return_dict=True, **kwargs)
            feature_maps = image_outputs.fpn_hidden_states
            vision_hidden_states = image_outputs.hidden_states
            vision_attentions = image_outputs.attentions

            # add no memory embedding to the last feature map
            feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding

            # reshape feature maps to the same shape as the backbone feature sizes
            image_embeddings = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
            ]

        if input_points is not None and input_labels is None:
            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

        if input_points is None and input_boxes is None:
            # If no points are provided, pad with an empty point (with label -1)
            input_points = torch.zeros(
                batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
            )
            input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)

        if input_masks is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
                input_masks = F.interpolate(
                    input_masks.float(),
                    size=self.prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                ).to(input_masks.dtype)

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
            image_embeddings=image_embeddings[-1],
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            high_resolution_features=image_embeddings[:-1],
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            **kwargs,
        )

        return Sam2ImageSegmentationOutput(
            iou_scores=iou_scores,
            pred_masks=low_res_multimasks,
            object_score_logits=object_score_logits,
            image_embeddings=image_embeddings,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
        )


__all__ = [
    "Sam2Model",
    "Sam2VisionModel",
    "Sam2PreTrainedModel",
    "Sam2ImageProcessorFast",
    "Sam2HieraDetModel",
]