# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from copy import deepcopy
from functools import lru_cache, partial
from typing import Any, Optional, Union

import numpy as np
from huggingface_hub.dataclasses import validate_typed_dict

from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from .image_transforms import (
    convert_to_rgb,
    get_resize_output_image_size,
    get_size_with_aspect_ratio,
    group_images_by_shape,
    reorder_images,
)
from .image_utils import (
    ChannelDimension,
    ImageInput,
    ImageType,
    SizeDict,
    get_image_size,
    get_image_size_for_max_height_width,
    get_image_type,
    infer_channel_dimension_format,
    make_flat_list_of_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from .processing_utils import ImagesKwargs, Unpack
from .utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_vision_available,
    logging,
)
from .utils.import_utils import is_rocm_platform, is_torchdynamo_compiling


if is_vision_available():
    from .image_utils import PILImageResampling

if is_torch_available():
    import torch

if is_torchvision_available():
    import torchvision.transforms.v2.functional as tvF

    from .image_utils import pil_torch_interpolation_mapping
else:
    pil_torch_interpolation_mapping = None


logger = logging.get_logger(__name__)


@lru_cache(maxsize=10)
def validate_fast_preprocess_arguments(
    do_rescale: bool | None = None,
    rescale_factor: float | None = None,
    do_normalize: bool | None = None,
    image_mean: float | list[float] | None = None,
    image_std: float | list[float] | None = None,
    do_center_crop: bool | None = None,
    crop_size: SizeDict | None = None,
    do_resize: bool | None = None,
    size: SizeDict | None = None,
    interpolation: Optional["tvF.InterpolationMode"] = None,
    return_tensors: str | TensorType | None = None,
    data_format: ChannelDimension = ChannelDimension.FIRST,
):
    """
    Checks the validity of the typically used arguments in an `ImageProcessorFast` `preprocess` method.
    Raises a `ValueError` if incompatible arguments are detected.
    """
    validate_preprocess_arguments(
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_center_crop=do_center_crop,
        crop_size=crop_size,
        do_resize=do_resize,
        size=size,
        interpolation=interpolation,
    )
    # Extra checks for ImageProcessorFast
    if return_tensors is not None and return_tensors != "pt":
        raise ValueError("Only returning PyTorch tensors is currently supported.")

    if data_format != ChannelDimension.FIRST:
        raise ValueError("Only channel first data format is currently supported.")


def safe_squeeze(tensor: "torch.Tensor", axis: int | None = None) -> "torch.Tensor":
    """
    Squeezes a tensor, but only if the axis specified has dim 1.
    """
    if axis is None:
        return tensor.squeeze()

    try:
        return tensor.squeeze(axis=axis)
    except ValueError:
        return tensor


def max_across_indices(values: Iterable[Any]) -> list[Any]:
    """
    Return the maximum value across all indices of an iterable of values.
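
    Example (a minimal sketch; the inputs are illustrative):

        >>> max_across_indices([(1, 5, 2), (3, 4, 6)])
        [3, 5, 6]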
    """
    return [max(values_i) for values_i in zip(*values)]


def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int, ...]:
    """
    Get the maximum height and width across all images in a batch.
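
    Example (a minimal sketch; shapes are illustrative):

        >>> import torch
        >>> get_max_height_width([torch.zeros(3, 32, 64), torch.zeros(3, 48, 48)])
        (48, 64)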
    """
    _, max_height, max_width = max_across_indices([img.shape for img in images])

    return (max_height, max_width)


def divide_to_patches(
    image: Union[np.ndarray, "torch.Tensor"], patch_size: int
) -> list[Union[np.ndarray, "torch.Tensor"]]:
    """
    Divides an image into patches of a specified size.

    Args:
        image (`Union[np.ndarray, "torch.Tensor"]`):
            The input image.
        patch_size (`int`):
            The size of each patch.

    Returns:
        `list[Union[np.ndarray, "torch.Tensor"]]`: A list of patches.
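
    Example (a minimal sketch; the input shape is illustrative):

        >>> import torch
        >>> patches = divide_to_patches(torch.zeros(3, 4, 4), patch_size=2)
        >>> len(patches)
        4
        >>> patches[0].shape
        torch.Size([3, 2, 2])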
    """
    patches = []
    height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            patch = image[:, i : i + patch_size, j : j + patch_size]
            patches.append(patch)

    return patches


@auto_docstring
class BaseImageProcessorFast(BaseImageProcessor):
    r"""
    Base class for fast image processors using PyTorch and TorchVision for image transformations.

    This class provides a complete implementation for standard image preprocessing operations (resize, crop, rescale,
    normalize) with GPU support and batch processing optimizations. Most image processors can be implemented by simply
    setting class attributes; only processors requiring custom logic need to override methods.

    Basic Implementation
    --------------------

    For processors that only need standard operations (resize, center crop, rescale, normalize), define class
    attributes:

        class MyImageProcessorFast(BaseImageProcessorFast):
            resample = PILImageResampling.BILINEAR
            image_mean = IMAGENET_DEFAULT_MEAN
            image_std = IMAGENET_DEFAULT_STD
            size = {"height": 224, "width": 224}
            do_resize = True
            do_rescale = True
            do_normalize = True

    Custom Processing
    -----------------

    Override `_preprocess` (most common):
        For custom image processing logic, override `_preprocess`. This method receives a list of torch tensors with
        the channel dimension first and should return a `BatchFeature`. Use `group_images_by_shape` and
        `reorder_images` for efficient batch processing:

        def _preprocess(
            self,
            images: list[torch.Tensor],
            do_resize: bool,
            size: SizeDict,
            # ... other parameters
            **kwargs,
        ) -> BatchFeature:
            # Group images by shape for batched operations
            grouped_images, indices = group_images_by_shape(images)
            processed_groups = {}

            for shape, stacked_images in grouped_images.items():
                if do_resize:
                    stacked_images = self.resize(stacked_images, size)
                # Custom processing here
                processed_groups[shape] = stacked_images

            processed_images = reorder_images(processed_groups, indices)
            return BatchFeature(data={"pixel_values": torch.stack(processed_images)})

    Override `_preprocess_image_like_inputs` (for additional inputs):
        For processors handling multiple input types (e.g., images + segmentation maps), override this method:

        def _preprocess_image_like_inputs(
            self,
            images: ImageInput,
            segmentation_maps: Optional[ImageInput] = None,
            *,
            do_convert_rgb: bool,
            input_data_format: ChannelDimension,
            device: Optional[torch.device] = None,
            **kwargs,
        ) -> BatchFeature:
            images = self._prepare_image_like_inputs(images, do_convert_rgb, input_data_format, device)
            batch_feature = self._preprocess(images, **kwargs)

            if segmentation_maps is not None:
                # Process segmentation maps separately
                maps = self._prepare_image_like_inputs(segmentation_maps, ...)
                batch_feature["labels"] = self._preprocess(maps, ...)

            return batch_feature

    Override `_further_process_kwargs` (for custom kwargs formatting):
        To format custom kwargs before validation:

        def _further_process_kwargs(self, custom_param=None, **kwargs):
            kwargs = super()._further_process_kwargs(**kwargs)
            if custom_param is not None:
                kwargs["custom_param"] = self._format_custom_param(custom_param)
            return kwargs

    Override `_validate_preprocess_kwargs` (for custom validation):
        To add custom validation logic:

        def _validate_preprocess_kwargs(self, custom_param=None, **kwargs):
            super()._validate_preprocess_kwargs(**kwargs)
            if custom_param is not None and custom_param < 0:
                raise ValueError("custom_param must be non-negative")

    Override `_prepare_images_structure` (for nested inputs):
        By default, nested image lists are flattened. Override to preserve structure:

        def _prepare_images_structure(self, images, expected_ndims=3):
            # Custom logic to handle the nested structure
            return images  # Return as-is or with custom processing

    Custom Parameters
    -----------------

    To add parameters beyond `ImagesKwargs`, create a custom kwargs class and set it as `valid_kwargs`:

        class MyImageProcessorKwargs(ImagesKwargs):
            custom_param: Optional[int] = None
            another_param: Optional[bool] = None

        class MyImageProcessorFast(BaseImageProcessorFast):
            valid_kwargs = MyImageProcessorKwargs
            custom_param = 10  # default value

            def _preprocess(self, images, custom_param, **kwargs):
                # Use custom_param in processing
                ...

    Key Notes
    ---------

    - Images in `_preprocess` are always torch tensors with the channel dimension first, regardless of input format.
    - Arguments not provided by users default to the class attribute values.
    - Use the batch processing utilities (`group_images_by_shape`, `reorder_images`) for GPU efficiency.
    - Image loading, format conversion, and argument handling are automatic; focus only on the processing logic.
    """

    resample = None
    image_mean = None
    image_std = None
    size = None
    default_to_square = True
    crop_size = None
    do_resize = None
    do_center_crop = None
    do_pad = None
    pad_size = None
    do_rescale = None
    rescale_factor = 1 / 255
    do_normalize = None
    do_convert_rgb = None
    return_tensors = None
    data_format = ChannelDimension.FIRST
    input_data_format = None
    device = None
    model_input_names = ["pixel_values"]
    image_seq_length = None
    valid_kwargs = ImagesKwargs
    unused_kwargs = None

    def __init__(self, **kwargs: Unpack[ImagesKwargs]):
        super().__init__(**kwargs)
        kwargs = self.filter_out_unused_kwargs(kwargs)
        size = kwargs.pop("size", self.size)
        self.size = (
            get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
            if size is not None
            else None
        )
        crop_size = kwargs.pop("crop_size", self.crop_size)
        self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
        pad_size = kwargs.pop("pad_size", self.pad_size)
        self.pad_size = get_size_dict(size=pad_size, param_name="pad_size") if pad_size is not None else None

        for key in self.valid_kwargs.__annotations__:
            kwarg = kwargs.pop(key, None)
            if kwarg is not None:
                setattr(self, key, kwarg)
            else:
                setattr(self, key, deepcopy(getattr(self, key, None)))

        # get valid kwargs names
        self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())

    @property
    def is_fast(self) -> bool:
        """
        `bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision).
        """
        return True

    def pad(
        self,
        images: list["torch.Tensor"],
        pad_size: SizeDict | None = None,
        fill_value: int | None = 0,
        padding_mode: str | None = "constant",
        return_mask: bool = False,
        disable_grouping: bool | None = False,
        is_nested: bool | None = False,
        **kwargs,
    ) -> Union[tuple["torch.Tensor", "torch.Tensor"], "torch.Tensor"]:
        """
        Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch.

        Args:
            images (`list[torch.Tensor]`):
                Images to pad.
            pad_size (`SizeDict`, *optional*):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. If
                not provided, images are padded to the largest height and width in the batch.
            fill_value (`int`, *optional*, defaults to `0`):
                The constant value used to fill the padded area.
            padding_mode (`str`, *optional*, defaults to `"constant"`):
                The padding mode to use. Can be any of the modes supported by
                `torch.nn.functional.pad` (e.g. constant, reflection, replication).
            return_mask (`bool`, *optional*, defaults to `False`):
                Whether to return a pixel mask denoting the padded regions.
            disable_grouping (`bool`, *optional*, defaults to `False`):
                Whether to disable grouping of images by size.

        Returns:
            `Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]`: The padded images, and the pixel masks if `return_mask` is `True`.
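
        Example (a minimal sketch; shapes are illustrative):

            >>> import torch
            >>> processor = BaseImageProcessorFast()
            >>> images = [torch.zeros(3, 2, 2), torch.zeros(3, 4, 4)]
            >>> [img.shape for img in processor.pad(images)]
            [torch.Size([3, 4, 4]), torch.Size([3, 4, 4])]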
        """
        if pad_size is not None:
            if not (pad_size.height and pad_size.width):
                raise ValueError(f"Pad size must contain 'height' and 'width' keys. Got pad_size={pad_size}.")
            pad_size = (pad_size.height, pad_size.width)
        else:
            pad_size = get_max_height_width(images)

        grouped_images, grouped_images_index = group_images_by_shape(
            images, disable_grouping=disable_grouping, is_nested=is_nested
        )
        processed_images_grouped = {}
        processed_masks_grouped = {}
        for shape, stacked_images in grouped_images.items():
            image_size = stacked_images.shape[-2:]
            padding_height = pad_size[0] - image_size[0]
            padding_width = pad_size[1] - image_size[1]
            if padding_height < 0 or padding_width < 0:
                raise ValueError(
                    f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
                    f"image size. Got pad_size={pad_size}, image_size={image_size}."
                )
            if image_size != pad_size:
                padding = (0, 0, padding_width, padding_height)
                stacked_images = tvF.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
            processed_images_grouped[shape] = stacked_images

            if return_mask:
                # keep only one channel from the channel dimension in the pixel mask
                stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :]
                stacked_masks[..., : image_size[0], : image_size[1]] = 1
                processed_masks_grouped[shape] = stacked_masks

        processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=is_nested)
        if return_mask:
            processed_masks = reorder_images(processed_masks_grouped, grouped_images_index, is_nested=is_nested)
            return processed_images, processed_masks

        return processed_images

    def resize(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        interpolation: Optional["tvF.InterpolationMode"] = None,
        antialias: bool = True,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image, e.g. `InterpolationMode.BICUBIC`.
            antialias (`bool`, *optional*, defaults to `True`):
                Whether to use antialiasing.

        Returns:
            `torch.Tensor`: The resized image.
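
        Example (a minimal sketch; the input shape and target size are illustrative):

            >>> import torch
            >>> from transformers.image_utils import SizeDict
            >>> processor = BaseImageProcessorFast()
            >>> processor.resize(torch.zeros(3, 32, 32), SizeDict(height=16, width=16)).shape
            torch.Size([3, 16, 16])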
        """
        interpolation = interpolation if interpolation is not None else tvF.InterpolationMode.BILINEAR
        if size.shortest_edge and size.longest_edge:
            # Resize the image so that the shortest edge or the longest edge is of the given size
            # while maintaining the aspect ratio of the original image.
            new_size = get_size_with_aspect_ratio(
                image.size()[-2:],
                size.shortest_edge,
                size.longest_edge,
            )
        elif size.shortest_edge:
            new_size = get_resize_output_image_size(
                image,
                size=size.shortest_edge,
                default_to_square=False,
                input_data_format=ChannelDimension.FIRST,
            )
        elif size.max_height and size.max_width:
            new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
        elif size.height and size.width:
            new_size = (size.height, size.width)
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key."
                f" Got {size}."
            )
        # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
        # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
        # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
        if is_torchdynamo_compiling() and is_rocm_platform():
            return self.compile_friendly_resize(image, new_size, interpolation, antialias)
        return tvF.resize(image, new_size, interpolation=interpolation, antialias=antialias)

    @staticmethod
    def compile_friendly_resize(
        image: "torch.Tensor",
        new_size: tuple[int, int],
        interpolation: Optional["tvF.InterpolationMode"] = None,
        antialias: bool = True,
    ) -> "torch.Tensor":
        """
        A wrapper around `tvF.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
        """
        if image.dtype == torch.uint8:
            # 256 is used on purpose instead of 255 to avoid numerical differences
            # see https://github.com/huggingface/transformers/pull/38540#discussion_r2127165652
            image = image.float() / 256
            image = tvF.resize(image, new_size, interpolation=interpolation, antialias=antialias)
            image = image * 256
            # torch.where is used on purpose instead of torch.clamp to avoid a bug in torch.compile
            # see https://github.com/huggingface/transformers/pull/38540#discussion_r2126888471
            image = torch.where(image > 255, 255, image)
            image = torch.where(image < 0, 0, image)
            image = image.round().to(torch.uint8)
        else:
            image = tvF.resize(image, new_size, interpolation=interpolation, antialias=antialias)
        return image

    def rescale(
        self,
        image: "torch.Tensor",
        scale: float,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Rescale an image by a scale factor: `image = image * scale`.

        Args:
            image (`torch.Tensor`):
                Image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.

        Returns:
            `torch.Tensor`: The rescaled image.
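
        Example (a minimal sketch; values are illustrative):

            >>> import torch
            >>> processor = BaseImageProcessorFast()
            >>> processor.rescale(torch.full((3, 2, 2), 255.0), 1 / 255).max()
            tensor(1.)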
        """
        return image * scale

    def normalize(
        self,
        image: "torch.Tensor",
        mean: float | Iterable[float],
        std: float | Iterable[float],
        **kwargs,
    ) -> "torch.Tensor":
        """
        Normalize an image: `image = (image - image_mean) / image_std`.

        Args:
            image (`torch.Tensor`):
                Image to normalize.
            mean (`torch.Tensor`, `float` or `Iterable[float]`):
                Image mean to use for normalization.
            std (`torch.Tensor`, `float` or `Iterable[float]`):
                Image standard deviation to use for normalization.

        Returns:
            `torch.Tensor`: The normalized image.
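
        Example (a minimal sketch; mean and std are illustrative):

            >>> import torch
            >>> processor = BaseImageProcessorFast()
            >>> processor.normalize(torch.ones(3, 2, 2), mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]).max()
            tensor(1.)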
        """
        return tvF.normalize(image, mean, std)

    @lru_cache(maxsize=10)
    def _fuse_mean_std_and_rescale_factor(
        self,
        do_normalize: bool | None = None,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        do_rescale: bool | None = None,
        rescale_factor: float | None = None,
        device: Optional["torch.device"] = None,
    ) -> tuple:
        if do_rescale and do_normalize:
            # Fused rescale and normalize: (image * r - mean) / std == (image - mean / r) / (std / r),
            # so folding the rescale factor into the mean and std lets us skip the separate rescale pass.
            image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
            image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
            do_rescale = False
        return image_mean, image_std, do_rescale

    def rescale_and_normalize(
        self,
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float],
        image_std: float | list[float],
    ) -> "torch.Tensor":
        """
        Rescale and normalize images.
        """
        image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            device=images.device,
        )
        # if/elif, as a fused rescale-and-normalize is used when both are set to True
        if do_normalize:
            images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
        elif do_rescale:
            images = self.rescale(images, rescale_factor)

        return images

    def center_crop(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
        any edge, the image is padded with 0's and then center cropped.

        Note: this overrides torchvision's `center_crop` to match the behavior of the slow image processors.

        Args:
            image (`torch.Tensor`):
                Image to center crop.
            size (`SizeDict`):
                Size of the output image.

        Returns:
            `torch.Tensor`: The center cropped image.
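
        Example (a minimal sketch; shapes are illustrative):

            >>> import torch
            >>> from transformers.image_utils import SizeDict
            >>> processor = BaseImageProcessorFast()
            >>> processor.center_crop(torch.zeros(3, 10, 10), SizeDict(height=4, width=4)).shape
            torch.Size([3, 4, 4])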
        """
        if size.height is None or size.width is None:
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
        image_height, image_width = image.shape[-2:]
        crop_height, crop_width = size.height, size.width

        if crop_width > image_width or crop_height > image_height:
            padding_ltrb = [
                (crop_width - image_width) // 2 if crop_width > image_width else 0,
                (crop_height - image_height) // 2 if crop_height > image_height else 0,
                (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
                (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
            ]
            image = tvF.pad(image, padding_ltrb, fill=0)  # PIL uses a fill value of 0
            image_height, image_width = image.shape[-2:]
            if crop_width == image_width and crop_height == image_height:
                return image

        crop_top = int((image_height - crop_height) / 2.0)
        crop_left = int((image_width - crop_width) / 2.0)
        return tvF.crop(image, crop_top, crop_left, crop_height, crop_width)

    def convert_to_rgb(
        self,
        image: ImageInput,
    ) -> ImageInput:
        """
        Converts an image to RGB format. Only converts if the image is of type `PIL.Image.Image`; otherwise returns
        the image as is.

        Args:
            image (`ImageInput`):
                The image to convert.

        Returns:
            `ImageInput`: The converted image.
        """
        return convert_to_rgb(image)

    def filter_out_unused_kwargs(self, kwargs: dict):
        """
        Filter out the unused kwargs from the kwargs dictionary.
        """
        if self.unused_kwargs is None:
            return kwargs

        for kwarg_name in self.unused_kwargs:
            if kwarg_name in kwargs:
                logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
                kwargs.pop(kwarg_name)
        return kwargs

    def _prepare_images_structure(
        self,
        images: ImageInput,
        expected_ndims: int = 3,
    ) -> ImageInput:
        """
        Prepare the images structure for processing.

        Args:
            images (`ImageInput`):
                The input images to process.

        Returns:
            `ImageInput`: The images with a valid nesting.
        """
        # Checks for `str` in case of a URL/local path and optionally loads the images
        images = self.fetch_images(images)
        return make_flat_list_of_images(images, expected_ndims=expected_ndims)

    def _process_image(
        self,
        image: ImageInput,
        do_convert_rgb: bool | None = None,
        input_data_format: str | ChannelDimension | None = None,
        device: Optional["torch.device"] = None,
    ) -> "torch.Tensor":
        image_type = get_image_type(image)
        if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
            raise ValueError(f"Unsupported input image type {image_type}")

        if do_convert_rgb:
            image = self.convert_to_rgb(image)

        if image_type == ImageType.PIL:
            image = tvF.pil_to_tensor(image)
        elif image_type == ImageType.NUMPY:
            # not using tvF.to_tensor as it doesn't handle (C, H, W) numpy arrays
            image = torch.from_numpy(image).contiguous()

        # If the image is 2D, unsqueeze it to add a channel dimension for processing
        if image.ndim == 2:
            image = image.unsqueeze(0)

        # Infer the channel dimension format if not provided
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        if input_data_format == ChannelDimension.LAST:
            # We force the channel dimension to be first for torch tensors, as this is what torchvision expects.
            image = image.permute(2, 0, 1).contiguous()

        # Now that we have torch tensors, we can move them to the right device
        if device is not None:
            image = image.to(device)

        return image

    def _prepare_image_like_inputs(
        self,
        images: ImageInput,
        do_convert_rgb: bool | None = None,
        input_data_format: str | ChannelDimension | None = None,
        device: Optional["torch.device"] = None,
        expected_ndims: int = 3,
    ) -> list["torch.Tensor"]:
        """
        Prepare image-like inputs for processing.

        Args:
            images (`ImageInput`):
                The image-like inputs to process.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the images to RGB.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The input data format of the images.
            device (`torch.device`, *optional*):
                The device to put the processed images on.
            expected_ndims (`int`, *optional*):
                The expected number of dimensions for the images (can be 2 for segmentation maps, etc.).

        Returns:
            `list[torch.Tensor]`: The processed images.
        """
        # Get structured images (potentially nested)
        images = self._prepare_images_structure(images, expected_ndims=expected_ndims)

        process_image_partial = partial(
            self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )

        # Check if we have a nested structure, assuming the nesting is consistent
        has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple))

        if has_nested_structure:
            processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
        else:
            processed_images = [process_image_partial(img) for img in images]

        return processed_images

    def _further_process_kwargs(
        self,
        size: SizeDict | None = None,
        crop_size: SizeDict | None = None,
        pad_size: SizeDict | None = None,
        default_to_square: bool | None = None,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        data_format: ChannelDimension | None = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated.
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        if kwargs is None:
            kwargs = {}
        if size is not None:
            size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
        if crop_size is not None:
            crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
        if pad_size is not None:
            pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
        if isinstance(image_mean, list):
            image_mean = tuple(image_mean)
        if isinstance(image_std, list):
            image_std = tuple(image_std)
        if data_format is None:
            data_format = ChannelDimension.FIRST

        kwargs["size"] = size
        kwargs["crop_size"] = crop_size
        kwargs["pad_size"] = pad_size
        kwargs["image_mean"] = image_mean
        kwargs["image_std"] = image_std
        kwargs["data_format"] = data_format

        # torch's resize uses `interpolation` instead of `resample`.
        # Check whether resample is an int before checking if it's an instance of PILImageResampling,
        # because with pillow < 9.1.0, resample is an int and PILImageResampling is a module.
        # Checking PILImageResampling alone would fail with `TypeError: isinstance() arg 2 must be a type or tuple of types`.
        resample = kwargs.pop("resample")
        kwargs["interpolation"] = (
            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
        )

        return kwargs

    def _validate_preprocess_kwargs(
        self,
        do_rescale: bool | None = None,
        rescale_factor: float | None = None,
        do_normalize: bool | None = None,
        image_mean: float | tuple[float] | None = None,
        image_std: float | tuple[float] | None = None,
        do_resize: bool | None = None,
        size: SizeDict | None = None,
        do_center_crop: bool | None = None,
        crop_size: SizeDict | None = None,
        interpolation: Optional["tvF.InterpolationMode"] = None,
        return_tensors: str | TensorType | None = None,
        data_format: ChannelDimension | None = None,
        **kwargs,
    ):
        """
        Validate the kwargs for the preprocess method.
        """
        validate_fast_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            interpolation=interpolation,
            return_tensors=return_tensors,
            data_format=data_format,
        )

    @auto_docstring
    def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
        # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)

        # Perform type validation on the received kwargs
        validate_typed_dict(self.valid_kwargs, kwargs)

        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self._valid_kwargs_names:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        # Extract parameters that are only used for preparing the input images
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        device = kwargs.pop("device")

        # Update kwargs that need further processing before being validated
        kwargs = self._further_process_kwargs(**kwargs)

        # Validate kwargs
        self._validate_preprocess_kwargs(**kwargs)

        # Pop kwargs that are not needed in _preprocess
        kwargs.pop("data_format")

        return self._preprocess_image_like_inputs(
            images, *args, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device, **kwargs
        )
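
    # A typical end-to-end call, sketched for illustration only (`MyImageProcessorFast` is a
    # hypothetical subclass; concrete subclasses supply the real default attributes):
    #
    #     processor = MyImageProcessorFast()
    #     batch = processor.preprocess(images, return_tensors="pt")
    #     batch["pixel_values"]  # processed images, stacked into a single tensor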

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        *args,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Union[str, "torch.device"] | None = None,
        **kwargs: Unpack[ImagesKwargs],
    ) -> BatchFeature:
        """
        Preprocess image-like inputs.
        To be overridden by subclasses when image-like inputs other than images should be processed.
        It can be used for segmentation maps, depth maps, etc.
        """
        # Prepare the input images
        images = self._prepare_image_like_inputs(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        return self._preprocess(images, *args, **kwargs)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["tvF.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        do_pad: bool | None,
        pad_size: SizeDict | None,
        disable_grouping: bool | None,
        return_tensors: str | TensorType | None,
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

        if do_pad:
            processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    def to_dict(self):
        encoder_dict = super().to_dict()

        # Filter out None values that are class defaults, but preserve explicitly set None values
        filtered_dict = {}
        for key, value in encoder_dict.items():
            if value is None:
                class_default = getattr(type(self), key, "NOT_FOUND")
                # Keep None only if the user explicitly set it (i.e. the class default is non-None)
                if class_default != "NOT_FOUND" and class_default is not None:
                    filtered_dict[key] = value
            else:
                filtered_dict[key] = value

        filtered_dict.pop("_valid_processor_keys", None)
        filtered_dict.pop("_valid_kwargs_names", None)
        return filtered_dict