# Copyright 2022 Multimedia Computing Group, Nanjing University and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VideoMAE (masked autoencoder) model."""

import collections.abc
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.nn import MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
from ...utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ...utils.generic import can_return_tuple, check_model_inputs
from .configuration_videomae import VideoMAEConfig


logger = logging.get_logger(__name__)

@dataclass
@auto_docstring(
    custom_intro="""
    Class for VideoMAEDecoder's outputs, with potential hidden states and attentions.
    """
)
class VideoMAEDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits.
    """

    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None

@dataclass
@auto_docstring(
    custom_intro="""
    Class for VideoMAEForPreTraining's outputs, with potential hidden states and attentions.
    """
)
class VideoMAEForPreTrainingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`):
        Pixel reconstruction loss.
    logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None

# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
    """Sinusoid position encoding table"""

    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

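# For illustration: the returned table has shape (1, n_position, d_hid), e.g.
# get_sinusoid_encoding_table(4, 6).shape == torch.Size([1, 4, 6]), so it can be
# broadcast-added to a batch of patch embeddings.
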
class VideoMAEEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings.

    """

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = VideoMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # fixed sin-cos embedding
        self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
        self.config = config

    def forward(self, pixel_values, bool_masked_pos):
        # create patch embeddings
        embeddings = self.patch_embeddings(pixel_values)

        # add position embeddings
        embeddings = embeddings + self.position_embeddings.detach().type_as(embeddings).to(
            device=embeddings.device, copy=True
        )
        # only keep visible patches
        # ~bool_masked_pos means visible
        if bool_masked_pos is not None:
            batch_size, _, num_channels = embeddings.shape
            embeddings = embeddings[~bool_masked_pos]
            embeddings = embeddings.reshape(batch_size, -1, num_channels)

        return embeddings

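# For illustration, with the default videomae-base configuration the sequence has 1568 patches,
# so `embeddings` is (batch_size, 1568, hidden_size) before masking; if, say, 1408 of those
# patches are masked, only the 160 visible patches are kept and the output becomes
# (batch_size, 160, hidden_size).
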
class VideoMAEPatchEmbeddings(nn.Module):
    """
    Video to Patch Embedding. This module turns a batch of videos of shape (batch_size, num_frames, num_channels,
    height, width) into a tensor of shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.

    The seq_len (the number of patches) equals (number of frames // tubelet_size) * (height // patch_size) * (width //
    patch_size).

    """

    def __init__(self, config):
        super().__init__()

        image_size = config.image_size
        patch_size = config.patch_size
        num_channels = config.num_channels
        hidden_size = config.hidden_size
        num_frames = config.num_frames
        tubelet_size = config.tubelet_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size
        self.tubelet_size = int(tubelet_size)
        num_patches = (
            (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
        )
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.projection = nn.Conv3d(
            in_channels=num_channels,
            out_channels=hidden_size,
            kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
            stride=(self.tubelet_size, patch_size[0], patch_size[1]),
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        # permute to (batch_size, num_channels, num_frames, height, width)
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings

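# For illustration, with the default configuration (image_size=224, patch_size=16, num_frames=16,
# tubelet_size=2): num_patches = (224 // 16) * (224 // 16) * (16 // 2) = 14 * 14 * 8 = 1568. The
# Conv3d projection maps (batch_size, 3, 16, 224, 224) to (batch_size, hidden_size, 8, 14, 14),
# which is then flattened and transposed to (batch_size, 1568, hidden_size).
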
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

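# Shape summary for the eager path above, for reference: query/key/value come in as
# (batch_size, num_heads, seq_len, head_dim), attn_weights is (batch_size, num_heads, seq_len, seq_len),
# and attn_output is returned transposed back to (batch_size, seq_len, num_heads, head_dim).
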
class VideoMAESelfAttention(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )
        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)

        if config.qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
            self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))
        else:
            self.q_bias = None
            self.v_bias = None

    def forward(self, hidden_states: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_length, _ = hidden_states.shape

        k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
        keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
        values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
        queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)

        key_layer = keys.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        value_layer = values.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        query_layer = queries.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            None,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        return context_layer, attention_probs

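# A note on the bias handling above: when `config.qkv_bias` is set, the query and value projections
# get learnable biases (`q_bias`, `v_bias`), while the key projection always uses a fixed all-zero
# bias, so only the q and v biases are trained.
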
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
class VideoMAESelfOutput(nn.Module):
    """
    The residual connection is defined in VideoMAELayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VideoMAE
class VideoMAEAttention(nn.Module):
    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.attention = VideoMAESelfAttention(config)
        self.output = VideoMAESelfOutput(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        self_attn_output, _ = self.attention(hidden_states)
        output = self.output(self_attn_output, hidden_states)
        return output

# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
class VideoMAEIntermediate(nn.Module):
    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

# Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->VideoMAE
class VideoMAEOutput(nn.Module):
    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states + input_tensor
        return hidden_states

# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
class VideoMAELayer(GradientCheckpointingLayer):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = VideoMAEAttention(config)
        self.intermediate = VideoMAEIntermediate(config)
        self.output = VideoMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states_norm = self.layernorm_before(hidden_states)
        attention_output = self.attention(hidden_states_norm)

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in VideoMAE, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        return layer_output

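# For reference, the block above follows the usual pre-norm ViT ordering:
# x = x + Attention(LayerNorm(x)), then x = x + MLP(LayerNorm(x)), with the second residual
# added inside VideoMAEOutput.forward.
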
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VideoMAE
class VideoMAEEncoder(nn.Module):
    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoMAELayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
        for i, layer_module in enumerate(self.layer):
            hidden_states = layer_module(hidden_states)

        return BaseModelOutput(last_hidden_state=hidden_states)

@auto_docstring
class VideoMAEPreTrainedModel(PreTrainedModel):
    config: VideoMAEConfig
    base_model_prefix = "videomae"
    main_input_name = "pixel_values"
    input_modalities = "video"
    supports_gradient_checkpointing = True
    _no_split_modules = ["VideoMAEEmbeddings", "VideoMAELayer"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": VideoMAELayer,
        "attentions": VideoMAESelfAttention,
    }

@auto_docstring
class VideoMAEModel(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = VideoMAEEmbeddings(config)
        self.encoder = VideoMAEEncoder(config)

        if config.use_mean_pooling:
            self.layernorm = None
        else:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        bool_masked_pos: torch.BoolTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
            batch must have the same number of masked patches. If `None`, then all patches are considered. Sequence
            length is `(num_frames // tubelet_size) * (image_size // patch_size) ** 2`.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import VideoMAEVideoProcessor, VideoMAEModel
        >>> from huggingface_hub import hf_hub_download

        >>> # replace this with your own video file
        >>> video_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )

        >>> video_processor = VideoMAEVideoProcessor.from_pretrained("MCG-NJU/videomae-base")
        >>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")

        >>> # prepare video for the model
        >>> inputs = video_processor(video_path, return_tensors="pt")

        >>> # forward pass
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 1568, 768]
        ```"""

        embedding_output = self.embeddings(pixel_values, bool_masked_pos)

        encoder_outputs: BaseModelOutput = self.encoder(embedding_output)
        sequence_output = encoder_outputs.last_hidden_state
        if self.layernorm is not None:
            sequence_output = self.layernorm(sequence_output)

        return BaseModelOutput(last_hidden_state=sequence_output)

class VideoMAEDecoder(nn.Module):
    def __init__(self, config: VideoMAEConfig):
        super().__init__()

        decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2

        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [VideoMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.norm = nn.LayerNorm(config.decoder_hidden_size)
        self.head = (
            nn.Linear(config.decoder_hidden_size, decoder_num_labels) if decoder_num_labels > 0 else nn.Identity()
        )

        self.gradient_checkpointing = False
        self.config = decoder_config

    def forward(self, hidden_states: torch.Tensor, return_token_num: int):
        # Apply transformer layers
        for layer_module in self.decoder_layers:
            hidden_states = layer_module(hidden_states)

        hidden_states = hidden_states[:, -return_token_num:]

        # predictor projection
        hidden_states = self.norm(hidden_states)
        logits = self.head(hidden_states)

        return VideoMAEDecoderOutput(logits=logits)

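# For illustration: the decoder receives the full token sequence (batch_size, num_patches,
# decoder_hidden_size) and keeps only the last `return_token_num` positions (the mask tokens), so
# with the default configuration the logits are (batch_size, num_masked_patches, 1536), where
# 1536 = tubelet_size * patch_size**2 * num_channels = 2 * 16 * 16 * 3.
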
@auto_docstring(
    custom_intro="""
    The VideoMAE Model transformer with the decoder on top for self-supervised pre-training.
    """
)
class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.videomae = VideoMAEModel(config)

        self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=False)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        self.position_embeddings = get_sinusoid_encoding_table(
            self.videomae.embeddings.num_patches, config.decoder_hidden_size
        )

        self.decoder = VideoMAEDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        bool_masked_pos: torch.BoolTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> VideoMAEForPreTrainingOutput:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
            batch must have the same number of masked patches. Sequence length is `(num_frames // tubelet_size) *
            (image_size // patch_size) ** 2`.

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, VideoMAEForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 16
        >>> video = list(np.random.randint(0, 256, (num_frames, 3, 224, 224)))

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        >>> model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")

        >>> pixel_values = image_processor(video, return_tensors="pt").pixel_values

        >>> num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
        >>> seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
        >>> bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss = outputs.loss
        ```"""
        outputs: BaseModelOutput = self.videomae(pixel_values, bool_masked_pos=bool_masked_pos, **kwargs)

        sequence_output = outputs.last_hidden_state
        sequence_output = self.encoder_to_decoder(sequence_output)

        # [batch_size, num_visible_patches, decoder_hidden_size]
        batch_size, _, num_channels = sequence_output.shape

        # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly.
        if bool_masked_pos is None:
            raise ValueError("One must provide a boolean mask")

        expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
        expanded_position_embeddings = expanded_position_embeddings.detach().to(device=pixel_values.device, copy=True)
        pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
        pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)

        # [batch_size, num_patches, decoder_hidden_size]
        x_full = torch.cat([sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1)

        # [batch_size, num_masked_patches, num_channels * patch_size * patch_size]
        decoder_outputs: VideoMAEDecoderOutput = self.decoder(x_full, pos_emb_mask.shape[1])
        logits = decoder_outputs.logits

        loss = None
        with torch.no_grad():
            # calculate the labels to be predicted
            if self.config.num_channels != 3:
                # Can't unnormalize with default means/stds
                frames = pixel_values
            else:
                # first, unnormalize the frames
                device = pixel_values.device
                dtype = pixel_values.dtype
                mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
                std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
                frames = pixel_values * std + mean  # in [0, 1]

            batch_size, time, num_channels, height, width = frames.shape
            tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
            if self.config.norm_pix_loss:
                # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                frames = frames.view(
                    batch_size,
                    time // tubelet_size,
                    tubelet_size,
                    num_channels,
                    height // patch_size,
                    patch_size,
                    width // patch_size,
                    patch_size,
                )
                # step 2: move dimensions to concatenate:
                frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
                # step 3: concatenate:
                frames = frames.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size,
                    num_channels,
                )
                # step 4: normalize. The authors find that the mean is about 0.48 and standard deviation is about 0.08.
                frames_norm = (frames - frames.mean(dim=-2, keepdim=True)) / (
                    frames.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
                )
                # step 5: reshape to (batch_size, T//ts * H//ps * W//ps, ts * ps * ps * C)
                videos_patch = frames_norm.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size * num_channels,
                )
            else:
                if self.config.num_channels != 3:
                    raise ValueError(
                        "Can't unnormalize non-RGB images. Consider setting config.norm_pix_loss to True."
                    )
                # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                frames = frames.view(
                    batch_size,
                    time // tubelet_size,
                    tubelet_size,
                    num_channels,
                    height // patch_size,
                    patch_size,
                    width // patch_size,
                    patch_size,
                )
                # step 2: move dimensions to concatenate: (batch_size, T//ts, H//ps, W//ps, ts, ps, ps, C)
                frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
                # step 3: concatenate
                videos_patch = frames.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size * num_channels,
                )

            batch_size, _, num_channels = videos_patch.shape
            labels = videos_patch[bool_masked_pos].reshape(batch_size, -1, num_channels)

        loss_fct = MSELoss()
        loss = loss_fct(logits, labels)

        return VideoMAEForPreTrainingOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

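# For illustration: the reconstruction targets above are built by patchifying the (optionally
# unnormalized) frames into (batch_size, 1568, 1536) patches for the default configuration, then
# selecting only the masked positions, so `labels` has the same (batch_size, num_masked_patches, 1536)
# shape as `logits` and the MSE loss is computed on masked patches only. With `norm_pix_loss` enabled,
# each patch is additionally normalized by its own mean and standard deviation before the comparison.
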
@auto_docstring(
    custom_intro="""
    VideoMAE Model transformer with a video classification head on top (a linear layer on top of the average pooled hidden
    states of all tokens) e.g. for ImageNet.
    """
)
class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.videomae = VideoMAEModel(config)

        # Classifier head
        self.fc_norm = nn.LayerNorm(config.hidden_size) if config.use_mean_pooling else None
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import torch
        >>> from transformers import VideoMAEVideoProcessor, VideoMAEForVideoClassification
        >>> from huggingface_hub import hf_hub_download

        >>> # replace this with your own video file
        >>> video_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )

        >>> video_processor = VideoMAEVideoProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
        >>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")

        >>> inputs = video_processor(video_path, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     logits = outputs.logits

        >>> # model predicts one of the 400 Kinetics-400 classes
        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])
        eating spaghetti
        ```"""
        outputs: BaseModelOutput = self.videomae(pixel_values, **kwargs)
        sequence_output = outputs.last_hidden_state

        if self.fc_norm is not None:
            output = sequence_output.mean(1)
            output = self.fc_norm(output)
        else:
            output = sequence_output[:, 0]

        logits = self.classifier(output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config, **kwargs)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["VideoMAEForPreTraining", "VideoMAEModel", "VideoMAEPreTrainedModel", "VideoMAEForVideoClassification"]