# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Dilated Neighborhood Attention Transformer model."""

import math
from dataclasses import dataclass

import torch
from torch import nn

from ...activations import ACT2FN
from ...backbone_utils import BackboneMixin
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    OptionalDependencyNotAvailable,
    auto_docstring,
    is_natten_available,
    logging,
    requires_backends,
)
from .configuration_dinat import DinatConfig


if is_natten_available():
    from natten.functional import natten2dav, natten2dqkrpb
else:
    # Fallback stubs so that this module can still be imported without NATTEN installed; any attempt to
    # actually run neighborhood attention raises instead.

    def natten2dqkrpb(*args, **kwargs):
        raise OptionalDependencyNotAvailable()

    def natten2dav(*args, **kwargs):
        raise OptionalDependencyNotAvailable()


logger = logging.get_logger(__name__)

# drop_path and DinatDropPath are from the timm library.


@dataclass
@auto_docstring(
    custom_intro="""
    Dinat encoder's outputs, with potential hidden states and attentions.
    """
)
class DinatEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None


@dataclass
@auto_docstring(
    custom_intro="""
    Dinat model's outputs that also contains a pooling of the last hidden states.
    """
)
class DinatModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None


@dataclass
@auto_docstring(
    custom_intro="""
    Dinat outputs for image classification.
    """
)
class DinatImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None


class DinatEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = DinatPatchEmbeddings(config)

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values: torch.FloatTensor | None) -> torch.Tensor:
        """Project `pixel_values` to patch embeddings, then apply layer norm and dropout."""
        embeddings = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings


class DinatPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        patch_size = config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        self.num_channels = num_channels

        if patch_size != 4:
            # TODO: Support arbitrary patch sizes.
            raise ValueError("Dinat only supports patch size of 4 at the moment.")

        # Two stride-2 3x3 convolutions: together they reduce each spatial dimension by a factor of 4.
        self.projection = nn.Sequential(
            nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        )

    def forward(self, pixel_values: torch.FloatTensor | None) -> torch.Tensor:
        """
        Embed `pixel_values` and move the channel dimension last.

        Raises:
            ValueError: If the channel dimension of `pixel_values` differs from `config.num_channels`.
        """
        _, num_channels, _, _ = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.projection(pixel_values)
        # rearrange b c h w -> b h w c
        embeddings = embeddings.permute(0, 2, 3, 1)

        return embeddings


class DinatDownsampler(nn.Module):
    """
    Convolutional Downsampling Layer.

    Args:
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.dim = dim
        # Stride-2 convolution that doubles the channel count.
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
        """Downsample a channels-last feature map and normalize the result (output stays channels-last)."""
        # The convolution expects channels-first, so permute in and back out around it.
        input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        input_feature = self.norm(input_feature)
        return input_feature


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Dinat
class DinatDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float | None = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class NeighborhoodAttention(nn.Module):
    """
    Multi-head neighborhood self-attention over a channels-last 2D feature map, computed with the NATTEN
    `natten2dqkrpb` / `natten2dav` kernels.

    Args:
        config:
            Model configuration; `qkv_bias` and `attention_probs_dropout_prob` are read from it.
        dim (`int`):
            Number of input channels. Must be a multiple of `num_heads`.
        num_heads (`int`):
            Number of attention heads.
        kernel_size (`int`):
            Neighborhood size passed to the NATTEN kernels.
        dilation (`int`):
            Dilation value passed to the NATTEN kernels.

    Raises:
        ValueError: If `dim` is not a multiple of `num_heads`.
    """

    def __init__(self, config, dim, num_heads, kernel_size, dilation):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = dim // num_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.kernel_size = kernel_size
        self.dilation = dilation

        # rpb is learnable relative positional biases; same concept is used Swin.
        self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1)))

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """Reshape `(batch, height, width, all_head_size)` into `(batch, heads, height, width, head_size)`."""
        head_shape = projected.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return projected.view(head_shape).permute(0, 3, 1, 2, 4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        """
        Apply neighborhood attention to `hidden_states`.

        Args:
            hidden_states (`torch.Tensor`):
                Channels-last feature map of shape `(batch_size, height, width, channels)`, as produced by
                `DinatLayer`.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether to also return the attention probabilities.

        Returns:
            `tuple`: `(context_layer,)` or `(context_layer, attention_probs)`, where `context_layer` has the same
            leading dimensions as `hidden_states` and `all_head_size` channels.
        """
        # The input is a 4D spatial map, not a flattened token sequence: the spatial axes must be kept so that
        # the 5D permute on the context layer below is valid.
        query_layer = self._split_heads(self.query(hidden_states))
        key_layer = self._split_heads(self.key(hidden_states))
        value_layer = self._split_heads(self.value(hidden_states))

        # Apply the scale factor before computing attention weights. It's usually more efficient because
        # attention weights are typically a bigger tensor compared to query.
        # It gives identical results because scalars are commutable in matrix multiplication.
        query_layer = query_layer / math.sqrt(self.attention_head_size)

        # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases.
        attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, self.dilation)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
        # Move heads next to head_size, then merge the two back into a single channel dimension.
        context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class NeighborhoodAttentionOutput(nn.Module):
    """Output projection (linear layer followed by dropout) applied to the attention context."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted but not used here; the residual connection is added in `DinatLayer`.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class NeighborhoodAttentionModule(nn.Module):
    """Neighborhood attention followed by its output projection."""

    def __init__(self, config, dim, num_heads, kernel_size, dilation):
        super().__init__()
        self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size, dilation)
        self.output = NeighborhoodAttentionOutput(config, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        self_outputs = self.self(hidden_states, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class DinatIntermediate(nn.Module):
    """First half of the feed-forward block: expand channels by `config.mlp_ratio` and apply the activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        # `hidden_act` is either the name of an activation in ACT2FN or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class DinatOutput(nn.Module):
    """Second half of the feed-forward block: project back to `dim` channels and apply dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class DinatLayer(nn.Module):
    """
    One Dinat block: pre-norm neighborhood attention and a pre-norm feed-forward network, each wrapped in a
    residual connection with optional layer scale and stochastic depth.
    """

    def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.kernel_size = config.kernel_size
        self.dilation = dilation
        # Smallest spatial extent the layer runs attention on; smaller inputs are padded (see `maybe_pad`).
        self.window_size = self.kernel_size * self.dilation
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = NeighborhoodAttentionModule(
            config, dim, num_heads, kernel_size=self.kernel_size, dilation=self.dilation
        )
        self.drop_path = DinatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DinatIntermediate(config, dim)
        self.output = DinatOutput(config, dim)
        # Row 0 scales the attention branch, row 1 scales the feed-forward branch.
        self.layer_scale_parameters = (
            nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
            if config.layer_scale_init_value > 0
            else None
        )

    def maybe_pad(self, hidden_states, height, width):
        """
        Pad the right/bottom of a channels-last map so both spatial dimensions are at least `window_size`.

        Returns:
            `tuple`: the (possibly padded) hidden states and the pad values passed to `nn.functional.pad`.
        """
        window_size = self.window_size
        pad_values = (0, 0, 0, 0, 0, 0)
        if height < window_size or width < window_size:
            pad_l = pad_t = 0
            pad_r = max(0, window_size - width)
            pad_b = max(0, window_size - height)
            # Pad order is last dimension first: channels, then width, then height.
            pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b)
            hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the block on a `(batch_size, height, width, channels)` tensor.

        Returns:
            `tuple`: `(layer_output,)` or `(layer_output, attention_probs)` when `output_attentions` is set.
        """
        _, height, width, _ = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        # pad hidden_states if they are smaller than kernel size x dilation
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)

        attention_output = attention_outputs[0]

        # Crop the padding away again so the residual addition matches the original spatial size.
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_output = attention_output[:, :height, :width, :].contiguous()

        if self.layer_scale_parameters is not None:
            attention_output = self.layer_scale_parameters[0] * attention_output

        hidden_states = shortcut + self.drop_path(attention_output)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.output(self.intermediate(layer_output))

        if self.layer_scale_parameters is not None:
            layer_output = self.layer_scale_parameters[1] * layer_output

        layer_output = hidden_states + self.drop_path(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs


class DinatStage(nn.Module):
    """A stack of `depth` `DinatLayer`s followed by an optional downsampling layer."""

    def __init__(self, config, dim, depth, num_heads, dilations, drop_path_rate, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.layers = nn.ModuleList(
            [
                DinatLayer(
                    config=config,
                    dim=dim,
                    num_heads=num_heads,
                    dilation=dilations[i],
                    drop_path_rate=drop_path_rate[i],
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        """
        Returns:
            `tuple`: `(hidden_states, hidden_states_before_downsampling)`, followed by the attention
            probabilities of the last layer of the stage when `output_attentions` is set.
        """
        for layer_module in self.layers:
            layer_outputs = layer_module(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states_before_downsampling)

        stage_outputs = (hidden_states, hidden_states_before_downsampling)

        if output_attentions:
            # Only the outputs of the last layer are still in scope here.
            stage_outputs += layer_outputs[1:]
        return stage_outputs


class DinatEncoder(nn.Module):
    """Sequence of `DinatStage`s; every stage but the last downsamples and doubles the channel count."""

    def __init__(self, config):
        super().__init__()
        self.num_levels = len(config.depths)
        self.config = config
        # Stochastic depth rate increases linearly over all layers of all stages.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
        self.levels = nn.ModuleList(
            [
                DinatStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    dilations=config.dilations[i_layer],
                    drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    downsample=DinatDownsampler if (i_layer < self.num_levels - 1) else None,
                )
                for i_layer in range(self.num_levels)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool | None = False,
        output_hidden_states: bool | None = False,
        output_hidden_states_before_downsampling: bool | None = False,
        return_dict: bool | None = True,
    ) -> tuple | DinatEncoderOutput:
        """
        Run all stages on channels-last `hidden_states`.

        When `output_hidden_states` is set, the embeddings and one state per stage are collected, both
        channels-last (`hidden_states`) and channels-first (`reshaped_hidden_states`).
        `output_hidden_states_before_downsampling` selects whether the per-stage state is taken before or after
        the stage's downsampling layer.
        """
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            # rearrange b h w c -> b c h w
            reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for layer_module in self.levels:
            layer_outputs = layer_module(hidden_states, output_attentions)

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]

            if output_hidden_states:
                stage_state = (
                    hidden_states_before_downsampling if output_hidden_states_before_downsampling else hidden_states
                )
                # rearrange b h w c -> b c h w
                reshaped_hidden_state = stage_state.permute(0, 3, 1, 2)
                all_hidden_states += (stage_state,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                # Indices 0 and 1 of the stage outputs are the hidden states; attentions follow.
                all_self_attentions += layer_outputs[2:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return DinatEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


@auto_docstring
class DinatPreTrainedModel(PreTrainedModel):
    config: DinatConfig
    base_model_prefix = "dinat"
    main_input_name = "pixel_values"
    input_modalities = ("image",)


@auto_docstring
class DinatModel(DinatPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.config = config
        self.num_levels = len(config.depths)
        # Channels double at every stage after the first.
        self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))

        self.embeddings = DinatEmbeddings(config)
        self.encoder = DinatEncoder(config)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | DinatModelOutput:
        # Fall back to the configuration for every flag that was not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            # Flatten the spatial dimensions, then average over them: (b, h, w, c) -> (b, c, h*w) -> (b, c).
            pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]

            return output

        return DinatModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )


@auto_docstring(
    custom_intro="""
    Dinat Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """
)
class DinatForImageClassification(DinatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.num_labels = config.num_labels
        self.dinat = DinatModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.dinat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | DinatImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.dinat(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output of `DinatModel` (its pooling layer is enabled by default).
        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DinatImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


@auto_docstring(
    custom_intro="""
    NAT backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class DinatBackbone(BackboneMixin, DinatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.embeddings = DinatEmbeddings(config)
        self.encoder = DinatEncoder(config)
        # One entry for the embeddings ("stem") plus one per stage.
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self.out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> BackboneOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        # Hidden states are always requested from the encoder because the feature maps are built from them;
        # `output_hidden_states` only controls whether they are also returned to the caller.
        outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                batch_size, num_channels, height, width = hidden_state.shape
                # Layer norm acts on the last dimension, so go channels-last, normalize, and go back.
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = ["DinatForImageClassification", "DinatModel", "DinatPreTrainedModel", "DinatBackbone"]