# Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DETR model."""
import math
from collections.abc import Callable
from dataclasses import dataclass
import torch
import torch.nn as nn
from ... import initialization as init
from ...activations import ACT2FN
from ...backbone_utils import load_backbone
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithCrossAttentions,
Seq2SeqModelOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import (
ModelOutput,
TransformersKwargs,
auto_docstring,
logging,
)
from ...utils.generic import can_return_tuple, check_model_inputs
from .configuration_detr import DetrConfig
# Module-level logger, named after this module's import path.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
    """
)
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    r"""
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
        used to compute the weighted average in the cross-attention heads.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
        layernorm.
    """

    # Stacked per-layer decoder activations (post-layernorm); only populated when
    # `config.auxiliary_loss=True`.
    intermediate_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
    """
)
class DetrModelOutput(Seq2SeqModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
        Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
        layernorm.
    """

    # Stacked per-layer decoder activations (post-layernorm); only populated when
    # `config.auxiliary_loss=True`.
    intermediate_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DetrForObjectDetection`].
    """
)
class DetrObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
        Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Loss terms (populated only when labels are provided).
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    # Per-query predictions.
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    # Decoder-side states/attentions.
    last_hidden_state: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    # Encoder-side states/attentions.
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DetrForSegmentation`].
    """
)
class DetrSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
        Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
        Segmentation masks logits for all queries. See also
        [`~DetrImageProcessor.post_process_semantic_segmentation`] or
        [`~DetrImageProcessor.post_process_instance_segmentation`]
        [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
        segmentation masks respectively.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    """

    # Loss terms (populated only when labels are provided).
    loss: torch.FloatTensor | None = None
    loss_dict: dict | None = None
    # Per-query predictions (boxes plus segmentation mask logits).
    logits: torch.FloatTensor | None = None
    pred_boxes: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    auxiliary_outputs: list[dict] | None = None
    # Decoder-side states/attentions.
    last_hidden_state: torch.FloatTensor | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    # Encoder-side states/attentions.
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
class DetrFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.
    """

    def __init__(self, n):
        super().__init__()
        # Everything is a plain buffer (not a Parameter), so nothing here is ever optimized.
        for buffer_name, initial_value in (
            ("weight", torch.ones(n)),
            ("bias", torch.zeros(n)),
            ("running_mean", torch.zeros(n)),
            ("running_var", torch.ones(n)),
        ):
            self.register_buffer(buffer_name, initial_value)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Vanilla BatchNorm2d checkpoints carry a `num_batches_tracked` counter that this
        # frozen variant has no use for; drop it so it is not reported as unexpected.
        state_dict.pop(prefix + "num_batches_tracked", None)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # Fold the frozen statistics into a single per-channel scale/shift,
        # broadcast over the (N, C, H, W) input.
        epsilon = 1e-5
        scale = (self.weight * (self.running_var + epsilon).rsqrt()).reshape(1, -1, 1, 1)
        shift = self.bias.reshape(1, -1, 1, 1) - self.running_mean.reshape(1, -1, 1, 1) * scale
        return x * scale + shift
def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.

    Args:
        model (torch.nn.Module):
            input model
    """
    for child_name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = DetrFrozenBatchNorm2d(child.num_features)
            # Meta tensors carry no data, so there is nothing to copy in that case.
            if child.weight.device != torch.device("meta"):
                frozen.weight.copy_(child.weight)
                frozen.bias.copy_(child.bias)
                frozen.running_mean.copy_(child.running_mean)
                frozen.running_var.copy_(child.running_var)
            model._modules[child_name] = frozen
        if len(list(child.children())) > 0:
            replace_batch_norm(child)
class DetrConvEncoder(nn.Module):
    """
    Convolutional backbone, using either the AutoBackbone API or one from the timm library.

    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        backbone = load_backbone(config)
        self.intermediate_channel_sizes = backbone.channels

        # replace batch norm by frozen batch norm
        with torch.no_grad():
            replace_batch_norm(backbone)

        # We used to load with timm library directly instead of the AutoBackbone API
        # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
        is_timm_model = hasattr(backbone, "_backbone")
        if is_timm_model:
            backbone = backbone._backbone
        self.model = backbone

        if "resnet" in config.backbone_config.model_type:
            # Freeze every parameter outside the last three stages; timm and HF
            # backbones use different stage naming schemes.
            if is_timm_model:
                trainable_markers = ("layer2", "layer3", "layer4")
            else:
                trainable_markers = ("stage.1", "stage.2", "stage.3")
            for name, parameter in self.model.named_parameters():
                if not any(marker in name for marker in trainable_markers):
                    parameter.requires_grad_(False)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        # send pixel_values through the model to get list of feature maps
        features = self.model(pixel_values)
        if isinstance(features, dict):
            features = features.feature_maps

        # Pair each feature map with the pixel mask downsampled to its spatial resolution.
        outputs = []
        for feature_map in features:
            resized = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:])
            outputs.append((feature_map, resized.to(torch.bool)[0]))
        return outputs
class DetrSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(
        self,
        num_position_features: int = 64,
        temperature: int = 10000,
        normalize: bool = False,
        scale: float | None = None,
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_position_features = num_position_features
        self.temperature = temperature
        self.normalize = normalize
        # Default scale maps normalized cumulative positions onto [0, 2*pi).
        self.scale = 2 * math.pi if scale is None else scale

    @compile_compatible_method_lru_cache(maxsize=1)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # `shape` is the feature map's (batch, channels, height, width); the lru-cache
        # decorator avoids recomputing for repeated shapes.
        if mask is None:
            # NOTE(review): an all-zero mask yields all-zero cumulative sums (and hence
            # all-zero embeddings); callers presumably pass the real pixel mask whenever
            # positions matter — confirm against the model's forward.
            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        # Cumulative counts of valid pixels give each location its (row, col) coordinate,
        # skipping padded pixels.
        y_embed = mask.cumsum(1, dtype=dtype)
        x_embed = mask.cumsum(2, dtype=dtype)
        if self.normalize:
            eps = 1e-6
            # Normalize by the last (= largest) cumulative value along each axis.
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        # Geometric frequency ladder: temperature^(2*floor(i/2)/d), shared by sin/cos pairs.
        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin on even channels and cos on odd channels.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # Concatenate y- then x-embeddings along features: (batch, 2*num_position_features, H, W).
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
        # expected by the encoder
        pos = pos.flatten(2).permute(0, 2, 1)
        return pos
class DetrLearnedPositionEmbedding(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        # 50 rows x 50 columns is the maximum supported feature-map size.
        self.row_embeddings = nn.Embedding(50, embedding_dim)
        self.column_embeddings = nn.Embedding(50, embedding_dim)

    @compile_compatible_method_lru_cache(maxsize=1)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: torch.Tensor | None = None,
    ):
        height, width = shape[-2:]
        x_emb = self.column_embeddings(torch.arange(width, device=device))
        y_emb = self.row_embeddings(torch.arange(height, device=device))
        # Broadcast column embeddings over rows and row embeddings over columns, then
        # concatenate along the feature axis: (height, width, 2 * embedding_dim).
        grid = torch.cat(
            [x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1
        )
        # (H, W, C) -> (batch, C, H, W), replicated across the batch.
        grid = grid.permute(2, 0, 1).unsqueeze(0).repeat(shape[0], 1, 1, 1)
        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
        # expected by the encoder
        return grid.flatten(2).permute(0, 2, 1)
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Pure-PyTorch scaled-dot-product attention, used when no fused backend is selected."""
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores: (batch, heads, q_len, k_len).
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the additive mask to the key length before applying it.
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Weighted sum of values, returned as (batch, q_len, heads, head_dim).
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights
class DetrSelfAttention(nn.Module):
    """
    Multi-headed self-attention from 'Attention Is All You Need' paper.

    In DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
    """

    def __init__(
        self,
        config: DetrConfig,
        hidden_size: int,
        num_attention_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.config = config
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = dropout
        self.is_causal = False
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Position embeddings are added to both queries and keys (but not values).
        """
        batch_dims = hidden_states.shape[:-1]
        per_head_shape = (*batch_dims, -1, self.head_dim)

        # Queries and keys see the position-augmented states; values never do.
        if position_embeddings is None:
            qk_source = hidden_states
        else:
            qk_source = hidden_states + position_embeddings

        query_states = self.q_proj(qk_source).view(per_head_shape).transpose(1, 2)
        key_states = self.k_proj(qk_source).view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        # Dispatch to the configured backend (eager / sdpa / flash), falling back to eager.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the hidden dimension before the output projection.
        attn_output = attn_output.reshape(*batch_dims, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
class DetrCrossAttention(nn.Module):
    """
    Multi-headed cross-attention from 'Attention Is All You Need' paper.

    In DETR, queries get their own position embeddings, while keys get encoder position embeddings.
    Values don't get any position embeddings.
    """

    def __init__(
        self,
        config: DetrConfig,
        hidden_size: int,
        num_attention_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.config = config
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = dropout
        self.is_causal = False
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        encoder_position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Position embeddings logic:
        - Queries get position_embeddings
        - Keys get encoder_position_embeddings
        - Values don't get any position embeddings
        """
        query_dims = hidden_states.shape[:-1]
        kv_dims = key_value_states.shape[:-1]

        query_source = hidden_states if position_embeddings is None else hidden_states + position_embeddings
        key_source = (
            key_value_states
            if encoder_position_embeddings is None
            else key_value_states + encoder_position_embeddings
        )

        query_states = self.q_proj(query_source).view(*query_dims, -1, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(key_source).view(*kv_dims, -1, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(key_value_states).view(*kv_dims, -1, self.head_dim).transpose(1, 2)

        # Dispatch to the configured backend (eager / sdpa / flash), falling back to eager.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the hidden dimension before the output projection.
        attn_output = attn_output.reshape(*query_dims, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
class DetrMLP(nn.Module):
    """Two-layer feed-forward block (hidden -> intermediate -> hidden) with activation and dropout."""

    def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.dropout = config.dropout

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.activation_fn(self.fc1(hidden_states))
        expanded = nn.functional.dropout(expanded, p=self.activation_dropout, training=self.training)
        projected = self.fc2(expanded)
        return nn.functional.dropout(projected, p=self.dropout, training=self.training)
class DetrEncoderLayer(GradientCheckpointingLayer):
    """Single DETR encoder layer: post-norm self-attention followed by a post-norm MLP block."""

    def __init__(self, config: DetrConfig):
        super().__init__()
        self.hidden_size = config.d_model
        self.self_attn = DetrSelfAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.dropout = config.dropout
        self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        spatial_position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
                Spatial position embeddings (2D positional encodings of image locations), to be added to both
                the queries and keys in self-attention (but not to values).
        """
        # --- self-attention block (residual + post-norm) ---
        attn_input = hidden_states
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=spatial_position_embeddings,
            **kwargs,
        )
        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(attn_input + attn_output)

        # --- feed-forward block (residual + post-norm) ---
        hidden_states = self.final_layer_norm(hidden_states + self.mlp(hidden_states))

        # Guard against low-precision overflow during training by clamping non-finite activations.
        if self.training and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states
class DetrDecoderLayer(GradientCheckpointingLayer):
    """
    Single DETR decoder layer: post-norm self-attention over the object queries, cross-attention
    into the encoder output, and a feed-forward MLP, each wrapped with residual + LayerNorm.
    """

    def __init__(self, config: DetrConfig):
        super().__init__()
        self.hidden_size = config.d_model
        self.self_attn = DetrSelfAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.encoder_attn = DetrCrossAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_attention_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        spatial_position_embeddings: torch.Tensor | None = None,
        object_queries_position_embeddings: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only
                in the cross-attention layer (not to values).
            object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings for the object query slots. In self-attention, these are added to both queries
                and keys (not values). In cross-attention, these are added to queries only (not to keys or values).
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
        """
        residual = hidden_states
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=object_queries_position_embeddings,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        # Cross-Attention Block
        # Skipped entirely when no encoder states are supplied (decoder-only usage).
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states, _ = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_embeddings=object_queries_position_embeddings,
                encoder_position_embeddings=spatial_position_embeddings,
                **kwargs,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)
        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
class DetrConvBlock(nn.Module):
    """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""

    def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # At most 8 groups; fewer when the channel count is smaller than 8.
        self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
        self.activation = ACT2FN[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.conv(x)
        hidden = self.norm(hidden)
        return self.activation(hidden)
class DetrFPNFusionStage(nn.Module):
    """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""

    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
        super().__init__()
        # 1x1 conv that projects FPN features onto the channel count of the upsampled stream.
        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
        self.refine = DetrConvBlock(current_channels, output_channels, activation)

    def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
            fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)

        Returns:
            Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
        """
        adapted = self.fpn_adapter(fpn_features)
        # Nearest-neighbor upsampling to the FPN resolution (1x1 conv preserves H/W).
        upsampled = nn.functional.interpolate(features, size=adapted.shape[-2:], mode="nearest")
        return self.refine(adapted + upsampled)
class DetrMaskHeadSmallConv(nn.Module):
    """
    Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.

    Combines attention maps (spatial localization) with encoder features (semantics) and progressively
    upsamples through multiple scales, fusing with FPN features for high-resolution detail.
    """

    def __init__(
        self,
        input_channels: int,
        fpn_channels: list[int],
        hidden_size: int,
        activation_function: str = "relu",
    ):
        super().__init__()
        # hidden_size is divided down to /16 below, hence the divisibility requirement.
        if input_channels % 8 != 0:
            raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
        self.conv1 = DetrConvBlock(input_channels, input_channels, activation_function)
        self.conv2 = DetrConvBlock(input_channels, hidden_size // 2, activation_function)
        # Progressive channel reduction: /2 -> /4 -> /8 -> /16
        self.fpn_stages = nn.ModuleList(
            [
                DetrFPNFusionStage(fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function),
                DetrFPNFusionStage(fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function),
                DetrFPNFusionStage(fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function),
            ]
        )
        # Final 3x3 conv collapses to a single mask-logit channel per query.
        self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)

    def forward(
        self,
        features: torch.Tensor,
        attention_masks: torch.Tensor,
        fpn_features: list[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            features: Encoder output features, shape (batch_size, hidden_size, H, W)
            attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
            fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)

        Returns:
            Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
        """
        num_queries = attention_masks.shape[1]
        # Expand to (batch_size * num_queries) dimension
        # (every query shares the same encoder/FPN features, paired with its own attention map).
        features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
        attention_masks = attention_masks.flatten(0, 1)
        fpn_features = [
            fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
        ]
        # Concatenate semantics (features) with localization (attention maps) along channels.
        hidden_states = torch.cat([features, attention_masks], dim=1)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        # Each stage upsamples and fuses with the next-higher-resolution FPN map.
        for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
            hidden_states = fpn_stage(hidden_states, fpn_feat)
        return self.output_conv(hidden_states)
class DetrMHAttentionMap(nn.Module):
    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = dropout
        # No value/output projection: this head only produces attention maps.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(
        self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
    ):
        # query_states: (batch, num_queries, hidden); key_states: a (batch, hidden, H, W) feature map.
        query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
        key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
        query_states = self.q_proj(query_states).view(query_hidden_shape)
        # Apply k_proj as a 1x1 convolution so the spatial layout of the key map is preserved.
        key_states = nn.functional.conv2d(
            key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
        ).view(key_hidden_shape)
        batch_size, num_queries, num_heads, head_dim = query_states.shape
        _, _, _, height, width = key_states.shape
        # Fold heads into the batch dimension for a single batched matmul.
        query_shape = (batch_size * num_heads, num_queries, head_dim)
        key_shape = (batch_size * num_heads, height * width, head_dim)
        attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
        query = query_states.transpose(1, 2).contiguous().view(query_shape)
        key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
        # Scores reshaped/transposed to (batch, num_queries, num_heads, height, width).
        attn_weights = (
            (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
        )
        if attention_mask is not None:
            # Additive mask (large negative values at padded positions).
            attn_weights = attn_weights + attention_mask
        # Softmax over the flattened (num_heads, height, width) trailing dimensions jointly,
        # then restore the original 5D shape — mirrors the original DETR mask head.
        attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        return attn_weights
@auto_docstring
class DetrPreTrainedModel(PreTrainedModel):
    config: DetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    # Modules that must not be split across devices by model parallelism.
    _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True
    _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
    # BatchNorm bookkeeping buffers from older checkpoints that have no counterpart here.
    _keys_to_ignore_on_load_unexpected = [
        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
    ]

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize module weights per DETR's scheme.

        Special-cased modules (mask head, attention-map head, learned position
        embedding) are checked before the generic `nn.Linear`/`nn.Conv2d` branches,
        so the order of the `isinstance` chain matters.
        """
        std = self.config.init_std
        xavier_std = self.config.init_xavier_std
        if isinstance(module, DetrMaskHeadSmallConv):
            # DetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_uniform_(m.weight, a=1)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
        elif isinstance(module, DetrMHAttentionMap):
            # Xavier init with a configurable gain for the attention-map projections.
            init.zeros_(module.k_proj.bias)
            init.zeros_(module.q_proj.bias)
            init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
            init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
        elif isinstance(module, DetrLearnedPositionEmbedding):
            init.uniform_(module.row_embeddings.weight)
            init.uniform_(module.column_embeddings.weight)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
class DetrEncoder(DetrPreTrainedModel):
    """
    Transformer encoder over the flattened backbone feature map: a stack of
    [`DetrEncoderLayer`] modules with bidirectional (non-causal) self-attention.

    Args:
        config (`DetrConfig`): Model configuration object.
    """

    _can_record_outputs = {"hidden_states": DetrEncoderLayer, "attentions": DetrSelfAttention}

    def __init__(self, config: DetrConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layers = nn.ModuleList(DetrEncoderLayer(config) for _ in range(config.encoder_layers))
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        spatial_position_embeddings=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).
                [What are attention masks?](../glossary#attention-mask)
            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
        """
        # Input dropout before the first layer.
        hidden_states = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)

        # Expand [batch_size, seq_len] padding mask into a full bidirectional attention mask.
        if attention_mask is not None:
            attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )

        # Each layer receives the spatial position embeddings as an extra input.
        for layer in self.layers:
            hidden_states = layer(
                hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
            )

        return BaseModelOutput(last_hidden_state=hidden_states)
class DetrDecoder(DetrPreTrainedModel):
    """
    Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules,
    which apply self-attention to the queries and cross-attention to the encoder's outputs.

    Args:
        config (`DetrConfig`): Model configuration object.
    """

    _can_record_outputs = {
        "hidden_states": DetrDecoderLayer,
        "attentions": DetrSelfAttention,
        "cross_attentions": DetrCrossAttention,
    }

    def __init__(self, config: DetrConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        # in DETR, the decoder uses layernorm after the last decoder layer output
        self.layernorm = nn.LayerNorm(config.d_model)
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        spatial_position_embeddings=None,
        object_queries_position_embeddings=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> DetrDecoderOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The query embeddings that are passed into the decoder.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:
                - 1 for queries that are **not masked**,
                - 0 for queries that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:
                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).
            spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer.
            object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.

        Raises:
            ValueError: If `inputs_embeds` is not provided.
        """
        # Fix: `hidden_states` used to be assigned only when `inputs_embeds` was not None,
        # leaving it unbound (NameError) in the layer loop below otherwise. `inputs_embeds`
        # is also dereferenced by both mask expansions, so fail early with a clear message.
        if inputs_embeds is None:
            raise ValueError("You have to specify inputs_embeds")
        hidden_states = inputs_embeds

        # expand decoder attention mask (for self-attention on object queries)
        if attention_mask is not None:
            # [batch_size, num_queries] -> [batch_size, 1, num_queries, num_queries]
            attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )

        # expand encoder attention mask (for cross-attention on encoder outputs)
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            encoder_attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        # optional intermediate hidden states (collected only for the auxiliary loss)
        intermediate = () if self.config.auxiliary_loss else None

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask,
                spatial_position_embeddings,
                object_queries_position_embeddings,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )
            if self.config.auxiliary_loss:
                # Each intermediate state goes through the shared final layernorm.
                hidden_states = self.layernorm(hidden_states)
                intermediate += (hidden_states,)

        # finally, apply layernorm
        hidden_states = self.layernorm(hidden_states)

        # stack intermediate decoder activations
        if self.config.auxiliary_loss:
            intermediate = torch.stack(intermediate)

        return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)
@auto_docstring(
    custom_intro="""
    The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
    any specific head on top.
    """
)
class DetrModel(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # Create backbone + positional encoding
        self.backbone = DetrConvEncoder(config)
        if config.position_embedding_type == "sine":
            self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True)
        elif config.position_embedding_type == "learned":
            self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2)
        else:
            raise ValueError(f"Not supported {config.position_embedding_type}")

        # Learned embeddings for the object-query slots.
        self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
        # 1x1 conv that maps the backbone's last feature map channels to d_model.
        self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

        self.encoder = DetrEncoder(config)
        self.decoder = DetrDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        for _, param in self.backbone.model.named_parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        """Re-enable gradients for every backbone parameter."""
        for _, param in self.backbone.model.named_parameters():
            param.requires_grad_(True)

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        decoder_attention_mask: torch.FloatTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | DetrModelOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
            - 1 for queries that are **not masked**,
            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation. Useful for tasks that require custom query initialization.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrModel
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # the last hidden states are the final query embeddings of the Transformer decoder
        >>> # these are of shape (batch_size, num_queries, hidden_size)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 100, 256]
        ```"""
        if pixel_values is None and inputs_embeds is None:
            raise ValueError("You have to specify either pixel_values or inputs_embeds")

        if inputs_embeds is None:
            batch_size, num_channels, height, width = pixel_values.shape
            device = pixel_values.device
            if pixel_mask is None:
                # No mask supplied: treat every pixel as real.
                pixel_mask = torch.ones(((batch_size, height, width)), device=device)
            vision_features = self.backbone(pixel_values, pixel_mask)
            # Only the last (coarsest) backbone feature map feeds the transformer.
            feature_map, mask = vision_features[-1]
            # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
            # Position embeddings are already flattened to (batch_size, sequence_length, hidden_size) format
            projected_feature_map = self.input_projection(feature_map)
            flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
            spatial_position_embeddings = self.position_embedding(
                shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
            )
            flattened_mask = mask.flatten(1)
        else:
            batch_size = inputs_embeds.shape[0]
            device = inputs_embeds.device
            flattened_features = inputs_embeds
            # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
            # Assume square feature map
            seq_len = inputs_embeds.shape[1]
            feat_dim = int(seq_len**0.5)
            # Create position embeddings for the inferred spatial size
            spatial_position_embeddings = self.position_embedding(
                shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
                device=device,
                dtype=inputs_embeds.dtype,
            )
            # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
            if pixel_mask is not None:
                mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
                flattened_mask = mask.flatten(1)
            else:
                # If no mask provided, assume all positions are valid
                flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                spatial_position_embeddings=spatial_position_embeddings,
                **kwargs,
            )

        # Broadcast the learned query position embeddings over the batch.
        object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
        if decoder_inputs_embeds is not None:
            queries = decoder_inputs_embeds
        else:
            queries = torch.zeros_like(object_queries_position_embeddings)

        # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            inputs_embeds=queries,
            attention_mask=decoder_attention_mask,
            spatial_position_embeddings=spatial_position_embeddings,
            object_queries_position_embeddings=object_queries_position_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=flattened_mask,
            **kwargs,
        )

        return DetrModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
        )
class DetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Layer widths: input_dim -> hidden_dim (repeated) -> output_dim.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers))

    def forward(self, x):
        # ReLU between layers; the final layer's output is returned raw.
        last = self.num_layers - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = nn.functional.relu(x)
        return x
@auto_docstring(
    custom_intro="""
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
    such as COCO detection.
    """
)
class DetrForObjectDetection(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # DETR encoder-decoder model
        self.model = DetrModel(config)

        # Object detection heads
        self.class_labels_classifier = nn.Linear(
            config.d_model, config.num_labels + 1
        )  # We add one for the "no object" class
        # 3-layer MLP predicting (center_x, center_y, height, width), squashed by sigmoid below.
        self.bbox_predictor = DetrMLPPredictionHead(
            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.LongTensor | None = None,
        decoder_attention_mask: torch.FloatTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: list[dict] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | DetrObjectDetectionOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
            - 1 for queries that are **not masked**,
            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation. Useful for tasks that require custom query initialization.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
        Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
        Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
        Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
        Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
        ```"""
        # First, sent images through DETR base model to obtain encoder + decoder outputs
        outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            **kwargs,
        )

        # Final decoder hidden states: (batch_size, num_queries, d_model).
        sequence_output = outputs[0]

        # class logits + predicted bounding boxes
        logits = self.class_labels_classifier(sequence_output)
        pred_boxes = self.bbox_predictor(sequence_output).sigmoid()

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                # Run both heads on every intermediate decoder layer output as well.
                intermediate = outputs.intermediate_hidden_states
                outputs_class = self.class_labels_classifier(intermediate)
                outputs_coord = self.bbox_predictor(intermediate).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
            )

        return DetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
@auto_docstring(
    custom_intro="""
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
    such as COCO panoptic.
    """
)
class DetrForSegmentation(DetrPreTrainedModel):
    # Maps old checkpoint parameter names to the refactored module layout.
    _checkpoint_conversion_mapping = {
        "bbox_attention.q_linear": "bbox_attention.q_proj",
        "bbox_attention.k_linear": "bbox_attention.k_proj",
        # Mask head refactor
        "mask_head.lay1": "mask_head.conv1.conv",
        "mask_head.gn1": "mask_head.conv1.norm",
        "mask_head.lay2": "mask_head.conv2.conv",
        "mask_head.gn2": "mask_head.conv2.norm",
        "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
        "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
        "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
        "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
        "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
        "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
        "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
        "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
        "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
        "mask_head.out_lay": "mask_head.output_conv",
    }

    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # object detection model
        self.detr = DetrForObjectDetection(config)

        # segmentation head
        hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
        intermediate_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes

        # Mask head input: projected features concatenated with per-head attention maps,
        # refined with the three finest backbone feature maps (FPN style).
        self.mask_head = DetrMaskHeadSmallConv(
            input_channels=hidden_size + number_of_heads,
            fpn_channels=intermediate_channel_sizes[::-1][-3:],
            hidden_size=hidden_size,
            activation_function=config.activation_function,
        )

        self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.LongTensor | None = None,
        decoder_attention_mask: torch.FloatTensor | None = None,
        encoder_outputs: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: list[dict] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | DetrSegmentationOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
            - 1 for queries that are **not masked**,
            - 0 for queries that are **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
            multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation. Useful for tasks that require custom query initialization.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
            dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
            bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
            should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.

        Examples:

        ```python
        >>> import io
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image
        >>> import torch
        >>> import numpy

        >>> from transformers import AutoImageProcessor, DetrForSegmentation
        >>> from transformers.image_transforms import rgb_to_id

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
        >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
        >>> # Segmentation results are returned as a list of dictionaries
        >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])

        >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
        >>> panoptic_seg = result[0]["segmentation"]
        >>> panoptic_seg.shape
        torch.Size([300, 500])

        >>> # Get prediction score and segment_id to class_id mapping of each segment
        >>> panoptic_segments_info = result[0]["segments_info"]
        >>> len(panoptic_segments_info)
        5
        ```"""
        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            # No mask supplied: treat every pixel as real.
            pixel_mask = torch.ones((batch_size, height, width), device=device)

        # Run the backbone and keep all intermediate feature maps (needed for the FPN mask head).
        vision_features = self.detr.model.backbone(pixel_values, pixel_mask)

        feature_map, mask = vision_features[-1]
        # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
        projected_feature_map = self.detr.model.input_projection(feature_map)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        spatial_position_embeddings = self.detr.model.position_embedding(
            shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
        )
        flattened_mask = mask.flatten(1)

        if encoder_outputs is None:
            encoder_outputs = self.detr.model.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                spatial_position_embeddings=spatial_position_embeddings,
                **kwargs,
            )

        # Broadcast the learned query position embeddings over the batch.
        object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
        if decoder_inputs_embeds is not None:
            queries = decoder_inputs_embeds
        else:
            queries = torch.zeros_like(object_queries_position_embeddings)

        decoder_outputs = self.detr.model.decoder(
            inputs_embeds=queries,
            attention_mask=decoder_attention_mask,
            spatial_position_embeddings=spatial_position_embeddings,
            object_queries_position_embeddings=object_queries_position_embeddings,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=flattened_mask,
            **kwargs,
        )

        sequence_output = decoder_outputs[0]

        # Detection heads shared with DetrForObjectDetection.
        logits = self.detr.class_labels_classifier(sequence_output)
        pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()

        # Un-flatten the encoder output back into a (batch, d_model, H, W) memory map.
        height, width = feature_map.shape[-2:]
        memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
            batch_size, self.config.d_model, height, width
        )
        attention_mask = flattened_mask.view(batch_size, height, width)
        # NOTE(review): `attention_mask` was just assigned above, so this check is always true; kept as-is.
        if attention_mask is not None:
            # Convert the boolean-like padding mask into additive attention bias
            # (0 where real, dtype-min where padded).
            min_dtype = torch.finfo(memory.dtype).min
            attention_mask = torch.where(
                attention_mask.unsqueeze(1).unsqueeze(1),
                torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
                min_dtype,
            )

        # Per-query, per-head attention maps over the memory grid.
        bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)

        # FPN features are passed finest-last in reverse backbone order.
        seg_masks = self.mask_head(
            features=projected_feature_map,
            attention_masks=bbox_mask,
            fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
        )

        pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                # Run both heads on every intermediate decoder layer output as well.
                intermediate = decoder_outputs.intermediate_hidden_states
                outputs_class = self.detr.class_labels_classifier(intermediate)
                outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
            )

        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            pred_masks=pred_masks,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
# Public API of this module, re-exported by the `transformers.models.detr` package.
__all__ = [
    "DetrForObjectDetection",
    "DetrForSegmentation",
    "DetrModel",
    "DetrPreTrainedModel",
]