You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1664 lines
72 KiB
1664 lines
72 KiB
# Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch DETR model."""
|
|
|
|
import math
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ... import initialization as init
|
|
from ...activations import ACT2FN
|
|
from ...backbone_utils import load_backbone
|
|
from ...masking_utils import create_bidirectional_mask
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithCrossAttentions,
|
|
Seq2SeqModelOutput,
|
|
)
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...pytorch_utils import compile_compatible_method_lru_cache
|
|
from ...utils import (
|
|
ModelOutput,
|
|
TransformersKwargs,
|
|
auto_docstring,
|
|
logging,
|
|
)
|
|
from ...utils.generic import can_return_tuple, check_model_inputs
|
|
from .configuration_detr import DetrConfig
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
    """
)
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    r"""
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
        used to compute the weighted average in the cross-attention heads.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
        layernorm.
    """

    # Stacked per-layer decoder outputs (after layernorm); only populated when auxiliary decoding losses are used.
    intermediate_hidden_states: torch.FloatTensor | None = None
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
    """
)
class DetrModelOutput(Seq2SeqModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
        layernorm.
    """

    # Stacked per-layer decoder outputs (after layernorm); only populated when auxiliary decoding losses are used.
    intermediate_hidden_states: torch.FloatTensor | None = None
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DetrForObjectDetection`].
    """
)
class DetrObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
        Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Training-only loss terms (populated when `labels` are provided).
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    # Per-query predictions from the final decoder layer.
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    # Per-decoder-layer predictions; only present when `config.auxiliary_loss=True` and labels are given.
    auxiliary_outputs: list[dict] | None = None
    # Decoder-side hidden states and attention maps.
    last_hidden_state: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    # Encoder-side hidden states and attention maps.
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DetrForSegmentation`].
    """
)
class DetrSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
        Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
        Segmentation masks logits for all queries. See also
        [`~DetrImageProcessor.post_process_semantic_segmentation`] or
        [`~DetrImageProcessor.post_process_instance_segmentation`]
        [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
        segmentation masks respectively.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Training-only loss terms (populated when `labels` are provided).
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    # Per-query predictions: class logits, boxes, and segmentation mask logits.
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    # Per-decoder-layer predictions; only present when `config.auxiliary_loss=True` and labels are given.
    auxiliary_outputs: list[dict] | None = None
    # Decoder-side hidden states and attention maps.
    last_hidden_state: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    # Encoder-side hidden states and attention maps.
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
|
|
|
|
|
|
class DetrFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.
    """

    def __init__(self, n):
        super().__init__()
        # All four quantities are buffers, not parameters: they never receive gradients.
        for buffer_name, initializer in (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        ):
            self.register_buffer(buffer_name, initializer(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Checkpoints saved from a regular BatchNorm2d contain a "num_batches_tracked"
        # entry that this frozen variant has no buffer for — drop it before delegating.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        """Apply the frozen normalization as a single elementwise scale-and-shift over NCHW input."""
        epsilon = 1e-5
        # Reshape buffers up front so they broadcast over the (N, C, H, W) input.
        gamma = self.weight.reshape(1, -1, 1, 1)
        beta = self.bias.reshape(1, -1, 1, 1)
        var = self.running_var.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)
        # y = gamma * (x - mean) / sqrt(var + eps) + beta, folded into one multiply-add.
        scale = gamma * (var + epsilon).rsqrt()
        shift = beta - mean * scale
        return x * scale + shift
|
|
|
|
|
|
def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.

    Args:
        model (torch.nn.Module):
            input model
    """
    for child_name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = DetrFrozenBatchNorm2d(child.num_features)

            # Modules on the meta device hold no real data to transfer.
            if child.weight.device != torch.device("meta"):
                frozen.weight.copy_(child.weight)
                frozen.bias.copy_(child.bias)
                frozen.running_mean.copy_(child.running_mean)
                frozen.running_var.copy_(child.running_var)

            model._modules[child_name] = frozen

        # Descend into composite children to catch nested batch-norm layers.
        if len(list(child.children())) > 0:
            replace_batch_norm(child)
|
|
|
|
|
|
class DetrConvEncoder(nn.Module):
    """
    Convolutional backbone, using either the AutoBackbone API or one from the timm library.

    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        backbone = load_backbone(config)
        self.intermediate_channel_sizes = backbone.channels

        # Freeze batch statistics: swap every BatchNorm2d for its frozen counterpart.
        with torch.no_grad():
            replace_batch_norm(backbone)

        # We used to load with timm library directly instead of the AutoBackbone API
        # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
        is_timm_model = hasattr(backbone, "_backbone")
        if is_timm_model:
            backbone = backbone._backbone
        self.model = backbone

        # For ResNet backbones only the later stages stay trainable; the stem and first
        # stage are frozen. Parameter naming differs between timm and HF backbones.
        if "resnet" in config.backbone_config.model_type:
            trainable_markers = (
                ("layer2", "layer3", "layer4") if is_timm_model else ("stage.1", "stage.2", "stage.3")
            )
            for name, parameter in self.model.named_parameters():
                if not any(marker in name for marker in trainable_markers):
                    parameter.requires_grad_(False)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        """Run the backbone and pair each feature map with a correspondingly downsampled pixel mask."""
        features = self.model(pixel_values)
        if isinstance(features, dict):
            # ModelOutput subclasses dict, so this branch unwraps BackboneOutput objects.
            features = features.feature_maps

        outputs = []
        for feature_map in features:
            # Downsample pixel_mask to the spatial resolution of this feature map.
            resized = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:])
            outputs.append((feature_map, resized.to(torch.bool)[0]))
        return outputs
|
|
|
|
|
|
class DetrSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(
        self,
        num_position_features: int = 64,
        temperature: int = 10000,
        normalize: bool = False,
        scale: float | None = None,
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_position_features = num_position_features
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    @compile_compatible_method_lru_cache(maxsize=1)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Build 2D sine/cosine position embeddings for a feature map of the given shape.

        Args:
            shape: Feature-map shape `(batch_size, channels, height, width)`.
            device: Device on which to create the embeddings.
            dtype: Floating dtype of the returned embeddings.
            mask (*optional*): Boolean mask of shape `(batch_size, height, width)` where `True` marks
                valid (non-padding) pixels. Defaults to all pixels valid.

        Returns:
            Embeddings of shape `(batch_size, height * width, 2 * num_position_features)`.
        """
        if mask is None:
            # FIX: the default mask must be all-ones (every pixel valid), matching the
            # valid-pixel semantics of `pixel_mask`. With a zeros mask, both cumulative
            # sums below are identically zero and every position collapses to the same
            # embedding, defeating the positional encoding entirely.
            mask = torch.ones((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        # Cumulative sums over valid pixels assign each location its (y, x) coordinate.
        y_embed = mask.cumsum(1, dtype=dtype)
        x_embed = mask.cumsum(2, dtype=dtype)
        if self.normalize:
            eps = 1e-6
            # Normalize coordinates to [0, scale] using the last valid coordinate per row/column.
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sine (even feature indices) and cosine (odd feature indices).
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
        # expected by the encoder
        pos = pos.flatten(2).permute(0, 2, 1)
        return pos
|
|
|
|
|
|
class DetrLearnedPositionEmbedding(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        # One learned table per spatial axis; 50 entries bound the feature-map side length.
        self.row_embeddings = nn.Embedding(50, embedding_dim)
        self.column_embeddings = nn.Embedding(50, embedding_dim)

    @compile_compatible_method_lru_cache(maxsize=1)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: torch.Tensor | None = None,
    ):
        """Return learned 2D position embeddings of shape (batch_size, height * width, 2 * embedding_dim)."""
        height, width = shape[-2:]
        column_ids = torch.arange(width, device=device)
        row_ids = torch.arange(height, device=device)
        x_embeddings = self.column_embeddings(column_ids)
        y_embeddings = self.row_embeddings(row_ids)
        # Broadcast the per-axis embeddings across the grid and concatenate channel-wise.
        grid = torch.cat(
            [
                x_embeddings.unsqueeze(0).repeat(height, 1, 1),
                y_embeddings.unsqueeze(1).repeat(1, width, 1),
            ],
            dim=-1,
        )
        grid = grid.permute(2, 0, 1).unsqueeze(0).repeat(shape[0], 1, 1, 1)
        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
        # expected by the encoder
        return grid.flatten(2).permute(0, 2, 1)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain (non-fused) scaled dot-product attention.

    Returns the attention output (transposed back to batch-seq-heads layout) together with the
    post-softmax attention probabilities.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores: scaled dot product between queries and keys.
    scores = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        # Crop the additive mask to the key length before applying it.
        scores = scores + attention_mask[:, :, :, : key.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value)
    context = context.transpose(1, 2).contiguous()

    return context, probs
|
|
|
|
|
|
class DetrSelfAttention(nn.Module):
    """
    Multi-headed self-attention from 'Attention Is All You Need' paper.

    In DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
    """

    def __init__(
        self,
        config: DetrConfig,
        hidden_size: int,
        num_attention_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.config = config
        # Per-head width; assumes hidden_size is divisible by num_attention_heads.
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = dropout
        # DETR attention is always bidirectional.
        self.is_causal = False

        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Position embeddings are added to both queries and keys (but not values).
        """
        batch_dims = hidden_states.shape[:-1]
        split_heads_shape = (*batch_dims, -1, self.head_dim)

        # Queries and keys see position information; values deliberately do not.
        if position_embeddings is not None:
            qk_source = hidden_states + position_embeddings
        else:
            qk_source = hidden_states

        queries = self.q_proj(qk_source).view(split_heads_shape).transpose(1, 2)
        keys = self.k_proj(qk_source).view(split_heads_shape).transpose(1, 2)
        values = self.v_proj(hidden_states).view(split_heads_shape).transpose(1, 2)

        # Resolve the configured attention backend (eager, sdpa, flash-attention, ...).
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads and project back to the model dimension.
        attn_output = attn_output.reshape(*batch_dims, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
|
|
|
|
|
|
class DetrCrossAttention(nn.Module):
    """
    Multi-headed cross-attention from 'Attention Is All You Need' paper.

    In DETR, queries get their own position embeddings, while keys get encoder position embeddings.
    Values don't get any position embeddings.
    """

    def __init__(
        self,
        config: DetrConfig,
        hidden_size: int,
        num_attention_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.config = config
        # Per-head width; assumes hidden_size is divisible by num_attention_heads.
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = dropout
        # Cross-attention in DETR is never causal.
        self.is_causal = False

        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        encoder_position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Position embeddings logic:
        - Queries get position_embeddings
        - Keys get encoder_position_embeddings
        - Values don't get any position embeddings
        """
        # Queries come from the decoder stream, keys/values from the encoder stream,
        # so the two sides may have different sequence lengths.
        query_input_shape = hidden_states.shape[:-1]
        query_hidden_shape = (*query_input_shape, -1, self.head_dim)

        kv_input_shape = key_value_states.shape[:-1]
        kv_hidden_shape = (*kv_input_shape, -1, self.head_dim)

        # Add the respective position embeddings before projecting; values stay position-free.
        query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
        key_input = (
            key_value_states + encoder_position_embeddings
            if encoder_position_embeddings is not None
            else key_value_states
        )

        query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2)
        key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2)
        value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2)

        # Resolve the configured attention backend, falling back to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Attention dropout is only active during training.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back and apply the output projection.
        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
|
class DetrMLP(nn.Module):
    """Two-layer feed-forward block (expand, activate, project back) with dropout, as used in transformer layers."""

    def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.dropout = config.dropout

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # fc1 -> activation -> dropout -> fc2 -> dropout.
        intermediate = self.activation_fn(self.fc1(hidden_states))
        intermediate = nn.functional.dropout(intermediate, p=self.activation_dropout, training=self.training)
        projected = self.fc2(intermediate)
        return nn.functional.dropout(projected, p=self.dropout, training=self.training)
|
|
|
|
|
|
class DetrEncoderLayer(GradientCheckpointingLayer):
    """Post-norm transformer encoder layer: self-attention over image features, then a feed-forward block."""

    def __init__(self, config: DetrConfig):
        super().__init__()
        self.hidden_size = config.d_model
        self.self_attn = DetrSelfAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.dropout = config.dropout
        self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        spatial_position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
                Spatial position embeddings (2D positional encodings of image locations), to be added to both
                the queries and keys in self-attention (but not to values).
        """
        # Self-attention sub-block with post-norm residual connection.
        attn_input = hidden_states
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=spatial_position_embeddings,
            **kwargs,
        )
        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(attn_input + attn_output)

        # Feed-forward sub-block with post-norm residual connection.
        mlp_input = hidden_states
        hidden_states = self.final_layer_norm(mlp_input + self.mlp(hidden_states))

        # During training, guard against fp16 overflow by clamping non-finite activations.
        if self.training:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        return hidden_states
|
|
|
|
|
|
class DetrDecoderLayer(GradientCheckpointingLayer):
    """Post-norm transformer decoder layer: self-attention over object queries, cross-attention into the encoder
    output, then a feed-forward block — each followed by a residual connection and layernorm."""

    def __init__(self, config: DetrConfig):
        super().__init__()
        self.hidden_size = config.d_model

        self.self_attn = DetrSelfAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.encoder_attn = DetrCrossAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        spatial_position_embeddings: torch.Tensor | None = None,
        object_queries_position_embeddings: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only
                in the cross-attention layer (not to values).
            object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings for the object query slots. In self-attention, these are added to both queries
                and keys (not values). In cross-attention, these are added to queries only (not to keys or values).
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
        """
        residual = hidden_states

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=object_queries_position_embeddings,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Post-norm: dropout, residual add, then layernorm.
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        # Skipped entirely when no encoder states are given (decoder-only usage).
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, _ = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_embeddings=object_queries_position_embeddings,
                encoder_position_embeddings=spatial_position_embeddings,
                **kwargs,
            )

            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
|
|
|
|
|
|
class DetrConvBlock(nn.Module):
    """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""

    def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Cap groups at 8; assumes out_channels is divisible by the resulting group count.
        self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
        self.activation = ACT2FN[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.norm(x)
        return self.activation(x)
|
|
|
|
|
|
class DetrFPNFusionStage(nn.Module):
    """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""

    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
        super().__init__()
        # 1x1 conv aligning the FPN channel count with the current feature channels.
        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
        self.refine = DetrConvBlock(current_channels, output_channels, activation)

    def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
            fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)

        Returns:
            Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
        """
        aligned_fpn = self.fpn_adapter(fpn_features)
        # Nearest-neighbor upsample the current features to the FPN resolution, then fuse.
        upsampled = nn.functional.interpolate(features, size=aligned_fpn.shape[-2:], mode="nearest")
        return self.refine(aligned_fpn + upsampled)
|
|
|
|
|
|
class DetrMaskHeadSmallConv(nn.Module):
    """
    Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.

    Combines attention maps (spatial localization) with encoder features (semantics) and progressively
    upsamples through multiple scales, fusing with FPN features for high-resolution detail.
    """

    def __init__(
        self,
        input_channels: int,
        fpn_channels: list[int],
        hidden_size: int,
        activation_function: str = "relu",
    ):
        super().__init__()
        # hidden_size is divided down to hidden_size // 16 below, hence the divisibility check.
        # NOTE(review): input_channels presumably equals hidden_size + num attention heads, matching
        # the channel concatenation in forward — confirm against the caller.
        if input_channels % 8 != 0:
            raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")

        self.conv1 = DetrConvBlock(input_channels, input_channels, activation_function)
        self.conv2 = DetrConvBlock(input_channels, hidden_size // 2, activation_function)

        # Progressive channel reduction: /2 -> /4 -> /8 -> /16
        self.fpn_stages = nn.ModuleList(
            [
                DetrFPNFusionStage(fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function),
                DetrFPNFusionStage(fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function),
                DetrFPNFusionStage(fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function),
            ]
        )

        # Final projection to a single mask-logit channel per query.
        self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)

    def forward(
        self,
        features: torch.Tensor,
        attention_masks: torch.Tensor,
        fpn_features: list[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            features: Encoder output features, shape (batch_size, hidden_size, H, W)
            attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
            fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)

        Returns:
            Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
        """
        num_queries = attention_masks.shape[1]

        # Expand to (batch_size * num_queries) dimension
        # so each query gets its own copy of the shared encoder/FPN features.
        features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
        attention_masks = attention_masks.flatten(0, 1)
        fpn_features = [
            fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
        ]

        # Concatenate semantics (encoder features) with localization (attention maps) channel-wise.
        hidden_states = torch.cat([features, attention_masks], dim=1)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Progressively upsample, fusing with increasingly high-resolution FPN features.
        for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
            hidden_states = fpn_stage(hidden_states, fpn_feat)

        return self.output_conv(hidden_states)
|
|
|
|
|
|
class DetrMHAttentionMap(nn.Module):
|
|
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
|
|
|
|
def __init__(
|
|
self,
|
|
hidden_size: int,
|
|
num_attention_heads: int,
|
|
dropout: float = 0.0,
|
|
bias: bool = True,
|
|
):
|
|
super().__init__()
|
|
self.head_dim = hidden_size // num_attention_heads
|
|
self.scaling = self.head_dim**-0.5
|
|
self.attention_dropout = dropout
|
|
|
|
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
|
|
|
|
def forward(
|
|
self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
|
|
):
|
|
query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
|
|
key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
|
|
|
|
query_states = self.q_proj(query_states).view(query_hidden_shape)
|
|
key_states = nn.functional.conv2d(
|
|
key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
|
|
).view(key_hidden_shape)
|
|
|
|
batch_size, num_queries, num_heads, head_dim = query_states.shape
|
|
_, _, _, height, width = key_states.shape
|
|
query_shape = (batch_size * num_heads, num_queries, head_dim)
|
|
key_shape = (batch_size * num_heads, height * width, head_dim)
|
|
attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
|
|
|
|
query = query_states.transpose(1, 2).contiguous().view(query_shape)
|
|
key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
|
|
|
|
attn_weights = (
|
|
(torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
|
|
)
|
|
|
|
if attention_mask is not None:
|
|
attn_weights = attn_weights + attention_mask
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
|
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
|
|
|
return attn_weights
|
|
|
|
|
|
@auto_docstring
class DetrPreTrainedModel(PreTrainedModel):
    config: DetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True
    _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
    # BatchNorm bookkeeping buffers from older checkpoints that no longer map to a parameter.
    _keys_to_ignore_on_load_unexpected = [
        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
    ]

    @torch.no_grad()
    def _init_weights(self, module):
        # Initialize weights following the DETR reference implementation. The elif
        # chain is ordered from most- to least-specific module type: specialized
        # DETR modules are matched before the generic nn.Linear/nn.Conv2d branch.
        std = self.config.init_std
        xavier_std = self.config.init_xavier_std

        if isinstance(module, DetrMaskHeadSmallConv):
            # DetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_uniform_(m.weight, a=1)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
        elif isinstance(module, DetrMHAttentionMap):
            # Attention-map head: zero biases, Xavier-uniform weights with the configured gain.
            init.zeros_(module.k_proj.bias)
            init.zeros_(module.q_proj.bias)
            init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
            init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
        elif isinstance(module, DetrLearnedPositionEmbedding):
            # Learned 2D position embeddings are drawn uniformly, per the reference implementation.
            init.uniform_(module.row_embeddings.weight)
            init.uniform_(module.column_embeddings.weight)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
|
|
|
|
|
|
class DetrEncoder(DetrPreTrainedModel):
    """
    Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
    [`DetrEncoderLayer`] modules.

    Args:
        config (`DetrConfig`): Model configuration object.
    """

    _can_record_outputs = {"hidden_states": DetrEncoderLayer, "attentions": DetrSelfAttention}

    def __init__(self, config: DetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        self.layers = nn.ModuleList(DetrEncoderLayer(config) for _ in range(config.encoder_layers))

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        spatial_position_embeddings=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).

                [What are attention masks?](../glossary#attention-mask)
            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
        """
        hidden_states = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)

        if attention_mask is not None:
            # [batch_size, seq_len] -> broadcastable [batch_size, 1, target_seq_len, source_seq_len]
            attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )

        # The spatial position embeddings are re-injected at every encoder layer.
        for layer in self.layers:
            hidden_states = layer(
                hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
            )

        return BaseModelOutput(last_hidden_state=hidden_states)
|
|
|
|
|
|
class DetrDecoder(DetrPreTrainedModel):
    """
    Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules,
    which apply self-attention to the queries and cross-attention to the encoder's outputs.

    Args:
        config (`DetrConfig`): Model configuration object.
    """

    _can_record_outputs = {
        "hidden_states": DetrDecoderLayer,
        "attentions": DetrSelfAttention,
        "cross_attentions": DetrCrossAttention,
    }

    def __init__(self, config: DetrConfig):
        super().__init__(config)
        self.dropout = config.dropout

        self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        # in DETR, the decoder uses layernorm after the last decoder layer output
        self.layernorm = nn.LayerNorm(config.d_model)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        spatial_position_embeddings=None,
        object_queries_position_embeddings=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> DetrDecoderOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The query embeddings that are passed into the decoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:

                - 1 for queries that are **not masked**,
                - 0 for queries that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:

                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).

            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer.
            object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
        """
        # NOTE(review): if `inputs_embeds` is None, `hidden_states` is never bound and the
        # loop below raises NameError; callers (DetrModel) always pass it — confirm intended.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds

        # expand decoder attention mask (for self-attention on object queries)
        if attention_mask is not None:
            # [batch_size, num_queries] -> [batch_size, 1, num_queries, num_queries]
            attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )

        # expand encoder attention mask (for cross-attention on encoder outputs)
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            encoder_attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        # optional intermediate hidden states (one entry per decoder layer, used for auxiliary losses)
        intermediate = () if self.config.auxiliary_loss else None

        # decoder layers

        for idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask,
                spatial_position_embeddings,
                object_queries_position_embeddings,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

            if self.config.auxiliary_loss:
                # NOTE(review): with auxiliary loss enabled, layernorm is applied here AND again
                # after the loop, so the final layer's output is normalized twice — this mirrors
                # the reference DETR implementation; confirm before "fixing".
                hidden_states = self.layernorm(hidden_states)
                intermediate += (hidden_states,)

        # finally, apply layernorm
        hidden_states = self.layernorm(hidden_states)

        # stack intermediate decoder activations: (num_layers, batch_size, num_queries, hidden_size)
        if self.config.auxiliary_loss:
            intermediate = torch.stack(intermediate)

        return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
    any specific head on top.
    """
)
class DetrModel(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # Convolutional backbone producing multi-scale feature maps.
        self.backbone = DetrConvEncoder(config)

        # 2D position embeddings: d_model // 2 channels per spatial axis.
        if config.position_embedding_type == "sine":
            self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True)
        elif config.position_embedding_type == "learned":
            self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2)
        else:
            raise ValueError(f"Not supported {config.position_embedding_type}")
        # Learned embeddings for the object query slots consumed by the decoder.
        self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
        # 1x1 conv projecting the backbone's last feature map to the model dimension.
        self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

        self.encoder = DetrEncoder(config)
        self.decoder = DetrDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_backbone(self):
        # Disable gradient updates for all backbone parameters.
        for _, param in self.backbone.model.named_parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        # Re-enable gradient updates for all backbone parameters.
        for _, param in self.backbone.model.named_parameters():
            param.requires_grad_(True)

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        decoder_attention_mask: torch.FloatTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | DetrModelOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:

            - 1 for queries that are **not masked**,
            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation. Useful for tasks that require custom query initialization.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrModel
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # the last hidden states are the final query embeddings of the Transformer decoder
        >>> # these are of shape (batch_size, num_queries, hidden_size)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 100, 256]
        ```"""
        if pixel_values is None and inputs_embeds is None:
            raise ValueError("You have to specify either pixel_values or inputs_embeds")

        if inputs_embeds is None:
            # Standard path: run the backbone on the raw image.
            batch_size, num_channels, height, width = pixel_values.shape
            device = pixel_values.device

            if pixel_mask is None:
                pixel_mask = torch.ones(((batch_size, height, width)), device=device)
            vision_features = self.backbone(pixel_values, pixel_mask)
            # The backbone returns (feature_map, mask) pairs; only the last (coarsest) level feeds the encoder.
            feature_map, mask = vision_features[-1]

            # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
            # Position embeddings are already flattened to (batch_size, sequence_length, hidden_size) format
            projected_feature_map = self.input_projection(feature_map)
            flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
            spatial_position_embeddings = self.position_embedding(
                shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
            )
            flattened_mask = mask.flatten(1)
        else:
            # Backbone-bypass path: caller provides pre-flattened features directly.
            batch_size = inputs_embeds.shape[0]
            device = inputs_embeds.device
            flattened_features = inputs_embeds
            # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
            # Assume square feature map
            seq_len = inputs_embeds.shape[1]
            feat_dim = int(seq_len**0.5)
            # Create position embeddings for the inferred spatial size
            spatial_position_embeddings = self.position_embedding(
                shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
                device=device,
                dtype=inputs_embeds.dtype,
            )
            # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
            if pixel_mask is not None:
                mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
                flattened_mask = mask.flatten(1)
            else:
                # If no mask provided, assume all positions are valid
                flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                spatial_position_embeddings=spatial_position_embeddings,
                **kwargs,
            )

        # Expand the learned query embeddings to the batch:
        # (num_queries, d_model) -> (batch_size, num_queries, d_model).
        object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )

        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
        if decoder_inputs_embeds is not None:
            queries = decoder_inputs_embeds
        else:
            queries = torch.zeros_like(object_queries_position_embeddings)

        # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            inputs_embeds=queries,
            attention_mask=decoder_attention_mask,
            spatial_position_embeddings=spatial_position_embeddings,
            object_queries_position_embeddings=object_queries_position_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=flattened_mask,
            **kwargs,
        )

        return DetrModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
        )
|
|
|
|
|
|
class DetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Layer i maps dims[i] -> dims[i + 1]; all hidden layers share `hidden_dim`.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(dims[:-1], dims[1:]))

    def forward(self, x):
        # ReLU after every layer except the last, which stays linear.
        last = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last:
                x = nn.functional.relu(x)
        return x
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
    such as COCO detection.
    """
)
class DetrForObjectDetection(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # DETR encoder-decoder model
        self.model = DetrModel(config)

        # Object detection heads
        self.class_labels_classifier = nn.Linear(
            config.d_model, config.num_labels + 1
        )  # We add one for the "no object" class
        # 3-layer MLP predicting (center_x, center_y, width, height), squashed to [0, 1] via sigmoid in forward.
        self.bbox_predictor = DetrMLPPredictionHead(
            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.LongTensor | None = None,
        decoder_attention_mask: torch.FloatTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: list[dict] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | DetrObjectDetectionOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:

            - 1 for queries that are **not masked**,
            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation. Useful for tasks that require custom query initialization.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
        Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
        Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
        Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
        Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
        ```"""

        # First, send images through DETR base model to obtain encoder + decoder outputs
        outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            **kwargs,
        )

        # Final decoder hidden states: (batch_size, num_queries, d_model).
        sequence_output = outputs[0]

        # class logits + predicted bounding boxes
        logits = self.class_labels_classifier(sequence_output)
        pred_boxes = self.bbox_predictor(sequence_output).sigmoid()

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                # Re-apply both heads to every intermediate decoder layer output
                # so the loss can supervise each layer (auxiliary losses).
                intermediate = outputs.intermediate_hidden_states
                outputs_class = self.class_labels_classifier(intermediate)
                outputs_coord = self.bbox_predictor(intermediate).sigmoid()
            # Bipartite-matching (Hungarian) loss, provided by the PreTrainedModel loss machinery.
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
            )

        return DetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
    such as COCO panoptic.
    """
)
class DetrForSegmentation(DetrPreTrainedModel):
    # Maps parameter names from older checkpoints onto the refactored module layout:
    # q/k linear layers were renamed to *_proj, and the monolithic mask head was split
    # into conv blocks (conv/norm) and three FPN fusion stages (fpn_adapter/refine).
    _checkpoint_conversion_mapping = {
        "bbox_attention.q_linear": "bbox_attention.q_proj",
        "bbox_attention.k_linear": "bbox_attention.k_proj",
        # Mask head refactor
        "mask_head.lay1": "mask_head.conv1.conv",
        "mask_head.gn1": "mask_head.conv1.norm",
        "mask_head.lay2": "mask_head.conv2.conv",
        "mask_head.gn2": "mask_head.conv2.norm",
        "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
        "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
        "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
        "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
        "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
        "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
        "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
        "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
        "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
        "mask_head.out_lay": "mask_head.output_conv",
    }
|
|
|
|
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # object detection model
        self.detr = DetrForObjectDetection(config)

        # segmentation head
        hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
        intermediate_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes

        # The mask head consumes the projected encoder features (hidden_size channels)
        # concatenated with one cross-attention map per head (number_of_heads channels).
        # fpn_channels: the backbone channel sizes reversed, keeping the last three —
        # presumably the higher-resolution stages in decreasing channel order (TODO confirm
        # against DetrConvEncoder.intermediate_channel_sizes).
        self.mask_head = DetrMaskHeadSmallConv(
            input_channels=hidden_size + number_of_heads,
            fpn_channels=intermediate_channel_sizes[::-1][-3:],
            hidden_size=hidden_size,
            activation_function=config.activation_function,
        )

        # Attention-map module producing the per-query spatial attention used by the mask head
        # (dropout disabled: the maps are consumed directly as localization signals).
        self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
        # Initialize weights and apply final processing
        self.post_init()
|
|
|
|
@auto_docstring
|
|
@can_return_tuple
|
|
def forward(
|
|
self,
|
|
pixel_values: torch.FloatTensor,
|
|
pixel_mask: torch.LongTensor | None = None,
|
|
decoder_attention_mask: torch.FloatTensor | None = None,
|
|
encoder_outputs: torch.FloatTensor | None = None,
|
|
inputs_embeds: torch.FloatTensor | None = None,
|
|
decoder_inputs_embeds: torch.FloatTensor | None = None,
|
|
labels: list[dict] | None = None,
|
|
**kwargs: Unpack[TransformersKwargs],
|
|
) -> tuple[torch.FloatTensor] | DetrSegmentationOutput:
|
|
r"""
|
|
decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
|
|
Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
|
|
|
|
- 1 for queries that are **not masked**,
|
|
- 0 for queries that are **masked**.
|
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
|
|
multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
|
|
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
|
|
Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
|
|
embedded representation. Useful for tasks that require custom query initialization.
|
|
labels (`list[Dict]` of len `(batch_size,)`, *optional*):
|
|
Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
|
|
dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
|
|
bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
|
|
should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
|
|
`torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
|
|
`torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.
|
|
|
|
Examples:
|
|
|
|
```python
|
|
>>> import io
|
|
>>> import httpx
|
|
>>> from io import BytesIO
|
|
>>> from PIL import Image
|
|
>>> import torch
|
|
>>> import numpy
|
|
|
|
>>> from transformers import AutoImageProcessor, DetrForSegmentation
|
|
>>> from transformers.image_transforms import rgb_to_id
|
|
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
>>> with httpx.stream("GET", url) as response:
|
|
... image = Image.open(BytesIO(response.read()))
|
|
|
|
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
|
|
>>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
|
|
|
|
>>> # prepare image for the model
|
|
>>> inputs = image_processor(images=image, return_tensors="pt")
|
|
|
|
>>> # forward pass
|
|
>>> outputs = model(**inputs)
|
|
|
|
>>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
|
|
>>> # Segmentation results are returned as a list of dictionaries
|
|
>>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
|
|
|
|
>>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
|
|
>>> panoptic_seg = result[0]["segmentation"]
|
|
>>> panoptic_seg.shape
|
|
torch.Size([300, 500])
|
|
>>> # Get prediction score and segment_id to class_id mapping of each segment
|
|
>>> panoptic_segments_info = result[0]["segments_info"]
|
|
>>> len(panoptic_segments_info)
|
|
5
|
|
```"""
|
|
|
|
batch_size, num_channels, height, width = pixel_values.shape
|
|
device = pixel_values.device
|
|
|
|
if pixel_mask is None:
|
|
pixel_mask = torch.ones((batch_size, height, width), device=device)
|
|
|
|
vision_features = self.detr.model.backbone(pixel_values, pixel_mask)
|
|
feature_map, mask = vision_features[-1]
|
|
|
|
# Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
|
|
projected_feature_map = self.detr.model.input_projection(feature_map)
|
|
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
|
spatial_position_embeddings = self.detr.model.position_embedding(
|
|
shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
|
|
)
|
|
flattened_mask = mask.flatten(1)
|
|
|
|
if encoder_outputs is None:
|
|
encoder_outputs = self.detr.model.encoder(
|
|
inputs_embeds=flattened_features,
|
|
attention_mask=flattened_mask,
|
|
spatial_position_embeddings=spatial_position_embeddings,
|
|
**kwargs,
|
|
)
|
|
|
|
object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
|
|
batch_size, 1, 1
|
|
)
|
|
|
|
# Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
|
|
if decoder_inputs_embeds is not None:
|
|
queries = decoder_inputs_embeds
|
|
else:
|
|
queries = torch.zeros_like(object_queries_position_embeddings)
|
|
|
|
decoder_outputs = self.detr.model.decoder(
|
|
inputs_embeds=queries,
|
|
attention_mask=decoder_attention_mask,
|
|
spatial_position_embeddings=spatial_position_embeddings,
|
|
object_queries_position_embeddings=object_queries_position_embeddings,
|
|
encoder_hidden_states=encoder_outputs.last_hidden_state,
|
|
encoder_attention_mask=flattened_mask,
|
|
**kwargs,
|
|
)
|
|
|
|
sequence_output = decoder_outputs[0]
|
|
|
|
logits = self.detr.class_labels_classifier(sequence_output)
|
|
pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()
|
|
|
|
height, width = feature_map.shape[-2:]
|
|
memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
|
|
batch_size, self.config.d_model, height, width
|
|
)
|
|
attention_mask = flattened_mask.view(batch_size, height, width)
|
|
|
|
if attention_mask is not None:
|
|
min_dtype = torch.finfo(memory.dtype).min
|
|
attention_mask = torch.where(
|
|
attention_mask.unsqueeze(1).unsqueeze(1),
|
|
torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
|
|
min_dtype,
|
|
)
|
|
|
|
bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
|
|
|
|
seg_masks = self.mask_head(
|
|
features=projected_feature_map,
|
|
attention_masks=bbox_mask,
|
|
fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
|
|
)
|
|
|
|
pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
|
|
|
|
loss, loss_dict, auxiliary_outputs = None, None, None
|
|
if labels is not None:
|
|
outputs_class, outputs_coord = None, None
|
|
if self.config.auxiliary_loss:
|
|
intermediate = decoder_outputs.intermediate_hidden_states
|
|
outputs_class = self.detr.class_labels_classifier(intermediate)
|
|
outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
|
|
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
|
logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
|
|
)
|
|
|
|
return DetrSegmentationOutput(
|
|
loss=loss,
|
|
loss_dict=loss_dict,
|
|
logits=logits,
|
|
pred_boxes=pred_boxes,
|
|
pred_masks=pred_masks,
|
|
auxiliary_outputs=auxiliary_outputs,
|
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
|
decoder_attentions=decoder_outputs.attentions,
|
|
cross_attentions=decoder_outputs.cross_attentions,
|
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
|
encoder_hidden_states=encoder_outputs.hidden_states,
|
|
encoder_attentions=encoder_outputs.attentions,
|
|
)
|
|
|
|
|
|
# Explicit public API of this module (same names, same order as before).
__all__ = ["DetrForObjectDetection", "DetrForSegmentation", "DetrModel", "DetrPreTrainedModel"]
|