# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch DINOv2 model.""" import collections.abc from collections.abc import Callable import torch from torch import nn from ... import initialization as init from ...activations import ACT2FN from ...backbone_utils import BackboneMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging, torch_int from ...utils.generic import can_return_tuple, check_model_inputs from .configuration_dinov2 import Dinov2Config logger = logging.get_logger(__name__) class Dinov2Embeddings(nn.Module): """ Construct the CLS token, mask token, position and patch embeddings. """ def __init__(self, config: Dinov2Config) -> None: super().__init__() self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) if config.use_mask_token: self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) self.patch_embeddings = Dinov2PatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.patch_size = config.patch_size self.use_mask_token = config.use_mask_token self.config = config def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) target_dtype = patch_pos_embed.dtype patch_pos_embed = nn.functional.interpolate( patch_pos_embed.to(torch.float32), size=(new_height, new_width), mode="bicubic", align_corners=False, ).to(dtype=target_dtype) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embeddings.projection.weight.dtype embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) if bool_masked_pos is not None and self.use_mask_token: embeddings = torch.where( bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings ) # add the [CLS] token to the embedded patch tokens cls_tokens = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) # add positional encoding to each token embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) embeddings = self.dropout(embeddings) return embeddings class Dinov2PatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer. """ def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: num_channels = pixel_values.shape[1] if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." f" Expected {self.num_channels} but got {num_channels}." ) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings # Copied from transformers.models.bert.modeling_bert.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor | None, scaling: float | None = None, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attention_mask = attention_mask[:, :, :, : key.shape[-2]] attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 class Dinov2SelfAttention(nn.Module): def __init__(self, config: Dinov2Config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size {config.hidden_size} is not a multiple of the number of attention " f"heads {config.num_attention_heads}." ) self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.dropout_prob = config.attention_probs_dropout_prob self.scaling = self.attention_head_size**-0.5 self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size = hidden_states.shape[0] new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2) value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) context_layer, attention_probs = attention_interface( self, query_layer, key_layer, value_layer, None, is_causal=self.is_causal, scaling=self.scaling, dropout=0.0 if not self.training else self.dropout_prob, ) new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.reshape(new_context_layer_shape) return context_layer, attention_probs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 class Dinov2SelfOutput(nn.Module): """ The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the layernorm applied before each block. """ def __init__(self, config: Dinov2Config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2 class Dinov2Attention(nn.Module): def __init__(self, config: Dinov2Config): super().__init__() self.attention = Dinov2SelfAttention(config) self.output = Dinov2SelfOutput(config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self_attn_output, _ = self.attention(hidden_states) output = self.output(self_attn_output, hidden_states) return output class Dinov2LayerScale(nn.Module): def __init__(self, config) -> None: super().__init__() self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: return hidden_state * self.lambda1 # Copied from transformers.models.beit.modeling_beit.drop_path def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ if drop_prob == 0.0 or not training: return input keep_prob = 1 - drop_prob shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize output = input.div(keep_prob) * random_tensor return output # Copied from transformers.models.beit.modeling_beit.BeitDropPath class Dinov2DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob: float | None = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return drop_path(hidden_states, self.drop_prob, self.training) def extra_repr(self) -> str: return f"p={self.drop_prob}" class Dinov2MLP(nn.Module): def __init__(self, config) -> None: super().__init__() in_features = out_features = config.hidden_size hidden_features = int(config.hidden_size * config.mlp_ratio) self.fc1 = nn.Linear(in_features, hidden_features, bias=True) if isinstance(config.hidden_act, str): self.activation = ACT2FN[config.hidden_act] else: self.activation = config.hidden_act self.fc2 = nn.Linear(hidden_features, out_features, bias=True) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.fc1(hidden_state) hidden_state = self.activation(hidden_state) hidden_state = self.fc2(hidden_state) return hidden_state class Dinov2SwiGLUFFN(nn.Module): def __init__(self, config) -> None: super().__init__() in_features = out_features = config.hidden_size hidden_features = int(config.hidden_size * config.mlp_ratio) hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) self.weights_out = nn.Linear(hidden_features, out_features, bias=True) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.weights_in(hidden_state) x1, x2 = hidden_state.chunk(2, dim=-1) hidden = nn.functional.silu(x1) * x2 return self.weights_out(hidden) class Dinov2Layer(GradientCheckpointingLayer): """This corresponds to the Block class in the original implementation.""" def __init__(self, config: Dinov2Config) -> None: super().__init__() self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.attention = Dinov2Attention(config) self.layer_scale1 = Dinov2LayerScale(config) self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_swiglu_ffn: self.mlp = Dinov2SwiGLUFFN(config) else: self.mlp = Dinov2MLP(config) self.layer_scale2 = Dinov2LayerScale(config) def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: hidden_states_norm = self.norm1(hidden_states) self_attention_output = self.attention(hidden_states_norm) self_attention_output = self.layer_scale1(self_attention_output) # first residual connection hidden_states = self.drop_path(self_attention_output) + hidden_states # in Dinov2, layernorm is also applied after self-attention layer_output = self.norm2(hidden_states) layer_output = self.mlp(layer_output) layer_output = self.layer_scale2(layer_output) # second residual connection layer_output = self.drop_path(layer_output) + hidden_states return layer_output class Dinov2Encoder(nn.Module): def __init__(self, config: Dinov2Config): super().__init__() self.config = config self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward(self, hidden_states: torch.Tensor, output_hidden_states: bool = False) -> BaseModelOutput: all_hidden_states = [hidden_states] if output_hidden_states else None for i, layer_module in enumerate(self.layer): hidden_states = layer_module(hidden_states) if all_hidden_states: all_hidden_states.append(hidden_states) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=tuple(all_hidden_states) if all_hidden_states else None, ) @auto_docstring class Dinov2PreTrainedModel(PreTrainedModel): config: Dinov2Config base_model_prefix = "dinov2" main_input_name = "pixel_values" input_modalities = ("image",) supports_gradient_checkpointing = True _no_split_modules = ["Dinov2Layer"] _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True _can_record_outputs = { "attentions": Dinov2SelfAttention, } @torch.no_grad() def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): init.zeros_(module.bias) init.ones_(module.weight) elif isinstance(module, Dinov2Embeddings): init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) if self.config.use_mask_token: init.zeros_(module.mask_token) elif isinstance(module, Dinov2LayerScale): init.constant_(module.lambda1, self.config.layerscale_value) @auto_docstring class Dinov2Model(Dinov2PreTrainedModel): def __init__(self, config: Dinov2Config): super().__init__(config) self.config = config self.embeddings = Dinov2Embeddings(config) self.encoder = Dinov2Encoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> Dinov2PatchEmbeddings: return self.embeddings.patch_embeddings @check_model_inputs(tie_last_hidden_states=False) @auto_docstring def forward( self, pixel_values: torch.Tensor | None = None, bool_masked_pos: torch.Tensor | None = None, output_hidden_states: bool | None = None, **kwargs, ) -> BaseModelOutputWithPooling: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for pre-training. """ if output_hidden_states is None: output_hidden_states = self.config.output_hidden_states if pixel_values is None: raise ValueError("You have to specify pixel_values") embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) encoder_outputs: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=output_hidden_states) sequence_output = encoder_outputs.last_hidden_state sequence_output = self.layernorm(sequence_output) pooled_output = sequence_output[:, 0, :] return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, ) @auto_docstring( custom_intro=""" Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token) e.g. for ImageNet. """ ) class Dinov2ForImageClassification(Dinov2PreTrainedModel): def __init__(self, config: Dinov2Config) -> None: super().__init__(config) self.num_labels = config.num_labels self.dinov2 = Dinov2Model(config) # Classifier head self.classifier = ( nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity() ) # Initialize weights and apply final processing self.post_init() @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.Tensor | None = None, labels: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> ImageClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ outputs: BaseModelOutputWithPooling = self.dinov2(pixel_values, **kwargs) sequence_output = outputs.last_hidden_state # batch_size, sequence_length, hidden_size cls_token = sequence_output[:, 0] patch_tokens = sequence_output[:, 1:] linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) logits = self.classifier(linear_input) loss = None if labels is not None: loss = self.loss_function(labels, logits, self.config, **kwargs) return ImageClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @auto_docstring( custom_intro=""" Dinov2 backbone, to be used with frameworks like DETR and MaskFormer. """ ) class Dinov2Backbone(BackboneMixin, Dinov2PreTrainedModel): def __init__(self, config): super().__init__(config) self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] self.embeddings = Dinov2Embeddings(config) self.encoder = Dinov2Encoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> Dinov2PatchEmbeddings: return self.embeddings.patch_embeddings @check_model_inputs @auto_docstring def forward( self, pixel_values: torch.Tensor, output_hidden_states: bool | None = None, **kwargs ) -> BackboneOutput: r""" Examples: ```python >>> from transformers import AutoImageProcessor, AutoBackbone >>> import torch >>> from PIL import Image >>> import httpx >>> from io import BytesIO >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> with httpx.stream("GET", url) as response: ... image = Image.open(BytesIO(response.read())) >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") >>> model = AutoBackbone.from_pretrained( ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] ... ) >>> inputs = processor(image, return_tensors="pt") >>> outputs = model(**inputs) >>> feature_maps = outputs.feature_maps >>> list(feature_maps[-1].shape) [1, 768, 16, 16] ```""" if output_hidden_states is None: output_hidden_states = self.config.output_hidden_states embedding_output = self.embeddings(pixel_values) output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True) hidden_states = output.hidden_states feature_maps = [] for stage, hidden_state in zip(self.stage_names, hidden_states): if stage in self.out_features: if self.config.apply_layernorm: hidden_state = self.layernorm(hidden_state) if self.config.reshape_hidden_states: hidden_state = hidden_state[:, 1:] # this was actually a bug in the original implementation that we copied here, # cause normally the order is height, width batch_size, _, height, width = pixel_values.shape patch_size = self.config.patch_size hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1) hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() feature_maps.append(hidden_state) return BackboneOutput( feature_maps=tuple(feature_maps), hidden_states=hidden_states if output_hidden_states else None, ) __all__ = ["Dinov2ForImageClassification", "Dinov2Model", "Dinov2PreTrainedModel", "Dinov2Backbone"]