# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VitPose model."""

from dataclasses import dataclass

import torch
from torch import nn

from ... import initialization as init
from ...backbone_utils import load_backbone
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
from ...utils.generic import can_return_tuple
from .configuration_vitpose import VitPoseConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of pose estimation models.
    """
)
class VitPoseEstimatorOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
    heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
        Heatmaps as predicted by the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
        (also called feature maps) of the model at the output of each stage.
    """

    loss: torch.FloatTensor | None = None
    heatmaps: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None


@auto_docstring
class VitPosePreTrainedModel(PreTrainedModel):
    config: VitPoseConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True

    @torch.no_grad()
    def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)


def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
    """Flip the flipped heatmaps back to the original form.

    Args:
        output_flipped (`torch.Tensor` of shape `(batch_size, num_keypoints, height, width)`):
            The output heatmaps obtained from the flipped images.
        flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`):
            Pairs of keypoints which are mirrored (for example, left ear -- right ear).
        target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
            Target type to use. Can be `"gaussian-heatmap"` or `"combined-target"`.
            gaussian-heatmap: classification target with a Gaussian distribution.
            combined-target: the combination of a classification target (response map) and a regression target (offset map).
            Paper ref: Huang et al., "The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation" (CVPR 2020).

    Returns:
        torch.Tensor: heatmaps that are flipped back to the original image.
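
    Example (a minimal sketch with dummy tensors; the shapes and keypoint pairs are illustrative only):

    ```python
    >>> import torch
    >>> from transformers.models.vitpose.modeling_vitpose import flip_back

    >>> # 17 keypoints on 64x48 heatmaps; each row of flip_pairs holds two mirrored keypoint indices
    >>> heatmaps = torch.randn(1, 17, 64, 48)
    >>> flip_pairs = torch.tensor([[1, 2], [3, 4]])
    >>> restored = flip_back(heatmaps, flip_pairs)
    >>> restored.shape
    torch.Size([1, 17, 64, 48])
    ```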
"""
    if target_type not in ["gaussian-heatmap", "combined-target"]:
        raise ValueError("target_type should be gaussian-heatmap or combined-target")

    if output_flipped.ndim != 4:
        raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")

    batch_size, num_keypoints, height, width = output_flipped.shape
    channels = 1
    if target_type == "combined-target":
        channels = 3
        output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]
    output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
    output_flipped_back = output_flipped.clone()

    # Swap left-right parts
    for left, right in flip_pairs.tolist():
        output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
        output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
    output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))

    # Flip horizontally
    output_flipped_back = output_flipped_back.flip(-1)
    return output_flipped_back


class VitPoseSimpleDecoder(nn.Module):
    """
    Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the
    feature maps into heatmaps.
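
    Example (a minimal sketch with a dummy feature map; the 16x12 patch grid assumes the default 256x192 input
    resolution with 16x16 patches):

    ```python
    >>> import torch
    >>> from transformers import VitPoseConfig
    >>> from transformers.models.vitpose.modeling_vitpose import VitPoseSimpleDecoder

    >>> config = VitPoseConfig()
    >>> decoder = VitPoseSimpleDecoder(config)
    >>> features = torch.randn(1, config.backbone_config.hidden_size, 16, 12)  # (batch, channels, height, width)
    >>> heatmaps = decoder(features)  # upsampled by `config.scale_factor`, with `config.num_labels` channels
    ```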
"""

    def __init__(self, config: VitPoseConfig):
        super().__init__()
        self.activation = nn.ReLU()
        self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)
        self.conv = nn.Conv2d(
            config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, hidden_state: torch.Tensor, flip_pairs: torch.Tensor | None = None) -> torch.Tensor:
        # Transform input: ReLU + upsample
        hidden_state = self.activation(hidden_state)
        hidden_state = self.upsampling(hidden_state)
        heatmaps = self.conv(hidden_state)

        if flip_pairs is not None:
            heatmaps = flip_back(heatmaps, flip_pairs)

        return heatmaps


class VitPoseClassicDecoder(nn.Module):
    """
    Classic decoding head consisting of 2 deconvolutional blocks, followed by a 1x1 convolution layer,
    turning the feature maps into heatmaps.
    """

    def __init__(self, config: VitPoseConfig):
        super().__init__()

        self.deconv1 = nn.ConvTranspose2d(
            config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.batchnorm1 = nn.BatchNorm2d(256)
        self.relu1 = nn.ReLU()

        self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.relu2 = nn.ReLU()

        self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_state: torch.Tensor, flip_pairs: torch.Tensor | None = None):
        hidden_state = self.deconv1(hidden_state)
        hidden_state = self.batchnorm1(hidden_state)
        hidden_state = self.relu1(hidden_state)

        hidden_state = self.deconv2(hidden_state)
        hidden_state = self.batchnorm2(hidden_state)
        hidden_state = self.relu2(hidden_state)

        heatmaps = self.conv(hidden_state)

        if flip_pairs is not None:
            heatmaps = flip_back(heatmaps, flip_pairs)

        return heatmaps


@auto_docstring(
    custom_intro="""
    The VitPose model with a pose estimation head on top.
    """
)
class VitPoseForPoseEstimation(VitPosePreTrainedModel):
    def __init__(self, config: VitPoseConfig):
        super().__init__(config)

        self.backbone = load_backbone(config)

        # sanity check: the backbone config must expose the attributes used by the head and by `forward`
        if not hasattr(self.backbone.config, "hidden_size"):
            raise ValueError("The backbone should have a hidden_size attribute")
        if not hasattr(self.backbone.config, "image_size"):
            raise ValueError("The backbone should have an image_size attribute")
        if not hasattr(self.backbone.config, "patch_size"):
            raise ValueError("The backbone should have a patch_size attribute")

        self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        dataset_index: torch.Tensor | None = None,
        flip_pairs: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> VitPoseEstimatorOutput:
        r"""
        dataset_index (`torch.Tensor` of shape `(batch_size,)`):
            Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.

            This corresponds to the dataset index used during training: with a single training dataset, index 0
            refers to that dataset; with multiple datasets, index 0 refers to dataset A (e.g. MPII) and index 1
            refers to dataset B (e.g. CrowdPose).
        flip_pairs (`torch.Tensor`, *optional*):
            Pairs of keypoints which are mirrored (for example, left ear -- right ear).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
        >>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
        >>> inputs = processor(image, boxes=boxes, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> heatmaps = outputs.heatmaps
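
        >>> # Optionally convert heatmaps into per-person keypoint coordinates and scores
        >>> # (sketch; assumes the checkpoint's image processor provides `post_process_pose_estimation`)
        >>> pose_results = processor.post_process_pose_estimation(outputs, boxes=boxes)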
        ```"""
        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not yet supported")

        outputs: BackboneOutput = self.backbone.forward_with_filtered_kwargs(
            pixel_values,
            dataset_index=dataset_index,
            **kwargs,
        )

        # Turn the output hidden states into a tensor of shape (batch_size, num_channels, height, width)
        sequence_output = outputs.feature_maps[-1]
        batch_size = sequence_output.shape[0]
        patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0]
        patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1]
        sequence_output = sequence_output.permute(0, 2, 1)
        sequence_output = sequence_output.reshape(batch_size, -1, patch_height, patch_width).contiguous()

        heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)

        return VitPoseEstimatorOutput(
            loss=loss,
            heatmaps=heatmaps,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]